Adding fast bottleneck implementation into contrib (NVIDIA#1079)
* initial commit for adding fast bottleneck

* sync cudnn-frontend module

Co-authored-by: pbialecki <pbialecki@nvidia.com>
FDecaYed and ptrblck authored Apr 17, 2021
1 parent 5c9b21d commit 705cba9
Showing 7 changed files with 1,942 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -2,3 +2,6 @@
path = apex/contrib/csrc/multihead_attn/cutlass
url = https://github.com/NVIDIA/cutlass.git
branch = v1.2.0
[submodule "apex/contrib/csrc/cudnn-frontend"]
path = apex/contrib/csrc/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
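Note: since cudnn-frontend is added as a git submodule, a fresh clone presumably needs its submodules initialized (for example with git submodule update --init --recursive) before the new extension can be built.
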
1 change: 1 addition & 0 deletions apex/contrib/bottleneck/__init__.py
@@ -0,0 +1 @@
from .bottleneck import Bottleneck
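With this __init__.py in place, the new layer should be importable from the contrib namespace once apex is installed with the extension built (import path inferred from the file location; a minimal sketch):

    from apex.contrib.bottleneck import Bottleneck
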
214 changes: 214 additions & 0 deletions apex/contrib/bottleneck/bottleneck.py
@@ -0,0 +1,214 @@
import torch
from torch import nn
import fast_bottleneck

def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    weight_tensor_nchw = tensor
    nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)

class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))

def get_scale_bias(self, nhwc=False):
scale = self.weight * self.running_var.rsqrt()
bias = self.bias - self.running_mean * scale
if nhwc:
scale = scale.reshape(1, 1, 1, -1)
bias = bias.reshape(1, 1, 1, -1)
else:
scale = scale.reshape(1, -1, 1, 1)
bias = bias.reshape(1, -1, 1, 1)
return scale, bias

def forward(self, x):
scale, bias = self.get_scale_bias()
return x * scale + bias


@torch.jit.script
def drelu_dscale1(grad_o, output, scale1):
    relu_mask = (output>0).half()
    dx_relu = relu_mask * grad_o
    g1 = dx_relu * scale1
    return g1, dx_relu

@torch.jit.script
def drelu_dscale2(grad_o, output, scale1, scale2):
    relu_mask = (output>0).half()
    dx_relu = relu_mask * grad_o
    g1 = dx_relu * scale1
    g2 = dx_relu * scale2
    return g1, g2

class BottleneckFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):
        # TODO: clean up order of tensors
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in nhwc while shape can be nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)
        ctx.save_for_backward(*(args+outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
        return outputs[2]

    # backward relu is not exposed, MUL with mask used now
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        outputs = ctx.saved_tensors[-3:]

        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and generating drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)

        return (None, None, None, None, *grads)

bottleneck_function = BottleneckFunction.apply

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class Bottleneck(torch.nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
    # The 3x3-stride variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    # Here we put the stride at the first 1x1 convolution.

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False):
        super(Bottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError('Only support groups == 1')
        if dilation != 1:
            raise RuntimeError('Only support dilation == 1')
        if norm_func is None:
            norm_func = FrozenBatchNorm2d
        else:
            raise RuntimeError('Only support frozen BN now.')

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)

        self.use_cudnn = use_cudnn

        # setup conv weights
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weights in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # TODO: prevent unsupported case usage
        # supported cases:
        #                  native   cudnn
        # normal             yes      no
        # channels_last      yes      yes
        # explicit_nhwc      no       yes
        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            for p in self.parameters():
                with torch.no_grad():
                    p.data = p.data.permute(0,2,3,1).contiguous()
        return

    def forward(self, x):
        if self.use_cudnn:
            # calculate scale/bias from registered buffers
            # TODO: make this better
            s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
            s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
            s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
            w_scale = [s1, s2, s3]
            w_bias = [b1, b2, b3]
            if self.downsample is not None:
                s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
                w_scale.append(s4)
                w_bias.append(b4)

            out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
            return out

        if self.explicit_nhwc:
            raise RuntimeError('explicit nhwc with native ops is not supported.')

        # fall back to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
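
For orientation, here is a minimal usage sketch of the module above, mirroring the channels_last path that test.py below exercises. The channel sizes and input shape are illustrative, and it assumes apex was built with the fast_bottleneck extension and runs on a CUDA device:

    import torch
    from apex.contrib.bottleneck import Bottleneck

    # illustrative sizes: 32 in_channels, 8 bottleneck_channels, 128 out_channels, stride 2
    model = Bottleneck(32, 8, 128, stride=2, use_cudnn=True)
    model = model.cuda().half().to(memory_format=torch.channels_last)

    x = torch.randn(17, 32, 28, 28, device='cuda', dtype=torch.half)
    x = x.to(memory_format=torch.channels_last).requires_grad_()

    out = model(x)           # cuDNN path: conv + frozen-BN scale/bias + ReLU per stage
    out.mean().backward()    # populates x.grad and the conv weight grads in model.w_conv
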
71 changes: 71 additions & 0 deletions apex/contrib/bottleneck/test.py
@@ -0,0 +1,71 @@
import torch
from bottleneck import Bottleneck
torch.manual_seed(23337)

# use True to print layerwise sum for all outputs in reference code path
DEBUG = False  # True

for stride, o_channel in [(1,32), (1,128), (2,32)]:
print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel)
a_ = torch.randn(17,32,28,28)

a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()
model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last)

# test model
b = model(a)
b.mean().backward()
d_grad = a.grad.float()
a.grad = None
torch.cuda.synchronize()

if DEBUG:
print("[DEBUG] ref dx :", d_grad.sum().item())
# print wgrad. we don't need to reset since later cpp print before accumulation
for i, w in enumerate(model.w_conv):
print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item())

wgrads = []
for w in model.w_conv:
wgrads.append(w.grad.float())

model.use_cudnn = True
model.zero_grad()
c = model(a)
c.mean().backward()

torch.cuda.synchronize()
print("comparing native and channels_last:")
print("max error fprop:", (b-c).abs().max().item(), "max elem:", b.abs().max().item())
print("max error dgrad:", (d_grad-a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):
print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())

nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_()
nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half()
for p,q in zip(model.parameters(), nhwc_model.parameters()):
# model's storage is already in nhwc, we clone and assign to explicit nhwc model
q.data.copy_(p.data.permute(0,2,3,1).contiguous())
for p,q in zip(model.buffers(), nhwc_model.buffers()):
q.data.copy_(p.data)

d = nhwc_model(nhwc_a)
d.mean().backward()
torch.cuda.synchronize()

# reset reference to cudnn channels_last permute
#c_s = c.storage().tolist()
#d_s = d.storage().tolist()
#print(max([x-y for x,y in zip(c_s,d_s)]))
c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous()
d_grad = a.grad.float().permute(0,2,3,1).contiguous()
wgrads = []
for w in model.w_conv:
wgrads.append(w.grad.float().permute(0,2,3,1).contiguous())

torch.cuda.synchronize()
print("comparing nhwc and channels_last:")
print("max error fprop:", (d-c).abs().max().item(), "max elem:", c.abs().max().item())
print("max error dgrad:", (d_grad-nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())
