Table of Contents
- 1. Linear Model
- 2. Gradient Descent Algorithm
- 3. Backpropagation
  - 3.1 Principle
  - 3.2 Tensor in PyTorch
- 4. Implementing a Linear Model with PyTorch
1. Linear Model
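A minimal sketch of the idea, using the same toy data as the later sections: the linear model ŷ = x · w is evaluated for several candidate weights, and the weight with the smallest mean squared error (MSE) is the best fit.

# Linear model: try several candidate weights and compare their MSE
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def forward(x, w):
    return x * w

def mse(w):
    return sum((forward(x, w) - y) ** 2 for x, y in zip(x_data, y_data)) / len(x_data)

for w in [0.0, 1.0, 2.0, 3.0, 4.0]:
    print('w =', w, 'MSE =', mse(w))  # the error is smallest at w = 2.0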
2. Gradient Descent Algorithm
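For reference, the gradient used in the code below follows from differentiating the MSE cost of the linear model $\hat{y} = x \cdot w$; the update uses learning rate $\alpha = 0.01$:

$$
\text{cost}(w) = \frac{1}{N}\sum_{n=1}^{N}(x_n w - y_n)^2,\qquad
\frac{\partial\,\text{cost}}{\partial w} = \frac{1}{N}\sum_{n=1}^{N} 2x_n (x_n w - y_n),\qquad
w \leftarrow w - \alpha\,\frac{\partial\,\text{cost}}{\partial w}
$$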
# 梯度下降
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w = 3.0
def forward(x):
return x*w
# 损失函数
def cost(xs,ys):
cost = 0
for x,y in zip(xs,ys):
y_pred = forward(x)
cost += (y_pred - y) ** 2
return cost / len(xs)
# 计算梯度(求导)
def gradient(xs,ys):
grad = 0
for x,y in zip(xs,ys):
grad += 2*x*(x*w-y)
return grad / len(xs)
print('Predict(before training)',4,forward(4))
loss_list = []
# 学习过程
for epoch in range(100):
cost_val = cost(x_data, y_data)
loss_list.append(cost_val)
grad_val = gradient(x_data,y_data)
w -= 0.01 * grad_val
print('Epoch:',epoch,'w=',w,'loss=',cost_val)
print('Predict(after training)',4,forward(4))
plt.plot(range(len(loss_list)), loss_list)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Function')
plt.show()
print('w=',w)
Stochastic gradient descent (SGD):
SGD randomly picks a single sample out of the n data points and updates the weight with the gradient of that single sample's loss.
This randomness may help escape saddle points.
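Written out for a single sample $(x_n, y_n)$, the update is:

$$
\text{loss}_n(w) = (x_n w - y_n)^2,\qquad
\frac{\partial\,\text{loss}_n}{\partial w} = 2x_n (x_n w - y_n),\qquad
w \leftarrow w - \alpha\,\frac{\partial\,\text{loss}_n}{\partial w}
$$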
code:
# Stochastic gradient descent (SGD)
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 3.0  # initial weight

def forward(x):
    return x * w

# Loss for a single sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# Gradient of the single-sample loss with respect to w
def gradient(x, y):
    return 2 * x * (x * w - y)

print('Predict (before training)', 4, forward(4))
l_list = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        grad = gradient(x, y)
        w = w - 0.01 * grad
        print('\tgrad:', x, y, grad)
        l = loss(x, y)
        l_list.append(l)
print('Predict (after training)', 4, forward(4))
plt.plot(range(len(l_list)), l_list)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Loss Function')
plt.show()
print('w=', w)
However, the updates of stochastic gradient descent cannot be parallelized (each update depends on the previous one), so it is relatively slow.
Batch gradient descent uses the entire dataset at once and parallelizes well, but the optimum it finds may not be as good as the one found by SGD.
We therefore take a compromise: mini-batch gradient descent, which computes the gradient over small groups of samples at a time.
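A rough sketch of mini-batch gradient descent on the same toy data (the batch size of 2 is an arbitrary choice for illustration):

# Mini-batch gradient descent: update w using the average gradient of a small batch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 3.0
batch_size = 2

def batch_gradient(xs, ys):
    # average gradient of the squared error over one mini-batch
    return sum(2 * x * (x * w - y) for x, y in zip(xs, ys)) / len(xs)

for epoch in range(100):
    for i in range(0, len(x_data), batch_size):
        xb = x_data[i:i + batch_size]
        yb = y_data[i:i + batch_size]
        w -= 0.01 * batch_gradient(xb, yb)
print('w =', w)  # converges toward 2.0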
3. Backpropagation
3.1 Principle
When the network becomes complex, writing out analytical gradient expressions as above gets unwieldy and hard to implement.
Backpropagation (BP): treat the network as a graph, propagate gradients along that graph, and apply the chain rule to obtain the derivatives we ultimately need.
As the figure above shows, no matter how many layers the network has, repeated linear transformations collapse into a single expression of the same linear form, so stacking purely linear layers adds nothing essential.
Therefore a nonlinear activation function is applied to the output of each layer.
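Concretely, for two stacked linear layers:

$$
W_2\,(W_1 x + b_1) + b_2 = (W_2 W_1)\,x + (W_2 b_1 + b_2) = W x + b
$$

which is again a single linear model; inserting a nonlinearity between layers, e.g. $h = \sigma(W_1 x + b_1)$, prevents this collapse.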
Forward pass + backward pass:
Example: f = x * w, with x = 2, w = 3.
The green box in the figure is the compute node of the function, which can compute the local gradients.
The full analysis of the computational graph is as follows:
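A minimal worked version of that analysis for the example $f = x \cdot w$ with $x = 2$, $w = 3$:

$$
f = x \cdot w = 6,\qquad
\frac{\partial f}{\partial x} = w = 3,\qquad
\frac{\partial f}{\partial w} = x = 2
$$

In the backward pass the node receives an upstream gradient $\frac{\partial L}{\partial f}$ and multiplies it with each local gradient by the chain rule, e.g. $\frac{\partial L}{\partial w} = \frac{\partial L}{\partial f} \cdot \frac{\partial f}{\partial w} = 2\,\frac{\partial L}{\partial f}$.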
3.2 Tensor in PyTorch
code:
import torch
import matplotlib.pyplot as plt

# y = x * w
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.Tensor([3.0])  # initial weight (3.0 to match the sample output below)
w.requires_grad = True   # gradients are needed for w (off by default)

def forward(x):
    return x * w  # returns a tensor and builds the computational graph

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

print('Predict (before training)', 4, forward(4).item())
l_list = []
# Use w.grad.item() when you only need the numeric value of the gradient;
# use w.grad.data when you need to operate on the gradient as a tensor,
# pass it to other PyTorch functions, or combine it with other tensors.
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)           # forward pass builds the graph
        l_list.append(l.item())
        l.backward()             # backward pass; gradient accumulates in w.grad
        print('\tgrad:', x, y, w.grad.item())
        w.data = w.data - 0.01 * w.grad.data  # update via .data so no graph is built
        w.grad.data.zero_()      # clear the gradient, otherwise it accumulates
    print('progress:', epoch, l.item())
print('Predict (after training)', 4, forward(4).item())
plt.plot(range(len(l_list)), l_list)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Loss Function')
plt.show()
Predict(before training) 4 12.0
grad: 1.0 2.0 2.0
grad: 2.0 4.0 7.84
grad: 3.0 6.0 16.228800000000003
grad: 1.0 2.0 1.478624
grad: 2.0 4.0 5.796206080000001
grad: 3.0 6.0 11.998146585600002
grad: 1.0 2.0 1.093164466688
grad: 2.0 4.0 4.285204709416959
grad: 3.0 6.0 8.870373748493105
grad: 1.0 2.0 0.808189608196038
grad: 2.0 4.0 3.1681032641284688
grad: 3.0 6.0 6.557973756745934
grad: 1.0 2.0 0.5975042756146296
grad: 2.0 4.0 2.3422167604093467
grad: 3.0 6.0 4.8483886940473475
grad: 1.0 2.0 0.4417420810132029
grad: 2.0 4.0 1.731628957571754
grad: 3.0 6.0 3.5844719421735274
grad: 1.0 2.0 0.3265852213980329
grad: 2.0 4.0 1.2802140678802907
grad: 3.0 6.0 2.6500431205121995
grad: 1.0 2.0 0.241448373202223
grad: 2.0 4.0 0.9464776229527132
grad: 3.0 6.0 1.9592086795121144
grad: 1.0 2.0 0.17850567968888154
grad: 2.0 4.0 0.699742264380415
grad: 3.0 6.0 1.44846648726746
grad: 1.0 2.0 0.13197139106214628
grad: 2.0 4.0 0.5173278529636143
grad: 3.0 6.0 1.070868655634678
grad: 1.0 2.0 0.09756803306893769
grad: 2.0 4.0 0.38246668963023467
grad: 3.0 6.0 0.7917060475345838
grad: 1.0 2.0 0.07213321766426173
grad: 2.0 4.0 0.28276221324390605
grad: 3.0 6.0 0.5853177814148793
grad: 1.0 2.0 0.05332895341780031
grad: 2.0 4.0 0.20904949739777834
grad: 3.0 6.0 0.4327324596134048
grad: 1.0 2.0 0.03942673520922124
grad: 2.0 4.0 0.15455280202014876
grad: 3.0 6.0 0.3199243001817109
grad: 1.0 2.0 0.029148658460999677
grad: 2.0 4.0 0.1142627411671171
grad: 3.0 6.0 0.23652387421593346
grad: 1.0 2.0 0.02154995298411766
grad: 2.0 4.0 0.08447581569773988
grad: 3.0 6.0 0.17486493849431994
grad: 1.0 2.0 0.015932138840593524
grad: 2.0 4.0 0.06245398425512505
grad: 3.0 6.0 0.12927974740810555
grad: 1.0 2.0 0.011778821430517006
grad: 2.0 4.0 0.0461729800076256
grad: 3.0 6.0 0.09557806861577944
grad: 1.0 2.0 0.00870822402943805
grad: 2.0 4.0 0.034136238195397794
grad: 3.0 6.0 0.0706620130644744
grad: 1.0 2.0 0.0064380945236521825
grad: 2.0 4.0 0.02523733053271826
grad: 3.0 6.0 0.052241274202728505
grad: 1.0 2.0 0.004759760538470381
grad: 2.0 4.0 0.01865826131080439
grad: 3.0 6.0 0.03862260091336189
grad: 1.0 2.0 0.0035189480832178432
grad: 2.0 4.0 0.013794276486212453
grad: 3.0 6.0 0.028554152326460525
grad: 1.0 2.0 0.002601600545299121
grad: 2.0 4.0 0.010198274137572128
grad: 3.0 6.0 0.02111042746477665
grad: 1.0 2.0 0.0019233945023460208
grad: 2.0 4.0 0.007539706449197325
grad: 3.0 6.0 0.01560719234984198
grad: 1.0 2.0 0.0014219886363191492
grad: 2.0 4.0 0.0055741954543719885
grad: 3.0 6.0 0.011538584590550016
grad: 1.0 2.0 0.001051293262694486
grad: 2.0 4.0 0.004121069589761106
grad: 3.0 6.0 0.008530614050808794
grad: 1.0 2.0 0.0007772337246292338
grad: 2.0 4.0 0.0030467562005469517
grad: 3.0 6.0 0.006306785335137732
grad: 1.0 2.0 0.0005746182194235061
grad: 2.0 4.0 0.002252503420141494
grad: 3.0 6.0 0.004662682079690228
grad: 1.0 2.0 0.0004248221450389167
grad: 2.0 4.0 0.0016653028085542587
grad: 3.0 6.0 0.0034471768137098735
grad: 1.0 2.0 0.0003140761096931399
grad: 2.0 4.0 0.0012311783499967532
grad: 3.0 6.0 0.0025485391844988214
grad: 1.0 2.0 0.00023220023680980972
grad: 2.0 4.0 0.0009102249282939567
grad: 3.0 6.0 0.0018841656015720076
grad: 1.0 2.0 0.000171668421476312
grad: 2.0 4.0 0.000672940212187001
grad: 3.0 6.0 0.001392986239231675
grad: 1.0 2.0 0.00012691652401830567
grad: 2.0 4.0 0.0004975127741531082
grad: 3.0 6.0 0.0010298514425031158
grad: 1.0 2.0 9.383090920511705e-05
grad: 2.0 4.0 0.0003678171640828509
grad: 3.0 6.0 0.0007613815296476645
grad: 1.0 2.0 6.937031714571162e-05
grad: 2.0 4.0 0.0002719316432120422
grad: 3.0 6.0 0.0005628985014531906
grad: 1.0 2.0 5.1286307909848006e-05
grad: 2.0 4.0 0.00020104232700646207
grad: 3.0 6.0 0.0004161576169003922
grad: 1.0 2.0 3.7916582873442906e-05
grad: 2.0 4.0 0.00014863300486211983
grad: 3.0 6.0 0.00030767032005840633
grad: 1.0 2.0 2.8032184716586528e-05
grad: 2.0 4.0 0.00010988616408980079
grad: 3.0 6.0 0.0002274643596678061
grad: 1.0 2.0 2.0724530547688857e-05
grad: 2.0 4.0 8.124015974786403e-05
grad: 3.0 6.0 0.00016816713067413502
grad: 1.0 2.0 1.5321894128561553e-05
grad: 2.0 4.0 6.006182498552448e-05
grad: 3.0 6.0 0.0001243279777209949
grad: 1.0 2.0 1.1327660192073097e-05
grad: 2.0 4.0 4.440442795328181e-05
grad: 3.0 6.0 9.191716586265386e-05
grad: 1.0 2.0 8.37467511161094e-06
grad: 2.0 4.0 3.2828726435951694e-05
grad: 3.0 6.0 6.79554637201818e-05
grad: 1.0 2.0 6.191497806007362e-06
grad: 2.0 4.0 2.4270671399762023e-05
grad: 3.0 6.0 5.0240289800385085e-05
grad: 1.0 2.0 4.577448626363889e-06
grad: 2.0 4.0 1.794359861406747e-05
grad: 3.0 6.0 3.714324913239864e-05
grad: 1.0 2.0 3.3841626985164908e-06
grad: 2.0 4.0 1.3265917779392566e-05
grad: 3.0 6.0 2.7460449802063636e-05
grad: 1.0 2.0 2.5019520935032347e-06
grad: 2.0 4.0 9.80765220504054e-06
grad: 3.0 6.0 2.0301840065073407e-05
grad: 1.0 2.0 1.8497232057157476e-06
grad: 2.0 4.0 7.250914965339916e-06
grad: 3.0 6.0 1.5009393973031138e-05
grad: 1.0 2.0 1.3675225618570153e-06
grad: 2.0 4.0 5.360688440703143e-06
grad: 3.0 6.0 1.109662507481346e-05
grad: 1.0 2.0 1.011025839936508e-06
grad: 2.0 4.0 3.963221292480057e-06
grad: 3.0 6.0 8.203868070211229e-06
grad: 1.0 2.0 7.474635355109172e-07
grad: 2.0 4.0 2.9300570574264384e-06
grad: 3.0 6.0 6.06521810908589e-06
grad: 1.0 2.0 5.526087614171615e-07
grad: 2.0 4.0 2.166226344968436e-06
grad: 3.0 6.0 4.484088535150477e-06
grad: 1.0 2.0 4.08550288710785e-07
grad: 2.0 4.0 1.6015171304673004e-06
grad: 3.0 6.0 3.3151404608133817e-06
grad: 1.0 2.0 3.020461303293587e-07
grad: 2.0 4.0 1.1840208316016287e-06
grad: 3.0 6.0 2.450923123120674e-06
grad: 1.0 2.0 2.2330632898359681e-07
grad: 2.0 4.0 8.753608113920563e-07
grad: 3.0 6.0 1.811996877876254e-06
grad: 1.0 2.0 1.6509304945344638e-07
grad: 2.0 4.0 6.471647537864555e-07
grad: 3.0 6.0 1.3396310407642886e-06
grad: 1.0 2.0 1.22055272555599e-07
grad: 2.0 4.0 4.784566698390336e-07
grad: 3.0 6.0 9.904053097642418e-07
grad: 1.0 2.0 9.023692815190998e-08
grad: 2.0 4.0 3.5372875828443284e-07
grad: 3.0 6.0 7.322185258118452e-07
grad: 1.0 2.0 6.671324381812838e-08
grad: 2.0 4.0 2.6151591470124913e-07
grad: 3.0 6.0 5.413379398078177e-07
grad: 1.0 2.0 4.932190122985958e-08
grad: 2.0 4.0 1.933418545263521e-07
grad: 3.0 6.0 4.002176350326181e-07
grad: 1.0 2.0 3.6464273378555845e-08
grad: 2.0 4.0 1.4293995320713293e-07
grad: 3.0 6.0 2.9588569994132286e-07
grad: 1.0 2.0 2.6958474563798518e-08
grad: 2.0 4.0 1.0567721986376455e-07
grad: 3.0 6.0 2.1875184863517916e-07
grad: 1.0 2.0 1.993072373807081e-08
grad: 2.0 4.0 7.812843705323758e-08
grad: 3.0 6.0 1.6172586470020178e-07
grad: 1.0 2.0 1.473502297955065e-08
grad: 2.0 4.0 5.7761290861435555e-08
grad: 3.0 6.0 1.1956587187000878e-07
grad: 1.0 2.0 1.0893779212040045e-08
grad: 2.0 4.0 4.270361486646834e-08
grad: 3.0 6.0 8.839647946956575e-08
grad: 1.0 2.0 8.05390154567931e-09
grad: 2.0 4.0 3.157129313535734e-08
grad: 3.0 6.0 6.535257668360828e-08
grad: 1.0 2.0 5.9543463493128e-09
grad: 2.0 4.0 2.334103754719763e-08
grad: 3.0 6.0 4.8315953904420894e-08
grad: 1.0 2.0 4.402119557767037e-09
grad: 2.0 4.0 1.725630838222969e-08
grad: 3.0 6.0 3.5720557178819945e-08
grad: 1.0 2.0 3.254539748809293e-09
grad: 2.0 4.0 1.2757794820572599e-08
grad: 3.0 6.0 2.640863527858528e-08
grad: 1.0 2.0 2.4061197478886243e-09
grad: 2.0 4.0 9.431989411723407e-09
grad: 3.0 6.0 1.952421158080142e-08
grad: 1.0 2.0 1.778873048863261e-09
grad: 2.0 4.0 6.973181143621332e-09
grad: 3.0 6.0 1.443448560678462e-08
grad: 1.0 2.0 1.31514177326153e-09
grad: 2.0 4.0 5.155357030162122e-09
grad: 3.0 6.0 1.0671591610389441e-08
grad: 1.0 2.0 9.723004623651832e-10
grad: 2.0 4.0 3.811418736177075e-09
grad: 3.0 6.0 7.88963561149103e-09
grad: 1.0 2.0 7.188329931295812e-10
grad: 2.0 4.0 2.817824196199581e-09
grad: 3.0 6.0 5.832891503132487e-09
grad: 1.0 2.0 5.314415574275699e-10
grad: 2.0 4.0 2.083250905116074e-09
grad: 3.0 6.0 4.312326495892194e-09
grad: 1.0 2.0 3.929008229874853e-10
grad: 2.0 4.0 1.5401724340335932e-09
grad: 3.0 6.0 3.188153741007227e-09
grad: 1.0 2.0 2.9047608762766686e-10
grad: 2.0 4.0 1.1386660503376334e-09
grad: 3.0 6.0 2.357042561129674e-09
grad: 1.0 2.0 2.1475266009929328e-10
grad: 2.0 4.0 8.418297170464939e-10
grad: 3.0 6.0 1.7425900722400911e-09
grad: 1.0 2.0 1.5876899794875499e-10
grad: 2.0 4.0 6.223750403933082e-10
grad: 3.0 6.0 1.288313455916068e-09
grad: 1.0 2.0 1.1737988359072915e-10
grad: 2.0 4.0 4.601297121098469e-10
grad: 3.0 6.0 9.524701027885385e-10
grad: 1.0 2.0 8.678036067522044e-11
grad: 2.0 4.0 3.4017944017250556e-10
grad: 3.0 6.0 7.041727201340109e-10
grad: 1.0 2.0 6.415756814703855e-11
grad: 2.0 4.0 2.5149660132228746e-10
grad: 3.0 6.0 5.205968989230314e-10
grad: 1.0 2.0 4.743228032566549e-11
grad: 2.0 4.0 1.8593482309370302e-10
grad: 3.0 6.0 3.8488678910653107e-10
grad: 1.0 2.0 3.5067060366600344e-11
grad: 2.0 4.0 1.3746159766014898e-10
grad: 3.0 6.0 2.845510493898473e-10
grad: 1.0 2.0 2.5925039892626955e-11
grad: 2.0 4.0 1.0162537478208833e-10
grad: 3.0 6.0 2.1037038777649286e-10
grad: 1.0 2.0 1.9166890297128703e-11
grad: 2.0 4.0 7.5132788879273e-11
grad: 3.0 6.0 1.5552359400317073e-10
grad: 1.0 2.0 1.4169998507895798e-11
grad: 2.0 4.0 5.554667836804583e-11
grad: 3.0 6.0 1.1498002550069941e-10
grad: 1.0 2.0 1.0476064460362977e-11
grad: 2.0 4.0 4.106581741325499e-11
grad: 3.0 6.0 8.500400383582019e-11
grad: 1.0 2.0 7.744915819785092e-12
grad: 2.0 4.0 3.036149109902908e-11
grad: 3.0 6.0 6.284572862114146e-11
grad: 1.0 2.0 5.726086271806707e-12
grad: 2.0 4.0 2.2446045022661565e-11
grad: 3.0 6.0 4.646949491871055e-11
grad: 1.0 2.0 4.233946526710497e-12
grad: 2.0 4.0 1.659827830735594e-11
grad: 3.0 6.0 3.4356517630840244e-11
grad: 1.0 2.0 3.1299407510232413e-12
grad: 2.0 4.0 1.227107304657693e-11
grad: 3.0 6.0 2.5403679160262982e-11
grad: 1.0 2.0 2.3145929617385264e-12
grad: 2.0 4.0 9.07363073565648e-12
grad: 3.0 6.0 1.878497357665765e-11
grad: 1.0 2.0 1.7115198147621413e-12
grad: 2.0 4.0 6.707523425575346e-12
grad: 3.0 6.0 1.3887557770431158e-11
grad: 1.0 2.0 1.2647660696529783e-12
grad: 2.0 4.0 4.959588295605499e-12
grad: 3.0 6.0 1.0263789818054647e-11
grad: 1.0 2.0 9.352518759442319e-13
grad: 2.0 4.0 3.666400516522117e-12
grad: 3.0 6.0 7.58859641791787e-12
grad: 1.0 2.0 6.910028105266974e-13
grad: 2.0 4.0 2.7071678232459817e-12
grad: 3.0 6.0 5.6061821851471905e-12
grad: 1.0 2.0 5.10702591327572e-13
grad: 2.0 4.0 2.0037305148434825e-12
grad: 3.0 6.0 4.1460168631601846e-12
grad: 1.0 2.0 3.7836400679225335e-13
grad: 2.0 4.0 1.4814816040598089e-12
grad: 3.0 6.0 3.069544618483633e-12
grad: 1.0 2.0 2.7977620220553945e-13
grad: 2.0 4.0 1.0977885267493548e-12
grad: 3.0 6.0 2.27018404075352e-12
grad: 1.0 2.0 2.0694557179012918e-13
grad: 2.0 4.0 8.100187187665142e-13
grad: 3.0 6.0 1.6786572132332367e-12
Predict(after training) 4 8.000000000000306
w= 2.0000000000000764
Training the model y = w1 * x² + w2 * x + b:
import numpy as np
import matplotlib.pyplot as plt
import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w1 = torch.Tensor([1.0])  # initial weight for the x^2 term
w1.requires_grad = True   # gradients are not computed by default
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = True

def forward(x):
    return w1 * x**2 + w2 * x + b

def loss(x, y):  # builds the computational graph
    y_pred = forward(x)
    return (y_pred - y) ** 2

print('Predict (before training)', 4, forward(4).item())
loss_history = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()
        print('\tgrad:', x, y, w1.grad.item(), w2.grad.item(), b.grad.item())
        w1.data = w1.data - 0.01 * w1.grad.data  # .grad is a tensor, so use its .data
        w2.data = w2.data - 0.01 * w2.grad.data
        b.data = b.data - 0.01 * b.grad.data
        w1.grad.data.zero_()
        w2.grad.data.zero_()
        b.grad.data.zero_()
        loss_history.append(l.item())
    print('progress:', epoch, l.item())
print('Predict (after training)', 4, forward(4).item())
print('w1=', w1.item(), 'w2=', w2.item(), 'b=', b.item())
# Plot the loss curve
plt.plot(loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Loss Function')
plt.show()
w1= 0.24038191139698029 w2= 0.9266766309738159 b= 0.99135422706604
4. Implementing a Linear Model with PyTorch
The PyTorch workflow:
1. Prepare the data
With a mini-batch design, note that both x and y must be matrices (one sample per row).
2. Design the model
This mainly means constructing the computational graph.
3. Construct the loss function and the optimizer
(1) Using the SGD optimizer
import torch
import matplotlib.pyplot as plt

x_data = torch.Tensor([[1.0], [2.0], [3.0]])  # 3x1 matrix: one sample per row
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# Define the model
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # one input feature, one output feature

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = LinearModel()

# Loss function and optimizer
loss_history = []  # stores the loss of every epoch
criterion = torch.nn.MSELoss(reduction='sum')  # sum of squared errors (the deprecated size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(1000):
    y_pred = model(x_data)            # forward pass
    loss = criterion(y_pred, y_data)  # compute the loss
    print(epoch, loss.item())
    optimizer.zero_grad()             # clear accumulated gradients
    loss.backward()                   # backward pass
    optimizer.step()                  # update the parameters
    loss_history.append(loss.item())

print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=', y_test.data)

# Plot the loss curve
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Function')
plt.show()
(2) Adagrad
(3) Adam
(4) Adamax
(5) ASGD
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.ASGD(model.parameters(), lr=0.01)
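The other optimizers listed above can be swapped in the same way; for example, a sketch using Adam (the learning rate here is chosen arbitrarily):

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Adam adapts the step size per parameter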
4. Training cycle (forward pass + backward pass + update)
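Each iteration of the training loop above follows this fixed pattern:

y_pred = model(x_data)            # 1. forward pass
loss = criterion(y_pred, y_data)  # 2. compute the loss
optimizer.zero_grad()             # clear accumulated gradients
loss.backward()                   # 3. backward pass
optimizer.step()                  # 4. update the parameters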