origin
This commit is contained in:
22
digit_depth/dataio/data_loader.py
Normal file
22
digit_depth/dataio/data_loader.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Data loader for the color-normal datasets
|
||||
"""
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from digit_depth.dataio.digit_dataset import DigitRealImageAnnotDataset
|
||||
|
||||
|
||||
def data_loader(dir_dataset, params):
|
||||
"""A data loader for the color-normal datasets
|
||||
Args:
|
||||
dir_dataset: path to the dataset
|
||||
params: a dict of parameters
|
||||
"""
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
dataset = DigitRealImageAnnotDataset( dir_dataset=dir_dataset, annot_file=params.annot_file,
|
||||
transform=transform, annot_flag=params.annot_flag)
|
||||
dataloader = DataLoader(dataset, batch_size=params.batch_size, shuffle=params.shuffle,
|
||||
num_workers=params.num_workers)
|
||||
|
||||
return dataloader, dataset
|
||||
Reference in New Issue
Block a user