This commit is contained in:
wxchen
2022-12-29 23:08:25 +08:00
commit 21ad625896
42 changed files with 2336 additions and 0 deletions

View File

@@ -0,0 +1,43 @@
import glob
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
class DigitRealImageAnnotDataset(Dataset):
def __init__( self, dir_dataset, annot_file, transform=None, annot_flag=True, img_type="png" ):
self.dir_dataset = dir_dataset
print(f"Loading dataset from {dir_dataset}")
self.transform = transform
self.annot_flag = annot_flag
# a list of image paths sorted. dir_dataset is the root dir of the datasets (color)
self.img_files = sorted(glob.glob(f"{self.dir_dataset}/*.{img_type}"))
print(f"Found {len(self.img_files)} images")
if self.annot_flag:
self.annot_dataframe = pd.read_csv(annot_file, sep=",")
def __getitem__(self, idx):
"""Returns a tuple of (img, annot) where annot is a tensor of shape (3,1)"""
# read in image
img = Image.open(self.img_files[idx])
img = self.transform(img)
img = img.permute(0, 2, 1) # (3,240,320) -> (3,320,240)
# read in region annotations
if self.annot_flag:
img_name = self.img_files[idx]
row_filter = self.annot_dataframe["img_names"] == img_name
region_attr = self.annot_dataframe.loc[
row_filter, ["center_x", "center_y", "radius"]
]
annot = (torch.tensor(region_attr.values, dtype=torch.int32) if (len(region_attr) > 0) else torch.tensor([]))
data = img
if self.annot_flag:
data = (img, annot)
return data
def __len__(self):
return len(self.img_files)