多维线性回归模型

前言

今儿个是多维线性回归模型~

测试代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import random
import matplotlib.pyplot as plt

def make_features(x):
x = x.unsqueeze(1)
return torch.cat([x**i for i in range(1, 3 + 1)], 1)

W_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)
b_target = torch.FloatTensor([0.9])

def f(x):
return x.mm(W_target) + b_target[0]

def get_batch(batch_size=32, random=None):
if random is None:
random = torch.randn(batch_size)
x = make_features(random)
y = f(x)
return Variable(x), Variable(y)

class poly_model(nn.Module):
def __init__(self):
super().__init__()
self.poly = nn.Linear(3, 1)

def forward(self, x):
out = self.poly(x)
return out

model = poly_model()

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

epoch = 0
while True:
batch_x, batch_y = get_batch()

output = model(batch_x)
loss = criterion(output, batch_y)
print_loss = loss.item()

optimizer.zero_grad()

loss.backward()

optimizer.step()
epoch += 1
if print_loss < 1e-3:
Break

print("the number of epoches :", epoch)

def func_format(weight, bias, n):
func = ''
for i in range(n, 0, -1):
func += ' {:.2f} * x^{} +'.format(weight[i - 1], i)
return 'y =' + func + ' {:.2f}'.format(bias[0])

predict_weight = model.poly.weight.data.numpy().flatten()
predict_bias = model.poly.bias.data.numpy().flatten()
print('predicted function :', func_format(predict_weight, predict_bias, 3))
real_W = W_target.numpy().flatten()
real_b = b_target.numpy().flatten()
print('real function :', func_format(real_W, real_b, 3))

x = [random.randint(-200, 200) * 0.01 for i in range(20)]
x = np.array(sorted(x))
feature_x, y = get_batch(random=torch.from_numpy(x).float())
y = y.data.numpy()
plt.plot(x, y, 'ro', label='Original data')

model.eval()
x_sample = np.arange(-2, 2, 0.01)
x, y = get_batch(random=torch.from_numpy(x_sample).float())
y = model(x)
y_sample = y.data.numpy()
plt.plot(x_sample, y_sample, label='Fitting Line')
plt.show()

测试结果

1
2
3
the number of epoches : 1727
predicted function : y = 2.41 * x^3 + 2.99 * x^2 + 0.45 * x^1 + 0.93
real function : y = 2.40 * x^3 + 3.00 * x^2 + 0.50 * x^1 + 0.90

效果图

这个效果图在matplotlib中的实现要比一维线性回归模型要难一些,具体是:
80-84行先随机给一些在拟合方程上的点集,命名为’Original data’;
86-92行描绘出预测曲线,可以比较直观的看重合度。预测曲线在x轴上位于-2到2之间,x轴上的各点之间相差0.01,主要是为了曲线更加平滑。

相关函数

一、unsqueeze()

这里先介绍squeeze()函数,squeeze()中的参数0、1分别代表第零、第一维度,也就是行和列,理所应当的,squeeze(0)表示如果第零维度值为1,则去掉,否则不变。我写个例子:

1
2
3
4
5
6
import torch
import numpy as np
a = torch.randn(3, 1)
print(a)
a = a.squeeze(1)
print(a)

结果是:
1
2
3
4
tensor([[-0.2699],
[ 0.3355],
[-0.3069]])
tensor([-0.2699, 0.3355, -0.3069])

我感觉就是差不多行列变换,在行或者列为1的情况下。
unsqueeze()就是反向操作,咋变过来的就咋变回去:
1
2
3
4
5
6
7
8
import torch
import numpy as np
a = torch.randn(3, 1)
print(a)
a = a.squeeze(1)
print(a)
a = a.unsqueeze(1)
print(a)

结果是:
1
2
3
4
5
6
7
tensor([[-0.4933],
[ 0.0155],
[-0.5852]])
tensor([-0.4933, 0.0155, -0.5852])
tensor([[-0.4933],
[ 0.0155],
[-0.5852]])

二、torch.cat()

torch.cat()就是讲Tensor拼接在一起,有没有想到C里的strcat(),差不多~cat全称是concatnate
torch.cat()放两个参数,第一个是放需要拼接的Tensor,可以这样:

1
C=torch.cat((A,B),0)

比较常见,也可以是测试代码里的这样:
1
torch.cat([x**i for i in range(1, 3 + 1)], 1)

这里是通过循环的形式拼接三个Tensor。
第二个参数是横向或者纵向拼接,事实上还是按照第零维度或者第一维度比较好理解:0代表行与行拼接,1代表列与列拼接。

代码分析

具体思路和一维线性回归差不多,就是前期前向传播构成计算图是函数与函数之间相互嵌套比较复杂。
这个我发现一个以前没有注意到的点:

1
x.mm(W_target) + b_target[0]

其中x为32行3列,W_target为3行1列,b_target只有一个元素,照道理矩阵加法不能成立,但是:
1
2
3
4
5
6
7
8
9
10
11
12
import torch
import numpy
x = torch.ones(3)
x = x.unsqueeze(1)
print(x)
x = torch.cat([x**i for i in range(1, 4)], 1)
W_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)
b_target = torch.FloatTensor([0.9])
y = x.mm(W_target)
print(y)
y += b_target[0]
print(y)

结果为:
1
2
3
4
5
6
7
8
9
tensor([[1.],
[1.],
[1.]])
tensor([[5.9000],
[5.9000],
[5.9000]])
tensor([[6.8000],
[6.8000],
[6.8000]])

实际上是每行都加了b!

比较神奇~