教程目录
阅读:
线性回归模型及预测方法(Python实现)
有了前面章节的理论准备,我们利用上述实现好的函数,构建线性回归模型。在训练线性回归模型的过程中,我们使用如图 1 所示的数据集作为训练数据集。

图 1 原始数据
在求解模型的过程中,我们分别利用最小二乘法和全局牛顿法对其回归系数进行求解,求解的过程分为两步:
导入完训练数据后,可以利用最小二乘法对其参数进行训练,最小二乘法的具体实现在《最小二乘法》一节中已做详细说明,也可以利用全局牛顿法对其参数进行训练,全局牛顿法的具体实现也在《牛顿法》一文中介绍。
训练完成后,将最终的线性回归的模型参数保存在文件“weights”中,保存模型的 save_model 函数的具体实现如下所示:

图 2 最终的数据拟合效果
程序中,对新数据的预测主要有如下的步骤:

图 1 原始数据
在求解模型的过程中,我们分别利用最小二乘法和全局牛顿法对其回归系数进行求解,求解的过程分为两步:
- 训练线性回归模型;
- 利用训练好的线性回归模型预测新的数据;
训练线性回归模型
首先,我们利用训练样本训练模型,为了使得 Python 能够支持中文的注释和利用 numpy,我们需要在文件的开始加入:# coding:UTF-8 import numpy as np同时,在计算最小 m 值的过程中,需要使用 pow 函数,因此在文件中加入:
from math import pow线性回归模型的训练的主函数如下所示:
if __name__ == "__main__": # 1、导入数据集 print "----------- 1.load data ----------" feature, label = load_data("data.txt") # 1.1、最小二乘求解 print "----------- 2.training ----------" # 1.2、牛顿法 print " ---------- newton ----------" w_newton = newton(feature, label, 800, 0.1, 0.5) # 2、保存最终的结果 print "----------- 3.save result ----------" save_model("weights", w)该函数是线性回归模型训练的主函数,在线性回归模型的训练过程中,首先是导入训练数据,其中函数 load_data 的具体实现如下所示:
def load_data(file_path): '''导入数据 input: file_path(string):训练数据 output: feature(mat):特征 label(mat):标签 ''' f = open(file_path) feature = [] label = [] for line in f.readlines(): feature_tmp = [] lines = line.strip().split(" ") feature_tmp.append(1) # x0 for i in xrange(len(lines) - 1): feature_tmp.append(float(lines[i])) feature.append(feature_tmp) label.append(float(lines[-1])) f.close() return np.mat(feature), np.mat(label).Tload_data 函数将训练数据集中的特征导入到矩阵 feature 中,将样本标签导入到矩阵 label 中。
导入完训练数据后,可以利用最小二乘法对其参数进行训练,最小二乘法的具体实现在《最小二乘法》一节中已做详细说明,也可以利用全局牛顿法对其参数进行训练,全局牛顿法的具体实现也在《牛顿法》一文中介绍。
训练完成后,将最终的线性回归的模型参数保存在文件“weights”中,保存模型的 save_model 函数的具体实现如下所示:
def save_model(file_name, w): '''保存最终的模型 input: file_name(string):要保存的文件的名称 w(mat):训练好的线性回归模型 ''' f_result = open(file_name, "w") m, n = np.shape(w) for i in xrange(m): w_tmp = [] for j in xrange(n): w_tmp.append(str(w[i, j])) f_result.write(" ".join(w_tmp) + " ") f_result.close()函数 save_model 将训练好的线性回归模型 w 保存到 file_name 指定的文件中。
最终的训练结果
若使用最小二乘法进行训练,则训练过程为:
load data
training
least_square
save result
load data
training
newton
itration:0,error:0.0701706541513
itration:10,error:0.0701706541513
itration:20,error:0.0701706541513
itration:30,error:0.0701706541513
itration:40,error:0.0701706541513
save result
w0=0.00310499443379
w1=0.99450247031

图 2 最终的数据拟合效果
对新数据的预测
对于回归算法而言,训练好的模型需要能够对新的数据集进行预测。利用上述步骤,我们训练好线性回归模型,并将其保存在“weights”文件中,此时,我们需要利用训练好的线性回归模型对新数据进行预测,同样,为了能够使用 numpy 中的函数和对中文注释的支持,在文件的开始,我们加入:# coding:UTF-8 import numpy as np在对新数据的预测中,其主函数如下所示:
if __name__ == "__main__": # 1、导入测试数据 testData = load_data("data_test.txt") # 2、导入线性回归模型 w = load_model("weights") # 3、得到预测结果 predict = get_prediction(testData, w) # 4、保存最终的结果 save_predict("predict_result", predict)
程序中,对新数据的预测主要有如下的步骤:
- 利用函数 load_data导入测试数据集,该函数的具体实现为:
def load_data(file_path): '''导入测试数据 input: file_path(string):训练数据 output: feature(mat):特征 ''' f = open(file_path) feature = [] for line in f.readlines(): feature_tmp = [] lines = line.strip().split(" ") feature_tmp.append(1) # x0 for i in xrange(len(lines)): feature_tmp.append(float(lines[i])) feature.append(feature_tmp) f.close() return np.mat(feature)该函数实现了导入测试数据集的功能,函数的输入为测试数据集的位置,输出为测试数据集。
- 利用函数 load_model 导入训练好的线性回归的模型,函数load_model的具体实现为:
def load_model(model_file): '''导入模型 input: model_file(string):线性回归模型 output: w(mat):权重值 ''' w = [] f = open(model_file) for line in f.readlines(): w.append(float(line.strip())) f.close() return np.mat(w).T该函数将训练好的线性回归模型导入,函数输入为线性回归的参数所在的文件,其输出为权重值。
- 利用函数get_prediction对新数据进行预测,此函数的具体实现为:
def get_prediction(data, w): '''得到预测值 input: data(mat):测试数据 w(mat):权重值 output: 最终的预测 ''' return data * w该函数利用训练好的线性回归模型对新数据进行预测,函数的输入为测试数据 data 和线性回归模型 w,其输出为最终的预测值。
- 最终将预测的结果保存到文件“predict_result”中,save_predict函数的具体实现为:
def save_predict(file_name, predict): '''保存最终的预测值 input: file_name(string):需要保存的文件名 predict(mat):对测试数据的预测值 ''' m = np.shape(predict)[0] result = [] for i in xrange(m): result.append(str(predict[i,0])) f = open(file_name, "w") f.write(" ".join(result)) f.close()此函数将预测的结果 predict 保存到 file_name 指定的文件中。