Dataset与Dataloader

介绍

在pytorch中,Dataset 和 DataLoader 是 PyTorch 中的类,它们通常被用于数据处理和加载,是 torch.utils.data 模块的重要组成部分。

dataset和dataloader已经作为两个类封装在库中,需要的时候直接导入使用即可:

image-20250226202148473

如何导入

1
2
from torch.utils.data import Dataset # 导入dataset
from torch.utils.data import DataLoader # 导入dataloader

Dataset与Dataloader的区别

Dataset:

dataset相当与一个箱子,箱子里装有我们需要在一个项目中使用的数据,包括训练数据、验证数据和测试数据。因此,当我们定义dataset的时候,也就相当于我们定义了一个包含我们自己的数据集的大箱子,并且在这个箱子中,会根据不同的数据集结构拥有相应的存储结构。

Dataloader

有了存储数据的dataset,当我们使用数据集的时候,我们就需要dataloader作为一个搬运工将数据搬运出来,供我们后续的加载与使用(训练或测试)。而在搬运的过程中,我们也可以根据自己的需求自定义操作,例如在图像数据集中,我们可以选择一次取出多少张图片,可以对图像进行打乱顺序的操作,可以对图像进行旋转裁剪等操作。

加载已经封装在Pytorch中的数据集

以加载MINIST数据集为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)

test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)

加载自定义数据集

在加载自定义数据集时,我们不能直接从torch中download,而是需要自己根据数据集的形式自定义dataset封装数据。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class UavDataset(Dataset): # 构建数据集
def __init__(self, data_path, label_path=None,
random_choose=False, random_shift=False, random_move=False,p_interval=1,
window_size=-1, normalization=False, debug=False, use_mmap=True,is_test=False,random_rot=False,d=3):
"""
:param data_path: 数据文件路径
:param label_path: 标签文件路径
:param random_choose: 如果为 True,则随机选择输入序列的一部分
:param random_shift: 如果为 True,则在序列的开头或结尾随机填充零
:param random_move: 如果为 True,则在数据中随机移动
:param window_size: 输出序列的长度
:param normalization: 如果为 True,则对输入序列进行归一化
:param debug: 如果为 True,则只使用前 100 个样本
:param use_mmap: 如果为 True,则使用 mmap 模式加载数据,以节省内存
"""

self.is_test = is_test
self.debug = debug
self.data_path = data_path
self.label_path = label_path
self.random_choose = random_choose
self.random_shift = random_shift
self.random_move = random_move
self.window_size = window_size
self.normalization = normalization
self.use_mmap = use_mmap
self.p_interval = p_interval
self.random_rot = random_rot
self.d = d
self.load_data()

if normalization:
self.get_mean_map()



def __len__(self):
return len(self.data)

def __iter__(self):
return self

def __getitem__(self, index):
data_numpy = self.data[index]
data_numpy = np.array(data_numpy)

if self.d == 2:
data_numpy = data_numpy[[0,2],:,:,:]
elif self.d == 4:
data_numpy = data_numpy[[0,2,4,6],:,:,:]
else:
data_numpy = data_numpy

if self.random_rot:
data_numpy = random_rot(data_numpy, theta=0.3,d=self.d)

if self.normalization:
data_numpy = (data_numpy - self.mean_map) / self.std_map
if self.random_shift:
data_numpy = tools.random_shift(data_numpy)
if self.random_choose:
data_numpy = tools.random_choose(data_numpy, self.window_size)
# elif self.window_size > 0:
# data_numpy = tools.auto_pading(data_numpy, self.window_size)
if self.random_move:
data_numpy = tools.random_move(data_numpy)

使用自定义Dataset的时候需要注意的问题

对于dataset这个类需要重写内置方法时,有三个方法必须包含:

  1. _init_ 不必多说
  2. _len_ 数据集长度
  3. _getitem_ 按索引取出其中一个数据(切片)