From 705cba91fc26413d484f5ce2983427d840199a97 Mon Sep 17 00:00:00 2001 From: Deyu Fu Date: Sat, 17 Apr 2021 08:24:52 +0800 Subject: [PATCH] Adding fast bottleneck implementation into contrib (#1079) * initial commit for adding fast bottleneck * sync cudnn-frontend module Co-authored-by: pbialecki --- .gitmodules | 3 + apex/contrib/bottleneck/__init__.py | 1 + apex/contrib/bottleneck/bottleneck.py | 214 +++ apex/contrib/bottleneck/test.py | 71 + apex/contrib/csrc/bottleneck/bottleneck.cpp | 1634 +++++++++++++++++++ apex/contrib/csrc/cudnn-frontend | 1 + setup.py | 19 +- 7 files changed, 1942 insertions(+), 1 deletion(-) create mode 100644 apex/contrib/bottleneck/__init__.py create mode 100644 apex/contrib/bottleneck/bottleneck.py create mode 100644 apex/contrib/bottleneck/test.py create mode 100644 apex/contrib/csrc/bottleneck/bottleneck.cpp create mode 160000 apex/contrib/csrc/cudnn-frontend diff --git a/.gitmodules b/.gitmodules index f93c04075..6479428db 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,6 @@ path = apex/contrib/csrc/multihead_attn/cutlass url = https://github.com/NVIDIA/cutlass.git branch = v1.2.0 +[submodule "apex/contrib/csrc/cudnn-frontend"] + path = apex/contrib/csrc/cudnn-frontend + url = https://github.com/NVIDIA/cudnn-frontend.git diff --git a/apex/contrib/bottleneck/__init__.py b/apex/contrib/bottleneck/__init__.py new file mode 100644 index 000000000..4acbc25d7 --- /dev/null +++ b/apex/contrib/bottleneck/__init__.py @@ -0,0 +1 @@ +from .bottleneck import Bottleneck diff --git a/apex/contrib/bottleneck/bottleneck.py b/apex/contrib/bottleneck/bottleneck.py new file mode 100644 index 000000000..7c4e01a53 --- /dev/null +++ b/apex/contrib/bottleneck/bottleneck.py @@ -0,0 +1,214 @@ +import torch +from torch import nn +import fast_bottleneck + +def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): + weight_tensor_nchw = tensor + nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity) + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed + """ + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def get_scale_bias(self, nhwc=False): + scale = self.weight * self.running_var.rsqrt() + bias = self.bias - self.running_mean * scale + if nhwc: + scale = scale.reshape(1, 1, 1, -1) + bias = bias.reshape(1, 1, 1, -1) + else: + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + return scale, bias + + def forward(self, x): + scale, bias = self.get_scale_bias() + return x * scale + bias + + +@torch.jit.script +def drelu_dscale1(grad_o, output, scale1): + relu_mask = (output>0).half() + dx_relu = relu_mask * grad_o + g1 = dx_relu * scale1 + return g1, dx_relu + +@torch.jit.script +def drelu_dscale2(grad_o, output, scale1, scale2): + relu_mask = (output>0).half() + dx_relu = relu_mask * grad_o + g1 = dx_relu * scale1 + g2 = dx_relu * scale2 + return g1, g2 + +class BottleneckFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv): + # TODO: clean up order of tensors + args = [x, *conv[0:3], *scale[0:3], *bias[0:3]] + ctx.downsample = len(conv) > 3 + if ctx.downsample: + args.append(conv[3]) + args.append(scale[3]) + args.append(bias[3]) + + # weight buffers are always in nhwc while shape can be nhwc or channels_last + # here we pass in flag and let c++ handle it + # alternatively, we can put all sizes into a fixed format and pass it in + outputs = fast_bottleneck.forward(nhwc, stride_1x1, args) + ctx.save_for_backward(*(args+outputs)) + # save relu outputs for drelu + ctx.nhwc = nhwc + ctx.stride_1x1 = stride_1x1 + return outputs[2] + + # backward relu is not exposed, MUL with mask used now + # only support dgrad + @staticmethod + def backward(ctx, grad_o): + outputs = ctx.saved_tensors[-3:] + + if ctx.downsample: + grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11]) + else: + grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6]) + + # create input vector for backward + t_list = [*ctx.saved_tensors[0:10]] + t_list.append(grad_conv3) + t_list.append(grad_conv4) + + # outputs used for wgrad and generating drelu mask + t_list.append(outputs[0]) + t_list.append(outputs[1]) + + # in case there is downsample + if ctx.downsample: + t_list.append(ctx.saved_tensors[10]) + + grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list) + + return (None, None, None, None, *grads) + +bottleneck_function = BottleneckFunction.apply + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class Bottleneck(torch.nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + # here we put it at 1x1 + + def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, + dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False): + super(Bottleneck, self).__init__() + if groups != 1: + raise RuntimeError('Only support groups == 1') + if dilation != 1: + raise RuntimeError('Only support dilation == 1') + if norm_func == None: + norm_func = FrozenBatchNorm2d + else: + raise RuntimeError('Only support frozen BN now.') + + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + conv1x1(in_channels, out_channels, stride), + norm_func(out_channels), + ) + else: + self.downsample = None + + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(in_channels, bottleneck_channels, stride) + self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels) + self.conv3 = conv1x1(bottleneck_channels, out_channels) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + self.bn1 = norm_func(bottleneck_channels) + self.bn2 = norm_func(bottleneck_channels) + self.bn3 = norm_func(out_channels) + + self.use_cudnn = use_cudnn + + # setup conv weights + self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight] + if self.downsample is not None: + self.w_conv.append(self.downsample[0].weight) + + # init weight in nchw format before possible transpose + for w in self.w_conv: + kaiming_uniform_(w, a=1) + + # TODO: prevent unsupported case usage + # support cases + # native cudnn + # normal yes no + # channel_last yes yes + # explicit_nhwc no yes + self.explicit_nhwc = explicit_nhwc + if self.explicit_nhwc: + for p in self.parameters(): + with torch.no_grad(): + p.data = p.data.permute(0,2,3,1).contiguous() + return + + def forward(self, x): + if self.use_cudnn: + # calculate scale/bias from registered buffers + # TODO: make this better + s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc) + s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc) + s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc) + w_scale = [s1, s2, s3] + w_bias = [b1, b2, b3] + if self.downsample is not None: + s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc) + w_scale.append(s4) + w_bias.append(b4) + + out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv) + return out + + if self.explicit_nhwc: + raise RuntimeError('explicit nhwc with native ops is not supported.') + + # fallback to native ops + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out diff --git a/apex/contrib/bottleneck/test.py b/apex/contrib/bottleneck/test.py new file mode 100644 index 000000000..2c3c62130 --- /dev/null +++ b/apex/contrib/bottleneck/test.py @@ -0,0 +1,71 @@ +import torch +from bottleneck import Bottleneck +torch.manual_seed(23337) + +# use True to print layerwise sum for all outputs in reference code path +DEBUG = False#True + +for stride, o_channel in [(1,32), (1,128), (2,32)]: + print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel) + a_ = torch.randn(17,32,28,28) + + a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_() + model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last) + + # test model + b = model(a) + b.mean().backward() + d_grad = a.grad.float() + a.grad = None + torch.cuda.synchronize() + + if DEBUG: + print("[DEBUG] ref dx :", d_grad.sum().item()) + # print wgrad. we don't need to reset since later cpp print before accumulation + for i, w in enumerate(model.w_conv): + print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item()) + + wgrads = [] + for w in model.w_conv: + wgrads.append(w.grad.float()) + + model.use_cudnn = True + model.zero_grad() + c = model(a) + c.mean().backward() + + torch.cuda.synchronize() + print("comparing native and channels_last:") + print("max error fprop:", (b-c).abs().max().item(), "max elem:", b.abs().max().item()) + print("max error dgrad:", (d_grad-a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item()) + for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)): + print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item()) + + nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_() + nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half() + for p,q in zip(model.parameters(), nhwc_model.parameters()): + # model's storage is already in nhwc, we clone and assign to explicit nhwc model + q.data.copy_(p.data.permute(0,2,3,1).contiguous()) + for p,q in zip(model.buffers(), nhwc_model.buffers()): + q.data.copy_(p.data) + + d = nhwc_model(nhwc_a) + d.mean().backward() + torch.cuda.synchronize() + + # reset reference to cudnn channels_last permute + #c_s = c.storage().tolist() + #d_s = d.storage().tolist() + #print(max([x-y for x,y in zip(c_s,d_s)])) + c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous() + d_grad = a.grad.float().permute(0,2,3,1).contiguous() + wgrads = [] + for w in model.w_conv: + wgrads.append(w.grad.float().permute(0,2,3,1).contiguous()) + + torch.cuda.synchronize() + print("comparing nhwc and channels_last:") + print("max error fprop:", (d-c).abs().max().item(), "max elem:", c.abs().max().item()) + print("max error dgrad:", (d_grad-nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item()) + for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)): + print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item()) diff --git a/apex/contrib/csrc/bottleneck/bottleneck.cpp b/apex/contrib/csrc/bottleneck/bottleneck.cpp new file mode 100644 index 000000000..65b9a2bbb --- /dev/null +++ b/apex/contrib/csrc/bottleneck/bottleneck.cpp @@ -0,0 +1,1634 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include // for getcudnnhandle +#include +#include +#include +#include + +#include + +#ifdef DEBUG +#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false ) +#else +#define DEBUG_MSG(str) do { } while ( false ) +#endif + +#ifdef DEBUG_CUDNN +#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false ) +#else +#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false ) +#endif + +#define checkCudnnErr(...) \ + do { \ + int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ + if (err) { \ + return; \ + } \ + } while (0) + + +int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { + if (code) { + printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); + return 1; + } + return 0; +} + +void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true); +#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function + +void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) +{ + if (code != cudaSuccess) + { + const char * errorMessage = cudaGetErrorString(code); + fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage); + if (abort){ + cudaDeviceReset(); + exit(code); + } + } +} + +void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) { + // For INT8x4 and INT8x32 we still compute standard strides here to input + // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. + if (filterFormat == CUDNN_TENSOR_NCHW) { + strideA[nbDims - 1] = 1; + for (int64_t d = nbDims - 2; d >= 0; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; + } + } else { + // Here we assume that the format is CUDNN_TENSOR_NHWC + strideA[1] = 1; + strideA[nbDims - 1] = strideA[1] * dimA[1]; + for (int64_t d = nbDims - 2; d >= 2; d--) { + strideA[d] = strideA[d + 1] * dimA[d + 1]; + } + strideA[0] = strideA[2] * dimA[2]; + } +} + + +int getFwdConvDilatedFilterDim(int filterDim, int dilation) { + return ((filterDim - 1) * dilation) + 1; +} + +int getFwdConvPaddedImageDim(int tensorDim, int pad) { + return tensorDim + (2 * pad); +} + +int getFwdConvOutputDim( + int tensorDim, + int pad, + int filterDim, + int stride, + int dilation) +{ + int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; + return (p); +} + +enum { + X_TENSOR, + Y_TENSOR, + W_TENSOR, + Z_TENSOR, + B_TENSOR, + AFTERADD_TENSOR, + AFTERBIAS_TENSOR, + AFTERCONV_TENSOR, + OPTIONAL, + AFTEROPT_TENSOR, +}; + +using common_conv_descriptors = + std::tuple; + + +common_conv_descriptors +create_common_descriptors(int64_t* x_dim_padded, + int64_t* padA, + int64_t* convstrideA, + int64_t* dilationA, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + cudnnConvolutionMode_t mode) { + const int convDim = 2; + + int64_t strideA_padded[4]; + int64_t outstrideA_padded[4]; + int64_t filterstrideA_padded[4]; + + generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC); + + return common_conv_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, strideA_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, outstrideA_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, filterstrideA_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(mode) + .setNDims(convDim) + .setStrides(convDim, convstrideA) + .setPrePadding(convDim, padA) + .setPostPadding(convDim, padA) + .setDilation(convDim, dilationA) + .build()); +} + +using common_convbias_descriptors = std::tuple; + +common_convbias_descriptors +create_conv_bias_add_act_descriptors(int64_t* x_dim_padded, + int64_t* padA, + int64_t* convstrideA, + int64_t* dilationA, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType) { + const int convDim = 2; + + int64_t b_dim_padded[4]; + b_dim_padded[0] = 1; + b_dim_padded[1] = y_dim_padded[1]; + b_dim_padded[2] = 1; + b_dim_padded[3] = 1; + + int64_t x_stride_padded[4]; + int64_t y_stride_padded[4]; + int64_t w_stride_padded[4]; + int64_t b_stride_padded[4]; + + generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); + + return common_convbias_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, w_stride_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('z') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('b') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setVirtual() + .setId('A') // after add + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setVirtual() + .setId('B') // after bias + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('C') // after conv + .setAlignment(16) + .setVirtual() + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('i') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('D') // after optional add + .setAlignment(16) + .setVirtual() + .setDataType(dataType) + .build()); +} + +// tensor descriptors used for dgrad +enum { + X_OR_DX_TENSOR, + DY_TENSOR, + W_OR_DW_TENSOR, + SCALE_TENSOR, + RELU_TENSOR, + AFTER_DCONV_TENSOR, + AFTER_DRELU_TENSOR, +}; + +using dconv_descriptors = std::tuple; + +dconv_descriptors +create_dconv_descriptors(int64_t* x_dim_padded, + int64_t* padA, + int64_t* convstrideA, + int64_t* dilationA, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType) { + const int convDim = 2; + + int64_t b_dim_padded[4]; + b_dim_padded[0] = 1; + b_dim_padded[1] = x_dim_padded[1]; + b_dim_padded[2] = 1; + b_dim_padded[3] = 1; + + int64_t x_stride_padded[4]; + int64_t y_stride_padded[4]; + int64_t w_stride_padded[4]; + int64_t b_stride_padded[4]; + + generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); + generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); + + return dconv_descriptors(cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('x') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('y') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, w_dim_padded) + .setStrides(4, w_stride_padded) + .setId('w') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('s') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('A') // after dconv + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('B') // after drelu + .setAlignment(16) + .setDataType(dataType) + .build()); +} + +// create a cache for plan +std::unordered_map plan_cache; + +// TODO: better name +std::string getConvFusionString(int64_t* x_dim_padded, + int64_t* padA, + int64_t* convstrideA, + int64_t* dilationA, + int64_t* w_dim_padded, + cudnnDataType_t dataType, + std::string fusion_string) { + + for(int i=0;i<4;i++) { + fusion_string += 'X'; + fusion_string += std::to_string(x_dim_padded[i]); + } + for(int i=0;i<4;i++) { + fusion_string += 'W'; + fusion_string += std::to_string(w_dim_padded[i]); + } + for(int i=0;i<2;i++) { + fusion_string += 'P'; + fusion_string += std::to_string(padA[i]); + } + for(int i=0;i<2;i++) { + fusion_string += 'S'; + fusion_string += std::to_string(convstrideA[i]); + } + for(int i=0;i<2;i++) { + fusion_string += 'D'; + fusion_string += std::to_string(dilationA[i]); + } + fusion_string += 'T'; + fusion_string += std::to_string(dataType); + return fusion_string; +} + +cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, + std::stringstream& log_buf, + cudnn_frontend::OperationGraph& opGraph, + std::string cache_string, + bool use_heuristic = true){ + auto it = plan_cache.find(cache_string); + if (it != plan_cache.end()) { + DEBUG_CUDNN_MSG(log_buf, "Found plan in cache"); + return it->second; + } else { + if (use_heuristic){ + // TODO: confirm which mode to use + auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() + .setOperationGraph(opGraph) + .setHeurMode(CUDNN_HEUR_MODE_INSTANT) + .build(); + // try 3 times for now as WAR for no heuristic training + int max_tries = 3, count = 0; + auto& engine_configs = heuristics.getEngineConfig(max_tries); + while(true) { + try { + plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle_) + .setEngineConfig(engine_configs[count], opGraph.getTag()) + .build())); + break; + } catch (cudnn_frontend::cudnnException e) { + if (++count == max_tries) throw e; + } + } + }else{ + DEBUG_CUDNN_MSG(log_buf, "No plan in cache"); + // How many engines support this operation graph ? + auto total_engines = opGraph.getEngineCount(); + DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines."); + // We have to randomly pick one engine from [0, total_engines) + // Selecting "0" by default + auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build(); + DEBUG_CUDNN_MSG(log_buf, engine.describe()); + auto& knobs = engine.getSupportedKnobs(); + for (auto it = std::begin(knobs); it != std::end(knobs); ++it) { + DEBUG_CUDNN_MSG(log_buf, it->describe()); + } + if (knobs.begin() != knobs.end()) { + DEBUG_CUDNN_MSG(log_buf, "Updated knob choice"); + knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1); + DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe()); + } + + // Createmplacee the requisite engine config + auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); + DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); + plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); + } + + return plan_cache.find(cache_string)->second; + } +} + +void +run_conv_scale_bias_add_activation(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrZ, + at::Half* devPtrB, + at::Half* devPtrI) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the add operation + auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_MUL) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto biasDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // optional add + auto addDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create a optional add Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) + .setyDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(devPtrI ? ops.size() : 4, ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(devPtrI ? 6 : 5, data_ptrs) + .setUids(devPtrI ? 6 : 5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error"); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + +void +run_conv_scale_bias(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrZ, + at::Half* devPtrB) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the add operation + auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_MUL) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto addDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) // TODO: change enum to aftermul + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &scale_op, &add_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(5, data_ptrs) + .setUids(5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error"); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + + +void +run_dconv_drelu_dscale(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrZ, + at::Half* devPtrR) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_descriptors tensors = create_dconv_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the scale backward operation + auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_MUL) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create an relu backward Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(std::get(tensors)) + .setxDesc(std::get(tensors)) + .setdxDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Scale Node. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &act_op, &scale_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR}; + int64_t uids[] = {'x', 'y', 'w', 's', 'r'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(5, data_ptrs) + .setUids(5, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error"); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + +void +run_dconv(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + cudnnBackendDescriptorType_t mode) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_descriptors tensors = create_dconv_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + // mode should be one of following + // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR + // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR + auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); + if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { + conv_op_builder.setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta); + } + else { + conv_op_builder.setxDesc(std::get(tensors)) + .setdwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta); + } + auto conv_op = conv_op_builder.build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW}; + int64_t uids[] = {'x', 'y', 'w'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(3, data_ptrs) + .setUids(3, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error"); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + +void +run_dconv_add(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrR) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_descriptors tensors = create_dconv_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the add backward operation + auto addDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create add Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &add_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR}; + int64_t uids[] = {'x', 'y', 'w', 'r'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(4, data_ptrs) + .setUids(4, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error"); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + + +// inputs contains x,w,z,b,(i) +std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector inputs) { + + std::cout << std::fixed; + // create output vector + std::vector outputs; + auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; + + // setup dimensions + int64_t dimA[] = {0, 0, 0, 0}; + int64_t filterdimA1[] = {0, 0, 0, 0}; + int64_t filterdimA2[] = {0, 0, 0, 0}; + int64_t filterdimA3[] = {0, 0, 0, 0}; + int64_t filterdimA4[] = {0, 0, 0, 0}; + + // All dim calculation after this order of n,c,h,w + int axis[] {0,1,2,3}; + if (explicit_nhwc) { + axis[0] = 0; + axis[1] = 3; + axis[2] = 1; + axis[3] = 2; + } + for (int dim=0;dim<4;dim++) { + dimA[dim] = inputs[0].size(axis[dim]); + filterdimA1[dim] = inputs[1].size(axis[dim]); + filterdimA2[dim] = inputs[2].size(axis[dim]); + filterdimA3[dim] = inputs[3].size(axis[dim]); + } + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { + for (int dim=0;dim<4;dim++) { + filterdimA4[dim] = inputs[10].size(axis[dim]); + } + } + + // output dim in n,c,h,w used by backend + int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below + + // use these fixed value for test run + int64_t padA[] = {0, 0}; + int64_t padA1[] = {1, 1}; + int64_t dilationA[] = {1, 1}; + int64_t convstrideA[] = {1, 1}; + int64_t convstride1X1[] = {stride_1X1, stride_1X1}; + + // compute output from pad/stride/dilation + outdimA1[0] = dimA[0]; + outdimA1[1] = filterdimA1[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + } + + outdimA2[0] = outdimA1[0]; + outdimA2[1] = filterdimA2[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); + } + + outdimA3[0] = outdimA2[0]; + outdimA3[1] = filterdimA3[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); + } + + // Create output tensor in the correct shape in pytorch's view + int64_t outdim1[] = {0, 0, 0, 0}; + int64_t outdim2[] = {0, 0, 0, 0}; + int64_t outdim3[] = {0, 0, 0, 0}; + if (explicit_nhwc) { + axis[0] = 0; + axis[1] = 2; + axis[2] = 3; + axis[3] = 1; + } + for (int dim=0;dim<4;dim++) { + outdim1[dim] = outdimA1[axis[dim]]; + outdim2[dim] = outdimA2[axis[dim]]; + outdim3[dim] = outdimA3[axis[dim]]; + } + + // run + at::Half* x = inputs[0].data_ptr(); + at::Half* w = inputs[1].data_ptr(); + at::Half* z = inputs[4].data_ptr(); + at::Half* b = inputs[7].data_ptr(); + auto out1 = at::empty(outdim1, inputs[0].type(), output_format); + at::Half* y1 = out1.data_ptr(); + + run_conv_scale_bias_add_activation(dimA, + padA, + convstride1X1, + dilationA, + filterdimA1, + outdimA1, + CUDNN_DATA_HALF, + x, + w, + y1, + z, + b, + nullptr); + + DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item()); + + w = inputs[2].data_ptr(); + z = inputs[5].data_ptr(); + b = inputs[8].data_ptr(); + auto out2 = at::empty(outdim2, inputs[0].type(), output_format); + at::Half* y2 = out2.data_ptr(); + + run_conv_scale_bias_add_activation(outdimA1, + padA1, + convstrideA, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + y1, + w, + y2, + z, + b, + nullptr); + DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); + + // create output of conv3 + auto out3 = at::empty(outdim3, inputs[0].type(), output_format); + at::Half* y3 = out3.data_ptr(); + + // create output of conv4 that may exist + auto identity = at::empty_like(out3); + at::Half* yi = identity.data_ptr(); + + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){ + + w = inputs[10].data_ptr(); + z = inputs[11].data_ptr(); + b = inputs[12].data_ptr(); + run_conv_scale_bias(dimA, + padA, + convstride1X1, + dilationA, + filterdimA4, + outdimA3, + CUDNN_DATA_HALF, + x, + w, + yi, + z, + b); + DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item()); + } + else { + yi = x; + } + + w = inputs[3].data_ptr(); + z = inputs[6].data_ptr(); + b = inputs[9].data_ptr(); + + run_conv_scale_bias_add_activation(outdimA2, + padA, + convstrideA, + dilationA, + filterdimA3, + outdimA3, + CUDNN_DATA_HALF, + y2, + w, + y3, + z, + b, + yi); + DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); + + outputs.push_back(out1); + outputs.push_back(out2); + outputs.push_back(out3); + + return outputs; +} + +std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector inputs) { + + bool requires_grad = inputs[0].requires_grad(); + + std::cout << std::fixed; + // create output vector + std::vector outputs; + auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; + + // setup dimensions + int64_t dimA[] = {0, 0, 0, 0}; + int64_t filterdimA1[] = {0, 0, 0, 0}; + int64_t filterdimA2[] = {0, 0, 0, 0}; + int64_t filterdimA3[] = {0, 0, 0, 0}; + int64_t filterdimA4[] = {0, 0, 0, 0}; + + // All dim calculation after this order of n,c,h,w + int axis[] {0,1,2,3}; + if (explicit_nhwc) { + axis[0] = 0; + axis[1] = 3; + axis[2] = 1; + axis[3] = 2; + } + for (int dim=0;dim<4;dim++) { + dimA[dim] = inputs[0].size(axis[dim]); + filterdimA1[dim] = inputs[1].size(axis[dim]); + filterdimA2[dim] = inputs[2].size(axis[dim]); + filterdimA3[dim] = inputs[3].size(axis[dim]); + } + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { + for (int dim=0;dim<4;dim++) { + filterdimA4[dim] = inputs[14].size(axis[dim]); + } + } + + // output dim in n,c,h,w used by backend + int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below + int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below + + // use these fixed value for test run + int64_t padA[] = {0, 0}; + int64_t padA1[] = {1, 1}; + int64_t dilationA[] = {1, 1}; + int64_t convstrideA[] = {1, 1}; + int64_t convstride1X1[] = {stride_1X1, stride_1X1}; + + // compute output from pad/stride/dilation + outdimA1[0] = dimA[0]; + outdimA1[1] = filterdimA1[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + } + + outdimA2[0] = outdimA1[0]; + outdimA2[1] = filterdimA2[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); + } + + outdimA3[0] = outdimA2[0]; + outdimA3[1] = filterdimA3[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); + } + + // Create output tensor in the correct shape in pytorch's view + int64_t outdim1[] = {0, 0, 0, 0}; + int64_t outdim2[] = {0, 0, 0, 0}; + int64_t outdim3[] = {0, 0, 0, 0}; + if (explicit_nhwc) { + axis[0] = 0; + axis[1] = 2; + axis[2] = 3; + axis[3] = 1; + } + for (int dim=0;dim<4;dim++) { + outdim1[dim] = outdimA1[axis[dim]]; + outdim2[dim] = outdimA2[axis[dim]]; + outdim3[dim] = outdimA3[axis[dim]]; + } + + // dconv3+drelu2+dscale2 + at::Half* conv_in = inputs[13].data_ptr(); + at::Half* dy3 = inputs[10].data_ptr(); + + DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item()); + + // wgrad + auto wgrad3 = at::empty_like(inputs[3]); + at::Half* dw3 = wgrad3.data_ptr(); + run_dconv(outdimA2, + padA, + convstrideA, + dilationA, + filterdimA3, + outdimA3, + CUDNN_DATA_HALF, + conv_in, + dw3, + dy3, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + + // dgrad + auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format); + at::Half* dy2 = grad_out2.data_ptr(); + at::Half* w = inputs[3].data_ptr(); + at::Half* z = inputs[5].data_ptr(); + + at::Half* relu2 = inputs[13].data_ptr(); + + run_dconv_drelu_dscale(outdimA2, + padA, + convstrideA, + dilationA, + filterdimA3, + outdimA3, + CUDNN_DATA_HALF, + dy2, + w, + dy3, + z, + relu2); + + DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item()); + + // dconv2+drelu1+dscale1 + conv_in = inputs[12].data_ptr(); + + // wgrad + auto wgrad2 = at::empty_like(inputs[2]); + at::Half* dw2 = wgrad2.data_ptr(); + run_dconv(outdimA1, + padA1, + convstrideA, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + conv_in, + dw2, + dy2, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + + // dgrad + auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format); + at::Half* dy1 = grad_out1.data_ptr(); + w = inputs[2].data_ptr(); + z = inputs[4].data_ptr(); + + at::Half* relu1 = inputs[12].data_ptr(); + // fused dgrad + run_dconv_drelu_dscale(outdimA1, + padA1, + convstrideA, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + dy1, + w, + dy2, + z, + relu1); + +/* + // backward strided conv cannot be fused + // if stride == 1 but channel changes, we can fuse here + if (stride_1X1 != 1){ + // dgrad + run_dconv(outdimA1, + padA1, + convstride1X1, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + dy1, + w, + dy2, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + + // mul fused mask + grad_out1.mul_(inputs[15]); + } + else { + at::Half* relu1 = inputs[12].data_ptr(); + // fused dgrad + run_dconv_drelu_dscale(outdimA1, + padA1, + convstride1X1, + dilationA, + filterdimA2, + outdimA2, + CUDNN_DATA_HALF, + dy1, + w, + dy2, + z, + relu1); + } +*/ + DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item()); + + // create grads of conv4 that may exist + auto grad_x_conv4 = at::empty_like(inputs[0]); + at::Half* dx_conv4 = grad_x_conv4.data_ptr(); + at::Tensor wgrad4; + + // x used for dconv1 and dconv4 wgrad + at::Half* x = inputs[0].data_ptr(); + + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){ + w = inputs[14].data_ptr(); + at::Half* dy_conv4 = inputs[11].data_ptr(); + if (requires_grad) { + run_dconv(dimA, + padA, + convstride1X1, + dilationA, + filterdimA4, + outdimA3, + CUDNN_DATA_HALF, + dx_conv4, + w, + dy_conv4, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx + // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item()); + } + // wgrad + wgrad4 = at::empty_like(inputs[14]); + at::Half* dw4 = wgrad4.data_ptr(); + run_dconv(dimA, + padA, + convstride1X1, + dilationA, + filterdimA4, + outdimA3, + CUDNN_DATA_HALF, + x, + dw4, + dy_conv4, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + } + else { + // if there is no downsample, dx_conv4 is fork of drelu3 + dx_conv4 = inputs[11].data_ptr(); + } + + // dconv1+add + // wgrad + auto wgrad1 = at::empty_like(inputs[1]); + at::Half* dw1 = wgrad1.data_ptr(); + run_dconv(dimA, + padA, + convstride1X1, + dilationA, + filterdimA1, + outdimA1, + CUDNN_DATA_HALF, + x, + dw1, + dy1, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + + // dgrad + w = inputs[1].data_ptr(); + auto grad_x = at::empty_like(inputs[0]); + at::Half* dx = grad_x.data_ptr(); + + // backward strided conv cannot be fused + // if stride == 1 but channel changes, we can fuse here + if (requires_grad){ + if (stride_1X1 != 1){ + run_dconv(dimA, + padA, + convstride1X1, + dilationA, + filterdimA1, + outdimA1, + CUDNN_DATA_HALF, + dx, + w, + dy1, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + // add 2 together + grad_x.add_(grad_x_conv4); + } + else { + run_dconv_add(dimA, + padA, + convstride1X1, + dilationA, + filterdimA1, + outdimA1, + CUDNN_DATA_HALF, + dx, + w, + dy1, + dx_conv4); + } + } + + DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item()); + DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item()); + DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); + DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item()); + outputs.push_back(grad_x); + outputs.push_back(wgrad1); + outputs.push_back(wgrad2); + outputs.push_back(wgrad3); + + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { + DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item()); + outputs.push_back(wgrad4); + } + + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &bottleneck_forward, "Bottleneck block forward"); + m.def("backward", &bottleneck_backward, "Bottleneck block backward"); +} diff --git a/apex/contrib/csrc/cudnn-frontend b/apex/contrib/csrc/cudnn-frontend new file mode 160000 index 000000000..b4e1ad961 --- /dev/null +++ b/apex/contrib/csrc/cudnn-frontend @@ -0,0 +1 @@ +Subproject commit b4e1ad9613b89199982c9baf6ee91f6f98f5606d diff --git a/setup.py b/setup.py index e2576b0d5..e9cf39b37 100644 --- a/setup.py +++ b/setup.py @@ -291,7 +291,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): 'nvcc':['-O3', '--use_fast_math'] + version_dependent_macros})) -# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026 +# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026 generator_flag = [] torch_dir = torch.__path__[0] if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')): @@ -520,6 +520,23 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, 'nvcc':['-O3'] + version_dependent_macros})) +if "--fast_bottleneck" in sys.argv: + from torch.utils.cpp_extension import CUDAExtension + sys.argv.remove("--fast_bottleneck") + + from torch.utils.cpp_extension import BuildExtension + cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False) + + if torch.utils.cpp_extension.CUDA_HOME is None: + raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") + else: + subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) + ext_modules.append( + CUDAExtension(name='fast_bottleneck', + sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'], + include_dirs=['apex/contrib/csrc/cudnn-frontend/include'], + extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag})) + setup( name='apex', version='0.1',