Skip to content

Commit

Permalink
Port models/pcn_cd.py to PyTorch.
Browse files Browse the repository at this point in the history
  • Loading branch information
mihaimorariu committed Jul 2, 2020
1 parent 68e9d98 commit 6610a8e
Show file tree
Hide file tree
Showing 9 changed files with 387 additions and 0 deletions.
15 changes: 15 additions & 0 deletions io_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018

import numpy as np
from open3d import *


def read_pcd(filename):
pcd = read_point_cloud(filename)
return np.array(pcd.points)


def save_pcd(filename, points):
pcd = PointCloud()
pcd.points = Vector3dVector(points)
write_point_cloud(filename, pcd)
31 changes: 31 additions & 0 deletions pytorch_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import argparse
import numpy as np
import open3d as o3d
import torch

from common.visual import plot_side_by_side
from pytorch_models.pcn_cd import Model


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--input_path', default='demo_data/car.pcd')
parser.add_argument('--model_type', default='pcn_cd')
parser.add_argument('--checkpoint', default='data/trained_models/pcn_cd')
parser.add_argument('--num_gt_points', type=int, default=16384)
args = parser.parse_args()
return args


if __name__ == '__main__':
args = parse_arguments()

# device = torch.cuda.current_device()
model = Model()
# model.to(device)

partial = o3d.io.read_point_cloud(args.input_path)
partial = torch.Tensor(partial.points).transpose(1, 0)[None, :]
complete = model(input=partial, npts=[partial.shape[0]])

plot_side_by_side(partial, complete)
Empty file added pytorch_models/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions pytorch_models/fc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018

import tensorflow as tf
from tf_util import mlp, mlp_conv, chamfer, add_train_summary, add_valid_summary


class Model:
def __init__(self, inputs, gt, alpha):
self.num_output_points = 16384
self.features = self.create_encoder(inputs)
self.outputs = self.create_decoder(self.features)
self.loss, self.update = self.create_loss(self.outputs, gt)
self.visualize_ops = [inputs[0], self.outputs[0], gt[0]]
self.visualize_titles = ['input', 'output', 'ground truth']

def create_encoder(self, inputs):
with tf.variable_scope('encoder_0', reuse=tf.AUTO_REUSE):
features = mlp_conv(inputs, [128, 256])
features_global = tf.reduce_max(features, axis=1, keep_dims=True, name='maxpool_0')
features = tf.concat([features, tf.tile(features_global, [1, tf.shape(inputs)[1], 1])], axis=2)
with tf.variable_scope('encoder_1', reuse=tf.AUTO_REUSE):
features = mlp_conv(features, [512, 1024])
features = tf.reduce_max(features, axis=1, name='maxpool_1')
return features

def create_decoder(self, features):
with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
outputs = mlp(features, [1024, 1024, self.num_output_points * 3])
outputs = tf.reshape(outputs, [-1, self.num_output_points, 3])
return outputs

def create_loss(self, outputs, gt):
loss = chamfer(outputs, gt)
add_train_summary('train/loss', loss)
update_loss = add_valid_summary('valid/loss', loss)
return loss, update_loss
47 changes: 47 additions & 0 deletions pytorch_models/folding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018

import tensorflow as tf
from tf_util import mlp_conv, chamfer, add_train_summary, add_valid_summary


class Model:
def __init__(self, inputs, gt, alpha):
self.grid_size = 128
self.grid_scale = 0.5
self.num_output_points = 16384
self.features = self.create_encoder(inputs)
fold1, fold2 = self.create_decoder(self.features)
self.outputs = fold2
self.loss, self.update = self.create_loss(self.outputs, gt)
self.visualize_ops = [inputs[0], fold1[0], fold2[0], gt[0]]
self.visualize_titles = ['input', '1st folding', '2nd folding', 'ground truth']

def create_encoder(self, inputs):
with tf.variable_scope('encoder_0', reuse=tf.AUTO_REUSE):
features = mlp_conv(inputs, [128, 256])
features_global = tf.reduce_max(features, axis=1, keep_dims=True, name='maxpool_0')
features = tf.concat([features, tf.tile(features_global, [1, tf.shape(inputs)[1], 1])], axis=2)
with tf.variable_scope('encoder_1', reuse=tf.AUTO_REUSE):
features = mlp_conv(features, [512, 1024])
features = tf.reduce_max(features, axis=1, name='maxpool_1')
return features

def create_decoder(self, features):
with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
x = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
y = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
grid = tf.meshgrid(x, y)
grid = tf.reshape(tf.stack(grid, axis=2), [-1, 2])
grid = tf.tile(tf.expand_dims(grid, 0), [features.shape[0], 1, 1])
features = tf.tile(tf.expand_dims(features, 1), [1, self.num_output_points, 1])
with tf.variable_scope('folding_1'):
fold1 = mlp_conv(tf.concat([features, grid], axis=2), [512, 512, 3])
with tf.variable_scope('folding_2'):
fold2 = mlp_conv(tf.concat([features, fold1], axis=2), [512, 512, 3])
return fold1, fold2

