0
scripts/__init__.py
Normal file
103
scripts/create_image_dataset.py
Normal file
@@ -0,0 +1,103 @@
"""
Creates color+normal image datasets from an annotation file.

The datasets can be used to train MLP, CycleGAN, and Pix2Pix models.
"""
import os

import cv2
import hydra
import imageio
import numpy as np
import torch

from digit_depth.dataio.create_csv import (combine_csv, create_pixel_csv,
                                           create_train_test_csv)
from digit_depth.dataio.data_loader import data_loader
from digit_depth.dataio.generate_sphere_gt_normals import generate_sphere_gt_normals
from digit_depth.third_party import data_utils

base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


@hydra.main(config_path=f"{base_path}/config", config_name="rgb_to_normal.yaml", version_base=None)
def main(cfg):
    normal_dataloader, normal_dataset = data_loader(
        dir_dataset=os.path.join(base_path, "images"), params=cfg.dataloader
    )
    dirs = [
        f"{base_path}/datasets/A/imgs",
        f"{base_path}/datasets/B/imgs",
        f"{base_path}/datasets/A/csv",
        f"{base_path}/datasets/B/csv",
        f"{base_path}/datasets/train_test_split",
    ]
    for dir_path in dirs:  # renamed from `dir` to avoid shadowing the builtin
        print(f"Creating directory: {dir_path}")
        os.makedirs(dir_path, exist_ok=True)

    # iterate over images
    img_idx = 0
    radius_bearing = np.int32(0.5 * 6.0 * cfg.mm_to_pixel)  # pixel radius of the 6 mm sphere
    while img_idx < len(normal_dataset):
        # read image + annotations
        data = normal_dataset[img_idx]
        if cfg.dataloader.annot_flag:
            img, annot = data
            if annot.shape[0] == 0:
                img_idx += 1
                continue
        else:
            img = data

        # get annotation circle params
        if cfg.dataloader.annot_flag:
            annot_np = annot.cpu().detach().numpy()
            center_y, center_x, radius_annot = (
                annot_np[0][1],
                annot_np[0][0],
                annot_np[0][2],
            )
        else:
            center_y, center_x, radius_annot = 0, 0, 0

        img_color_np = img.permute(2, 1, 0).cpu().detach().numpy()  # (3,320,240) -> (240,320,3)

        # apply foreground mask
        fg_mask = np.zeros(img_color_np.shape[:2], dtype="uint8")
        fg_mask = cv2.circle(fg_mask, (center_x, center_y), radius_annot, 255, -1)

        # 1. rgb -> normal (generate ground-truth surface normals)
        img_mask = cv2.bitwise_and(img_color_np, img_color_np, mask=fg_mask)
        img_normal_np = generate_sphere_gt_normals(
            img_mask, center_x, center_y, radius=radius_bearing
        )

        # 2. downsample and convert back to NumPy: (320,240,3) -> (160,120,3)
        img_normal_np = data_utils.interpolate_img(
            img=torch.tensor(img_normal_np).permute(2, 0, 1), rows=160, cols=120)
        img_normal_np = img_normal_np.permute(1, 2, 0).cpu().detach().numpy()
        img_color_ds = data_utils.interpolate_img(
            img=torch.tensor(img_color_np).permute(2, 0, 1), rows=160, cols=120)
        img_color_np = img_color_ds.permute(1, 2, 0).cpu().detach().numpy()

        # 3. save color and normal images
        if cfg.dataset.save_dataset:
            imageio.imwrite(
                f"{dirs[0]}/{img_idx:04d}.png", (img_color_np * 255).astype(np.uint8)
            )
            imageio.imwrite(
                f"{dirs[1]}/{img_idx:04d}.png", (img_normal_np * 255).astype(np.uint8)
            )
            print(f"Saved image {img_idx:04d}")
        img_idx += 1

    # post-process CSV files and create the train/test split
    create_pixel_csv(img_dir=dirs[0], save_dir=dirs[2], img_type="color")
    create_pixel_csv(img_dir=dirs[1], save_dir=dirs[3], img_type="normal")
    combine_csv(dirs[2], img_type="color")
    combine_csv(dirs[3], img_type="normal")
    create_train_test_csv(color_path=f"{dirs[2]}/combined.csv",
                          normal_path=f"{dirs[3]}/combined.csv",
                          save_dir=dirs[4])


if __name__ == "__main__":
    main()
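For a quick sanity check of the generated pairs, something along these lines can be run from the repository root after the script finishes (a minimal sketch; the 0000.png name assumes at least one annotated image was processed):

import imageio

color = imageio.imread("datasets/A/imgs/0000.png")    # input color image
normal = imageio.imread("datasets/B/imgs/0000.png")   # ground-truth normal map
assert color.shape == normal.shape == (160, 120, 3)   # both downsampled to (160,120,3)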
84
scripts/depth.py
Normal file
@@ -0,0 +1,84 @@
""" Publishes a ROS topic named /depth/compressed and shows the depth image in an OpenCV window.
Known issue: rqt_image_view does not display the image due to a data-conversion problem, but OpenCV shows it correctly."""
import os

import cv2
import hydra
import rospy
import torch
from cv_bridge import CvBridge
from sensor_msgs.msg import CompressedImage

from digit_depth.third_party import geom_utils
from digit_depth.digit import DigitSensor
from digit_depth.train import MLP  # noqa: F401 (needed to unpickle the saved model)
from digit_depth.train.prepost_mlp import *

seed = 42
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


class ImageFeature:
    def __init__(self):
        # topic where we publish
        self.image_pub = rospy.Publisher("/depth/compressed",
                                         CompressedImage, queue_size=10)
        self.br = CvBridge()


@hydra.main(config_path=f"{base_path}/config", config_name="rgb_to_normal.yaml", version_base=None)
def show_depth(cfg):
    model = torch.load(cfg.model_path).to(device)
    model.eval()
    ic = ImageFeature()
    br = CvBridge()
    rospy.init_node('depth_node', anonymous=True)
    # depth map of the base (undeformed) image
    base_img = cv2.imread(cfg.base_img_path)
    base_img = preproc_mlp(base_img)
    base_img_proc = model(base_img).cpu().detach().numpy()
    base_img_proc, normal_base = post_proc_mlp(base_img_proc)
    # get gradx and grady
    gradx_base, grady_base = geom_utils._normal_to_grad_depth(
        img_normal=base_img_proc, gel_width=cfg.sensor.gel_width,
        gel_height=cfg.sensor.gel_height, bg_mask=None)
    # reconstruct depth
    img_depth_base = geom_utils._integrate_grad_depth(
        gradx_base, grady_base, boundary=None, bg_mask=None, max_depth=0.0237)
    img_depth_base = img_depth_base.detach().cpu().numpy()  # final depth image for the base frame
    # set up the DIGIT sensor
    digit = DigitSensor(30, "QVGA", cfg.sensor.serial_num)
    digit_call = digit()
    while not rospy.is_shutdown():
        frame = digit_call.get_frame()
        img_np = preproc_mlp(frame)
        img_np = model(img_np).detach().cpu().numpy()
        img_np, normal_img = post_proc_mlp(img_np)
        # get gradx and grady
        gradx_img, grady_img = geom_utils._normal_to_grad_depth(
            img_normal=img_np, gel_width=cfg.sensor.gel_width,
            gel_height=cfg.sensor.gel_height, bg_mask=None)
        # reconstruct depth
        img_depth = geom_utils._integrate_grad_depth(
            gradx_img, grady_img, boundary=None, bg_mask=None, max_depth=0.0237)
        img_depth = img_depth.detach().cpu().numpy()  # final depth image for the current frame
        # depth difference relative to the base frame, scaled for display
        depth_diff = (img_depth - img_depth_base) * 500
        # zero out the background plane and convert contact depth to positive values
        img_depth[img_depth == 0.0237] = 0
        img_depth[img_depth != 0] = (img_depth[img_depth != 0] - 0.0237) * (-1)
        img_depth = img_depth * 1000
        cv2.imshow("depth", depth_diff)
        msg = br.cv2_to_compressed_imgmsg(img_depth, "png")
        ic.image_pub.publish(msg)
        now = rospy.get_rostime()
        rospy.loginfo("published depth image at {}".format(now))
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()


if __name__ == "__main__":
    rospy.loginfo("starting...")
    show_depth()
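To check the published stream without rqt_image_view, a minimal subscriber along these lines can decode the compressed frames (a sketch; the topic name and png encoding match the publisher above):

import cv2
import numpy as np
import rospy
from sensor_msgs.msg import CompressedImage

def on_depth(msg):
    # decode the png-compressed payload back into an OpenCV image
    img = cv2.imdecode(np.frombuffer(msg.data, np.uint8), cv2.IMREAD_UNCHANGED)
    rospy.loginfo("received depth frame with shape %s", img.shape)

rospy.init_node("depth_listener", anonymous=True)
rospy.Subscriber("/depth/compressed", CompressedImage, on_depth, queue_size=1)
rospy.spin()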
69
scripts/label_data.py
Normal file
@@ -0,0 +1,69 @@
"""
Labels images for training the MLP depth-reconstruction model.

Specify the image folder containing the sphere images and the CSV file used to
store the labels (img_name, center_x, center_y, radius).
The image dataset should show the rolling of a sphere with a known radius.

Directions:
-- Click the left mouse button to select the center of the sphere.
-- Click the right mouse button to select a point on the circumference of the sphere.
-- Press ESC to move to the next image.
"""

import argparse
import csv
import glob
import math
import os

import cv2

base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


def click_and_store(event, x, y, flags, param):
    global count
    global center_x, center_y, circumference_x, circumference_y, radii
    if event == cv2.EVENT_LBUTTONDOWN:
        center_x = x
        center_y = y
        print("center_x: ", x)
        print("center_y: ", y)
        cv2.circle(image, (x, y), 3, (0, 0, 255), -1)
        cv2.imshow("image", image)
    elif event == cv2.EVENT_RBUTTONDOWN:
        circumference_x = x
        circumference_y = y
        print("circumference_x: ", x)
        print("circumference_y: ", y)
        cv2.circle(image, (x, y), 3, (0, 0, 255), -1)
        cv2.imshow("image", image)
        radius = math.sqrt(
            (center_x - circumference_x) ** 2 + (center_y - circumference_y) ** 2
        )
        print("radius: ", int(radius))
        radii.append(int(radius))
        with open(filename, "a", newline="") as csvfile:  # newline="" avoids blank rows on Windows
            writer = csv.writer(csvfile)
            print("Writing>>")
            count += 1
            writer.writerow([img_name, center_x, center_y, int(radius)])


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--folder", type=str, default="images", help="folder containing images")
    argparser.add_argument("--csv", type=str, default=f"{base_path}/csv/annotate.csv", help="CSV file to store results")
    args = argparser.parse_args()
    filename = args.csv
    img_folder = os.path.join(base_path, args.folder)
    img_files = sorted(glob.glob(f"{img_folder}/*.png"))
    os.makedirs(os.path.join(base_path, "csv"), exist_ok=True)
    # click coordinates are scalars set by the callback; only the radii accumulate
    center_x = center_y = circumference_x = circumference_y = 0
    radii = []
    count = 0
    for img in img_files:
        image = cv2.imread(img)
        img_name = img
        cv2.imshow("image", image)
        cv2.setMouseCallback("image", click_and_store, image)
        cv2.waitKey(0)
    cv2.destroyAllWindows()
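To verify the labels, the stored circles can be drawn back onto the images (a sketch, assuming the default annotate.csv location and one img_name,center_x,center_y,radius row per click pair):

import csv
import cv2

with open("csv/annotate.csv") as f:
    for img_name, cx, cy, r in csv.reader(f):
        image = cv2.imread(img_name)
        cv2.circle(image, (int(cx), int(cy)), int(r), (0, 255, 0), 1)  # labeled sphere outline
        cv2.imshow("check", image)
        cv2.waitKey(0)
cv2.destroyAllWindows()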
65
scripts/point_cloud.py
Normal file
@@ -0,0 +1,65 @@
""" Streams DIGIT depth maps and visualizes them as a 3D point cloud in an Open3D window. """
import copy
import os

import cv2
import hydra
import open3d as o3d
import torch
from attrdict import AttrDict

from digit_depth.third_party import geom_utils, vis_utils
from digit_depth.digit import DigitSensor
from digit_depth.train import MLP  # noqa: F401 (needed to unpickle the saved model)
from digit_depth.train.prepost_mlp import *

seed = 42
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


@hydra.main(config_path=f"{base_path}/config", config_name="rgb_to_normal.yaml", version_base=None)
def show_point_cloud(cfg):
    view_params = AttrDict({'fov': 60, 'front': [-0.1, 0.1, 0.1],
                            'lookat': [-0.001, -0.01, 0.01],
                            'up': [0.04, -0.05, 0.190], 'zoom': 2.5})
    vis3d = vis_utils.Visualizer3d(base_path=base_path, view_params=view_params)

    # projection params
    T_cam_offset = torch.tensor(cfg.sensor.T_cam_offset)
    proj_mat = torch.tensor(cfg.sensor.P)
    model = torch.load(cfg.model_path).to(device)
    model.eval()
    # depth map of the base (undeformed) image
    base_img = cv2.imread(cfg.base_img_path)
    base_img = preproc_mlp(base_img)
    base_img_proc = model(base_img).cpu().detach().numpy()
    base_img_proc, normal_base = post_proc_mlp(base_img_proc)
    # get gradx and grady
    gradx_base, grady_base = geom_utils._normal_to_grad_depth(
        img_normal=base_img_proc, gel_width=cfg.sensor.gel_width,
        gel_height=cfg.sensor.gel_height, bg_mask=None)
    # reconstruct depth
    img_depth_base = geom_utils._integrate_grad_depth(
        gradx_base, grady_base, boundary=None, bg_mask=None, max_depth=0.0237)
    img_depth_base = img_depth_base.detach().cpu().numpy()  # final depth image for the base frame
    # set up the DIGIT sensor
    digit = DigitSensor(30, "QVGA", cfg.sensor.serial_num)
    digit_call = digit()
    while True:
        frame = digit_call.get_frame()
        img_np = preproc_mlp(frame)
        img_np = model(img_np).detach().cpu().numpy()
        img_np, normal_img = post_proc_mlp(img_np)
        # get gradx and grady
        gradx_img, grady_img = geom_utils._normal_to_grad_depth(
            img_normal=img_np, gel_width=cfg.sensor.gel_width,
            gel_height=cfg.sensor.gel_height, bg_mask=None)
        # reconstruct depth
        img_depth = geom_utils._integrate_grad_depth(
            gradx_img, grady_img, boundary=None, bg_mask=None, max_depth=0.0237)
        view_mat = torch.eye(4)  # torch.inverse(T_cam_offset)
        # project depth to 3D
        points3d = geom_utils.depth_to_pts3d(depth=img_depth, P=proj_mat, V=view_mat, params=cfg.sensor)
        points3d = geom_utils.remove_background_pts(points3d, bg_mask=None)
        cloud = o3d.geometry.PointCloud()
        clouds = geom_utils.init_points_to_clouds(clouds=[copy.deepcopy(cloud)], points3d=[points3d])
        vis_utils.visualize_geometries_o3d(vis3d=vis3d, clouds=clouds)


if __name__ == "__main__":
    show_point_cloud()
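Frames can also be dumped to disk for offline inspection instead of live visualization (a sketch with random stand-in points; write_point_cloud accepts the same PointCloud objects built in the loop above):

import numpy as np
import open3d as o3d

cloud = o3d.geometry.PointCloud()
cloud.points = o3d.utility.Vector3dVector(np.random.rand(100, 3))  # stand-in for points3d
o3d.io.write_point_cloud("frame_cloud.pcd", cloud)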
49
scripts/record.py
Normal file
@@ -0,0 +1,49 @@
"""
Captures individual frames while the camera output is displayed.
-- Press SPACEBAR to capture a frame.
-- Press ESC to terminate the program.
"""
import argparse
import os

import cv2

from digit_depth.digit.digit_sensor import DigitSensor

base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


def record_frame(digit_sensor, dir_name: str):
    img_counter = len(os.listdir(dir_name))
    digit_call = digit_sensor()
    while True:
        frame = digit_call.get_frame()
        cv2.imshow("Capture Frame", frame)
        k = cv2.waitKey(1)
        if k % 256 == 27:
            # ESC hit
            print("Escape hit, closing...")
            break
        elif k % 256 == 32:
            # SPACEBAR hit
            img_name = "{}/{:0>4}.png".format(dir_name, img_counter)
            cv2.imwrite(img_name, frame)
            print("{} written!".format(img_name))
            img_counter += 1
    cv2.destroyAllWindows()


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--fps", type=int, default=30, help="frames per second (max 60 at QVGA)")
    argparser.add_argument("--resolution", type=str, default="QVGA", help="QVGA or VGA")
    argparser.add_argument("--serial_num", type=str, default="D00001", help="serial number of the DIGIT")
    args = argparser.parse_args()

    if not os.path.exists(os.path.join(base_path, "images")):
        os.makedirs(os.path.join(base_path, "images"), exist_ok=True)
        print("Directory {} created for saving images".format(os.path.join(base_path, "images")))
    digit = DigitSensor(args.fps, args.resolution, args.serial_num)

    record_frame(digit, os.path.join(base_path, "images"))
54
scripts/ros/depth_value_pub.py
Normal file
@@ -0,0 +1,54 @@
""" Node that publishes the maximum depth value when the gel is deformed. """
import os

import hydra
import numpy as np
import rospy
import torch
from std_msgs.msg import Float32

from digit_depth.third_party import geom_utils
from digit_depth.digit.digit_sensor import DigitSensor
from digit_depth.train.mlp_model import MLP  # noqa: F401 (needed to unpickle the saved model)
from digit_depth.train.prepost_mlp import *

seed = 42
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BASE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))


@hydra.main(config_path=f"{BASE_PATH}/config", config_name="rgb_to_normal.yaml", version_base=None)
def print_depth(cfg):
    model = torch.load(cfg.model_path).to(device)
    model.eval()
    # set up the DIGIT sensor
    digit = DigitSensor(30, "QVGA", cfg.sensor.serial_num)
    digit_call = digit()
    pub = rospy.Publisher('chatter', Float32, queue_size=1)
    rospy.init_node('depth', anonymous=True)
    try:
        while not rospy.is_shutdown():
            frame = digit_call.get_frame()
            img_np = preproc_mlp(frame)
            img_np = model(img_np).detach().cpu().numpy()
            img_np, normal_img = post_proc_mlp(img_np)

            # get gradx and grady
            gradx_img, grady_img = geom_utils._normal_to_grad_depth(
                img_normal=img_np, gel_width=cfg.sensor.gel_width,
                gel_height=cfg.sensor.gel_height, bg_mask=None)
            # reconstruct depth
            img_depth = geom_utils._integrate_grad_depth(
                gradx_img, grady_img, boundary=None, bg_mask=None, max_depth=0.0237)
            img_depth = img_depth.detach().cpu().numpy().flatten()

            # the deepest deformation corresponds to the minimum depth value
            max_depth = np.min(img_depth)
            rospy.loginfo(f"max:{max_depth}")
            img_depth_calibrated = np.abs(max_depth - 0.02362)

            # publish max depth value
            pub.publish(Float32(img_depth_calibrated * 10000))  # convert to mm
    except KeyboardInterrupt:
        print("Shutting down")
        digit_call.disconnect()  # was digit(), which would have opened a new connection


if __name__ == "__main__":
    rospy.loginfo("starting...")
    print_depth()
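A matching listener node is a few lines (a sketch; the topic name 'chatter' follows the publisher above):

import rospy
from std_msgs.msg import Float32

rospy.init_node("depth_value_listener", anonymous=True)
rospy.Subscriber("chatter", Float32, lambda msg: rospy.loginfo("depth: %.3f", msg.data))
rospy.spin()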
52
scripts/ros/digit_image_pub.py
Normal file
@@ -0,0 +1,52 @@
""" ROS image publisher for the DIGIT sensor """

import argparse

# OpenCV
import cv2
from cv_bridge import CvBridge

# ROS libraries
import rospy

# ROS messages
from sensor_msgs.msg import CompressedImage

from digit_depth.digit.digit_sensor import DigitSensor


class ImageFeature:

    def __init__(self):
        # topic where we publish
        self.image_pub = rospy.Publisher("/output/image_raw/compressed",
                                         CompressedImage, queue_size=10)
        self.br = CvBridge()


def rgb_pub(digit_sensor: DigitSensor):
    # initialize the ROS node
    ic = ImageFeature()
    rospy.init_node('image_feature', anonymous=True)
    digit_call = digit_sensor()
    br = CvBridge()
    while not rospy.is_shutdown():
        frame = digit_call.get_frame()
        msg = br.cv2_to_compressed_imgmsg(frame, "png")
        ic.image_pub.publish(msg)
        rospy.loginfo("published ...")
        if cv2.waitKey(1) == 27:
            break


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--fps", type=int, default=30)
    argparser.add_argument("--resolution", type=str, default="QVGA")
    argparser.add_argument("--serial_num", type=str, default="D00001")
    args, unknown = argparser.parse_known_args()
    digit = DigitSensor(args.fps, args.resolution, args.serial_num)
    rgb_pub(digit)
112
scripts/train_mlp.py
Normal file
@@ -0,0 +1,112 @@
import argparse
import os

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

from digit_depth.train import MLP, Color2NormalDataset

seed = 42
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


def train(train_loader, epochs, lr):
    model = MLP().to(device)
    wandb.init(project="MLP", name="Color 2 Normal model train")
    wandb.watch(model, log_freq=100)

    model.train()

    # loss and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    avg_loss = 0.0
    loss_record = []
    cnt = 0
    total_step = len(train_loader)
    for epoch in tqdm(range(1, 1 + epochs)):
        for i, (data, labels) in enumerate(train_loader):
            # move tensors to the configured device
            data = data.to(device)
            labels = labels.to(device)

            outputs = model(data)
            loss = criterion(outputs, labels)
            avg_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            cnt += 1

            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, total_step, loss.item()))
            loss_record.append(loss.item())
            # wandb.log({"Mini-batch loss": loss})
        # wandb.log({'Running test loss': avg_loss / cnt})
    os.makedirs(f"{base_path}/models", exist_ok=True)
    print(f"Saving model to {base_path}/models/")
    torch.save(model, f"{base_path}/models/mlp.ckpt")


def test(test_loader, criterion):
    model = torch.load(f"{base_path}/models/mlp.ckpt").to(device)
    model.eval()
    wandb.init(project="MLP", name="Color 2 Normal model test")
    wandb.watch(model, log_freq=100)
    avg_loss = 0.0
    cnt = 0
    with torch.no_grad():
        for idx, (data, labels) in enumerate(test_loader):
            data = data.to(device)
            labels = labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)
            avg_loss += loss.item()
            cnt += 1
            # wandb.log({"Mini-batch test loss": loss})
    avg_loss = avg_loss / cnt
    print("Test loss: {:.4f}".format(avg_loss))
    # wandb.log({'Average Test loss': avg_loss})


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--mode', type=str, default='train', help='train or test')
    argparser.add_argument('--batch_size', type=int, default=10000, help='batch size')
    argparser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
    argparser.add_argument('--epochs', type=int, default=2, help='epochs')
    argparser.add_argument('--train_path', type=str, default=f'{base_path}/datasets/train_test_split/train.csv',
                           help='train data path')
    argparser.add_argument('--test_path', type=str, default=f'{base_path}/datasets/train_test_split/test.csv',
                           help='test data path')
    option = argparser.parse_args()

    if option.mode == "train":
        train_set = Color2NormalDataset(option.train_path)
        train_loader = DataLoader(train_set, batch_size=option.batch_size, shuffle=True)
        print("Training set size: ", len(train_set))
        train(train_loader, option.epochs, option.learning_rate)
    elif option.mode == "test":
        test_set = Color2NormalDataset(option.test_path)
        test_loader = DataLoader(test_set, batch_size=option.batch_size, shuffle=True)
        criterion = nn.MSELoss()
        test(test_loader, criterion)


if __name__ == "__main__":
    main()
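Once training finishes, the checkpoint can be smoke-tested in isolation (a minimal sketch; the (N, 3) input shape is an assumption based on the per-pixel color-to-normal setup, not a documented interface):

import torch

from digit_depth.train import MLP  # noqa: F401 (needed so torch.load can unpickle the model)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load("models/mlp.ckpt").to(device)
model.eval()
with torch.no_grad():
    rgb = torch.rand(8, 3, device=device)  # hypothetical batch of 8 RGB pixels
    normals = model(rgb)
print(normals.shape)  # expected (8, 3): per-pixel surface normals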