【pytorch-learning】(二) 模型搭建-训练-测试
模型搭建
自定义模型
基本骨架为 torch.nn.Module
实现一个继承自Module的类,该类中主要包含init和forward方法。
通过forward方法完成网络的前向传递。
详细讲解见我记的B站小土堆的笔记:神经网络搭建
一个简单的代码示例:
1 | import torch |
torch.nn.Sequential
使用示例
1 | import torch |
迁移学习
修改已有网络模型用于自己的任务
实例一:将VGG16用于数据集CIFAR10的分类
通过torchvision.models.vgg16(pretained=True)(pretained=True会将其权重也下载下来)将模型下载下来之后,默认保存路径是C:\Users\[usename].cache\torch\hub\checkpoints
VGG16用于1000分类(最后全连接层的输出是1000),使用数据集CIFAR10是需要10分类的。因此难点在于如何使用该VGG模型进行迁移。 下文代码展现了两种主要方式
方式一
1 | import torchvision |
VGG模型添加了一层之后的结构
多了一层添加的:
(add_linear): Linear(in_features=1000, out_features=10, bias=True)
1 | VGG( |
方式二
想要加到其中的classifier里面
1 | vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10)) |
实例二:resnet18
获取网络中的模型,通过Sequential完成新模型的组建
1 | from torchvision.models import resnet18 ##从网络获取模型 |
模型保存和读取
方式一
1 | import torchvision |
方式二(推荐)
1 | import torchvision |
方式一的报错实例
报错过程:
a.py中创建模型,并使用toch.save(model,”model path”)
b.py中加载该模型:model = torch.load(“model path”)
解决方案:要让该文件能访问到该模型的定义。
在直接加载前重新定义一下该模型
或者
from model_save import *
即在头文件import一下该模型定义文件
模型训练
随机梯度下降
详细理论)
损失函数
通过损失函数计算经过训练的模型其预测结果与真实值的误差,loss越小说明二者越接近。
通过梯度下降搜索极值点。计算loss,通过loss.backward()反向传播进行自动求导获得梯度。
对于模型而言,变化的是各个权重参数w。
使用实例
1 | import torch |
优化器
决定梯度下降的方式。
使用主要关联四行代码:
optim = torch.optim.SGD(model.parameters(),lr=0.01)
(SGD为一种优化器)
epoch 内
optim.zero_grad()
##将优化器梯度清零,每一次循环注意清零result_loss.backward()
optim.step()
训练模型
1 | for i in range(epoch): |
模型测试评估
将测试集传入训练好的模型得到预测输出。
将预测输出和测试集的label进行比较以评估该模型。
有标签的评估指标有:
- 准确率
- 召回率
- 精确率
- F1-score
- …
另外,如ROC曲线,AUC,AP等也是常见的评估
自定义数据集实战
数据预处理与加载
1 | # -*- coding:UTF-8 -*- |
{'.ipynb_checkpoints': 0, 'bulbasaur': 1, 'charmander': 2, 'mewtwo': 3, 'pikachu': 4, 'squirtle': 5}
written into csv file: ./pokeman\image.csv
torch.Size([3, 224, 224]) torch.Size([])
数据集加载:DataLoader
1 | import visdom |
网络创建
1 | #Resnet.py |
Train and test
1 | for epoch in range(epoch): |
1 | #train_scratch.py |
best acc: 0.8583690987124464 best epoch 5
test acc 0.8497854077253219
optional: 迁移学习
1 | import torch |
best acc: 0.8412017167381974 best epoch 8
test acc 0.8025751072961373