别再只会调包了!手把手教你用NumPy从零推导线性回归的OLS公式(附Python代码)

张开发
2026/4/17 18:52:19 15 分钟阅读

分享文章

别再只会调包了!手把手教你用NumPy从零推导线性回归的OLS公式(附Python代码)
从零构建线性回归用NumPy揭秘最小二乘法的数学本质在数据科学领域线性回归就像乐高积木中的基础模块——看似简单却能构建复杂模型。许多初学者能够熟练调用sklearn的LinearRegression完成预测但当被问到为什么参数要这样计算时却陷入沉默。这种现象被业界称为调包侠综合征只会使用工具却不理解背后的数学逻辑。本文将带你用NumPy从零推导普通最小二乘法(OLS)就像拆开黑箱亲眼看看机器学习的齿轮如何转动。1. 线性回归的统计基础线性回归的核心思想是用直线描述自变量X和因变量Y之间的关系。当我们说拟合一条直线时实际上是在寻找参数β₀(截距)和β₁(斜率)的最佳估计值。这里的最佳在统计学中有明确定义——使预测值与真实值之间的差距最小。考虑一个简单的消费支出案例假设每月消费(Spending)与可支配收入(Income)存在线性关系Spending β₀ β₁ × Income ε其中ε代表无法用收入解释的随机误差。我们的目标是找到β₀和β₁使得所有数据点到预测直线的垂直距离(残差)最小。这就是最小二乘法的直观解释。关键概念残差平方和(RSS)是衡量模型拟合优度的核心指标计算公式为Σ(yᵢ - ŷᵢ)²其中ŷᵢ表示第i个预测值2. OLS的数学推导过程2.1 构建优化问题最小二乘法的数学本质是一个优化问题找到参数使残差平方和最小。用矩阵表示对于n个观测值和p个特征RSS(β) (Y - Xβ)ᵀ(Y - Xβ)其中Y是n×1的响应向量X是n×(p1)的设计矩阵(含截距项)β是(p1)×1的参数向量展开这个二次型我们得到RSS(β) YᵀY - 2βᵀXᵀY βᵀXᵀXβ2.2 求解极值点为了找到最小值我们对β求导并令导数等于零∂RSS/∂β -2XᵀY 2XᵀXβ 0整理得到正规方程(Normal Equation)XᵀXβ XᵀY当XᵀX可逆时参数的最优解为β̂ (XᵀX)⁻¹XᵀY这就是OLS估计量的矩阵形式揭示了参数估计如何通过数据矩阵运算得到。3. NumPy实现核心算法现在我们将数学公式转化为NumPy代码不使用任何现成的机器学习库。以下是关键步骤的实现import numpy as np def ols_fit(X, y): 手动实现OLS参数估计 # 添加截距列 X np.column_stack([np.ones(X.shape[0]), X]) # 计算XᵀX的逆 XtX np.dot(X.T, X) XtX_inv np.linalg.inv(XtX) # 计算Xᵀy Xty np.dot(X.T, y) # 求解参数 beta np.dot(XtX_inv, Xty) return beta # 示例数据收入与消费 income np.array([800,1100,1400,1700,2000,2300,2600,2900,3200,3500]) spending np.array([638,935,1155,1254,1408,1650,1925,2068,2266,2530]) # 拟合模型 beta ols_fit(income, spending) print(f截距β₀: {beta[0]:.2f}, 斜率β₁: {beta[1]:.2f})执行结果应显示截距β₀: 142.00, 斜率β₁: 0.67这个简单的实现揭示了机器学习库背后的核心计算过程。值得注意的是实际应用中我们会使用更稳定的数值计算方法(如QR分解)但上述代码最直接地反映了数学原理。4. 算法验证与效果评估4.1 与统计包结果对比为了验证我们的实现是否正确可以使用statsmodels进行交叉验证import statsmodels.api as sm X_with_intercept sm.add_constant(income) model sm.OLS(spending, X_with_intercept).fit() print(model.params)两种方法得到的参数估计应该完全一致这说明我们的手动实现是正确的。4.2 模型诊断指标除了参数估计完整的回归分析还需要评估模型质量。以下是几个核心指标的计算方法指标名称计算公式解释R²1 - RSS/TSS解释的方差比例调整R²1 - (RSS/(n-p-1))/(TSS/(n-1))考虑参数数量的修正R²MSERSS/n均方误差参数标准误√(σ²(XᵀX)⁻¹对角元素)估计的精确度其中RSS Σ(yᵢ - ŷᵢ)² (残差平方和)TSS Σ(yᵢ - ȳ)² (总平方和)σ² RSS/(n-p-1) (误差方差估计)在NumPy中实现这些指标def model_metrics(X, y, beta): X_design np.column_stack([np.ones(X.shape[0]), X]) y_pred np.dot(X_design, beta) residuals y - y_pred n len(y) p X.shape[1] RSS np.sum(residuals**2) TSS np.sum((y - np.mean(y))**2) r_squared 1 - RSS/TSS adj_r_squared 1 - (RSS/(n-p-1))/(TSS/(n-1)) mse RSS/n sigma_squared RSS/(n-p-1) var_beta sigma_squared * np.linalg.inv(np.dot(X_design.T, X_design)) std_err np.sqrt(np.diag(var_beta)) return { R²: r_squared, 调整R²: adj_r_squared, MSE: mse, 参数标准误: std_err }5. 工程实践中的注意事项在实际项目中直接使用正规方程可能会遇到数值不稳定的问题。以下是几个常见挑战及解决方案多重共线性问题当特征高度相关时XᵀX接近奇异矩阵解决方法使用正则化(Ridge回归)或主成分分析(PCA)大数据场景当n很大时计算(XᵀX)⁻¹需要O(p³)时间解决方法使用随机梯度下降(SGD)等迭代算法数值稳定性直接求逆可能引入数值误差更好的替代方案# 使用QR分解替代直接求逆 Q, R np.linalg.qr(X) beta np.linalg.solve(R, np.dot(Q.T, y))缺失值处理原始OLS假设数据完整实际应用中需要先进行缺失值填充或删除专业提示在实现生产级线性回归时考虑使用Cholesky分解代替直接矩阵求逆能显著提高数值稳定性并减少计算时间。

更多文章