23 lines
829 B
Python
23 lines
829 B
Python
"""
|
|
Data loader for the color-normal datasets
|
|
"""
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import DataLoader
|
|
|
|
from digit_depth.dataio.digit_dataset import DigitRealImageAnnotDataset
|
|
|
|
|
|
def data_loader(dir_dataset, params):
    """Build a DataLoader over the color-normal (DIGIT image/annotation) dataset.

    Args:
        dir_dataset: Path to the dataset directory, forwarded unchanged to
            ``DigitRealImageAnnotDataset``.
        params: Loader configuration. Either a mapping (e.g. a plain ``dict``)
            or an object exposing the settings as attributes (e.g. an
            ``argparse.Namespace`` or config object). It must provide:
              - ``annot_file``: annotation file passed to the dataset
              - ``annot_flag``: annotation flag passed to the dataset
              - ``batch_size``: DataLoader batch size
              - ``shuffle``: whether the DataLoader shuffles samples
              - ``num_workers``: number of DataLoader worker processes

    Returns:
        A ``(dataloader, dataset)`` tuple: the ``torch.utils.data.DataLoader``
        and the underlying ``DigitRealImageAnnotDataset`` it iterates over.

    Raises:
        KeyError: if ``params`` is a mapping missing a required setting.
        AttributeError: if ``params`` is a non-mapping object missing a
            required setting.
    """
    from collections.abc import Mapping

    def _param(name):
        # Accept both dict-style configs (as the original docstring promised)
        # and attribute-style configs (what the original code actually used).
        if isinstance(params, Mapping):
            return params[name]
        return getattr(params, name)

    # Only conversion to tensor is applied; no augmentation or normalization.
    transform = transforms.Compose([transforms.ToTensor()])

    dataset = DigitRealImageAnnotDataset(
        dir_dataset=dir_dataset,
        annot_file=_param("annot_file"),
        transform=transform,
        annot_flag=_param("annot_flag"),
    )

    dataloader = DataLoader(
        dataset,
        batch_size=_param("batch_size"),
        shuffle=_param("shuffle"),
        num_workers=_param("num_workers"),
    )

    return dataloader, dataset
|