线性回归与逻辑回归

线性回归

线性模型是许多领域中最流行的统计建模方法,最开始主要应用与经济学,遗传学等领域。又名回归模型,或最小二乘模型,这个名称来自于用来寻找这类模型系数的技术。线性模型需要假定观察值服从正态分布模型。

最小二乘法

最小二乘法是线性回归最基本的方法,但单变量的线性回归实现方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package wiley.streaming.ols;

public class SimpleLinearModel {

public double a, b;

public SimpleLinearModel(double a, double b) {
this.a = a;
this.b = b;
}

public double y(double x) {
return a + b * x;
}

public double error(double[] y, double[] x) {
double error = 0.0;
for (int i = 0; i < y.length; i++)
error += (y[i] - a - b * x[i]) * (y[i] - a - b * x[i]);
return error;
}

public void fit(double[] y, double[] x) {
double sumX = 0.0, sumY = 0.0;
double sumXY = 0.0, sumX2 = 0.0;
for (int i = 0; i < y.length; i++) {
sumX += x[i];
sumY += y[i];
sumXY += x[i] * y[i];
sumX2 += x[i] * x[i];
}

double n = (double) y.length;
b = (sumXY - (sumX * sumY) / n) / (sumX2 - (sumX * sumX) / n);
a = sumY / n - (b * sumX) / n;
}

}

QR分解

如果是多变量,如果满足特定条件,解决方案见括号:

  1. 不同的X变量之间没有关联,术语叫相互独立。(可以去掉相互关联的的变量只剩其中一个,或者对其进行正交变换,如主成分分析,这样就可以生成不相关的x值了)

  2. y的标准差不依赖与x的值(只有y的均值随x的值变化)

如果这些条件都得到了满足,并且已经将x的值放入k列n行的矩阵X,其中k是不同变量的数量,n是观察次数。那就可以通过下列表达式找到使平方误差的均值最小的向量B:
$$B=(X^TX)^{-1}X^Ty$$
该公式是线性方程组的解,这种方程组又名为正规方程。可以使用Apache Commons Math库等线性代数库来直接求值。但是,直接求解可能存在数值稳定性方面的问题,因此可以使用其他数据,最常见为QR分解算法。它是指任何矩阵A都可以用两个矩阵Q和R来表示,其中Q是正交矩阵(变量不相关),R是上对角矩阵。正交矩阵有$Q^TQ=I$的特性,其中I是单位矩阵。将X替换为QR分解,计算B的等式就变成:
$$B=(R^{-1}Q^T)y$$
下面是直接使用Commons Math库的普通线性回归和多元线性回归方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
package wiley.streaming.ols;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

public class LinearModel {

double[] B;

public LinearModel(double[] B) {
this.B = B;
}

public double y(double[] x) {
double y = 0;
for(int i=0;i<B.length;i++) y += B[i]*x[i];
return y;
}

public double error(double[] y,double[][] x) {
double error = 0.0;
for(int i=0;i<y.length;i++) {
double diff = y[i] - y(x[i]);
error += diff*diff;
}
return error;
}

public void fit(double[] y,double[][] x) {
RealMatrix X = new Array2DRowRealMatrix(x);
RealVector Y = new ArrayRealVector(y);
B = (new QRDecomposition(X)).getSolver().solve(Y).toArray();
}

public void fitOLS(double[] y,double[][] x) {
OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
ols.newSampleData(y, x);
B = ols.estimateRegressionParameters();
}
}

逻辑回归

对于观察值不服从正态分布的模型,也可以使用广义线性模型来拟合(Generalized Linear Model, GLM) 来拟合。这类模型中,都假设观察值y服从指数分布簇中的分布。如泊松分布就用来对计数数据进行建模。其中一个最流行的GLM模型是逻辑回归模型,他用来为伯努利分布的数据建模。具体来说,它被用来对事件发生的概率或观察值属于某个类的概率建模。逻辑回归模型也有针对多个类建模的扩展,在该扩展模型中,模型的输出是多项分布,不是伯努利分布。

与其他的线性模型相似,逻辑回归的迭代结果y’是通过对输入x的值进行加权线性组合得来的。但是,逻辑回归并不是直接使用这个概率,而是将它转换成0和1之间的值。如果给定输入和输出的观察值,要求B时,逻辑回归要比线性回归复杂,这时候常常使用牛顿迭代法的方法通过不断缩短误差来训练回归模型,神经网络中的低度下降法原理也是类似的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package wiley.streaming.ols;

public class LogisticRegression {

double[] B;

public LogisticRegression(double[] B) {
this.B = B;
}

public double y(double[] x) {
double y = 0;
for (int i = 0; i < B.length; i++)
y += B[i] * x[i];
return logit(y);
}

public static double invlogit(double p) {
return Math.log(p / (1 - p));
}

public static double logit(double y) {
return 1.0 / (1.0 + Math.exp(-y));
}

public LogisticRegression initialize(int k, double alpha) {
B = new double[k];
return this;
}

int MAX_ITER = 1000;
double alpha = 0.2;

public LogisticRegression fit(double[][] x, double[] y) {
//Initialize the weights
B = new double[x[0].length];
double lastError = Double.POSITIVE_INFINITY;
for (int iter = 0; iter < MAX_ITER; iter++) {
double err2 = 0;
for (int i = 0; i < x.length; i++) {
double t = y[i] - y(x[i]);
for (int j = 0; j < x[i].length; i++)
B[j] += alpha * t * x[i][j];
err2 += t * t;
}
err2 = Math.sqrt(err2);
if (err2 - lastError < 1e-6)
break;
lastError = err2;
}
return this;
}

}

参考资料

拜伦·埃利斯. 实时分析[M]. 机械工业出版社, 2016.