origin
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
15
tests/test_digit.py
Normal file
15
tests/test_digit.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import unittest
|
||||
|
||||
from digit_interface import Digit
|
||||
|
||||
from digit_depth import DigitSensor
|
||||
|
||||
|
||||
class TestDigit(unittest.TestCase):
|
||||
def test_digit_sensor(self):
|
||||
fps = 30
|
||||
resolution = "QVGA"
|
||||
serial_num = "12345"
|
||||
digit_sensor = DigitSensor(fps, resolution, serial_num)
|
||||
digit = digit_sensor()
|
||||
self.assertIsInstance(digit, Digit)
|
||||
16
tests/test_handlers.py
Normal file
16
tests/test_handlers.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import unittest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from digit_depth.handlers import image
|
||||
|
||||
class Handler(unittest.TestCase):
|
||||
"""Test for various data handlers"""
|
||||
def test_tensor_to_PIL(self):
|
||||
instance = image.ImageHandler(Image.open("/home/shuk/digit-depth/images/0001.png"), "RGB")
|
||||
tensor = torch.randn(1, 3, 224, 224)
|
||||
pil_image = instance.tensor_to_PIL()
|
||||
self.assertEqual(pil_image.size, (224, 224))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
26
tests/test_train.py
Normal file
26
tests/test_train.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import os
|
||||
import unittest
|
||||
import torch
|
||||
from digit_depth.train import MLP, Color2NormalDataset
|
||||
|
||||
base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
|
||||
class Train(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
model = MLP()
|
||||
x = torch.randn(1, 5)
|
||||
y = model(x)
|
||||
self.assertEqual(torch.Size([1, 3]), y.size())
|
||||
|
||||
def test_dataset(self):
|
||||
dataset = Color2NormalDataset(f'{base_path}/datasets/train_test_split/train.csv')
|
||||
x, y = dataset[0]
|
||||
self.assertEqual(torch.Size([5]), x.size())
|
||||
self.assertEqual(torch.Size([3]), y.size())
|
||||
self.assertLessEqual(x.max(), 1)
|
||||
self.assertGreaterEqual(x.min(), 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user