torchvision.transforms.Compose

PyTorch 的 torchvision 库中的一个类,用于将一系列图像变换操作组合在一起。通过将多个图像变换按顺序应用,可以方便地对数据集进行预处理和增强。

原型
1
class torchvision.transforms.Compose(transforms)
属性说明
  • transforms:包含多个变换操作的列表。这些变换操作会按照提供的顺序依次应用到图像上。
主要方法
实例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from torchvision import transforms
from PIL import Image

# 定义一系列变换操作
transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整图像大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
])

# 打开图像
image = Image.open('path/to/your/image.jpg')

# 应用变换
transformed_image = transform(image)

# transformed_image 现在是一个经过变换的张量,可以用于训练模型
  • Resize((128, 128)):将图像调整为 128x128 像素
  • RandomHorizontalFlip():以 0.5 的概率随机水平翻转图像,有助于数据增强
  • ToTensor():将 PIL.Imagenumpy.ndarray 转换为张量,并将像素值从 [0, 255] 范围转换到 [0, 1] 范围
  • Normalize(mean, std):使用给定的均值和标准差对图像进行归一化,常用于将数据标准化为标准正态分布,以加速神经网络的训练收敛

torchvision.datasets.***

所有的数据集都是torch.utils.data.Dataset的子类, 它们实现了__getitem__和__len__方法。因此,它们都可以传递给torch.utils.data.DataLoader.

收录数据集
1
2
3
4
5
6
7
8
9
10
11
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs"
)
属性说明
  • root:数据集所在目录的根目录
  • train:如果为True,则从训练集创建数据集,否则从测试集创建
  • transform:一个接受PIL图像的函数/变换,返回转换后的版本
  • download:默认False,如果为true,则从internet下载数据集 ,将其放在根目录中
  • target_transform:对数据集中的目标标签进行变换。通常在加载图像数据集时,不仅需要对图像本身进行预处理或数据增强(例如使用 transform 参数),还可能需要对与这些图像关联的目标标签进行相应的变换。
实例
1
2
3
4
5
6
7
8
9
from torchvision import datasets, transforms

cifar10_dataset = datasets.CIFAR10(
root='path/to/cifar10',
train=True,
transform=transforms.ToTensor(), # 图像变换
target_transform=square_transform, # 目标变换,例如,如果原始标签是 3,那么变换后的标签将是 9。
download=True
)

torch.utils.data.DataLoader

PyTorch 中的数据加载器,用于批量加载数据集,方便在训练和测试过程中进行数据迭代。它具有多种功能,包括批量处理、数据打乱、多进程数据加载等,使得数据处理更加高效和灵活。

原型
1
2
3
4
5
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
属性说明
  • dataset:要加载的数据集实例。通常是 torchvision.datasets 或自定义数据集。
  • batch_size:每个批次加载的样本数量,默认为1。
  • shuffle:是否在每个 epoch 开始时打乱数据,默认False。
  • sampler:自定义采样器,用于定义从数据集中提取样本的策略。
  • batch_sampler:自定义批次采样器,一次返回一个批次的索引。
  • num_workers:用于数据加载的子进程数量,默认0,0表示数据将在主进程中加载。
  • ……
实例
1
data_loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2)