SupervisedDataset.py
import os
import sys
import cv2
import numpy as np
from imageio import imread
from tqdm import tqdm
import io
import zipfile
import torch
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader as TorchDataLoader
from .BaseDataset import BaseDataset
def readImage(path, shape):
    # Read an RGB image, scale it to [0, 1], resize to (H, W) = shape if needed, return CHW.
    img = np.asarray(imread(path, pilmode='RGB'), np.float32) / 255.0
    if img.shape[0] != shape[0] or img.shape[1] != shape[1]:
        # cv2.resize takes dsize as (width, height), hence shape[::-1]
        img = cv2.resize(img, dsize=tuple(shape[::-1]), interpolation=cv2.INTER_AREA)
    return img.transpose(2, 0, 1)
def readDepth(path, shape):
    # Read a depth map stored as a .npy array and resize to (H, W) = shape if needed.
    img = np.load(path)
    if img.shape[0] != shape[0] or img.shape[1] != shape[1]:
        # nearest-neighbour interpolation avoids blending depth values across object edges
        img = cv2.resize(img, dsize=tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
    return img[None, ...]  # add channel dimension: (1, H, W)

class SupervisedDataset(BaseDataset):
    def __init__(self, dataset_path, mode, shape, **kwargs):
        assert os.path.isdir(dataset_path) and mode in ['train', 'val']
        super().__init__(**kwargs)
        self.shape = shape
        # Each line of <dataset_path>/<mode>.txt names one sample as space-separated
        # path components, which are joined into a relative directory path.
        with open('%s/%s.txt' % (dataset_path, mode), 'r') as f:
            lst = [os.path.join(*x.rstrip().split(' ')) for x in f]
        # Every sample directory is expected to hold a color.jpg / depth.npy pair.
        rgb_lst = ['%s/%s/color.jpg' % (dataset_path, x) for x in lst]
        depth_lst = ['%s/%s/depth.npy' % (dataset_path, x) for x in lst]
        self.data = list(zip(rgb_lst, depth_lst))

    def __getitem__(self, idx):
        rgb_path, depth_path = self.data[idx]
        rgb = readImage(rgb_path, self.shape)       # (3, H, W) float32 in [0, 1]
        depth = readDepth(depth_path, self.shape)   # (1, H, W)
        out = {
            'idx': idx,
            'rgb': rgb,
            'depth': depth
        }
        return out
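
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The directory layout below is
# inferred from the path construction in __init__; the concrete paths, the
# shape value, and the DataLoader settings are placeholder assumptions, and
# any BaseDataset-specific kwargs are omitted.
#
#   dataset_path/
#       train.txt                 # one sample per line, space-separated path parts
#       val.txt
#       <sample_dir>/color.jpg    # RGB input
#       <sample_dir>/depth.npy    # ground-truth depth map
#
#   dataset = SupervisedDataset(dataset_path='/path/to/data', mode='train',
#                               shape=(240, 320))
#   loader = TorchDataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
#   for batch in loader:
#       rgb, depth = batch['rgb'], batch['depth']  # (B, 3, H, W) and (B, 1, H, W)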