forked from NVIDIA/apex
* Enabled Conv-Bias-ReLU fusion

  The following modules are enabled using cuDNN runtime fusion:
  1) Conv-Bias-ReLU (+backward)
  2) Conv-Bias (+backward)
  3) Conv-Bias-Mask-ReLU (+backward)

* Casts cleanup and autocast in unittest

  - Remove redundant dtype casts
  - Simulate real usage in the unittest with torch.cuda.amp.autocast

* Fixed save_for_backward

Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: root <root@luna-0277.selene.nvidia.com>
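For orientation, each fused pattern is numerically equivalent to the following plain-PyTorch composition. This is a sketch only: the helper names are hypothetical, the 1-D bias layout follows F.conv2d, and placing the mask between the bias add and the ReLU is an assumption, not something this diff specifies.

import torch
import torch.nn.functional as F

def conv_bias_relu_reference(x, weight, bias, padding, stride):
    # What the fused cuDNN kernel computes in one pass:
    # convolution, bias add, then ReLU.
    return F.relu(F.conv2d(x, weight, bias, stride=stride, padding=padding))

def conv_bias_reference(x, weight, bias, padding, stride):
    # Same pattern without the activation.
    return F.conv2d(x, weight, bias, stride=stride, padding=padding)

def conv_bias_mask_relu_reference(x, weight, bias, mask, padding, stride):
    # Assumed semantics: elementwise mask applied before the ReLU.
    return F.relu(F.conv2d(x, weight, bias, stride=stride, padding=padding) * mask)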
1 parent 3c88451 · commit 23cfb57
Showing 5 changed files with 1,827 additions and 0 deletions.
@@ -0,0 +1,2 @@
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU
@@ -0,0 +1,76 @@
import torch

import fused_conv_bias_relu  # compiled extension exposing the cuDNN runtime-fused kernels

class ConvBiasReLU_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        # One fused cuDNN kernel: conv2d + bias add + ReLU.
        outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)
        # Save the ReLU output itself; the fused backward can derive the
        # activation mask from it, so no pre-activation tensor is kept.
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)

        # Gradients w.r.t. x, weight, bias; padding and stride are
        # non-tensor arguments and receive None.
        return grads[0], grads[1], grads[2], None, None

class ConvBiasMaskReLU_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, mask, padding, stride):
        # Fused conv2d + bias add + elementwise mask + ReLU.
        outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)

        # mask, padding, and stride receive no gradient.
        return grads[0], grads[1], grads[2], None, None, None

class ConvBias_(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        # Fused conv2d + bias add, no activation.
        outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)
        # No ReLU here, so only x and weight are needed for backward.
        ctx.save_for_backward(x, weight)
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)

        return grads[0], grads[1], grads[2], None, None

# Functional entry points; autograd.Function subclasses are invoked via .apply.
ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
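Mirroring the autocast usage the commit message describes for the unittest, the aliases are called directly under torch.cuda.amp.autocast; custom_fwd(cast_inputs=torch.half) then casts the floating-point CUDA tensor inputs to half precision before they reach the fused kernel. A minimal sketch, assuming a CUDA build of the fused_conv_bias_relu extension; the shapes, channels_last layout, and 1-D bias are illustrative assumptions, not part of this diff:

import torch

x = torch.randn(16, 64, 56, 56, device='cuda').to(memory_format=torch.channels_last).requires_grad_()
weight = torch.randn(128, 64, 3, 3, device='cuda').to(memory_format=torch.channels_last).requires_grad_()
bias = torch.randn(128, device='cuda', requires_grad=True)  # bias layout is an assumption

with torch.cuda.amp.autocast():
    out = ConvBiasReLU(x, weight, bias, 1, 1)  # padding=1, stride=1

out.float().sum().backward()  # fused backward fills x.grad, weight.grad, bias.grad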