Files
digit-depth/tests/test_train.py
2022-12-29 23:08:25 +08:00

27 lines
738 B
Python

import os
import unittest
import torch
from digit_depth.train import MLP, Color2NormalDataset
# Project root: two directory levels above this file (tests/ -> project).
# abspath() is applied *before* stripping path components so that a bare
# relative __file__ such as 'test_train.py' still resolves to the project
# root; applying it afterwards would turn dirname(dirname(...)) == '' into
# the current working directory instead.
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class Train(unittest.TestCase):
    """Smoke tests for the MLP model and Color2NormalDataset from digit_depth.train."""

    def test_shape(self):
        """MLP maps a (1, 5) input batch to a (1, 3) output batch."""
        model = MLP()
        x = torch.randn(1, 5)
        # Only the output shape is checked, so no autograd graph is needed.
        with torch.no_grad():
            y = model(x)
        self.assertEqual(torch.Size([1, 3]), y.size())

    def test_dataset(self):
        """First training sample is a 5-vector input in [0, 1] and a 3-vector target."""
        csv_path = f'{base_path}/datasets/train_test_split/train.csv'
        # Skip with an explicit reason (rather than erroring inside the
        # dataset loader) when the split CSV is not present in this checkout.
        if not os.path.isfile(csv_path):
            self.skipTest(f'training split CSV not found: {csv_path}')
        dataset = Color2NormalDataset(csv_path)
        x, y = dataset[0]
        self.assertEqual(torch.Size([5]), x.size())
        self.assertEqual(torch.Size([3]), y.size())
        # .item() converts the 0-d tensors to Python numbers so that a
        # failing assertion reports the offending value, not a tensor repr.
        self.assertLessEqual(x.max().item(), 1)
        self.assertGreaterEqual(x.min().item(), 0)
if __name__ == '__main__':
    # Executing this file directly hands control to unittest's
    # command-line runner, which discovers and runs the tests above.
    unittest.main()