Dataset && DataLoader

  • Dataset定义数据集的内容,是一个类似列表的数据结构,具有确定长度,使用索引能够获取数据集的元素
  • Dataloader 定义了按batch加载数据集的方法,他是一个实现了__iter__ 方法的可迭代对象,每一次迭代输出一个batch的数据; Dataloader能够控制batch大小,batch中元素的采样方法,以及整理为模型所需输入形式,并且能够使用多进程读取数据
  • 多数情况下,实现 __len__ 方法和 __getitem__ 方法,就能够构建数据集并使用默认数据管道进行加载

获取一个batch的步骤

  1. 确定数据集长度
    2。 从范围中抽样出 m 个数
  2. 从数据集中取m个数对应小标的数
  3. 将结果作为两个张量输出,最终拿到的结果是两个张量

Dataset和DataLoader的分工

确定数据集的长度是通过Dataset的 __len__ 方法实现的
从0 到 n - 1 中抽样m个数是通过DataLoader的sampler和batch_sampler参数指定的
* Sampler菜单树指定单个元素抽样方法,默认在参数shuffle=True 采用随机抽样,shuffle=False 顺序抽样
* batch_sampler 将多个抽样元素整理为列表,默认方法在drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,false时会保留
根据下标取元素是由Dataset的 __getitem__ 方法实现
DataLoader的参数 collate_fn指定,一般情况下无需设置
一般使用方法如下:

1
2
3
4
5
6
7
8
9
10
import torch
from torch.utils.data import TensorDataset,Dataset,DataLoader
from torch.utils.data import RandomSampler,BatchSampler

ds = TensorDataset(torch.randn(100,3),
torch.randint(low=0,high=2,size=(1000,)).float())
dl = DataLoader(ds,batch_size=4, drop_last = False)
features, labels = next(iter(dl))
print("features = ",features)
print("labels = ",labels)

使用Dataset创建数据集

Dataset 创建数据集常用的方法有:
* torch.utils.data.TensorDatset 根据tensor创建数据集
* 使用torchvision.datasets.ImageFolder 根据图片目录创建图片数据集
* 继承torch.utils.data.Dataset 创建自定义数据集
此外还可以通过:
* torch.utils.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集
* 调用Dataset加法运算符将多个数据集合并成一个数据集

使用Tensor创建数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
import torch
from torch.utils.data import TensorDataset,Dataset,DataLoder,random_split
from sklearn import datasets
iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))

n_train = int(len(ds_iris)*0.8)
n_val = len(ds_iris) - n_train
ds_train,ds_val = random_split(ds_iris,[n_train,n_val])

print(type(ds_iris))
print(type(ds_train))

dl_train,dl_val = DataLoader(ds_train,batch_size=8),Dataloader(ds_val,batch_size=8)
for features,labels in dl_train:
print(features,labels)
break

ds_data = ds_train + ds_val
print('len(ds_train) = ',len(ds_train))
print('len(ds_valid) = ',len(ds_val))
print('len(ds_train+ds_valid) = ',len(ds_data))

print(type(ds_data))

根据图片创建图片数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np 
import torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
from PIL import Image
img = Image.open()
img

transformes.RandomVerticalFlip()(img)
transforms.RandomRotation(45)(img)

transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(), #随机水平翻转
transforms.RandomVerticalFlip(), #随机垂直翻转
transforms.RandomRotation(45), #随机在45度角度内旋转
transforms.ToTensor() #转换成张量
]
)

transform_valid = transforms.Compose([
transforms.ToTensor()
]
)

def transform_label(x):
return torch.tesnor([x]).float()
ds_train = datasets.ImageFolder("/path/to/folder/train",transform=transform_train,target_transform=transform_label)
ds_val = datsets.ImageFolder("/path/to/folder",transform=transform_valid,target_transform=transform_label)
print(ds_train.class_to_idx)

dl_train = DataLoader(ds_train,batch_size=50,shuffle=True)
dl_val = DataLoader(ds_val,batch_size=50,shuffle=True)
for features,labels in dl_train:
print(features,shape)
print(labels.shape)
break

创建自定义数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from pathlib import Path
from PIL import Image
class Cifar2Dataset(Dataset):
def __init__(self,imgs_dir,img_transform):
self.files = list(Path(imgs_dir).rglob("*.jpg"))
self.transform = img_transform
def __len__(self,):
return len(self.files)
def __getitem__(self,i):
file_i = str(self.files[i])
img = Image.open(file_i)
tensor = self.transform(img)
label = torch.tensor([1.0]) if "1_automobile" in file_i else torch.tensor([0.0])
return tensor,label
1
2
3
4
5
6
7
8
9
10
11
12
13
# 定义图片增强
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(), #随机水平翻转
transforms.RandomVerticalFlip(), #随机垂直翻转
transforms.RandomRotation(45), #随机在45度角度内旋转
transforms.ToTensor() #转换成张量
]
)

transform_val = transforms.Compose([
transforms.ToTensor()
]
)
1
2
3
4
5
6
7
8
9
ds_train = Cifar2Dataset(train_dir,transfor_train)
ds_val = Cifar2Dataset(test_dir,transform_val)

dl_train = DataLoader(ds_train,batch_size = 50,shuffle=True)
dl_val = DataLoader(ds_val,batch_size=50,shuffle=True)
for features,labels in dl_train:
print(features.shape)
print(labels.shape)
break

使用DataLoader加载数据集

DataLoader能够控制batch代销,batch中的采样方法,以及将batch整理成模型输入形式的方法,并且使用多进程读取数据
DataLoader的函数签名如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
)

在一般情况下,我们仅仅配置dataset,batch_szie,shuffle,num_workers,pin_memory,drop_last这六个参数
有时候构建复杂数据集,还需要定义collate_fn 函数,其他一般默认值即可。
DataLoader除了可以加载torch.utils.data.Dataset外,还能够加载林一种数据集 torch.utils.data.IterableDatasete
和一般的dataset数据集不同,IterableDataset相当于一种迭代器结构,更加复杂,一般比较少使用

  • dataset : 数据集
  • batch_size: 批次大小
  • shuffle: 是否乱序
  • sampler: 样本采样函数,一般无需设置。
  • batch_sampler: 批次采样函数,一般无需设置。
  • num_workers: 使用多进程读取数据,设置的进程数。
  • collate_fn: 整理一个批次数据的函数。
  • pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
  • drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
  • timeout: 加载一个数据批次的最长等待时间,一般无需设置。
  • worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。
    1
    2
    3
    4
    5
    6
    7
    8
    ds = TesnorDataset(torch.arange(1,50))
    dl = DataLoader(ds,
    batch_size = 10,
    shuffle = True,
    num_workers = 2,
    drop_last = True)
    for batch, in dl:
    print(batch)

哦我草真的有人会面试的时候手搓数据集和训练流程的,我真的是服了