Dataset && DataLoader
- Dataset定义数据集的内容,是一个类似列表的数据结构,具有确定长度,使用索引能够获取数据集的元素
- Dataloader 定义了按batch加载数据集的方法,他是一个实现了__iter__ 方法的可迭代对象,每一次迭代输出一个batch的数据; Dataloader能够控制batch大小,batch中元素的采样方法,以及整理为模型所需输入形式,并且能够使用多进程读取数据
- 多数情况下,实现 __len__ 方法和 __getitem__ 方法,就能够构建数据集并使用默认数据管道进行加载
获取一个batch的步骤
- 确定数据集长度
2。 从范围中抽样出 m 个数
- 从数据集中取m个数对应小标的数
- 将结果作为两个张量输出,最终拿到的结果是两个张量
Dataset和DataLoader的分工
确定数据集的长度是通过Dataset的 __len__ 方法实现的
从0 到 n - 1 中抽样m个数是通过DataLoader的sampler和batch_sampler参数指定的
* Sampler菜单树指定单个元素抽样方法,默认在参数shuffle=True 采用随机抽样,shuffle=False 顺序抽样
* batch_sampler 将多个抽样元素整理为列表,默认方法在drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,false时会保留
根据下标取元素是由Dataset的 __getitem__ 方法实现
DataLoader的参数 collate_fn指定,一般情况下无需设置
一般使用方法如下:
1 2 3 4 5 6 7 8 9 10
| import torch from torch.utils.data import TensorDataset,Dataset,DataLoader from torch.utils.data import RandomSampler,BatchSampler
ds = TensorDataset(torch.randn(100,3), torch.randint(low=0,high=2,size=(1000,)).float()) dl = DataLoader(ds,batch_size=4, drop_last = False) features, labels = next(iter(dl)) print("features = ",features) print("labels = ",labels)
|
使用Dataset创建数据集
Dataset 创建数据集常用的方法有:
* torch.utils.data.TensorDatset 根据tensor创建数据集
* 使用torchvision.datasets.ImageFolder 根据图片目录创建图片数据集
* 继承torch.utils.data.Dataset 创建自定义数据集
此外还可以通过:
* torch.utils.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集
* 调用Dataset加法运算符将多个数据集合并成一个数据集
使用Tensor创建数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| import numpy as np import torch from torch.utils.data import TensorDataset,Dataset,DataLoder,random_split from sklearn import datasets iris = datasets.load_iris() ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))
n_train = int(len(ds_iris)*0.8) n_val = len(ds_iris) - n_train ds_train,ds_val = random_split(ds_iris,[n_train,n_val])
print(type(ds_iris)) print(type(ds_train))
dl_train,dl_val = DataLoader(ds_train,batch_size=8),Dataloader(ds_val,batch_size=8) for features,labels in dl_train: print(features,labels) break
ds_data = ds_train + ds_val print('len(ds_train) = ',len(ds_train)) print('len(ds_valid) = ',len(ds_val)) print('len(ds_train+ds_valid) = ',len(ds_data))
print(type(ds_data))
|
根据图片创建图片数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
| import numpy as np import torch from torch.utils.data import DataLoader from torchvision import transforms,datasets from PIL import Image img = Image.open() img
transformes.RandomVerticalFlip()(img) transforms.RandomRotation(45)(img)
transform_train = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.RandomRotation(45), transforms.ToTensor() ] )
transform_valid = transforms.Compose([ transforms.ToTensor() ] )
def transform_label(x): return torch.tesnor([x]).float() ds_train = datasets.ImageFolder("/path/to/folder/train",transform=transform_train,target_transform=transform_label) ds_val = datsets.ImageFolder("/path/to/folder",transform=transform_valid,target_transform=transform_label) print(ds_train.class_to_idx)
dl_train = DataLoader(ds_train,batch_size=50,shuffle=True) dl_val = DataLoader(ds_val,batch_size=50,shuffle=True) for features,labels in dl_train: print(features,shape) print(labels.shape) break
|
创建自定义数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| from pathlib import Path from PIL import Image class Cifar2Dataset(Dataset): def __init__(self,imgs_dir,img_transform): self.files = list(Path(imgs_dir).rglob("*.jpg")) self.transform = img_transform def __len__(self,): return len(self.files) def __getitem__(self,i): file_i = str(self.files[i]) img = Image.open(file_i) tensor = self.transform(img) label = torch.tensor([1.0]) if "1_automobile" in file_i else torch.tensor([0.0]) return tensor,label
|
1 2 3 4 5 6 7 8 9 10 11 12 13
| transform_train = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.RandomRotation(45), transforms.ToTensor() ] )
transform_val = transforms.Compose([ transforms.ToTensor() ] )
|
1 2 3 4 5 6 7 8 9
| ds_train = Cifar2Dataset(train_dir,transfor_train) ds_val = Cifar2Dataset(test_dir,transform_val)
dl_train = DataLoader(ds_train,batch_size = 50,shuffle=True) dl_val = DataLoader(ds_val,batch_size=50,shuffle=True) for features,labels in dl_train: print(features.shape) print(labels.shape) break
|
使用DataLoader加载数据集
DataLoader能够控制batch代销,batch中的采样方法,以及将batch整理成模型输入形式的方法,并且使用多进程读取数据
DataLoader的函数签名如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, )
|
在一般情况下,我们仅仅配置dataset,batch_szie,shuffle,num_workers,pin_memory,drop_last这六个参数
有时候构建复杂数据集,还需要定义collate_fn 函数,其他一般默认值即可。
DataLoader除了可以加载torch.utils.data.Dataset外,还能够加载林一种数据集 torch.utils.data.IterableDatasete
和一般的dataset数据集不同,IterableDataset相当于一种迭代器结构,更加复杂,一般比较少使用
- dataset : 数据集
- batch_size: 批次大小
- shuffle: 是否乱序
- sampler: 样本采样函数,一般无需设置。
- batch_sampler: 批次采样函数,一般无需设置。
- num_workers: 使用多进程读取数据,设置的进程数。
- collate_fn: 整理一个批次数据的函数。
- pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
- drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
- timeout: 加载一个数据批次的最长等待时间,一般无需设置。
- worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。
1 2 3 4 5 6 7 8
| ds = TesnorDataset(torch.arange(1,50)) dl = DataLoader(ds, batch_size = 10, shuffle = True, num_workers = 2, drop_last = True) for batch, in dl: print(batch)
|
哦我草真的有人会面试的时候手搓数据集和训练流程的,我真的是服了