diff --git a/pc_distance/__init__.py b/pc_distance/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/pc_distance/makefile b/pc_distance/makefile new file mode 100644 index 0000000..50608bd --- /dev/null +++ b/pc_distance/makefile @@ -0,0 +1,26 @@ +cuda_inc = /opt/cuda/include/ +cuda_lib = /opt/cuda/lib64/ +nvcc = /opt/cuda/bin/nvcc +tf_lib = /home/mmorariu/.local/lib/python3.8/site-packages/tensorflow +tf_inc = $(tf_lib)/include + +all: tf_nndistance_so.so tf_approxmatch_so.so + +tf_nndistance.cu.o: tf_nndistance.cu + $(nvcc) tf_nndistance.cu -o tf_nndistance.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC + +tf_nndistance_so.so: tf_nndistance.cpp tf_nndistance.cu.o + g++ tf_nndistance.cpp tf_nndistance.cu.o -o tf_nndistance_so.so \ + -I $(cuda_inc) -I $(tf_inc) -L $(cuda_lib) -lcudart -L $(tf_lib)/libtensorflow_framework.so.2 \ + -shared -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 -fPIC -O2 + +tf_approxmatch.cu.o: tf_approxmatch.cu + $(nvcc) tf_approxmatch.cu -o tf_approxmatch.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC + +tf_approxmatch_so.so: tf_approxmatch.cpp tf_approxmatch.cu.o + g++ -shared $(CPPFLAGS) tf_approxmatch.cpp tf_approxmatch.cu.o -o tf_approxmatch_so.so \ + -I $(cuda_inc) -I $(tf_inc) -L $(cuda_lib) -lcudart -L $(tf_lib)/libtensorflow_framework.so.2 \ + -shared -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 -fPIC -O2 + +clean: + rm -rf *.o *.so diff --git a/pc_distance/tf_approxmatch.cpp b/pc_distance/tf_approxmatch.cpp new file mode 100755 index 0000000..e12ffa9 --- /dev/null +++ b/pc_distance/tf_approxmatch.cpp @@ -0,0 +1,329 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include +#include +#include +using namespace tensorflow; +REGISTER_OP("ApproxMatch") + .Input("xyz1: float32") + .Input("xyz2: float32") + .Output("match: float32"); +REGISTER_OP("MatchCost") + .Input("xyz1: float32") + .Input("xyz2: float32") + .Input("match: float32") + .Output("cost: float32"); +REGISTER_OP("MatchCostGrad") + .Input("xyz1: float32") + .Input("xyz2: float32") + .Input("match: float32") + .Output("grad1: float32") + .Output("grad2: float32"); + +void approxmatch_cpu(int b,int n,int m,const float * xyz1,const float * xyz2,float * match){ + for (int i=0;i saturatedl(n,double(factorl)),saturatedr(m,double(factorr)); + std::vector weight(n*m); + for (int j=0;j=-2;j--){ + //printf("i=%d j=%d\n",i,j); + double level=-powf(4.0,j); + if (j==-2) + level=0; + for (int k=0;k ss(m,1e-9); + for (int k=0;k ss2(m,0); + for (int k=0;kinput(0); + OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz1 shape")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&(xyz1_flat(0)); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + //OP_REQUIRES(context,n<=4096,errors::InvalidArgument("ApproxMatch handles at most 4096 dataset points")); + + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); + int m=xyz2_tensor.shape().dim_size(1); + //OP_REQUIRES(context,m<=1024,errors::InvalidArgument("ApproxMatch handles at most 1024 query points")); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&(xyz2_flat(0)); + Tensor * match_tensor=NULL; + 
OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m,n},&match_tensor)); + auto match_flat=match_tensor->flat(); + float * match=&(match_flat(0)); + Tensor temp_tensor; + OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{b,(n+m)*2},&temp_tensor)); + auto temp_flat=temp_tensor.flat(); + float * temp=&(temp_flat(0)); + approxmatchLauncher(b,n,m,xyz1,xyz2,match,temp); + } +}; +REGISTER_KERNEL_BUILDER(Name("ApproxMatch").Device(DEVICE_GPU), ApproxMatchGpuOp); +class ApproxMatchOp: public OpKernel{ + public: + explicit ApproxMatchOp(OpKernelConstruction* context):OpKernel(context){} + void Compute(OpKernelContext * context)override{ + const Tensor& xyz1_tensor=context->input(0); + OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz1 shape")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&(xyz1_flat(0)); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + //OP_REQUIRES(context,n<=4096,errors::InvalidArgument("ApproxMatch handles at most 4096 dataset points")); + + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("ApproxMatch expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); + int m=xyz2_tensor.shape().dim_size(1); + //OP_REQUIRES(context,m<=1024,errors::InvalidArgument("ApproxMatch handles at most 1024 query points")); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&(xyz2_flat(0)); + Tensor * match_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m,n},&match_tensor)); + auto match_flat=match_tensor->flat(); + float * match=&(match_flat(0)); + approxmatch_cpu(b,n,m,xyz1,xyz2,match); + } +}; +REGISTER_KERNEL_BUILDER(Name("ApproxMatch").Device(DEVICE_CPU), ApproxMatchOp); +class MatchCostGpuOp: public OpKernel{ + public: + explicit MatchCostGpuOp(OpKernelConstruction* context):OpKernel(context){} + void Compute(OpKernelContext * context)override{ + const Tensor& xyz1_tensor=context->input(0); + OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz1 shape")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&(xyz1_flat(0)); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); + int m=xyz2_tensor.shape().dim_size(1); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&(xyz2_flat(0)); + + const Tensor& match_tensor=context->input(2); + OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match shape")); + auto match_flat=match_tensor.flat(); + const float * match=&(match_flat(0)); + + Tensor * cost_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b},&cost_tensor)); + auto cost_flat=cost_tensor->flat(); + float * cost=&(cost_flat(0)); + matchcostLauncher(b,n,m,xyz1,xyz2,match,cost); + } +}; 
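The ApproxMatch and MatchCost kernels above are driven through the Python wrappers added later in this diff (pc_distance/tf_approxmatch.py). A minimal usage sketch, assuming the shared libraries build with the makefile above and the wrapper module imports cleanly under the installed TensorFlow; emd_loss is a hypothetical helper name, not part of this diff:

    # Sketch only: an EMD-style loss built on the approx_match / match_cost wrappers.
    # Shapes follow the op registrations above: match is (batch, m, n), cost is (batch,).
    # Assumes the repository root is on sys.path so that pc_distance is importable.
    import tensorflow as tf
    from pc_distance import tf_approxmatch

    def emd_loss(xyz1, xyz2):
        # xyz1: (batch, n, 3) dataset points, xyz2: (batch, m, 3) query points
        match = tf_approxmatch.approx_match(xyz1, xyz2)      # soft assignment, (batch, m, n)
        cost = tf_approxmatch.match_cost(xyz1, xyz2, match)  # one scalar per batch element
        return tf.reduce_mean(cost)

MatchCost returns one value per batch element, so the reduction over the batch (mean here, sum in the test code further down) is left to the caller.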
+REGISTER_KERNEL_BUILDER(Name("MatchCost").Device(DEVICE_GPU), MatchCostGpuOp); +class MatchCostOp: public OpKernel{ + public: + explicit MatchCostOp(OpKernelConstruction* context):OpKernel(context){} + void Compute(OpKernelContext * context)override{ + const Tensor& xyz1_tensor=context->input(0); + OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz1 shape")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&(xyz1_flat(0)); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); + int m=xyz2_tensor.shape().dim_size(1); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&(xyz2_flat(0)); + + const Tensor& match_tensor=context->input(2); + OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match shape")); + auto match_flat=match_tensor.flat(); + const float * match=&(match_flat(0)); + + Tensor * cost_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b},&cost_tensor)); + auto cost_flat=cost_tensor->flat(); + float * cost=&(cost_flat(0)); + matchcost_cpu(b,n,m,xyz1,xyz2,match,cost); + } +}; +REGISTER_KERNEL_BUILDER(Name("MatchCost").Device(DEVICE_CPU), MatchCostOp); + +class MatchCostGradGpuOp: public OpKernel{ + public: + explicit MatchCostGradGpuOp(OpKernelConstruction* context):OpKernel(context){} + void Compute(OpKernelContext * context)override{ + const Tensor& xyz1_tensor=context->input(0); + OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCostGrad expects (batch_size,num_points,3) xyz1 shape")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&(xyz1_flat(0)); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCostGrad expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); + int m=xyz2_tensor.shape().dim_size(1); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&(xyz2_flat(0)); + + const Tensor& match_tensor=context->input(2); + OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match shape")); + auto match_flat=match_tensor.flat(); + const float * match=&(match_flat(0)); + + Tensor * grad1_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad1_tensor)); + auto grad1_flat=grad1_tensor->flat(); + float * grad1=&(grad1_flat(0)); + Tensor * grad2_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad2_tensor)); + auto grad2_flat=grad2_tensor->flat(); + float * grad2=&(grad2_flat(0)); + matchcostgradLauncher(b,n,m,xyz1,xyz2,match,grad1,grad2); + } +}; 
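MatchCostGrad above returns d(cost)/d(xyz1) with shape (batch, n, 3) and d(cost)/d(xyz2) with shape (batch, m, 3); the @tf.RegisterGradient('MatchCost') hook in tf_approxmatch.py scales both by the incoming per-batch gradient and returns None for the match input. A small sketch that checks this wiring end to end, under the assumptions that the op library loads, the wrapper imports cleanly, and a TF 1.x-style graph API is available (tf.compat.v1 on the TensorFlow 2 install referenced in the makefile):

    # Sketch only: confirm gradients flow through MatchCost into the point coordinates.
    # ApproxMatch itself is registered as non-differentiable (ops.NoGradient), so the
    # matching is treated as a constant when the cost is differentiated.
    import numpy as np
    import tensorflow.compat.v1 as tf
    from pc_distance import tf_approxmatch

    tf.disable_eager_execution()  # needed for tf.gradients / tf.Session under TF 2

    xyz1 = tf.constant(np.random.rand(2, 64, 3).astype('float32'))  # (batch, n, 3)
    xyz2 = tf.Variable(np.random.rand(2, 32, 3).astype('float32'))  # (batch, m, 3)
    match = tf_approxmatch.approx_match(xyz1, xyz2)                 # (batch, m, n)
    cost = tf_approxmatch.match_cost(xyz1, xyz2, match)             # (batch,)
    grad_xyz2, = tf.gradients(tf.reduce_sum(cost), [xyz2])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(grad_xyz2).shape)  # expected: (2, 32, 3)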
+REGISTER_KERNEL_BUILDER(Name("MatchCostGrad").Device(DEVICE_GPU), MatchCostGradGpuOp); +class MatchCostGradOp: public OpKernel{ + public: + explicit MatchCostGradOp(OpKernelConstruction* context):OpKernel(context){} + void Compute(OpKernelContext * context)override{ + const Tensor& xyz1_tensor=context->input(0); + OP_REQUIRES(context,xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz1 shape")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&(xyz1_flat(0)); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3 && xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("MatchCost expects (batch_size,num_points,3) xyz2 shape, and batch_size must match")); + int m=xyz2_tensor.shape().dim_size(1); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&(xyz2_flat(0)); + + const Tensor& match_tensor=context->input(2); + OP_REQUIRES(context,match_tensor.dims()==3 && match_tensor.shape().dim_size(0)==b && match_tensor.shape().dim_size(1)==m && match_tensor.shape().dim_size(2)==n,errors::InvalidArgument("MatchCost expects (batch_size,#query,#dataset) match shape")); + auto match_flat=match_tensor.flat(); + const float * match=&(match_flat(0)); + + Tensor * grad1_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad1_tensor)); + auto grad1_flat=grad1_tensor->flat(); + float * grad1=&(grad1_flat(0)); + Tensor * grad2_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad2_tensor)); + auto grad2_flat=grad2_tensor->flat(); + float * grad2=&(grad2_flat(0)); + matchcostgrad_cpu(b,n,m,xyz1,xyz2,match,grad1,grad2); + } +}; +REGISTER_KERNEL_BUILDER(Name("MatchCostGrad").Device(DEVICE_CPU), MatchCostGradOp); diff --git a/pc_distance/tf_approxmatch.cu b/pc_distance/tf_approxmatch.cu new file mode 100755 index 0000000..33c8e26 --- /dev/null +++ b/pc_distance/tf_approxmatch.cu @@ -0,0 +1,296 @@ +__global__ void approxmatch(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ + float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; + float multiL,multiR; + if (n>=m){ + multiL=1; + multiR=n/m; + }else{ + multiL=m/n; + multiR=1; + } + const int Block=1024; + __shared__ float buf[Block*4]; + for (int i=blockIdx.x;i=-2;j--){ + float level=-powf(4.0f,j); + if (j==-2){ + level=0; + } + for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp); +} +__global__ void matchcost(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){ + __shared__ float allsum[512]; + const int Block=1024; + __shared__ float buf[Block*3]; + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); +} +__global__ void matchcostgrad2(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ + __shared__ float sum_grad[256*3]; + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1); + matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2); +} + diff --git a/pc_distance/tf_approxmatch.py b/pc_distance/tf_approxmatch.py new file mode 100755 index 0000000..2bc8853 --- /dev/null +++ 
b/pc_distance/tf_approxmatch.py
@@ -0,0 +1,108 @@
+import tensorflow as tf
+from tensorflow.python.framework import ops
+import os.path as osp
+
+base_dir = osp.dirname(osp.abspath(__file__))
+
+approxmatch_module = tf.load_op_library(osp.join(base_dir, 'tf_approxmatch_so.so'))
+
+
+def approx_match(xyz1,xyz2):
+    '''
+input:
+    xyz1 : batch_size * #dataset_points * 3
+    xyz2 : batch_size * #query_points * 3
+returns:
+    match : batch_size * #query_points * #dataset_points
+    '''
+    return approxmatch_module.approx_match(xyz1,xyz2)
+ops.NoGradient('ApproxMatch')
+#@tf.RegisterShape('ApproxMatch')
+@ops.RegisterShape('ApproxMatch')
+def _approx_match_shape(op):
+    shape1=op.inputs[0].get_shape().with_rank(3)
+    shape2=op.inputs[1].get_shape().with_rank(3)
+    return [tf.TensorShape([shape1.dims[0],shape2.dims[1],shape1.dims[1]])]
+
+def match_cost(xyz1,xyz2,match):
+    '''
+input:
+    xyz1 : batch_size * #dataset_points * 3
+    xyz2 : batch_size * #query_points * 3
+    match : batch_size * #query_points * #dataset_points
+returns:
+    cost : batch_size
+    '''
+    return approxmatch_module.match_cost(xyz1,xyz2,match)
+#@tf.RegisterShape('MatchCost')
+@ops.RegisterShape('MatchCost')
+def _match_cost_shape(op):
+    shape1=op.inputs[0].get_shape().with_rank(3)
+    shape2=op.inputs[1].get_shape().with_rank(3)
+    shape3=op.inputs[2].get_shape().with_rank(3)
+    return [tf.TensorShape([shape1.dims[0]])]
+@tf.RegisterGradient('MatchCost')
+def _match_cost_grad(op,grad_cost):
+    xyz1=op.inputs[0]
+    xyz2=op.inputs[1]
+    match=op.inputs[2]
+    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
+    return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
+
+if __name__=='__main__':
+    alpha=0.5
+    beta=2.0
+    import bestmatch
+    import numpy as np
+    import math
+    import random
+    import cv2
+
+    import tf_nndistance
+
+    npoint=100
+
+    with tf.device('/gpu:2'):
+        pt_in=tf.placeholder(tf.float32,shape=(1,npoint*4,3))
+        mypoints=tf.Variable(np.random.randn(1,npoint,3).astype('float32'))
+        match=approx_match(pt_in,mypoints)
+        loss=tf.reduce_sum(match_cost(pt_in,mypoints,match))
+        #match=approx_match(mypoints,pt_in)
+        #loss=tf.reduce_sum(match_cost(mypoints,pt_in,match))
+        #distf,_,distb,_=tf_nndistance.nn_distance(pt_in,mypoints)
+        #loss=tf.reduce_sum((distf+1e-9)**0.5)*0.5+tf.reduce_sum((distb+1e-9)**0.5)*0.5
+        #loss=tf.reduce_max((distf+1e-9)**0.5)*0.5*npoint+tf.reduce_max((distb+1e-9)**0.5)*0.5*npoint
+
+        optimizer=tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
+    with tf.Session('') as sess:
+        sess.run(tf.initialize_all_variables())
+        while True:
+            meanloss=0
+            meantrueloss=0
+            for i in xrange(1001):
+                #phi=np.random.rand(4*npoint)*math.pi*2
+                #tpoints=(np.hstack([np.cos(phi)[:,None],np.sin(phi)[:,None],(phi*0)[:,None]])*random.random())[None,:,:]
+                #tpoints=((np.random.rand(400)-0.5)[:,None]*[0,2,0]+[(random.random()-0.5)*2,0,0]).astype('float32')[None,:,:]
+                tpoints=np.hstack([np.linspace(-1,1,400)[:,None],(random.random()*2*np.linspace(1,0,400)**2)[:,None],np.zeros((400,1))])[None,:,:]
+                trainloss,_=sess.run([loss,optimizer],feed_dict={pt_in:tpoints.astype('float32')})
+                trainloss,trainmatch=sess.run([loss,match],feed_dict={pt_in:tpoints.astype('float32')})
+                #trainmatch=trainmatch.transpose((0,2,1))
+                show=np.zeros((400,400,3),dtype='uint8')^255
+                trainmypoints=sess.run(mypoints)
+                for i in xrange(len(tpoints[0])):
+                    u=np.random.choice(range(len(trainmypoints[0])),p=trainmatch[0].T[i])
+                    cv2.line(show,
(int(tpoints[0][i,1]*100+200),int(tpoints[0][i,0]*100+200)), + (int(trainmypoints[0][u,1]*100+200),int(trainmypoints[0][u,0]*100+200)), + cv2.cv.CV_RGB(0,255,0)) + for x,y,z in tpoints[0]: + cv2.circle(show,(int(y*100+200),int(x*100+200)),2,cv2.cv.CV_RGB(255,0,0)) + for x,y,z in trainmypoints[0]: + cv2.circle(show,(int(y*100+200),int(x*100+200)),3,cv2.cv.CV_RGB(0,0,255)) + cost=((tpoints[0][:,None,:]-np.repeat(trainmypoints[0][None,:,:],4,axis=1))**2).sum(axis=2)**0.5 + #trueloss=bestmatch.bestmatch(cost)[0] + print(trainloss) #,trueloss + cv2.imshow('show',show) + cmd=cv2.waitKey(10)%256 + if cmd==ord('q'): + break diff --git a/pc_distance/tf_nndistance.cpp b/pc_distance/tf_nndistance.cpp new file mode 100755 index 0000000..46b0c60 --- /dev/null +++ b/pc_distance/tf_nndistance.cpp @@ -0,0 +1,254 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +REGISTER_OP("NnDistance") + .Input("xyz1: float32") + .Input("xyz2: float32") + .Output("dist1: float32") + .Output("idx1: int32") + .Output("dist2: float32") + .Output("idx2: int32"); +REGISTER_OP("NnDistanceGrad") + .Input("xyz1: float32") + .Input("xyz2: float32") + .Input("grad_dist1: float32") + .Input("idx1: int32") + .Input("grad_dist2: float32") + .Input("idx2: int32") + .Output("grad_xyz1: float32") + .Output("grad_xyz2: float32"); +using namespace tensorflow; + +static void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){ + for (int i=0;iinput(0); + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); + int m=xyz2_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&xyz1_flat(0); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&xyz2_flat(0); + Tensor * dist1_tensor=NULL; + Tensor * idx1_tensor=NULL; + Tensor * dist2_tensor=NULL; + Tensor * idx2_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); + auto dist1_flat=dist1_tensor->flat(); + auto idx1_flat=idx1_tensor->flat(); + OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); + OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); + auto dist2_flat=dist2_tensor->flat(); + auto idx2_flat=idx2_tensor->flat(); + float * dist1=&(dist1_flat(0)); + int * idx1=&(idx1_flat(0)); + float * dist2=&(dist2_flat(0)); + int * idx2=&(idx2_flat(0)); + nnsearch(b,n,m,xyz1,xyz2,dist1,idx1); + nnsearch(b,m,n,xyz2,xyz1,dist2,idx2); + } +}; +REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_CPU), NnDistanceOp); +class NnDistanceGradOp : public OpKernel{ + public: + explicit NnDistanceGradOp(OpKernelConstruction* context):OpKernel(context){} + void Compute(OpKernelContext * 
context)override{ + const Tensor& xyz1_tensor=context->input(0); + const Tensor& xyz2_tensor=context->input(1); + const Tensor& grad_dist1_tensor=context->input(2); + const Tensor& idx1_tensor=context->input(3); + const Tensor& grad_dist2_tensor=context->input(4); + const Tensor& idx2_tensor=context->input(5); + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); + int m=xyz2_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); + OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); + OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); + OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); + OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&xyz1_flat(0); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&xyz2_flat(0); + auto idx1_flat=idx1_tensor.flat(); + const int * idx1=&idx1_flat(0); + auto idx2_flat=idx2_tensor.flat(); + const int * idx2=&idx2_flat(0); + auto grad_dist1_flat=grad_dist1_tensor.flat(); + const float * grad_dist1=&grad_dist1_flat(0); + auto grad_dist2_flat=grad_dist2_tensor.flat(); + const float * grad_dist2=&grad_dist2_flat(0); + Tensor * grad_xyz1_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); + Tensor * grad_xyz2_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); + auto grad_xyz1_flat=grad_xyz1_tensor->flat(); + float * grad_xyz1=&grad_xyz1_flat(0); + auto grad_xyz2_flat=grad_xyz2_tensor->flat(); + float * grad_xyz2=&grad_xyz2_flat(0); + for (int i=0;iinput(0); + const Tensor& xyz2_tensor=context->input(1); + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); + int m=xyz2_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&xyz1_flat(0); + 
auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&xyz2_flat(0); + Tensor * dist1_tensor=NULL; + Tensor * idx1_tensor=NULL; + Tensor * dist2_tensor=NULL; + Tensor * idx2_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); + auto dist1_flat=dist1_tensor->flat(); + auto idx1_flat=idx1_tensor->flat(); + OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); + OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); + auto dist2_flat=dist2_tensor->flat(); + auto idx2_flat=idx2_tensor->flat(); + float * dist1=&(dist1_flat(0)); + int * idx1=&(idx1_flat(0)); + float * dist2=&(dist2_flat(0)); + int * idx2=&(idx2_flat(0)); + NmDistanceKernelLauncher(b,n,xyz1,m,xyz2,dist1,idx1,dist2,idx2); + } +}; +REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_GPU), NnDistanceGpuOp); + +void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2); +class NnDistanceGradGpuOp : public OpKernel{ + public: + explicit NnDistanceGradGpuOp(OpKernelConstruction* context):OpKernel(context){} + void Compute(OpKernelContext * context)override{ + const Tensor& xyz1_tensor=context->input(0); + const Tensor& xyz2_tensor=context->input(1); + const Tensor& grad_dist1_tensor=context->input(2); + const Tensor& idx1_tensor=context->input(3); + const Tensor& grad_dist2_tensor=context->input(4); + const Tensor& idx2_tensor=context->input(5); + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); + int b=xyz1_tensor.shape().dim_size(0); + int n=xyz1_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); + int m=xyz2_tensor.shape().dim_size(1); + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); + OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); + OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); + OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); + OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); + auto xyz1_flat=xyz1_tensor.flat(); + const float * xyz1=&xyz1_flat(0); + auto xyz2_flat=xyz2_tensor.flat(); + const float * xyz2=&xyz2_flat(0); + auto idx1_flat=idx1_tensor.flat(); + const int * idx1=&idx1_flat(0); + auto idx2_flat=idx2_tensor.flat(); + const int * idx2=&idx2_flat(0); + auto grad_dist1_flat=grad_dist1_tensor.flat(); + const float * grad_dist1=&grad_dist1_flat(0); + auto grad_dist2_flat=grad_dist2_tensor.flat(); + const float * 
grad_dist2=&grad_dist2_flat(0); + Tensor * grad_xyz1_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); + Tensor * grad_xyz2_tensor=NULL; + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); + auto grad_xyz1_flat=grad_xyz1_tensor->flat(); + float * grad_xyz1=&grad_xyz1_flat(0); + auto grad_xyz2_flat=grad_xyz2_tensor->flat(); + float * grad_xyz2=&grad_xyz2_flat(0); + NmDistanceGradKernelLauncher(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_dist2,idx2,grad_xyz1,grad_xyz2); + } +}; +REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_GPU), NnDistanceGradGpuOp); diff --git a/pc_distance/tf_nndistance.cu b/pc_distance/tf_nndistance.cu new file mode 100755 index 0000000..b755122 --- /dev/null +++ b/pc_distance/tf_nndistance.cu @@ -0,0 +1,159 @@ +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +// #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +void NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i){ + NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); + NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); + NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); +} + +#endif diff --git a/pc_distance/tf_nndistance.py b/pc_distance/tf_nndistance.py new file mode 100755 index 0000000..ea04015 --- /dev/null +++ b/pc_distance/tf_nndistance.py @@ -0,0 +1,79 @@ +import os, sys +import tensorflow as tf +from tensorflow.python.framework import ops +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +nn_distance_module=tf.load_op_library(os.path.join(BASE_DIR, 'tf_nndistance_so.so')) + +def nn_distance(xyz1,xyz2): + ''' +Computes the distance of nearest neighbors for a pair of point clouds +input: xyz1: (batch_size,#points_1,3) the first point cloud +input: xyz2: (batch_size,#points_2,3) the second point cloud +output: dist1: (batch_size,#point_1) distance from first to second +output: idx1: (batch_size,#point_1) nearest neighbor from first to second +output: dist2: (batch_size,#point_2) distance from second to first +output: idx2: (batch_size,#point_2) nearest neighbor from second to first + ''' + return nn_distance_module.nn_distance(xyz1,xyz2) +#@tf.RegisterShape('NnDistance') +#def _nn_distance_shape(op): + #shape1=op.inputs[0].get_shape().with_rank(3) + #shape2=op.inputs[1].get_shape().with_rank(3) + #return [tf.TensorShape([shape1.dims[0],shape1.dims[1]]),tf.TensorShape([shape1.dims[0],shape1.dims[1]]), + #tf.TensorShape([shape2.dims[0],shape2.dims[1]]),tf.TensorShape([shape2.dims[0],shape2.dims[1]])] +@ops.RegisterGradient('NnDistance') +def _nn_distance_grad(op,grad_dist1,grad_idx1,grad_dist2,grad_idx2): + xyz1=op.inputs[0] + xyz2=op.inputs[1] + idx1=op.outputs[1] + idx2=op.outputs[3] + return nn_distance_module.nn_distance_grad(xyz1,xyz2,grad_dist1,idx1,grad_dist2,idx2) + + +if __name__=='__main__': + import numpy as np + import 
random + import time + from tensorflow.python.ops.gradient_checker import compute_gradient + random.seed(100) + np.random.seed(100) + with tf.Session('') as sess: + xyz1=np.random.randn(32,16384,3).astype('float32') + xyz2=np.random.randn(32,1024,3).astype('float32') + #with tf.device('/gpu:0'): + if True: + inp1=tf.Variable(xyz1) + inp2=tf.constant(xyz2) + reta,retb,retc,retd=nn_distance(inp1,inp2) + loss=tf.reduce_sum(reta)+tf.reduce_sum(retc) + train=tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss) + sess.run(tf.initialize_all_variables()) + t0=time.time() + t1=t0 + best=1e100 + for i in xrange(100): + trainloss,_=sess.run([loss,train]) + newt=time.time() + best=min(best,newt-t1) + print(i,trainloss,(newt-t0)/(i+1),best) + t1=newt + #print sess.run([inp1,retb,inp2,retd]) + #grads=compute_gradient([inp1,inp2],[(16,32,3),(16,32,3)],loss,(1,),[xyz1,xyz2]) + #for i,j in grads: + #print i.shape,j.shape,np.mean(np.abs(i-j)),np.mean(np.abs(i)),np.mean(np.abs(j)) + #for i in xrange(10): + #t0=time.time() + #a,b,c,d=sess.run([reta,retb,retc,retd],feed_dict={inp1:xyz1,inp2:xyz2}) + #print 'time',time.time()-t0 + #print a.shape,b.shape,c.shape,d.shape + #print a.dtype,b.dtype,c.dtype,d.dtype + #samples=np.array(random.sample(range(xyz2.shape[1]),100),dtype='int32') + #dist1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).min(axis=-1) + #idx1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) + #print np.abs(dist1-a[:,samples]).max() + #print np.abs(idx1-b[:,samples]).max() + #dist2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).min(axis=-1) + #idx2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) + #print np.abs(dist2-c[:,samples]).max() + #print np.abs(idx2-d[:,samples]).max() +
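Finally, a usage sketch for the NnDistance op wrapped above. dist1 and dist2 hold squared nearest-neighbour distances in both directions (nnsearch stores the minimum squared distance without a square root), so a symmetric Chamfer loss can be assembled as follows; chamfer is a hypothetical helper, not part of this diff, and the reduction (mean of square roots here, as in the commented-out loss in tf_approxmatch.py's test code) is a modelling choice:

    # Sketch only: symmetric Chamfer distance built on the nn_distance wrapper.
    # Assumes the shared library built with the makefile above loads successfully.
    import tensorflow as tf
    from pc_distance import tf_nndistance

    def chamfer(pcd1, pcd2):
        # pcd1: (batch, n1, 3), pcd2: (batch, n2, 3)
        dist1, _, dist2, _ = tf_nndistance.nn_distance(pcd1, pcd2)  # squared distances
        # 1e-9 guards the sqrt at zero distance, mirroring the test code above.
        return tf.reduce_mean(tf.sqrt(dist1 + 1e-9)) + tf.reduce_mean(tf.sqrt(dist2 + 1e-9))

Dropping the square roots gives the squared-distance variant that the benchmark above optimises (loss=tf.reduce_sum(reta)+tf.reduce_sum(retc)).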