origin
This commit is contained in:
1
digit_depth/handlers/__init__.py
Normal file
1
digit_depth/handlers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .image import ImageHandler
|
||||
38
digit_depth/handlers/image.py
Normal file
38
digit_depth/handlers/image.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class ImageHandler:
|
||||
def __init__(self, img_path, convert="RGB"):
|
||||
self.img = Image.open(img_path).convert(convert)
|
||||
self.convert = convert
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_PIL(self, img_tensor):
|
||||
img_tensor = img_tensor.squeeze_(0)
|
||||
return transforms.ToPILImage()(img_tensor).convert(self.convert)
|
||||
|
||||
@property
|
||||
def tensor(self):
|
||||
return transforms.ToTensor()(self.img).unsqueeze_(0)
|
||||
|
||||
@property
|
||||
def image(self):
|
||||
return self.img
|
||||
|
||||
@property
|
||||
def nparray(self):
|
||||
return np.array(self.img)
|
||||
|
||||
@staticmethod
|
||||
def save(file_name, img):
|
||||
if isinstance(img, Image.Image):
|
||||
# this is a PIL image
|
||||
img.save(file_name)
|
||||
else:
|
||||
# cv2 image
|
||||
cv2.imwrite(file_name, img)
|
||||
Reference in New Issue
Block a user