0%

线性回归解析法推导以及代码实现

本文介绍线性回归定义以及推导过程,最后使用python实现了这个算法。

1. 线性回归定义

  如果输入$X$是列向量,目标$y$也是连续值,预测函数输出也是连续值。那就是一个回归问题。

  这就是一个线性回归问题。简单起见可以把公式改为:

  其中$\hat W$和$\hat X$称之为增广权重向量增广特征向量

2. 解析方法解

  线性回归的损失函数通常定义为:

  模型的经验风险为:

  如果最小化$R(Y,f(X,W))$,要计算$R(Y,f(X,W))$对$W$的导数:(可以把这个平方了之后再进行求导,然后使用平方转化为$A^T$矩阵)

  让这个$\frac{\partial R(Y, f(X,W))}{\partial W}$值为0,那么就能得到解析解:

3. 代码实现

  整个代码放到我的GitHub上了。这个例子是house price和Year相关性的例子。

3.1 输入数据

2000 2.000
2001 2.500
2002 2.900
2003 3.147
2004 4.515
2005 4.903
2006 5.365
2007 5.704
2008 6.853
2009 7.971
2010 8.561
2011 10.000
2012 11.280
2013 12.900

3.2 代码实现

  依赖numpy和matplotlib。如果没有安装,可以参考我之前写的Theano安装配置的第一部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
# @Author: Lich_Amnesia
# @Email: [email protected]
# @Date: 2016-04-08 18:35:50
# @Last Modified time: 2016-04-08 19:38:38
# @FileName: LinearRegression.py

import numpy as np
import matplotlib.pyplot as plt
import sys,os

# use numpy load from txt function to load data
def loadData():
file_path = 'data/LR_in.txt'
file = open(file_path)
data_set = np.loadtxt(file)
X0 = np.array([1.0 for i in range(data_set.shape[0])])
return np.c_[X0,data_set[:,:-1]],data_set[:,-1]

# use X^T * X to calculate the answer
def calculateMethod(X_parameters, Y_parameters):
X = np.mat(X_parameters)
# import! this y should be Y.T, you can print it to find the reason
y = np.mat(Y_parameters).T
tmp1 = np.dot(X.T,X).I
tmp2 = np.dot(X.T,y)
theta = np.dot(tmp1,tmp2)
theta = np.array(theta)
print(theta)
return theta

# use calculated theta, it will returrn predict Y
def predictOut(X_parameters, theta):
X = np.mat(X_parameters)
theta = np.mat(theta)
out = np.dot(X,theta)
return out

# use matplotlib to draw X-Y axis points
def draw(X_parameters, Y_parameters,theta):
plt.scatter(X_parameters[:,-1],Y_parameters,color='blue')
Y_predict_out = predictOut(X_parameters,theta)
plt.plot(X_parameters[:,-1],Y_predict_out,color='r',linewidth=4)
plt.xlabel('Year')
plt.ylabel('House Price')
plt.show()
return

def main():
X_parameters, Y_parameters = loadData()
theta = calculateMethod(X_parameters,Y_parameters)
draw(X_parameters,Y_parameters,theta)
# print(X_parameters)
return

if __name__ == '__main__':
main()

参考文献

[1] 神经网络与深度学习讲义20151211


因为我们是朋友,所以你可以使用我的文字,但请注明出处:http://alwa.info