1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
class UavDataset(Dataset):
    """Build the dataset: a map-style dataset over pre-loaded sample arrays.

    ``load_data()`` and ``get_mean_map()`` are called from the constructor but
    are not part of this listing; presumably they populate ``self.data`` and
    (when normalizing) ``self.mean_map`` / ``self.std_map`` — TODO confirm.
    """

    def __init__(self, data_path, label_path=None, random_choose=False, random_shift=False,
                 random_move=False, p_interval=1, window_size=-1, normalization=False, debug=False,
                 use_mmap=True, is_test=False, random_rot=False, d=3):
        """
        :param data_path: path of the data file
        :param label_path: path of the label file
        :param random_choose: if True, randomly choose a portion of the input
            sequence (via ``tools.random_choose`` with ``window_size``)
        :param random_shift: if True, randomly pad zeros at the beginning or end
            of the sequence (via ``tools.random_shift``)
        :param random_move: if True, apply ``tools.random_move`` to each sample
        :param p_interval: stored only; not used by the methods shown here
            (presumably consumed by ``load_data()``)
        :param window_size: length of the output sequence; only consulted when
            ``random_choose`` is True
        :param normalization: if True, normalize each sample as
            ``(x - mean_map) / std_map``
        :param debug: if True, only the first 100 samples are used (per the
            original docs; presumably enforced in ``load_data()``)
        :param use_mmap: if True, load the data in mmap mode to save memory
            (per the original docs; presumably enforced in ``load_data()``)
        :param is_test: stored only; not used by the methods shown here
        :param random_rot: if True, apply the module-level ``random_rot``
            augmentation with ``theta=0.3``
        :param d: channel selection on the first axis of each sample:
            2 keeps channels [0, 2], 4 keeps channels [0, 2, 4, 6], any other
            value keeps every channel. Also forwarded to ``random_rot``.
        """
        self.is_test = is_test
        self.debug = debug
        self.data_path = data_path
        self.label_path = label_path
        self.random_choose = random_choose
        self.random_shift = random_shift
        self.random_move = random_move
        self.window_size = window_size
        self.normalization = normalization
        self.use_mmap = use_mmap
        self.p_interval = p_interval
        self.random_rot = random_rot
        self.d = d
        self.load_data()
        if normalization:
            self.get_mean_map()

    def __len__(self):
        """Return the number of samples in ``self.data``."""
        return len(self.data)

    def __iter__(self):
        """Yield ``self[i]`` for every index, in order.

        Returning ``self`` here would not work: the class defines no
        ``__next__``, so ``iter()`` would reject it with ``TypeError``.
        """
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, index):
        """Load sample ``index`` and apply the configured transforms.

        NOTE(review): no ``return`` statement is visible in this listing, so as
        written this method returns ``None`` and labels are never read; the
        listing looks truncated. Confirm against the upstream source.
        """
        # np.array copies by default, so the transforms below never modify the
        # underlying (possibly mmap'd) storage in self.data.
        data_numpy = np.array(self.data[index])

        # Channel selection along the first axis; requires a 4-D sample when
        # d is 2 or 4 (presumably (C, T, V, M) — TODO confirm).
        if self.d == 2:
            data_numpy = data_numpy[[0, 2], :, :, :]
        elif self.d == 4:
            data_numpy = data_numpy[[0, 2, 4, 6], :, :, :]

        if self.random_rot:
            # `random_rot` here resolves to the module-level augmentation
            # function; the constructor flag lives in self.random_rot.
            data_numpy = random_rot(data_numpy, theta=0.3, d=self.d)
        if self.normalization:
            data_numpy = (data_numpy - self.mean_map) / self.std_map
        if self.random_shift:
            data_numpy = tools.random_shift(data_numpy)
        if self.random_choose:
            data_numpy = tools.random_choose(data_numpy, self.window_size)
        # No padding to window_size is applied when random_choose is False.
        if self.random_move:
            data_numpy = tools.random_move(data_numpy)
|