Conv-Bias-ReLU fusion (NVIDIA#1332)
* Enabled Conv-Bias-ReLU fusion

The following modules are enabled using cuDNN runtime fusion:
1) Conv-Bias-ReLU (+backward)
2) Conv-Bias (+backward)
3) Conv-Bias-Mask-ReLU (+backward)

* Cast cleanup and autocast in the unit test

- Remove redundant dtype casts
- Simulate real-world usage in the unit test via torch.cuda.amp.autocast

Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>

* Fixed save_for_backward

Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: root <root@luna-0277.selene.nvidia.com>
3 people authored Mar 30, 2022
1 parent 3c88451 commit 23cfb57
Showing 5 changed files with 1,827 additions and 0 deletions.
2 changes: 2 additions & 0 deletions apex/contrib/conv_bias_relu/__init__.py
@@ -0,0 +1,2 @@
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU

76 changes: 76 additions & 0 deletions apex/contrib/conv_bias_relu/conv_bias_relu.py
@@ -0,0 +1,76 @@
import torch

import fused_conv_bias_relu  # C++ extension exposing the cuDNN runtime-fused conv kernels


# Conv + bias + ReLU fused into a single cuDNN runtime-fusion kernel.
class ConvBiasReLU_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)
        # The fused backward needs the forward output in addition to x and weight.
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)

        # Gradients for x, weight, bias; padding and stride are non-tensor arguments.
        return grads[0], grads[1], grads[2], None, None


# Conv + bias + elementwise mask + ReLU; the mask takes no gradient, so the
# backward reuses the same fused kernel as ConvBiasReLU_.
class ConvBiasMaskReLU_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, mask, padding, stride):
        outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)

        # Extra None for the mask argument, which requires no gradient.
        return grads[0], grads[1], grads[2], None, None, None


# Conv + bias without an activation; backward uses the no-ReLU variant of the kernel.
class ConvBias_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)
        ctx.save_for_backward(x, weight)
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)

        return grads[0], grads[1], grads[2], None, None


# User-facing entry points.
ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
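
For reference, a minimal usage sketch of the new ops, mirroring the unit test's approach of wrapping the call in torch.cuda.amp.autocast. The tensor shapes, the half/channels_last layout, and the broadcastable bias shape below are illustrative assumptions rather than requirements stated in this commit. ConvBias takes the same arguments as ConvBiasReLU, and ConvBiasMaskReLU additionally takes an elementwise mask before padding and stride.

import torch
from apex.contrib.conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU

# Illustrative shapes; half precision and channels_last layout are assumed for the fused kernels.
x = torch.randn(2, 64, 32, 32, dtype=torch.half, device='cuda').to(memory_format=torch.channels_last).requires_grad_()
weight = torch.randn(64, 64, 3, 3, dtype=torch.half, device='cuda').to(memory_format=torch.channels_last).requires_grad_()
bias = torch.randn(1, 64, 1, 1, dtype=torch.half, device='cuda').requires_grad_()  # assumed broadcastable bias shape

with torch.cuda.amp.autocast():
    out = ConvBiasReLU(x, weight, bias, 1, 1)  # padding=1, stride=1 passed as plain ints

out.float().sum().backward()  # exercises the fused backward; populates x.grad, weight.grad, bias.grad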

(The remaining 3 changed files are not shown.)
