forked from NVIDIA/apex
Commit: Adding fast bottleneck implementation into contrib (NVIDIA#1079)

* initial commit for adding fast bottleneck
* sync cudnn-frontend module

Co-authored-by: pbialecki <pbialecki@nvidia.com>

Showing 7 changed files with 1,942 additions and 1 deletion.
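For orientation, here is a minimal usage sketch of the new module, mirroring the test file further below. The import path apex.contrib.bottleneck is an assumption based on the contrib layout implied by the commit title; the diff itself only shows the package __init__ re-exporting Bottleneck, and it assumes the fast_bottleneck CUDA extension has been built.

# Hedged usage sketch; apex.contrib.bottleneck import path is an assumption.
import torch
from apex.contrib.bottleneck import Bottleneck

# half precision, channels_last memory format, cuDNN-fused path enabled
model = Bottleneck(in_channels=32, bottleneck_channels=8, out_channels=128,
                   stride=1, use_cudnn=True)
model = model.cuda().half().to(memory_format=torch.channels_last)

x = torch.randn(17, 32, 28, 28).cuda().half() \
        .to(memory_format=torch.channels_last).requires_grad_()
y = model(x)        # fused conv + frozen-BN scale/bias + ReLU
y.mean().backward()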
@@ -0,0 +1 @@
from .bottleneck import Bottleneck
@@ -0,0 +1,214 @@
import torch
from torch import nn
import fast_bottleneck


def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    weight_tensor_nchw = tensor
    nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed
    """
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def get_scale_bias(self, nhwc=False):
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        if nhwc:
            scale = scale.reshape(1, 1, 1, -1)
            bias = bias.reshape(1, 1, 1, -1)
        else:
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
        return scale, bias

    def forward(self, x):
        scale, bias = self.get_scale_bias()
        return x * scale + bias


@torch.jit.script
def drelu_dscale1(grad_o, output, scale1):
    relu_mask = (output > 0).half()
    dx_relu = relu_mask * grad_o
    g1 = dx_relu * scale1
    return g1, dx_relu


@torch.jit.script
def drelu_dscale2(grad_o, output, scale1, scale2):
    relu_mask = (output > 0).half()
    dx_relu = relu_mask * grad_o
    g1 = dx_relu * scale1
    g2 = dx_relu * scale2
    return g1, g2


class BottleneckFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):
        # TODO: clean up order of tensors
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in nhwc while the shape can be nhwc or channels_last
        # here we pass in a flag and let the c++ side handle it
        # alternatively, we could put all sizes into a fixed format and pass that in
        outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)
        ctx.save_for_backward(*(args + outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
        return outputs[2]

    # backward relu is not exposed, MUL with mask is used instead
    # only dgrad is supported
    @staticmethod
    def backward(ctx, grad_o):
        outputs = ctx.saved_tensors[-3:]

        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and for generating the drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)

        return (None, None, None, None, *grads)


bottleneck_function = BottleneckFunction.apply


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(torch.nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # following "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385).
    # The torchvision variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    # Here we put the stride at the 1x1 convolution.

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False):
        super(Bottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError('Only support groups == 1')
        if dilation != 1:
            raise RuntimeError('Only support dilation == 1')
        if norm_func is None:
            norm_func = FrozenBatchNorm2d
        else:
            raise RuntimeError('Only support frozen BN now.')

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)

        self.use_cudnn = use_cudnn

        # setup conv weights
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weights in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # TODO: prevent unsupported case usage
        # supported cases:
        #                  native   cudnn
        #  normal          yes      no
        #  channels_last   yes      yes
        #  explicit_nhwc   no       yes
        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            for p in self.parameters():
                with torch.no_grad():
                    p.data = p.data.permute(0, 2, 3, 1).contiguous()
        return

    def forward(self, x):
        if self.use_cudnn:
            # calculate scale/bias from registered buffers
            # TODO: make this better
            s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
            s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
            s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
            w_scale = [s1, s2, s3]
            w_bias = [b1, b2, b3]
            if self.downsample is not None:
                s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
                w_scale.append(s4)
                w_bias.append(b4)

            out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
            return out

        if self.explicit_nhwc:
            raise RuntimeError('explicit nhwc with native ops is not supported.')

        # fallback to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
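On the cuDNN path above, FrozenBatchNorm2d.forward is never called; instead get_scale_bias folds the frozen statistics and affine parameters into a per-channel scale and bias that the fused kernel applies after each convolution. The following standalone sketch of that folding is an illustration added here, not part of the commit; note that this FrozenBatchNorm2d has no eps term, so the identity is exact.

import torch

# Frozen BN: y = weight * (x - running_mean) / sqrt(running_var) + bias
# Folded:    y = x * scale + shift, with scale = weight * rsqrt(running_var)
#                                   and  shift = bias - running_mean * scale
n = 8
weight, bias = torch.rand(n) + 0.5, torch.randn(n)
running_mean, running_var = torch.randn(n), torch.rand(n) + 0.1

scale = weight * running_var.rsqrt()
shift = bias - running_mean * scale

x = torch.randn(4, n, 7, 7)
folded = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
direct = (x - running_mean.reshape(1, -1, 1, 1)) * running_var.rsqrt().reshape(1, -1, 1, 1) \
         * weight.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
print(torch.allclose(folded, direct, atol=1e-6))  # True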
@@ -0,0 +1,71 @@
import torch
from bottleneck import Bottleneck

torch.manual_seed(23337)

# set to True to print layerwise sums for all outputs in the reference code path
DEBUG = False

for stride, o_channel in [(1, 32), (1, 128), (2, 32)]:
    print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel)
    a_ = torch.randn(17, 32, 28, 28)

    a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()
    model = Bottleneck(32, 8, o_channel, stride=stride).cuda().half().to(memory_format=torch.channels_last)

    # reference run with native ops
    b = model(a)
    b.mean().backward()
    d_grad = a.grad.float()
    a.grad = None
    torch.cuda.synchronize()

    if DEBUG:
        print("[DEBUG] ref dx :", d_grad.sum().item())
        # print wgrad. no reset needed since the later cpp print happens before accumulation
        for i, w in enumerate(model.w_conv):
            print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item())

    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float())

    model.use_cudnn = True
    model.zero_grad()
    c = model(a)
    c.mean().backward()

    torch.cuda.synchronize()
    print("comparing native and channels_last:")
    print("max error fprop:", (b - c).abs().max().item(), "max elem:", b.abs().max().item())
    print("max error dgrad:", (d_grad - a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
    for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):
        print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())

    nhwc_a = a_.permute(0, 2, 3, 1).contiguous().cuda().half().requires_grad_()
    nhwc_model = Bottleneck(32, 8, o_channel, stride=stride, explicit_nhwc=True, use_cudnn=True).cuda().half()
    for p, q in zip(model.parameters(), nhwc_model.parameters()):
        # model's storage is already in nhwc, clone and assign to the explicit nhwc model
        q.data.copy_(p.data.permute(0, 2, 3, 1).contiguous())
    for p, q in zip(model.buffers(), nhwc_model.buffers()):
        q.data.copy_(p.data)

    d = nhwc_model(nhwc_a)
    d.mean().backward()
    torch.cuda.synchronize()

    # permute the cudnn channels_last reference outputs/grads into explicit nhwc layout
    # c_s = c.storage().tolist()
    # d_s = d.storage().tolist()
    # print(max([x - y for x, y in zip(c_s, d_s)]))
    c = c.contiguous(memory_format=torch.contiguous_format).permute(0, 2, 3, 1).contiguous()
    d_grad = a.grad.float().permute(0, 2, 3, 1).contiguous()
    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float().permute(0, 2, 3, 1).contiguous())

    torch.cuda.synchronize()
    print("comparing nhwc and channels_last:")
    print("max error fprop:", (d - c).abs().max().item(), "max elem:", c.abs().max().item())
    print("max error dgrad:", (d_grad - nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
    for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
        print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())