def create_loss(self, outputs, gt):
loss = chamfer(outputs, gt)
add_train_summary('train/loss', loss)
update_loss = add_valid_summary('valid/loss', loss)
return loss, update_loss
80 changes: 80 additions & 0 deletions pytorch_models/pcn_cd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from pytorch_tf_util import *


class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.num_coarse = 1024
self.grid_size = 4
self.grid_scale = 0.05
self.num_fine = self.grid_size ** 2 * self.num_coarse

self.mlp_conv1 = MLPConv(in_channels=3, layer_dims=[128, 256])
self.mlp_conv2 = MLPConv(in_channels=512, layer_dims=[512, 1024])
self.mlp_conv3 = MLPConv(in_channels=1029, layer_dims=[512, 512, 3])
self.mlp = MLP(in_channels=1024, layer_dims=[1024, 1024, self.num_coarse * 3])

def __call__(self, input, npts, gt=None, alpha=None):
features = self.encode(input, npts)
coarse, fine = self.decode(features)

if gt is not None:
assert alpha is not None
loss = self.loss(coarse, fine, gt, alpha)
return loss

return coarse, fine

def encode(self, inputs, npts):
features = self.mlp_conv1(inputs)
features_global = point_unpool(point_maxpool(features, npts, keepdim=True), npts)
features_global = features_global.repeat(1, 1, features.shape[2])
features = torch.cat([features, features_global], dim=1)
features = self.mlp_conv2(features)
features = point_maxpool(features, npts)
return features

def decode(self, features):
coarse = self.mlp(features)
coarse = torch.reshape(coarse, [-1, self.num_coarse, 3])

x = torch.linspace(-0.05, 0.05, self.grid_size)
y = torch.linspace(-0.05, 0.05, self.grid_size)

grid = torch.meshgrid(x, y)
grid = torch.unsqueeze(torch.reshape(torch.stack(grid, axis=2), [-1, 2]), 0)
grid_feat = grid.repeat(features.shape[0], self.num_coarse, 1)

point_feat = torch.unsqueeze(coarse, 2)
point_feat = point_feat.repeat(1, 1, self.grid_size ** 2, 1)
point_feat = torch.reshape(point_feat, [-1, self.num_fine, 3])

global_feat = torch.unsqueeze(features, 1)
global_feat = global_feat.repeat(1, self.num_fine, 1)

feat = torch.cat([grid_feat, point_feat, global_feat], axis=2)

center = torch.unsqueeze(coarse, 2)
center = center.repeat(1, 1, self.grid_size ** 2, 1)
center = torch.reshape(center, [-1, 3, self.num_fine])

feat = feat.transpose(1, 2)
fine = self.mlp_conv3(feat)
fine = fine + center

return coarse, fine

def loss(self, coarse, fine, gt, alpha):
loss_coarse = chamfer(coarse, gt)
# add_train_summary('train/coarse_loss', loss_coarse)
# update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)
#
loss_fine = chamfer(fine, gt)
# add_train_summary('train/fine_loss', loss_fine)
# update_fine = add_valid_summary('valid/fine_loss', loss_fine)
#
loss = loss_coarse + alpha * loss_fine
# add_train_summary('train/loss', loss)
# update_loss = add_valid_summary('valid/loss', loss)
#
return loss
69 changes: 69 additions & 0 deletions pytorch_models/pcn_emd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018

import tensorflow as tf
from tf_util import *


class Model:
def __init__(self, inputs, npts, gt, alpha):
self.num_coarse = 1024
self.grid_size = 4
self.grid_scale = 0.05
self.num_fine = self.grid_size ** 2 * self.num_coarse
self.features = self.create_encoder(inputs, npts)
self.coarse, self.fine = self.create_decoder(self.features)
self.loss, self.update = self.create_loss(self.coarse, self.fine, gt, alpha)
self.outputs = self.fine
self.visualize_ops = [tf.split(inputs[0], npts, axis=0), self.coarse, self.fine, gt]
self.visualize_titles = ['input', 'coarse output', 'fine output', 'ground truth']

def create_encoder(self, inputs, npts):
with tf.variable_scope('encoder_0', reuse=tf.AUTO_REUSE):
features = mlp_conv(inputs, [128, 256])
features_global = point_unpool(point_maxpool(features, npts, keepdims=True), npts)
features = tf.concat([features, features_global], axis=2)
with tf.variable_scope('encoder_1', reuse=tf.AUTO_REUSE):
features = mlp_conv(features, [512, 1024])
features = point_maxpool(features, npts)
return features

def create_decoder(self, features):
with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
coarse = mlp(features, [1024, 1024, self.num_coarse * 3])
coarse = tf.reshape(coarse, [-1, self.num_coarse, 3])

