【pytorch learning】(一)自定义数据集预处理和加载

通用实现

关键库:Dataset和 DataLoader

from torch.utils.data import Dataset,DataLoader

代码使用数据集:宝可梦

该数据集目录格式为:

root dir:

  • dir1(name is the class)
    • pic1
    • pic2
  • dir2(name is the class)

Dataset

继承Dataset类的实现

需要重载方法:

  • __len__:返回数据集长度(数据数量)
  • __getitem__:获取数据集的item

关键思路:具体见代码及代码注释

自定义方法load_csv:

  1. 遍历目录,将image_path-label存储到csv文件中:
  • glob.glob获取具体路径的数据路径
  • csv.writer(f)将image_path,label写入csv文件
  2. 读取csv文件,获得image_path-label
  • reader = csv.reader(f),for row in reader:…

len方法:

  • 返回数据集长度

__getitem__方法:

  • 传入index,从我们全局的列表中获得对应的数据

  • 通过读取csv文件,传入index, 可以获得对应的image_path - label对。

    目标:载入image_path,处理图片,返回符合要求的图片数据和label对

  • 处理从路径上读取的数据,返回指定格式。如使用torchvision.transform,最后处理完的格式是tensor

  • 返回数据和label对

关于mode,可以通过判断需要训练集或者测试集这样子,返回切分后的数据集

import torch
import os,glob
import random,csv
import time
from torch.utils.data import Dataset,DataLoader
from PIL import Image
from torchvision import transforms

class Pokemen(Dataset):
    """Custom Dataset for an image folder laid out as one sub-directory per class.

    Expected layout::

        root/
            class_name_1/
                pic1.png
                pic2.jpg
                ...
            class_name_2/
                ...
    """

    def __init__(self, root, resize, mode):
        """
        :param root: root directory of the dataset (one sub-dir per class)
        :param resize: target square side length of the returned images
        :param mode: 'train' (first 60%), 'val' (next 20%) or 'test' (last 20%);
                     any other value keeps the full dataset
        """
        super(Pokemen, self).__init__()
        self.root = root
        self.resize = resize

        # Map class-name -> integer label. sorted() makes the mapping
        # deterministic across runs.
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                # skip plain files (e.g. the cached csv)
                continue
            if name.startswith('.'):
                # FIX: skip hidden directories such as '.ipynb_checkpoints',
                # which would otherwise be counted as a bogus class
                continue
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)

        # image paths and labels, cached in a csv file under root
        self.images, self.labels = self.load_csv("image.csv")

        # split the (pre-shuffled) dataset: 60% train / 20% val / 20% test
        if mode == 'train':
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        elif mode == 'test':
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

        # Build the transform pipeline once instead of on every __getitem__
        # call. Image.open is done separately in __getitem__ so the dataset
        # carries no lambda attribute and stays picklable (num_workers > 0).
        self.tf = transforms.Compose([
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            # ImageNet statistics — a reasonable default for natural images
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def load_csv(self, filename):
        """Build (or reuse) a csv of ``image_path,label`` rows and read it back.

        :param filename: csv file name; the file is stored under self.root
        :return: (images, labels) — parallel lists of paths and int labels
        """
        savepath = os.path.join(self.root, filename)
        if not os.path.exists(savepath):
            # first run: scan the class directories and cache the result
            images = []
            for name in self.name2label.keys():
                # collect every image of this class, e.g.
                # './pokeman/squirtle/00000073.png'
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.gif'))
            random.shuffle(images)
            with open(savepath, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # the parent directory name is the class name
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('written into csv file:', savepath)

        # read (possibly pre-existing) csv back into parallel lists
        images, labels = [], []
        with open(savepath) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)

        return images, labels

    def __len__(self):
        """
        :return: number of samples in the (already split) dataset
        """
        return len(self.images)

    def __getitem__(self, item):
        """
        :param item: index in [0, len(self) - 1]
        :return: (image tensor of shape [3, resize, resize], label tensor)
        """
        img, label = self.images[item], self.labels[item]

        # load from disk, force 3 channels, then apply the shared pipeline
        img = self.tf(Image.open(img).convert('RGB'))
        label = torch.tensor(label)
        return img, label

    def denormalize(self, x_hat):
        """Invert the ImageNet normalization for visualization.

        x_hat = (x - mean) / std  =>  x = x_hat * std + mean

        :param x_hat: normalized image tensor of shape [3, H, W]
        :return: de-normalized tensor of the same shape
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # reshape to [3, 1, 1] so the stats broadcast over H and W
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)

        x = x_hat * std + mean
        return x

if __name__ == '__main__':
    # Build the training split of the dataset, then wrap it in a
    # batching DataLoader.
    train_db = Pokemen(root='pokeman', resize=224, mode='train')
    train_loader = DataLoader(dataset=train_db, batch_size=32, shuffle=True)

    # Pull a single batch to sanity-check the tensor shapes.
    batch_x, batch_y = next(iter(train_loader))
    print(batch_x.shape, batch_y.shape)
{'.ipynb_checkpoints': 0, 'bulbasaur': 1, 'charmander': 2, 'mewtwo': 3, 'pikachu': 4, 'squirtle': 5}
written into csv file: pokeman\image.csv
torch.Size([32, 3, 224, 224]) torch.Size([32])

DataLoader

传入数据集,处理为batch。主要参数:batch_size

dl_train = DataLoader(dataset=db_train,batch_size=32,shuffle=True)

其他常用参数:

  • num_workers ( int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • drop_last (bool, optional): set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

官方文档:torch.utils.data — PyTorch 1.13 documentation

dl_train返回的参数对应dataset get_item的return 参数

  • x,y = next(iter(dataset))

  • X,Y = next(iter(dataloader))

X.shape == (batchsize,x.shape)

进行训练时的一般模式:

# Typical training loop over the DataLoader.
for idx, (x, label) in enumerate(dl_train):
    # FIX: original read `y.to(device)` but `y` is undefined in this loop;
    # the batch target variable is `label`.
    x, label = x.to(device), label.to(device)
    pred = model(x)
    loss = lossFunction(pred, label)
    optimizer.zero_grad()
    loss.backward()
    # FIX: original called `loss.step()`; the parameter update is performed
    # by the optimizer, not the loss tensor.
    optimizer.step()


特殊便捷方式数据集加载

使用函数datasets.ImageFolder

使用场景:适用于将数据集分文件夹存储,文件夹名称为对应的label

# Minimal preprocessing: resize every image, then convert to a tensor.
tf = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# ImageFolder expects one sub-directory per class under root; the
# directory name becomes the class label.
db = torchvision.datasets.ImageFolder(root='./pokeman/', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)
print(db.class_to_idx)