【pytorch learning】(一)自定义数据集预处理和加载
通用实现
关键库:Dataset和 DataLoader
from torch.utils.data import Dataset,DataLoader
代码使用数据集:宝可梦
该数据集目录格式为:
root dir:
- dir1(name is the class)
- pic1
- pic2
- …
- dir2(name is the class)
- …
Dataset
继承Dataset类的实现
需要重载方法:
- __len__:返回数据集长度(数据数量)
- __get_item__:获取数据集的item
关键思路:具体见代码及代码注释
自定义方法load_csv:
- 遍历目录,将image_path-label存储到csv文件中:
- 用
glob.glob
获取具体路径的数据路径 csv.writer(f)
将image_path,label写入csv文件
- 读取csv文件,获得image_path-label
- reader =
csv.reader(f)
,for row in reader:…
len方法:
- 返回数据集长度
get_item方法:
传入index,从我们全局的列表中获得对应的数据
通过读取csv文件,传入index, 可以获得对应的image_path - label对。
目标:载入image_path,处理图片,返回符合要求的图片数据和label对
处理从路径上读取的数据,返回指定格式。如使用
torchvision.transform
,最后处理完的格式是tensor返回数据和label对
关于mode,可以通过判断需要训练集或者测试集这样子,返回切分后的数据集
1 | import torch |
{'.ipynb_checkpoints': 0, 'bulbasaur': 1, 'charmander': 2, 'mewtwo': 3, 'pikachu': 4, 'squirtle': 5}
written into csv file: pokeman\image.csv
torch.Size([32, 3, 224, 224]) torch.Size([32])
DataLoader
传入数据集,处理为batch。主要参数:batch_size
1 | dl_train = DataLoader(dataset=db_train,batch_size=32,shuffle=True) |
其他常用参数:
- num_workers ( int, optional): how many subprocesses to use for data loading.
0
means that the data will be loaded in the main process. (default:0
) - drop_last (bool, optional): set to
True
to drop the last incomplete batch, if the dataset size is not divisible by the batch size. IfFalse
and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default:False
)
官方文档:torch.utils.data — PyTorch 1.13 documentation
dl_train返回的参数对应dataset get_item的return 参数
x,y = next(iter(dataset))
X,Y = next(iter(dataloader))
X.shape == (batchsize,x.shape)
进行训练时的一般模式:
1 | for idx,(x,label) in enumerate(dl_train): |
特殊便捷方式数据集加载
使用函数datasets.ImageFolder
使用场景:适用于将数据集分文件夹存储,文件夹名称为对应的label
1 | tf = transforms.Compose([ |