44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
import glob
|
|
|
|
import pandas as pd
|
|
import torch
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class DigitRealImageAnnotDataset(Dataset):
    """Dataset of image files from one directory with optional CSV annotations.

    Images are discovered with ``glob`` as ``{dir_dataset}/*.{img_type}`` and
    kept in sorted order. When ``annot_flag`` is true, a comma-separated
    annotation file is loaded with pandas; it must provide the columns
    ``img_names``, ``center_x``, ``center_y`` and ``radius``.

    NOTE: rows are matched by comparing ``img_names`` with the path string
    produced by ``glob`` (i.e. including the ``dir_dataset`` prefix), not with
    the bare file name.
    """

    # Annotation columns returned for each image, in this order.
    _ANNOT_COLUMNS = ("center_x", "center_y", "radius")

    def __init__(self, dir_dataset, annot_file, transform=None, annot_flag=True, img_type="png"):
        """Index the images in *dir_dataset* and optionally load annotations.

        Args:
            dir_dataset: Directory that directly contains the image files.
            annot_file: Path (or buffer) of the annotation CSV. Only read when
                ``annot_flag`` is true.
            transform: Optional callable applied to each opened PIL image. Its
                result must support ``.permute(0, 2, 1)`` (e.g. a 3-D tensor).
            annot_flag: Whether ``__getitem__`` also returns annotations.
            img_type: File extension (without the dot) of the images to index.
        """
        super().__init__()
        self.dir_dataset = dir_dataset
        print(f"Loading dataset from {dir_dataset}")
        self.transform = transform
        self.annot_flag = annot_flag

        # Sorted list of image paths found directly under dir_dataset.
        self.img_files = sorted(glob.glob(f"{self.dir_dataset}/*.{img_type}"))
        print(f"Found {len(self.img_files)} images")

        # Always define the attribute so it can be inspected safely even when
        # annotations are disabled.
        self.annot_dataframe = (
            pd.read_csv(annot_file, sep=",") if self.annot_flag else None
        )

    def _load_annotation(self, img_name):
        """Return the annotation rows for *img_name* as an int32 ``(k, 3)`` tensor.

        ``k`` is the number of CSV rows whose ``img_names`` equals *img_name*;
        the columns are ``center_x``, ``center_y``, ``radius``. When no row
        matches, an empty ``(0, 3)`` int32 tensor is returned so dtype and
        rank do not depend on whether an annotation exists.
        """
        columns = list(self._ANNOT_COLUMNS)
        is_match = self.annot_dataframe["img_names"] == img_name
        region_attr = self.annot_dataframe.loc[is_match, columns]
        if region_attr.empty:
            return torch.empty((0, len(columns)), dtype=torch.int32)
        return torch.tensor(region_attr.values, dtype=torch.int32)

    def __getitem__(self, idx):
        """Return the sample at position *idx*.

        Returns:
            ``img`` when ``annot_flag`` is false, otherwise ``(img, annot)``
            where ``annot`` is an int32 tensor of shape ``(k, 3)`` holding
            ``center_x, center_y, radius`` for each matching CSV row.

            ``img`` is ``transform(image).permute(0, 2, 1)`` when a transform
            was supplied; without a transform the loaded PIL image is
            returned unchanged.
        """
        img_path = self.img_files[idx]

        # Open through a context manager so the file handle is released
        # deterministically; load() decodes the pixels first so the image
        # stays usable after the file is closed.
        with Image.open(img_path) as raw_img:
            raw_img.load()
            if self.transform is None:
                img = raw_img
            else:
                # Swap the last two axes; per the original author's note this
                # maps e.g. (3, 240, 320) -> (3, 320, 240).
                img = self.transform(raw_img).permute(0, 2, 1)

        if not self.annot_flag:
            return img
        return img, self._load_annotation(img_path)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.img_files)
|