digit-depth/scripts/train_mlp.py

import argparse
import os
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm
from digit_depth.train import MLP, Color2NormalDataset
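
# Trains or evaluates a small MLP (the "Color 2 Normal" model) that regresses
# surface normals from RGB tactile readings; the sensor is assumed to be a
# DIGIT, going by the digit_depth package. Metrics are logged to wandb.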
# Reproducibility and device setup
seed = 42
torch.manual_seed(seed)  # `torch.seed = seed` only rebinds an attribute; it does not seed the RNG
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

def train(train_loader, epochs, lr):
    """Train the MLP and save a checkpoint under <base_path>/models/."""
    model = MLP().to(device)
    wandb.init(project="MLP", name="Color 2 Normal model train")
    wandb.watch(model, log_freq=100)
    model.train()
    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    avg_loss = 0.0
    loss_record = []  # per-step losses, kept for optional analysis
    cnt = 0
    total_step = len(train_loader)
    for epoch in tqdm(range(1, epochs + 1)):
        for i, (data, labels) in enumerate(train_loader):
            # Move tensors to the configured device
            data = data.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(data)
            loss = criterion(outputs, labels)
            avg_loss += loss.item()
            # Backward pass and optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            cnt += 1
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, total_step, loss.item()))
            loss_record.append(loss.item())
            # wandb.log({"Mini-batch loss": loss})
            # wandb.log({'Running test loss': avg_loss / cnt})
    os.makedirs(f"{base_path}/models", exist_ok=True)
    print(f"Saving model to {base_path}/models/")
    # Save the whole model object; test() reloads it with torch.load
    torch.save(model, f"{base_path}/models/mlp.ckpt")

def test(test_loader, criterion):
    """Load the checkpoint saved by train() and report the average test loss."""
    model = torch.load(f"{base_path}/models/mlp.ckpt").to(device)
    wandb.init(project="MLP", name="Color 2 Normal model test")
    wandb.watch(model, log_freq=100)
    model.eval()
    avg_loss = 0.0
    cnt = 0
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.to(device)
            labels = labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)
            avg_loss += loss.item()
            cnt += 1
            # wandb.log({"Mini-batch test loss": loss})
    avg_loss = avg_loss / cnt
    print("Test loss: {:.4f}".format(avg_loss))
    # wandb.log({'Average Test loss': avg_loss})

def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--mode', type=str, default='train', help='train or test')
    argparser.add_argument('--batch_size', type=int, default=10000, help='batch size')
    argparser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
    argparser.add_argument('--epochs', type=int, default=2, help='epochs')
    argparser.add_argument('--train_path', type=str,
                           default=f'{base_path}/datasets/train_test_split/train.csv',
                           help='train data path')
    argparser.add_argument('--test_path', type=str,
                           default=f'{base_path}/datasets/train_test_split/test.csv',
                           help='test data path')
    option = argparser.parse_args()
    if option.mode == "train":
        train_set = Color2NormalDataset(option.train_path)
        train_loader = DataLoader(train_set, batch_size=option.batch_size, shuffle=True)
        print("Training set size: ", len(train_set))
        train(train_loader, option.epochs, option.learning_rate)
    elif option.mode == "test":
        test_set = Color2NormalDataset(option.test_path)
        # Shuffling is unnecessary for evaluation; the loss is order-independent
        test_loader = DataLoader(test_set, batch_size=option.batch_size, shuffle=False)
        criterion = nn.MSELoss()
        test(test_loader, criterion)


if __name__ == "__main__":
    main()
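
# Example usage (defaults shown above; the dataset CSVs are assumed to exist
# under <base_path>/datasets/train_test_split/):
#   python digit-depth/scripts/train_mlp.py --mode train --epochs 2 --learning_rate 0.001
#   python digit-depth/scripts/train_mlp.py --mode test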