with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
x = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
y = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
grid = tf.meshgrid(x, y)
grid = tf.expand_dims(tf.reshape(tf.stack(grid, axis=2), [-1, 2]), 0)
grid_feat = tf.tile(grid, [features.shape[0], self.num_coarse, 1])

point_feat = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
point_feat = tf.reshape(point_feat, [-1, self.num_fine, 3])

global_feat = tf.tile(tf.expand_dims(features, 1), [1, self.num_fine, 1])

feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)

center = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
center = tf.reshape(center, [-1, self.num_fine, 3])

fine = mlp_conv(feat, [512, 512, 3]) + center
return coarse, fine

def create_loss(self, coarse, fine, gt, alpha):
gt_ds = gt[:, :coarse.shape[1], :]
loss_coarse = earth_mover(coarse, gt_ds)
add_train_summary('train/coarse_loss', loss_coarse)
update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

loss_fine = chamfer(fine, gt)
add_train_summary('train/fine_loss', loss_fine)
update_fine = add_valid_summary('valid/fine_loss', loss_fine)

loss = loss_coarse + alpha * loss_fine
add_train_summary('train/loss', loss)
update_loss = add_valid_summary('valid/loss', loss)

return loss, [update_coarse, update_fine, update_loss]
81 changes: 81 additions & 0 deletions pytorch_tf_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import tensorflow as tf
import torch

from pc_distance import tf_nndistance, tf_approxmatch


class MLP(torch.nn.Module):
def __init__(self, in_channels, layer_dims):
super(MLP, self).__init__()
self.layer_dims = layer_dims
self.seqs = []

for _, out_channels in enumerate(layer_dims):
self.seqs.append(torch.nn.Sequential(
torch.nn.Linear(in_channels, out_channels),
torch.nn.ReLU()
))
in_channels = out_channels

def __call__(self, features):
output = features
for seq in self.seqs:
output = seq(output)
return output


class MLPConv(torch.nn.Module):
def __init__(self, in_channels, layer_dims):
super(MLPConv, self).__init__()
self.layer_dims = layer_dims
self.conv1s = []

for _, out_channels in enumerate(layer_dims):
self.conv1s.append(torch.nn.Conv1d(in_channels, out_channels, kernel_size=1))
in_channels = out_channels

def __call__(self, input):
output = input
for conv1 in self.conv1s:
output = conv1(output)
return output


def point_maxpool(input, npts, keepdim=False):
split_size = [input.shape[2] // npt for npt in npts]
output = [f.max(dim=2, keepdims=keepdim).values
for f in torch.split(input, split_size, dim=2)]
output = torch.cat(output, dim=0)
return output


def point_unpool(input, npts):
input = torch.split(input, input.shape[0], dim=0)
output = [f.repeat(1, npts[i], 1) for i, f in enumerate(input)]
output = torch.cat(output, dim=1)
return output


def chamfer(pcd1, pcd2):
dist1, _, dist2, _ = tf_nndistance.nn_distance(pcd1, pcd2)
dist1 = torch.mean(torch.sqrt(dist1))
dist2 = torch.mean(torch.sqrt(dist2))
return (dist1 + dist2) / 2


def earth_mover(pcd1, pcd2):
assert pcd1.shape[1] == pcd2.shape[1]
num_points = pcd1.shape[1].to(torch.float32)
match = tf_approxmatch.approx_match(pcd1, pcd2)
cost = tf_approxmatch.match_cost(pcd1, pcd2, match)
return toch.mean(cost / num_points)


def add_train_summary(name, value):
tf.summary.scalar(name, value, collections=['train_summary'])


def add_valid_summary(name, value):
avg, update = tf.metrics.mean(value)
tf.summary.scalar(name, avg, collections=['valid_summary'])
return update
28 changes: 28 additions & 0 deletions visu_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def plot_pcd_three_views(filename, pcds, titles, suptitle='', sizes=None, cmap='Reds', zdir='y',
xlim=(-0.3, 0.3), ylim=(-0.3, 0.3), zlim=(-0.3, 0.3)):
if sizes is None:
sizes = [0.5 for i in range(len(pcds))]
fig = plt.figure(figsize=(len(pcds) * 3, 9))
for i in range(3):
elev = 30
azim = -45 + 90 * i
for j, (pcd, size) in enumerate(zip(pcds, sizes)):
color = pcd[:, 0]
ax = fig.add_subplot(3, len(pcds), i * len(pcds) + j + 1, projection='3d')
ax.view_init(elev, azim)
ax.scatter(pcd[:, 0], pcd[:, 1], pcd[:, 2], zdir=zdir, c=color, s=size, cmap=cmap, vmin=-1, vmax=0.5)
ax.set_title(titles[j])
ax.set_axis_off()
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9, wspace=0.1, hspace=0.1)
plt.suptitle(suptitle)
fig.savefig(filename)
plt.close(fig)

0 comments on commit 6610a8e

Please sign in to comment.