PyTorch CNN – 定义损失函数和优化器

接下来,定义损失函数和优化器。

  • 损失函数使用交叉熵损失函数
  • 优化器使用SGD/随机梯度下降优化器
    • 学习率 lr=0.001
    • momentum即动量

Momentum/动量

SGD方法的一个缺点是,其更新方向完全依赖于当前的batch,因而其更新十分不稳定。解决这一问题的一个简单的做法便是引入momentum。
momentum即动量,它模拟的是物体运动时的惯性,即更新的时候在一定程度上保留之前更新的方向,同时利用当前batch的梯度微调最终的更新方向。这样一来,可以在一定程度上增加稳定性,从而学习地更快,并且还有一定摆脱局部最优的能力。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


浙ICP备17015664号-1 浙公网安备 33011002012336号 联系我们 网站地图  
@2019 qikegu.com 版权所有,禁止转载