datasets.ImageFolder是PyTorch供给的一个预界说数据集类,用于处理图画数据。它能够方便地将一组图画加载到内存中,并为每个图画分配标签。

  1. 数据集预备和目录结构

要运用datasets.ImageFolder,咱们需求预备好一个包括图画数据的目录,并依照以下方法进行安排:

root/
    class1/
        img1.jpg
        img2.jpg
        ...
    class2/
        img1.jpg
        img2.jpg
        ...
    ...

其中,root代表数据集根目录,class1、class2等代表不同的分类标签,img1、img2等代表图画文件名。每个类别(也称为标签)应该有一个单独的子目录,子目录中包括这个类别的一切图画文件。一起,每个图画文件在对应的子目录下,以其文件名作为其类别标签。这种目录安排方法能够让咱们轻松获取图画和对应的标签信息。

  1. 加载数据集

完结数据集预备之后,咱们就能够运用datasets.ImageFolder来加载它了。下面是一个示例代码:

import torchvision.datasets as datasets
import torchvision.transforms as transforms
data_dir = "/path/to/data"
transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
])
dataset = datasets.ImageFolder(root=data_dir, transform=transforms)

在这个比如中,咱们首要导入datasets和transforms模块,然后指定数据集的根目录data_dir。接下来,咱们界说一个 transforms 对象,它将图画转换为PyTorch张量,并调整巨细为(224, 224)。

最终,咱们运用datasets.ImageFolder来加载图画数据集。ImageFolder类需求两个参数:root 和 transform。root是数据集根目录;transform指定对每个图画应该履行的预处理操作,例如调整巨细、裁剪、翻转等。

  1. 数据集区分

对于机器学习任务,咱们一般需求将数据集区分红练习集、验证集和测验集。在PyTorch中,咱们能够运用torch.utils.data.random_split函数来完结数据集的区分。下面是一个示例代码:

from torch.utils.data import DataLoader, random_split
# Split the dataset into train and test sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# Split train dataset into train and validation sets
val_size = int(0.2 * len(train_dataset))
train_size = len(train_dataset) - val_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

在这个比如中,咱们先运用random_split函数将原始数据集区分为练习集和测验集,在这里80%的数据用于练习,20%的数据用于测验。然后,咱们再次运用random_split函数将练习集区分为练习集和验证集,其中80%的数据用于练习,20%的数据用于验证。

  1. 数据加载器

最终,咱们能够运用数据加载器(DataLoader)来加载数据集。数据加载器负责将图画数据和标签封装成批量,并供给线程方法加载数据以加速练习过程。下面是一个示例代码:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

在这里,咱们创建了三个数据加载器train_loader、val_loader 和 test_loader,它们别离对应练习集、验证集和测验集。batch_size参数指定了每个批次的巨细,shuffle参数表明是否随机化输入数据(在练习会集设置为True,在验证集和测验会集设置为False)。