transformer utils (NVIDIA#1181)
Co-authored-by: Piotr Bialecki <pbialecki@nvidia.com>
Co-authored-by: Eddie Yan <eddiey@nvidia.com>
Co-authored-by: Rishi Puri <riship@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
5 people authored Oct 2, 2021
1 parent bdac244 commit 365fdc1
Showing 46 changed files with 6,894 additions and 183 deletions.
141 changes: 140 additions & 1 deletion .gitignore
@@ -4,5 +4,144 @@ build
docs/build
*~
__pycache__

# Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so
.vscode

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
1 change: 1 addition & 0 deletions apex/__init__.py
@@ -18,3 +18,4 @@
from . import optimizers
from . import normalization
from . import pyprof
from . import transformer
8 changes: 8 additions & 0 deletions apex/_autocast_utils.py
@@ -0,0 +1,8 @@
import torch


def _cast_if_autocast_enabled(*args):
    # Pass the arguments through unchanged when autocast is off; otherwise
    # cast them to the current autocast GPU dtype in one shot.
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
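
For context, a hedged sketch of the pattern this helper supports (the wrapper and `MyFusedFunction` names are illustrative, not part of this commit): arguments are cast once according to the ambient autocast state, then the custom kernel runs inside an autocast-disabled region so nothing is cast a second time. The new wrappers in `apex/normalization/fused_layer_norm.py` below follow exactly this shape.

```python
# Illustrative only: MyFusedFunction stands in for any custom
# torch.autograd.Function backed by a fused CUDA kernel.
def my_fused_op(*args):
    args = _cast_if_autocast_enabled(*args)       # cast per the active autocast dtype
    with torch.cuda.amp.autocast(enabled=False):  # kernel sees already-cast tensors
        return MyFusedFunction.apply(*args)
```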
2 changes: 1 addition & 1 deletion apex/normalization/__init__.py
@@ -1 +1 @@
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
179 changes: 116 additions & 63 deletions apex/normalization/fused_layer_norm.py
@@ -1,71 +1,103 @@
import importlib
import numbers

import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F

from apex._autocast_utils import _cast_if_autocast_enabled

global fused_layer_norm_cuda
fused_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps
        )
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
            grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
        )
        return grad_input, grad_weight, grad_bias, None, None


class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps
        )
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

class FusedLayerNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, normalized_shape, eps):
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps)
        ctx.save_for_backward(input_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, mean, invvar = ctx.saved_tensors
        grad_input = None
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps
        )
        return grad_input, None, None


def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
    with torch.cuda.amp.autocast(enabled=False):
        return FusedLayerNormAffineFunction.apply(*args)


def fused_layer_norm(input, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, normalized_shape, eps)
    with torch.cuda.amp.autocast(enabled=False):
        return FusedLayerNormFunction.apply(*args)


def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
    with torch.cuda.amp.autocast(enabled=False):
        return FusedLayerNormAffineMixedDtypesFunction.apply(*args)
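
A usage sketch for these wrappers (assuming a CUDA device and a built `fused_layer_norm_cuda` extension): under autocast, the fp32 arguments are cast once before the kernel runs.

```python
import torch
from apex.normalization.fused_layer_norm import fused_layer_norm_affine

x = torch.randn(8, 1024, device="cuda")
weight = torch.ones(1024, device="cuda", requires_grad=True)
bias = torch.zeros(1024, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast():
    # The fp32 args are cast once (to fp16 by default) by
    # _cast_if_autocast_enabled; the kernel then runs with autocast
    # disabled, so no double casting occurs inside.
    y = fused_layer_norm_affine(x, weight, bias, (1024,), eps=1e-5)
y.float().sum().backward()  # gradients flow back to weight and bias
```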


class FusedLayerNorm(torch.nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in

@@ -126,8 +158,9 @@ class FusedLayerNorm(torch.nn.Module):

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()

        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
@@ -141,8 +174,8 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
@@ -152,14 +185,34 @@ def reset_parameters(self):

    def forward(self, input):
        if not input.is_cuda:
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.elementwise_affine:
            return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
        else:
            return fused_layer_norm(input, self.normalized_shape, self.eps)

    def extra_repr(self):
        return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)


# NOTE (mkozuki): Why "mixed"?
# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
class MixedFusedLayerNorm(FusedLayerNorm):

    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        if "elementwise_affine" in kwargs:
            import warnings
            warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument")
            elementwise_affine = kwargs.pop("elementwise_affine")
            if not elementwise_affine:
                raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`")

        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)

    def forward(self, input: torch.Tensor):
        # NOTE (mkozuki): CPU path is here mainly for unittest sake.
        if not input.is_cuda:
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
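
A sketch of the dtype behavior the NOTE above describes (assuming a CUDA build): with fp16 inputs, `FusedLayerNorm` follows the input dtype, while `MixedFusedLayerNorm` keeps fp32 parameters and returns fp32.

```python
import torch
from apex.normalization import FusedLayerNorm, MixedFusedLayerNorm

x = torch.randn(8, 1024, device="cuda", dtype=torch.half)

ln = FusedLayerNorm(1024).cuda().half()      # params cast to fp16 to match the input
mixed_ln = MixedFusedLayerNorm(1024).cuda()  # params stay fp32; fp16 input is allowed

print(ln(x).dtype)        # torch.float16 -- output follows the input dtype
print(mixed_ln(x).dtype)  # torch.float32 -- output follows the parameter dtype
```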
5 changes: 5 additions & 0 deletions apex/transformer/README.md
@@ -0,0 +1,5 @@
# apex.transformer

`apex.transformer` is a module that enables efficient training of large Transformer models at scale.

`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module.
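
For orientation only, a hypothetical sketch of how a `megatron.mpu`-style tensor-parallel layer is set up; the exact apex entry points and signatures below are assumptions modeled on Megatron-LM, not confirmed API of this commit.

```python
import torch
# Hypothetical imports -- assumed to mirror Megatron-LM's megatron.mpu layout.
from apex.transformer import parallel_state, tensor_parallel

# One process per GPU, launched e.g. via torchrun.
torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(2)  # assumed: 2-way tensor model parallelism

# Column-parallel GEMM: the 1024x4096 weight is split column-wise across the
# two tensor-parallel ranks, so each rank holds a 1024x2048 shard.
linear = tensor_parallel.ColumnParallelLinear(1024, 4096).cuda()
```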