PyTorch:数据集/数据处理
torchvision.transforms.Compose
PyTorch 的 torchvision 库中的一个类,用于将一系列图像变换操作组合在一起。通过将多个图像变换按顺序应用,可以方便地对数据集进行预处理和增强。
原型
1 | class torchvision.transforms.Compose(transforms) |
属性说明
transforms
:包含多个变换操作的列表。这些变换操作会按照提供的顺序依次应用到图像上。
主要方法
实例
1 | from torchvision import transforms |
- Resize((128, 128)):将图像调整为 128x128 像素
- RandomHorizontalFlip():以 0.5 的概率随机水平翻转图像,有助于数据增强
- ToTensor():将
PIL.Image
或numpy.ndarray
转换为张量,并将像素值从 [0, 255] 范围转换到 [0, 1] 范围 - Normalize(mean, std):使用给定的均值和标准差对图像进行归一化,常用于将数据标准化为标准正态分布,以加速神经网络的训练收敛
torchvision.datasets.***
所有的数据集都是torch.utils.data.Dataset
的子类, 它们实现了__getitem__和__len__方法。因此,它们都可以传递给torch.utils.data.DataLoader
.
收录数据集
1 | __all__ = ('LSUN', 'LSUNClass', |
属性说明
- root:数据集所在目录的根目录
- train:如果为True,则从训练集创建数据集,否则从测试集创建
- transform:一个接受PIL图像的函数/变换,返回转换后的版本
- download:默认False,如果为true,则从internet下载数据集 ,将其放在根目录中
- target_transform:对数据集中的目标标签进行变换。通常在加载图像数据集时,不仅需要对图像本身进行预处理或数据增强(例如使用
transform
参数),还可能需要对与这些图像关联的目标标签进行相应的变换。
实例
1 | from torchvision import datasets, transforms |
torch.utils.data.DataLoader
PyTorch 中的数据加载器,用于批量加载数据集,方便在训练和测试过程中进行数据迭代。它具有多种功能,包括批量处理、数据打乱、多进程数据加载等,使得数据处理更加高效和灵活。
原型
1 | DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, |
属性说明
- dataset:要加载的数据集实例。通常是
torchvision.datasets
或自定义数据集。 - batch_size:每个批次加载的样本数量,默认为1。
- shuffle:是否在每个 epoch 开始时打乱数据,默认False。
- sampler:自定义采样器,用于定义从数据集中提取样本的策略。
- batch_sampler:自定义批次采样器,一次返回一个批次的索引。
- num_workers:用于数据加载的子进程数量,默认0,
0
表示数据将在主进程中加载。 - ……
实例
1 | data_loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2) |
评论