From 365fdc180f02d56a92c943f72485f3514eafed9a Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 2 Oct 2021 18:27:06 +0900 Subject: [PATCH] transformer utils (#1181) Co-authored-by: Piotr Bialecki Co-authored-by: Eddie Yan Co-authored-by: Rishi Puri Co-authored-by: Sangkug Lym --- .gitignore | 141 +++- apex/__init__.py | 1 + apex/_autocast_utils.py | 8 + apex/normalization/__init__.py | 2 +- apex/normalization/fused_layer_norm.py | 179 ++-- apex/transformer/README.md | 5 + apex/transformer/__init__.py | 35 + apex/transformer/enums.py | 30 + apex/transformer/functional/__init__.py | 5 + apex/transformer/functional/fused_softmax.py | 202 +++++ apex/transformer/parallel_state.py | 346 ++++++++ apex/transformer/tensor_parallel/__init__.py | 79 ++ .../tensor_parallel/cross_entropy.py | 103 +++ apex/transformer/tensor_parallel/data.py | 113 +++ apex/transformer/tensor_parallel/layers.py | 471 +++++++++++ apex/transformer/tensor_parallel/mappings.py | 159 ++++ apex/transformer/tensor_parallel/memory.py | 136 ++++ .../tensor_parallel/microbatches.py | 164 ++++ apex/transformer/tensor_parallel/random.py | 311 +++++++ .../tensor_parallel/tests/__init__.py | 0 .../tensor_parallel/tests/arguments.py | 766 ++++++++++++++++++ .../tensor_parallel/tests/commons.py | 87 ++ .../tensor_parallel/tests/global_vars.py | 260 ++++++ apex/transformer/tensor_parallel/utils.py | 64 ++ csrc/layer_norm_cuda.cpp | 30 +- csrc/layer_norm_cuda_kernel.cu | 192 +++-- csrc/megatron/scaled_masked_softmax.cpp | 97 +++ csrc/megatron/scaled_masked_softmax.h | 505 ++++++++++++ csrc/megatron/scaled_masked_softmax_cuda.cu | 117 +++ .../scaled_upper_triang_masked_softmax.cpp | 72 ++ .../scaled_upper_triang_masked_softmax.h | 513 ++++++++++++ ...scaled_upper_triang_masked_softmax_cuda.cu | 98 +++ csrc/type_shim.h | 180 ++++ setup.py | 25 + .../test_fused_layer_norm.py | 169 +++- tests/L0/run_test.py | 72 +- tests/L0/run_transformer/__init__.py | 0 .../run_transformer/run_cross_entropy_test.py | 105 +++ tests/L0/run_transformer/run_data_test.py | 93 +++ .../L0/run_transformer/run_initialize_test.py | 102 +++ tests/L0/run_transformer/run_layers_test.py | 559 +++++++++++++ tests/L0/run_transformer/run_mappings_test.py | 61 ++ tests/L0/run_transformer/run_random_test.py | 211 +++++ tests/L0/run_transformer/run_utils_test.py | 20 + .../L0/run_transformer/test_fused_softmax.py | 137 ++++ tests/L0/run_transformer/test_mpu.py | 52 ++ 46 files changed, 6894 insertions(+), 183 deletions(-) create mode 100644 apex/_autocast_utils.py create mode 100644 apex/transformer/README.md create mode 100644 apex/transformer/__init__.py create mode 100644 apex/transformer/enums.py create mode 100644 apex/transformer/functional/__init__.py create mode 100644 apex/transformer/functional/fused_softmax.py create mode 100644 apex/transformer/parallel_state.py create mode 100644 apex/transformer/tensor_parallel/__init__.py create mode 100644 apex/transformer/tensor_parallel/cross_entropy.py create mode 100644 apex/transformer/tensor_parallel/data.py create mode 100644 apex/transformer/tensor_parallel/layers.py create mode 100644 apex/transformer/tensor_parallel/mappings.py create mode 100644 apex/transformer/tensor_parallel/memory.py create mode 100644 apex/transformer/tensor_parallel/microbatches.py create mode 100644 apex/transformer/tensor_parallel/random.py create mode 100644 apex/transformer/tensor_parallel/tests/__init__.py create mode 100644 apex/transformer/tensor_parallel/tests/arguments.py create mode 100644 
apex/transformer/tensor_parallel/tests/commons.py create mode 100644 apex/transformer/tensor_parallel/tests/global_vars.py create mode 100644 apex/transformer/tensor_parallel/utils.py create mode 100644 csrc/megatron/scaled_masked_softmax.cpp create mode 100644 csrc/megatron/scaled_masked_softmax.h create mode 100644 csrc/megatron/scaled_masked_softmax_cuda.cu create mode 100644 csrc/megatron/scaled_upper_triang_masked_softmax.cpp create mode 100644 csrc/megatron/scaled_upper_triang_masked_softmax.h create mode 100644 csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu create mode 100644 tests/L0/run_transformer/__init__.py create mode 100644 tests/L0/run_transformer/run_cross_entropy_test.py create mode 100644 tests/L0/run_transformer/run_data_test.py create mode 100644 tests/L0/run_transformer/run_initialize_test.py create mode 100644 tests/L0/run_transformer/run_layers_test.py create mode 100644 tests/L0/run_transformer/run_mappings_test.py create mode 100644 tests/L0/run_transformer/run_random_test.py create mode 100644 tests/L0/run_transformer/run_utils_test.py create mode 100644 tests/L0/run_transformer/test_fused_softmax.py create mode 100644 tests/L0/run_transformer/test_mpu.py diff --git a/.gitignore b/.gitignore index 3eb247e2d..d30f85c34 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,144 @@ build docs/build *~ __pycache__ +.vscode + +# Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions *.so -.vscode \ No newline at end of file + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/apex/__init__.py b/apex/__init__.py index 7027b032e..aa7841346 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -18,3 +18,4 @@ from . import optimizers from . import normalization from . import pyprof +from . import transformer diff --git a/apex/_autocast_utils.py b/apex/_autocast_utils.py new file mode 100644 index 000000000..e076d962d --- /dev/null +++ b/apex/_autocast_utils.py @@ -0,0 +1,8 @@ +import torch + + +def _cast_if_autocast_enabled(*args): + if not torch.is_autocast_enabled(): + return args + else: + return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) diff --git a/apex/normalization/__init__.py b/apex/normalization/__init__.py index b798883b3..07941f271 100644 --- a/apex/normalization/__init__.py +++ b/apex/normalization/__init__.py @@ -1 +1 @@ -from .fused_layer_norm import FusedLayerNorm +from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index c5b3b49ca..337af76a3 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -1,71 +1,103 @@ -import math -import torch +import importlib import numbers + +import torch from torch.nn.parameter import Parameter from torch.nn import init from torch.nn import functional as F -import importlib + +from apex._autocast_utils import _cast_if_autocast_enabled global fused_layer_norm_cuda fused_layer_norm_cuda = None + class FusedLayerNormAffineFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + return grad_input, grad_weight, grad_bias, None, None + + +class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( + 
input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output - @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - weight_ = weight.contiguous() - bias_ = bias.contiguous() - output, mean, invvar = fused_layer_norm_cuda.forward_affine( - input_, ctx.normalized_shape, weight_, bias_, ctx.eps) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, weight_, bias_, mean, invvar = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) - return grad_input, grad_weight, grad_bias, None, None class FusedLayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, normalized_shape, eps): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps) + ctx.save_for_backward(input_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, mean, invvar = ctx.saved_tensors + grad_input = None + grad_input = fused_layer_norm_cuda.backward( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps + ) + return grad_input, None, None + + +def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedLayerNormAffineFunction.apply(*args) - @staticmethod - def forward(ctx, input, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - output, mean, invvar = fused_layer_norm_cuda.forward( - input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, mean, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, mean, invvar = ctx.saved_tensors - grad_input = None - grad_input = fused_layer_norm_cuda.backward( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - ctx.eps) - return grad_input, None, None - -def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6): - return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps) def fused_layer_norm(input, normalized_shape, eps=1e-6): - return FusedLayerNormFunction.apply(input, normalized_shape, eps) + args = _cast_if_autocast_enabled(input, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedLayerNormFunction.apply(*args) + + +def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + 
return FusedLayerNormAffineMixedDtypesFunction.apply(*args) + class FusedLayerNorm(torch.nn.Module): r"""Applies Layer Normalization over a mini-batch of inputs as described in @@ -126,8 +158,9 @@ class FusedLayerNorm(torch.nn.Module): .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 """ + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): - super(FusedLayerNorm, self).__init__() + super().__init__() global fused_layer_norm_cuda fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") @@ -141,8 +174,8 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): self.weight = Parameter(torch.Tensor(*normalized_shape)) self.bias = Parameter(torch.Tensor(*normalized_shape)) else: - self.register_parameter('weight', None) - self.register_parameter('bias', None) + self.register_parameter("weight", None) + self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): @@ -152,14 +185,34 @@ def reset_parameters(self): def forward(self, input): if not input.is_cuda: - return F.layer_norm( - input, self.normalized_shape, self.weight, self.bias, self.eps) + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) if self.elementwise_affine: - return FusedLayerNormAffineFunction.apply( - input, self.weight, self.bias, self.normalized_shape,self.eps) + return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) else: - return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps) + return fused_layer_norm(input, self.normalized_shape, self.eps) def extra_repr(self): - return '{normalized_shape}, eps={eps}, ' \ - 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) + return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) + + +# NOTE (mkozuki): Why "mixed"? +# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype +# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype. +# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" +class MixedFusedLayerNorm(FusedLayerNorm): + + def __init__(self, normalized_shape, eps=1e-5, **kwargs): + if "elementwise_affine" in kwargs: + import warnings + warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument") + elementwise_affine = kwargs.pop("elementwise_affine") + if not elementwise_affine: + raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`") + + super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) + + def forward(self, input: torch.Tensor): + # NOTE (mkozuki): CPU path is here mainly for unittest sake. + if not input.is_cuda: + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) diff --git a/apex/transformer/README.md b/apex/transformer/README.md new file mode 100644 index 000000000..76f41505e --- /dev/null +++ b/apex/transformer/README.md @@ -0,0 +1,5 @@ +# apex.transformer + +`apex.transformer` is a module which enables efficient large Transformer models at scale. + +`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module. 
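A minimal sketch of the dtype behavior described in the `MixedFusedLayerNorm` note above (assuming a CUDA build of apex with the `fused_layer_norm_cuda` extension; shapes and dtypes are illustrative):

```python
import torch
from apex.normalization import FusedLayerNorm, MixedFusedLayerNorm

x = torch.randn(8, 1024, device="cuda", dtype=torch.half)

# FusedLayerNorm: the output dtype follows the input tensor's dtype (fp16 here).
ln = FusedLayerNorm(1024).cuda().half()
assert ln(x).dtype == torch.half

# MixedFusedLayerNorm: parameters stay fp32 and the output dtype follows the
# parameter dtype, even though the input is fp16.
mixed_ln = MixedFusedLayerNorm(1024).cuda()
assert mixed_ln(x).dtype == torch.float32
```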
diff --git a/apex/transformer/__init__.py b/apex/transformer/__init__.py new file mode 100644 index 000000000..3bb037f39 --- /dev/null +++ b/apex/transformer/__init__.py @@ -0,0 +1,35 @@ +from . import tensor_parallel +from . import functional +from .enums import LayerType +from .enums import AttnType +from .enums import AttnMaskType +from .parallel_state import ( + is_unitialized, + destroy_model_parallel, + get_data_parallel_group, + get_data_parallel_rank, + get_data_parallel_world_size, + get_embedding_group, + get_model_parallel_group, + get_tensor_model_parallel_group, + get_pipeline_model_parallel_group, + get_tensor_model_parallel_rank, + set_tensor_model_parallel_rank, + get_pipeline_model_parallel_rank, + set_pipeline_model_parallel_rank, + is_pipeline_first_stage, + is_pipeline_last_stage, + get_tensor_model_parallel_src_rank, + get_pipeline_model_parallel_first_rank, + get_pipeline_model_parallel_last_rank, + get_pipeline_model_parallel_next_rank, + get_pipeline_model_parallel_prev_rank, + get_tensor_model_parallel_world_size, + set_tensor_model_parallel_world_size, + get_pipeline_model_parallel_world_size, + set_pipeline_model_parallel_world_size, + get_virtual_pipeline_model_parallel_rank, + set_virtual_pipeline_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) diff --git a/apex/transformer/enums.py b/apex/transformer/enums.py new file mode 100644 index 000000000..dc050f6cd --- /dev/null +++ b/apex/transformer/enums.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum + + +class LayerType(enum.Enum): + encoder = 1 + decoder = 2 + + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 diff --git a/apex/transformer/functional/__init__.py b/apex/transformer/functional/__init__.py new file mode 100644 index 000000000..72ccdbe33 --- /dev/null +++ b/apex/transformer/functional/__init__.py @@ -0,0 +1,5 @@ +from .fused_softmax import FusedScaleMaskSoftmax + +__all__ = [ + "FusedScaleMaskSoftmax", +] diff --git a/apex/transformer/functional/fused_softmax.py b/apex/transformer/functional/fused_softmax.py new file mode 100644 index 000000000..09b345d61 --- /dev/null +++ b/apex/transformer/functional/fused_softmax.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
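A minimal usage sketch for the `FusedScaleMaskSoftmax` module added in `apex/transformer/functional/fused_softmax.py` below. The `attention_mask_func` and all shapes are illustrative assumptions following Megatron-LM conventions; the module falls back to the plain PyTorch softmax path whenever `is_kernel_available` returns False:

```python
import torch
from apex.transformer import AttnMaskType
from apex.transformer.functional import FusedScaleMaskSoftmax

def attention_mask_func(attention_scores, attention_mask):
    # Hypothetical mask function: fill masked positions before the softmax.
    return attention_scores.masked_fill(attention_mask, -10000.0)

fused_softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

# [b, np, sq, sk] attention scores; shapes chosen so is_kernel_available() holds.
scores = torch.randn(4, 16, 128, 128, device="cuda", dtype=torch.half)
# Boolean mask (True = masked), broadcast over heads as in Megatron-LM.
mask = torch.zeros(4, 1, 128, 128, device="cuda", dtype=torch.bool)
probs = fused_softmax(scores, mask)
```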
+import torch
+
+from apex._autocast_utils import _cast_if_autocast_enabled
+from ..enums import AttnMaskType
+
+
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
+    """
+    Fused operation which performs the following three operations in sequence
+    1. Scale the tensor.
+    2. Apply upper triangular mask (typically used in gpt models).
+    3. Perform softmax.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, scale):
+        import scaled_upper_triang_masked_softmax_cuda
+
+        scale_t = torch.tensor([scale])
+        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
+
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, output_grads):
+        import scaled_upper_triang_masked_softmax_cuda
+
+        softmax_results, scale_t = ctx.saved_tensors
+        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
+
+        return input_grads, None
+
+
+def scaled_upper_triang_masked_softmax(inputs, _, scale):
+    b, np, sq, sk = inputs.size()
+    assert sq == sk, "causal mask is only for self attention"
+    # Reshaping input to 3D tensor (attn_batches, sq, sk)
+    inputs = inputs.view(-1, sq, sk)
+    args = _cast_if_autocast_enabled(inputs, scale)
+    with torch.cuda.amp.autocast(enabled=False):
+        probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
+    return probs.view(b, np, sq, sk)
+
+
+# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
+# Without `cast_inputs` kwarg, somehow inputs are not cast to the dtype used in the autocast context.
+# So I needed to manually write two `torch.autograd.Function` inheritances.
+# Fused operation which performs the following three operations in sequence
+# 1. Scale the tensor.
+# 2. Apply the mask.
+# 3. Perform softmax.
+class ScaledMaskedSoftmax(torch.autograd.Function):
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
+    def forward(ctx, inputs, mask, scale):
+        import scaled_masked_softmax_cuda
+        scale_t = torch.tensor([scale])
+
+        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, output_grads):
+        import scaled_masked_softmax_cuda
+
+        softmax_results, scale_t = ctx.saved_tensors
+
+        input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
+        return input_grads, None, None
+
+
+def scaled_masked_softmax(inputs, mask, scale):
+    # input is 4D tensor (b, np, sq, sk)
+    args = _cast_if_autocast_enabled(inputs, mask, scale)
+    with torch.cuda.amp.autocast(enabled=False):
+        return ScaledMaskedSoftmax.apply(*args)
+
+
+class FusedScaleMaskSoftmax(torch.nn.Module):
+    """
+    fused operation: scaling + mask + softmax
+
+    Arguments:
+        input_in_fp16: flag to indicate if input is in fp16 data format.
+        input_in_bf16: flag to indicate if input is in bf16 data format.
+        attn_mask_type: attention mask type (pad or causal)
+        scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion
+        mask_func: mask function to be applied.
+        softmax_in_fp32: if true, softmax is performed at fp32 precision.
+        scale: scaling factor used in input tensor scaling.
+    """
+
+    def __init__(
+        self,
+        input_in_fp16,
+        input_in_bf16,
+        attn_mask_type,
+        scaled_masked_softmax_fusion,
+        mask_func,
+        softmax_in_fp32,
+        scale,
+    ):
+        super().__init__()
+        self.input_in_fp16 = input_in_fp16
+        self.input_in_bf16 = input_in_bf16
+        if self.input_in_fp16 and self.input_in_bf16:
+            raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
+        self.mask_func = mask_func
+        self.softmax_in_fp32 = softmax_in_fp32
+        self.scale = scale
+
+        if not (self.scale is None or softmax_in_fp32):
+            raise RuntimeError("softmax should be in fp32 when scaled")
+
+        if self.scaled_masked_softmax_fusion:
+            if self.attn_mask_type == AttnMaskType.causal:
+                self.fused_softmax_func = scaled_upper_triang_masked_softmax
+            elif self.attn_mask_type == AttnMaskType.padding:
+                self.fused_softmax_func = scaled_masked_softmax
+            else:
+                raise ValueError("Invalid attn_mask_type.")
+
+    def forward(self, input, mask):
+        # [b, np, sq, sk]
+        assert input.dim() == 4
+
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
+        else:
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user wants to fuse
+            and self.input_in_float16  # input must be fp16
+            and mask is not None  # mask tensor must not be None
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisible by 4
+            and attn_batches % 4 == 0  # np * b must be divisible by 4
+        ):
+            if 0 <= sk <= 2048:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False
+
+    def forward_fused_softmax(self, input, mask):
+        # input.shape = [b, np, sq, sk]
+        scale = self.scale if self.scale is not None else 1.0
+        return self.fused_softmax_func(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+
+        return probs
+
+    @staticmethod
+    def get_batch_per_block(sq, sk, b, np):
+        import scaled_masked_softmax_cuda
+
+        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
diff --git a/apex/transformer/parallel_state.py b/apex/transformer/parallel_state.py
new file mode 100644
index 000000000..6ee65513c
--- /dev/null
+++ b/apex/transformer/parallel_state.py
@@ -0,0 +1,346 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Model and data parallel groups.""" +import torch + +# TODO (mkozuki): Consider dissecting utils as this utils import is here +# only for ensure_divisibility +from .tensor_parallel import utils + + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None +# Inter-layer model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Embedding group. +_EMBEDDING_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + +_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None +_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +# A list of global ranks for each pipeline group to ease calculation of the source +# rank when broadcasting from the first or last pipeline stage +_PIPELINE_GLOBAL_RANKS = None + + +def is_unitialized(): + """Useful for code segments that may be accessed with or without mpu initialization""" + return _DATA_PARALLEL_GROUP is None + + +def initialize_model_parallel( + tensor_model_parallel_size_=1, pipeline_model_parallel_size_=1, virtual_pipeline_model_parallel_size_=None +): + """ + Initialize model data parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used to parallelize model tensor. + pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + if torch.distributed.get_rank() == 0: + print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_)) + print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_)) + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size) + pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) + # TODO (mkozuki): Consider moving `ensure_divisibility` to this file. 
+ utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size) + data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + num_data_parallel_groups = world_size // data_parallel_size + + if virtual_pipeline_model_parallel_size_ is not None: + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_ + + rank = torch.distributed.get_rank() + + # Build the data-parallel groups. + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + for i in range(data_parallel_size): + ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks] + group = torch.distributed.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized" + global _EMBEDDING_GROUP + assert _EMBEDDING_GROUP is None, "embedding group is already initialized" + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + # Setup embedding group (to exchange gradients between + # first and last stages). 
+ if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + else: + embedding_ranks = ranks + group = torch.distributed.new_group(embedding_ranks) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized" + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized" + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" + return _DATA_PARALLEL_GROUP + + +def get_embedding_group(): + """Get the embedding group the caller rank belongs to.""" + assert _EMBEDDING_GROUP is not None, "embedding group is not initialized" + return _EMBEDDING_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_pipeline_model_parallel_world_size(world_size): + """Set the pipeline model parallel size""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_rank(rank): + """Set pipeline model parallel rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + if 
_MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) + + +def is_pipeline_first_stage(ignore_virtual=False): + """Return True if in the first pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + if ( + get_virtual_pipeline_model_parallel_world_size() is not None + and get_virtual_pipeline_model_parallel_rank() != 0 + ): + return False + return get_pipeline_model_parallel_rank() == 0 + + +def is_pipeline_last_stage(ignore_virtual=False): + """Return True if in the last pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size() + if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != ( + virtual_pipeline_model_parallel_world_size - 1 + ): + return False + return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1) + + +def get_virtual_pipeline_model_parallel_rank(): + """Return the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + + +def set_virtual_pipeline_model_parallel_rank(rank): + """Set the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_virtual_pipeline_model_parallel_world_size(): + """Return the virtual pipeline-parallel world size.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_pipeline_model_parallel_first_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + return _PIPELINE_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_last_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + last_rank_local = get_pipeline_model_parallel_world_size() - 1 + return _PIPELINE_GLOBAL_RANKS[last_rank_local] + + +def get_pipeline_model_parallel_next_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + +def get_pipeline_model_parallel_prev_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return torch.distributed.get_rank(group=get_data_parallel_group()) + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global 
_TENSOR_MODEL_PARALLEL_GROUP + _TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + global _EMBEDDING_GROUP + _EMBEDDING_GROUP = None diff --git a/apex/transformer/tensor_parallel/__init__.py b/apex/transformer/tensor_parallel/__init__.py new file mode 100644 index 000000000..e87853e53 --- /dev/null +++ b/apex/transformer/tensor_parallel/__init__.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model parallel utility interface.""" + +from .cross_entropy import vocab_parallel_cross_entropy + +from .data import broadcast_data + +from .layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, + set_tensor_model_parallel_attributes, + set_defaults_if_not_set_tensor_model_parallel_attributes, + copy_tensor_model_parallel_attributes, +) + +from .mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, + scatter_to_tensor_model_parallel_region, +) + +from .random import ( + checkpoint, + get_cuda_rng_tracker, + init_checkpointed_activations_memory_buffer, + model_parallel_cuda_manual_seed, + reset_checkpointed_activations_memory_buffer, + gather_split_1d_tensor, + split_tensor_into_1d_equal_chunks, +) + +from .utils import divide, split_tensor_along_last_dim + + +__all__ = [ + # cross_entropy.py + "vocab_parallel_cross_entropy", + # data.py + "broadcast_data", + # layers.py + "ColumnParallelLinear", + "RowParallelLinear", + "VocabParallelEmbedding", + "set_tensor_model_parallel_attributes", + "set_defaults_if_not_set_tensor_model_parallel_attributes", + "copy_tensor_model_parallel_attributes", + # mappings.py + "copy_to_tensor_model_parallel_region", + "gather_from_tensor_model_parallel_region", + "reduce_from_tensor_model_parallel_region", + "scatter_to_tensor_model_parallel_region", + # random.py + "checkpoint", + "get_cuda_rng_tracker", + "init_checkpointed_activations_memory_buffer", + "model_parallel_cuda_manual_seed", + "reset_checkpointed_activations_memory_buffer", + "gather_split_1d_tensor", + "split_tensor_into_1d_equal_chunks", + # utils.py + "divide", + "split_tensor_along_last_dim", +] diff --git a/apex/transformer/tensor_parallel/cross_entropy.py b/apex/transformer/tensor_parallel/cross_entropy.py new file mode 100644 index 000000000..21d019ae4 --- /dev/null +++ b/apex/transformer/tensor_parallel/cross_entropy.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
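A minimal sketch of how the `parallel_state` groups above are typically initialized and queried, mirroring the 16-GPU example in the `initialize_model_parallel` docstring. This assumes one process per GPU launched via `torchrun` (which sets `LOCAL_RANK`); sizes are illustrative:

```python
import os
import torch
from apex.transformer import parallel_state

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# With 16 GPUs, 2-way tensor parallelism and 4-way pipeline parallelism
# leave 16 / (2 * 4) = 2-way data parallelism, as in the docstring example.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=2,
    pipeline_model_parallel_size_=4,
)

print(
    "tp rank", parallel_state.get_tensor_model_parallel_rank(),
    "pp rank", parallel_state.get_pipeline_model_parallel_rank(),
    "dp rank", parallel_state.get_data_parallel_rank(),
)

parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()
```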
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from ..parallel_state import get_tensor_model_parallel_group
+from ..parallel_state import get_tensor_model_parallel_rank
+from ..parallel_state import get_tensor_model_parallel_world_size
+from .utils import VocabUtility
+
+
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, vocab_parallel_logits, target):
+
+        # Maximum value along vocab dimension across all GPUs.
+        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+        torch.distributed.all_reduce(
+            logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
+        )
+        # Subtract the maximum value.
+        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+        # Get the partition's vocab indices
+        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+        partition_vocab_size = vocab_parallel_logits.size()[-1]
+        rank = get_tensor_model_parallel_rank()
+        world_size = get_tensor_model_parallel_world_size()
+        vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+
+        # Create a mask of valid vocab ids (1 means it needs to be masked).
+        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+        masked_target = target.clone() - vocab_start_index
+        masked_target[target_mask] = 0
+
+        # Get predicted-logits = logits[target].
+        # For simplicity, we convert logits to a 2-D tensor with size
+        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+        masked_target_1d = masked_target.view(-1)
+        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+        predicted_logits = predicted_logits_1d.view_as(target)
+        predicted_logits[target_mask] = 0.0
+        # All reduce is needed to get the chunks from other GPUs.
+        torch.distributed.all_reduce(
+            predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
+        )
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = vocab_parallel_logits
+        torch.exp(vocab_parallel_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        torch.distributed.all_reduce(
+            sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
+        )
+
+        # Loss = log(sum(exp(logits))) - predicted-logit.
+        loss = torch.log(sum_exp_logits) - predicted_logits
+
+        # Store softmax, target-mask and masked-target for backward pass.
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+        # All the inputs have softmax as their gradient.
+        grad_input = softmax
+        # For simplicity, work with the 2D gradient.
+        partition_vocab_size = softmax.size()[-1]
+        grad_2d = grad_input.view(-1, partition_vocab_size)
+
+        # Add the gradient from matching classes.
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
+
+        # Finally elementwise multiplication with the output gradients.
+        grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+        return grad_input, None
+
+
+def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
+    """Helper function for the cross entropy."""
+    return _VocabParallelCrossEntropy.apply(torch.clone(vocab_parallel_logits), target)
diff --git a/apex/transformer/tensor_parallel/data.py b/apex/transformer/tensor_parallel/data.py
new file mode 100644
index 000000000..901df041b
--- /dev/null
+++ b/apex/transformer/tensor_parallel/data.py
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from ..parallel_state import get_tensor_model_parallel_group
+from ..parallel_state import get_tensor_model_parallel_rank
+from ..parallel_state import get_tensor_model_parallel_src_rank
+
+
+_MAX_DATA_DIM = 5
+
+
+def _check_data_types(keys, data, target_dtype):
+    """Check that all the keys have the same target data type."""
+    for key in keys:
+        assert data[key].dtype == target_dtype, "{} has data type {} which " "is different than {}".format(
+            key, data[key].dtype, target_dtype
+        )
+
+
+def _build_key_size_numel_dictionaries(keys, data):
+    """Build the size on rank 0 and broadcast."""
+    max_dim = _MAX_DATA_DIM
+    sizes = [0 for _ in range(max_dim) for _ in keys]
+
+    # Pack the sizes on rank zero.
+    if get_tensor_model_parallel_rank() == 0:
+        offset = 0
+        for key in keys:
+            assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
+            size = data[key].size()
+            for i, s in enumerate(size):
+                sizes[i + offset] = s
+            offset += max_dim
+
+    # Move to GPU and broadcast.
+    sizes_cuda = torch.cuda.LongTensor(sizes)
+    torch.distributed.broadcast(
+        sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
+    )
+
+    # Move back to cpu and unpack.
+    sizes_cpu = sizes_cuda.cpu()
+    key_size = {}
+    key_numel = {}
+    total_numel = 0
+    offset = 0
+    for key in keys:
+        i = 0
+        size = []
+        numel = 1
+        while sizes_cpu[offset + i] > 0:
+            this_size = sizes_cpu[offset + i]
+            size.append(this_size)
+            numel *= this_size
+            i += 1
+        key_size[key] = size
+        key_numel[key] = numel
+        total_numel += numel
+        offset += max_dim
+
+    return key_size, key_numel, total_numel
+
+
+def broadcast_data(keys, data, datatype):
+    """Broadcast data from rank zero of each model parallel group to the
+    members of the same model parallel group.
+
+    Arguments:
+        keys: list of keys in the data dictionary to be broadcast
+        data: data dictionary of string keys and cpu tensor values.
+        datatype: torch data type of all tensors in data associated
+                  with keys.
+ """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) + # Pack on rank zero. + if get_tensor_model_parallel_rank() == 0: + # Check that all keys have the same data type. + _check_data_types(keys, data, datatype) + # Flatten the data associated with the keys + flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) + + # Broadcast + torch.distributed.broadcast( + flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(), + ) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output diff --git a/apex/transformer/tensor_parallel/layers.py b/apex/transformer/tensor_parallel/layers.py new file mode 100644 index 000000000..45630541f --- /dev/null +++ b/apex/transformer/tensor_parallel/layers.py @@ -0,0 +1,471 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch +import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from ..parallel_state import get_tensor_model_parallel_group +from ..parallel_state import get_tensor_model_parallel_rank +from ..parallel_state import get_tensor_model_parallel_world_size +from .mappings import copy_to_tensor_model_parallel_region +from .mappings import gather_from_tensor_model_parallel_region +from .mappings import reduce_from_tensor_model_parallel_region +from .mappings import scatter_to_tensor_model_parallel_region +from .random import get_cuda_rng_tracker +from .utils import divide +from .utils import VocabUtility + + +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { + "tensor_model_parallel": False, + "partition_dim": -1, + "partition_stride": 1, +} + + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel) or ( + get_tensor_model_parallel_rank() == 0 + ) + + +def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): + # Make sure the attributes are not set. + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + assert not hasattr(tensor, attribute) + # Set the attributes. 
+ setattr(tensor, "tensor_model_parallel", is_parallel) + setattr(tensor, "partition_dim", dim) + setattr(tensor, "partition_stride", stride) + + +def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): + def maybe_set(attribute, value): + if not hasattr(tensor, attribute): + setattr(tensor, attribute, value) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) + + +def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): + def maybe_copy(attribute): + if hasattr(source_tensor, attribute): + setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_copy(attribute) + + +def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): + """Initialize affine weight for model parallel on GPU.""" + + set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride) + + with get_cuda_rng_tracker().fork(): + init_method(weight) + + +# TODO (mkozuki): Re-consider removing params_dtype from arguments to make this +# more parallel with _initialize_affine_weight_gpu +def _initialize_affine_weight_cpu( + weight, + output_size, + input_size, + per_partition_size, + partition_dim, + init_method, + stride=1, + return_master_weight=False, + *, + params_dtype=torch.float32, +): + """Initialize affine weight for model parallel. + + Build the master weight on all processes and scatter + the relevant chunk.""" + + set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride) + + # Initialize master weight + master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False) + init_method(master_weight) + master_weight = master_weight.to(dtype=params_dtype) + + # Split and copy + per_partition_per_stride_size = divide(per_partition_size, stride) + weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim) + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__( + self, num_embeddings, embedding_dim, init_method=init.xavier_normal_, *, params_dtype=torch.float32, use_cpu_initialization=False, + ): + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set the detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2.0 + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. 
+        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
+            self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
+        )
+        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
+
+        # Allocate weights and initialize.
+        if use_cpu_initialization:
+            self.weight = Parameter(
+                torch.empty(self.num_embeddings_per_partition, self.embedding_dim, dtype=params_dtype)
+            )
+            _initialize_affine_weight_cpu(
+                self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method,
+                params_dtype=params_dtype,
+            )
+        else:
+            self.weight = Parameter(
+                torch.empty(
+                    self.num_embeddings_per_partition,
+                    self.embedding_dim,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype,
+                )
+            )
+            _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
+
+    def forward(self, input_):
+        if self.tensor_model_parallel_size > 1:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
+        # Get the embeddings.
+        output_parallel = F.embedding(
+            masked_input,
+            self.weight,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )
+        # Mask the output embedding.
+        if self.tensor_model_parallel_size > 1:
+            output_parallel[input_mask, :] = 0.0
+        # Reduce across all the model parallel GPUs.
+        output = reduce_from_tensor_model_parallel_region(output_parallel)
+        return output
+
+
+class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+    """
+    Column-parallel linear layer execution with asynchronous all-reduce
+    execution in backprop.
+    """
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+        output = torch.matmul(input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        grad_input = grad_output.matmul(weight)
+        # Asynchronous all-reduce
+        handle = torch.distributed.all_reduce(
+            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+        # Delay the start of weight gradient computation shortly (3us) to have
+        # all-reduce scheduled first and have GPU resources allocated
+        _ = torch.empty(1, device=grad_output.device) + 1
+        grad_weight = grad_output.t().matmul(input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+        handle.wait()
+        return grad_input, grad_weight, grad_bias
+
+
+class ColumnParallelLinear(torch.nn.Module):
+    """Linear layer with column parallelism.
+
+    The linear layer is defined as Y = XA + b. A is parallelized along
+    its second dimension as A = [A_1, ..., A_p].
+
+    Arguments:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias
+        gather_output: If true, call all-gather on output and make Y available
+                       to all GPUs, otherwise, every GPU will have its output
+                       which is Y_i = XA_i
+        init_method: method to initialize weights. Note that bias is always set
+                     to zero.
+        stride: For the strided linear layers.
+        keep_master_weight_for_test: This was added for testing and should be
+                                     set to False. It returns the master weights
+                                     used for initialization.
+        skip_bias_add: This was added to enable performance optimizations where bias
+                       can be fused with other elementwise operations. We skip
+                       adding bias but instead return it.
+    """
+
+    def __init__(
+        self,
+        input_size,
+        output_size,
+        bias=True,
+        gather_output=True,
+        init_method=init.xavier_normal_,
+        stride=1,
+        keep_master_weight_for_test=False,
+        skip_bias_add=False,
+        *,
+        no_async_tensor_model_parallel_allreduce=False,
+        params_dtype=torch.float32,
+        use_cpu_initialization=False,
+    ):
+        super(ColumnParallelLinear, self).__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.gather_output = gather_output
+        # Divide the weight matrix along the last dimension.
+        world_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, world_size)
+        self.skip_bias_add = skip_bias_add
+
+        # Parameters.
+        # Note: torch.nn.functional.linear performs XA^T + b and as a result
+        # we allocate the transpose.
+        # Initialize weight.
+        if use_cpu_initialization:
+            self.weight = Parameter(torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype))
+            self.master_weight = _initialize_affine_weight_cpu(
+                self.weight,
+                self.output_size,
+                self.input_size,
+                self.output_size_per_partition,
+                0,
+                init_method,
+                stride=stride,
+                return_master_weight=keep_master_weight_for_test,
+                params_dtype=params_dtype,
+            )
+        else:
+            self.weight = Parameter(
+                torch.empty(
+                    self.output_size_per_partition,
+                    self.input_size,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype,
+                )
+            )
+            _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=stride)
+
+        if bias:
+            if use_cpu_initialization:
+                self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
+            else:
+                self.bias = Parameter(
+                    torch.empty(self.output_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype)
+                )
+            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
+            # Always initialize bias to zero.
+            with torch.no_grad():
+                self.bias.zero_()
+        else:
+            self.register_parameter("bias", None)
+
+        self.async_tensor_model_parallel_allreduce = (
+            not no_async_tensor_model_parallel_allreduce and
+            world_size > 1)
+
+    def forward(self, input_):
+        bias = self.bias if not self.skip_bias_add else None
+
+        if self.async_tensor_model_parallel_allreduce:
+            input_shape = input_.shape
+            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
+            # Matrix multiply with asynchronous all-reduce execution
+            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
+                input_, self.weight, bias)
+            output_parallel = output_parallel.view(
+                input_shape[0], input_shape[1], output_parallel.shape[1])
+        else:
+            # Set up backprop all-reduce.
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+
+            # Matrix multiply.
+            output_parallel = F.linear(input_parallel, self.weight, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = gather_from_tensor_model_parallel_region(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
+class RowParallelLinear(torch.nn.Module):
+    """Linear layer with row parallelism.
+
+    The linear layer is defined as Y = XA + b. A is parallelized along
+    its first dimension and X along its second dimension as:
+               -   -
+              | A_1 |
+              | .   |
+          A = | .   |        X = [X_1, ..., X_p]
+              | .   |
+              | A_p |
+               -   -
+    Arguments:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias. Note that bias is not parallelized.
+        input_is_parallel: If true, we assume that the input is already
+                           split across the GPUs and we do not split
+                           again.
+        init_method: method to initialize weights. Note that bias is always set
+                     to zero.
+        stride: For the strided linear layers.
+        keep_master_weight_for_test: This was added for testing and should be
+                                     set to False. It returns the master weights
+                                     used for initialization.
+        skip_bias_add: This was added to enable performance optimization where bias
+                       can be fused with other elementwise operations. We skip
+                       adding bias but instead return it.
+    """
+
+    def __init__(
+        self,
+        input_size,
+        output_size,
+        bias=True,
+        input_is_parallel=False,
+        init_method=init.xavier_normal_,
+        stride=1,
+        keep_master_weight_for_test=False,
+        skip_bias_add=False,
+        *,
+        params_dtype=torch.float32,
+        use_cpu_initialization=False,
+    ):
+        super(RowParallelLinear, self).__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.input_is_parallel = input_is_parallel
+        # Divide the weight matrix along the last dimension.
+        world_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, world_size)
+        self.skip_bias_add = skip_bias_add
+
+        # Parameters.
+        # Note: torch.nn.functional.linear performs XA^T + b and as a result
+        # we allocate the transpose.
+        # Initialize weight.
+        if use_cpu_initialization:
+            self.weight = Parameter(torch.empty(self.output_size, self.input_size_per_partition, dtype=params_dtype))
+            self.master_weight = _initialize_affine_weight_cpu(
+                self.weight,
+                self.output_size,
+                self.input_size,
+                self.input_size_per_partition,
+                1,
+                init_method,
+                stride=stride,
+                return_master_weight=keep_master_weight_for_test,
+                params_dtype=params_dtype,
+            )
+        else:
+            self.weight = Parameter(
+                torch.empty(
+                    self.output_size,
+                    self.input_size_per_partition,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype,
+                )
+            )
+            _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride)
+        if bias:
+            if use_cpu_initialization:
+                self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
+            else:
+                self.bias = Parameter(
+                    torch.empty(self.output_size, device=torch.cuda.current_device(), dtype=params_dtype)
+                )
+            # Always initialize bias to zero.
+            with torch.no_grad():
+                self.bias.zero_()
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            input_parallel = scatter_to_tensor_model_parallel_region(input_)
+        # Matrix multiply.
+        output_parallel = F.linear(input_parallel, self.weight)
+        # All-reduce across all the partitions.
+        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if not self.skip_bias_add:
+            output = output_ + self.bias if self.bias is not None else output_
+            output_bias = None
+        else:
+            output = output_
+            output_bias = self.bias
+        return output, output_bias
diff --git a/apex/transformer/tensor_parallel/mappings.py b/apex/transformer/tensor_parallel/mappings.py
new file mode 100644
index 000000000..8159e0b62
--- /dev/null
+++ b/apex/transformer/tensor_parallel/mappings.py
@@ -0,0 +1,159 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from ..parallel_state import get_tensor_model_parallel_group
+from ..parallel_state import get_tensor_model_parallel_world_size
+from ..parallel_state import get_tensor_model_parallel_rank
+from .utils import split_tensor_along_last_dim
+
+
+def _reduce(input_):
+    """All-reduce the input tensor across model parallel group."""
+
+    # Bypass the function if we are using only 1 GPU.
+    if get_tensor_model_parallel_world_size() == 1:
+        return input_
+
+    # All-reduce.
+    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
+
+    return input_
+
+
+def _split(input_):
+    """Split the tensor along its last dimension and keep the
+    corresponding slice."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Split along last dimension.
+    input_list = split_tensor_along_last_dim(input_, world_size)
+
+    # Note: torch.split does not create contiguous tensors by default.
+    rank = get_tensor_model_parallel_rank()
+    output = input_list[rank].contiguous()
+
+    return output
+
+
+def _gather(input_):
+    """Gather tensors and concatenate along the last dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Size and dimension.
+    last_dim = input_.dim() - 1
+    rank = get_tensor_model_parallel_rank()
+
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
+
+    # Note: torch.cat already creates a contiguous tensor.
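+    # Shape sketch (illustrative): concatenating a [s, b, h/p] slice from each
+    # of p ranks along the last dimension reconstructs the full [s, b, h] tensor.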
+    output = torch.cat(tensor_list, dim=last_dim).contiguous()
+
+    return output
+
+
+class _CopyToModelParallelRegion(torch.autograd.Function):
+    """Pass the input to the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
+    @staticmethod
+    def forward(ctx, input_):
+        return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _reduce(grad_output)
+
+
+class _ReduceFromModelParallelRegion(torch.autograd.Function):
+    """All-reduce the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output
+
+
+class _ScatterToModelParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the chunk corresponding to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather(grad_output)
+
+
+class _GatherFromModelParallelRegion(torch.autograd.Function):
+    """Gather the input from model parallel region and concatenate."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _gather(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split(grad_output)
+
+
+# -----------------
+# Helper functions.
+# -----------------
+
+
+def copy_to_tensor_model_parallel_region(input_):
+    return _CopyToModelParallelRegion.apply(input_)
+
+
+def reduce_from_tensor_model_parallel_region(input_):
+    return _ReduceFromModelParallelRegion.apply(input_)
+
+
+def scatter_to_tensor_model_parallel_region(input_):
+    return _ScatterToModelParallelRegion.apply(input_)
+
+
+def gather_from_tensor_model_parallel_region(input_):
+    return _GatherFromModelParallelRegion.apply(input_)
diff --git a/apex/transformer/tensor_parallel/memory.py b/apex/transformer/tensor_parallel/memory.py
new file mode 100644
index 000000000..c23483d0a
--- /dev/null
+++ b/apex/transformer/tensor_parallel/memory.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+# A dictionary of all the memory buffers allocated.
+_MEM_BUFFS = dict()
+
+
+def allocate_mem_buff(name, numel, dtype, track_usage):
+    """Allocate a memory buffer."""
+    assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name)
+    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
+    return _MEM_BUFFS[name]
+
+
+def get_mem_buff(name):
+    """Get the memory buffer."""
+    return _MEM_BUFFS[name]
+
+
+class MemoryBuffer:
+    """Contiguous memory buffer.
+    Allocate a contiguous memory of type `dtype` and size `numel`. It is
+    used to reduce memory fragmentation.
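+
+    For example (illustrative sizes), one half-precision buffer of 2**20
+    elements can hand out views onto several activation tensors via `add`,
+    then be rewound with `reset` once those activations are no longer needed.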
+
+    Usage: After the allocation, the `_start` index is set to the first
+    index of the memory. A memory chunk starting from `_start` index
+    can be `allocated` for an input tensor, with the elements of the
+    tensor being copied. The buffer can be reused by resetting the
+    `_start` index.
+
+    """
+
+    def __init__(self, name, numel, dtype, track_usage):
+        if torch.distributed.get_rank() == 0:
+            element_size = torch.tensor([], dtype=dtype).element_size()
+            print(
+                "> building the {} memory buffer with {} num elements "
+                "and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024),
+                flush=True,
+            )
+        self.name = name
+        self.numel = numel
+        self.dtype = dtype
+        self.data = torch.empty(self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False)
+
+        # Index tracking the start of the free memory.
+        self._start = 0
+
+        # Values used for tracking usage.
+        self.track_usage = track_usage
+        if self.track_usage:
+            self.in_use_value = 0.0
+            self.total_value = 0.0
+
+    def reset(self):
+        """Reset the buffer start index to the beginning of the buffer."""
+        self._start = 0
+
+    def is_in_use(self):
+        """Whether the current buffer holds on to any memory."""
+        return self._start > 0
+
+    def numel_in_use(self):
+        """Return number of elements in use."""
+        return self._start
+
+    def add(self, tensor):
+        """Allocate a chunk of memory from the buffer to tensor and copy
+        the values."""
+        assert tensor.dtype == self.dtype, "Input tensor type {} different from buffer type {}".format(
+            tensor.dtype, self.dtype
+        )
+        # Number of elements of the input tensor.
+        tensor_numel = torch.numel(tensor)
+        new_start = self._start + tensor_numel
+        assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format(
+            tensor_numel, self.numel - self._start
+        )
+        # New tensor is a view into the memory.
+        new_tensor = self.data[self._start : new_start]
+        self._start = new_start
+        new_tensor = new_tensor.view(tensor.shape)
+        new_tensor.copy_(tensor)
+        # Return a pointer to the new tensor.
+        return new_tensor
+
+    def get_data(self):
+        """Return the data currently in use."""
+        if self.track_usage:
+            self.in_use_value += float(self._start)
+            self.total_value += float(self.numel)
+        return self.data[: self._start]
+
+    def print_average_usage(self):
+        """Print memory usage averaged over time. We would like this value
+        to be as high as possible."""
+        assert self.track_usage, "You need to enable track usage."
+        if torch.distributed.get_rank() == 0:
+            print(
+                "   > usage of {} memory buffer: {:.2f} %".format(
+                    self.name, self.in_use_value * 100.0 / self.total_value
+                ),
+                flush=True,
+            )
+
+
+class RingMemBuffer:
+    """A ring of memory buffers."""
+
+    def __init__(self, name, num_buffers, numel, dtype, track_usage):
+        self.num_buffers = num_buffers
+        self.buffers = [
+            allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage) for i in range(num_buffers)
+        ]
+        self._index = -1
+
+    def get_next_buffer(self):
+        self._index += 1
+        self._index = self._index % self.num_buffers
+        buff = self.buffers[self._index]
+        assert not buff.is_in_use(), "buffer is already in use."
+        return buff
diff --git a/apex/transformer/tensor_parallel/microbatches.py b/apex/transformer/tensor_parallel/microbatches.py
new file mode 100644
index 000000000..ca6330688
--- /dev/null
+++ b/apex/transformer/tensor_parallel/microbatches.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Megatron number of micro-batches calculators."""
+from abc import ABC
+from abc import abstractmethod
+
+
+def build_num_microbatches_calculator(args):
+
+    # Constant num micro-batches.
+    if args.rampup_batch_size is None:
+        num_microbatches_calculator = ConstantNumMicroBatches(
+            args.global_batch_size, args.micro_batch_size, args.data_parallel_size
+        )
+        if args.rank == 0:
+            print(
+                "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
+            )
+
+    else:
+        assert len(args.rampup_batch_size) == 3, (
+            "expected the following "
+            "format: --rampup-batch-size <start batch size> "
+            "<batch size increment> <ramp-up samples>"
+        )
+        start_batch_size = int(args.rampup_batch_size[0])
+        batch_size_increment = int(args.rampup_batch_size[1])
+        ramup_samples = int(args.rampup_batch_size[2])
+        if args.rank == 0:
+            print(
+                "will use batch size rampup starting from global batch "
+                "size {} to global batch size {} with batch size increments "
+                "{} over {} samples.".format(
+                    start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
+                ),
+                flush=True,
+            )
+        num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
+            start_batch_size,
+            batch_size_increment,
+            ramup_samples,
+            args.global_batch_size,
+            args.micro_batch_size,
+            args.data_parallel_size,
+        )
+
+    return num_microbatches_calculator
+
+
+class NumMicroBatchesCalculator(ABC):
+    def __init__(self):
+        self.num_micro_batches = None
+        self.current_global_batch_size = None
+
+    def get(self):
+        return self.num_micro_batches
+
+    def get_current_global_batch_size(self):
+        return self.current_global_batch_size
+
+    @abstractmethod
+    def update(self, consumed_samples, consistency_check):
+        pass
+
+
+class ConstantNumMicroBatches(NumMicroBatchesCalculator):
+    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
+        micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
+        assert global_batch_size % micro_batch_times_data_parallel == 0, (
+            "global batch size ({}) is not divisible by micro batch size ({})"
+            " times data parallel size ({})".format(global_batch_size, micro_batch_size, data_parallel_size)
+        )
+        self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
+        assert self.num_micro_batches >= 1
+        self.current_global_batch_size = global_batch_size
+
+    def update(self, consumed_samples, consistency_check):
+        pass
+
+
+class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
+    def __init__(
+        self,
+        start_batch_size,
+        batch_size_increment,
+        ramup_samples,
+        global_batch_size,
+        micro_batch_size,
+        data_parallel_size,
+    ):
+        """Batch size ramp up.
+        Over
+            steps = (global-batch-size - start-batch-size) / batch_size_increment
+        increment batch size from start-batch-size to global-batch-size using
+            rampup-samples / steps
+        samples.
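+        Example (illustrative numbers): start_batch_size=16,
+        batch_size_increment=8, global_batch_size=1024 gives
+        (1024 - 16) / 8 = 126 increments; with ramup_samples=300000 the global
+        batch size then grows once every 300000 / 126 ~= 2380 consumed samples.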
+ Arguments: + start_batch_size: global batch size to start with + batch_size_increment: global batch size increments + ramup_samples: number of samples to use ramp up global + batch size from `start_batch_size` to `global_batch_size` + global_batch_size: global batch size post rampup + micro_batch_size: micro batch size + data_parallel_size: data parallel size. + """ + + self.micro_batch_size = micro_batch_size + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size + assert self.micro_batch_times_data_parallel_size > 0 + + assert start_batch_size > 0 + self.start_batch_size = start_batch_size + + assert global_batch_size > 0 + self.global_batch_size = global_batch_size + diff_batch_size = self.global_batch_size - self.start_batch_size + assert diff_batch_size >= 0 + assert batch_size_increment > 0 + self.batch_size_increment = batch_size_increment + assert diff_batch_size % batch_size_increment == 0, ( + "expected " + "global batch size interval ({}) to be divisible by global batch " + "size increment ({})".format(diff_batch_size, batch_size_increment) + ) + + num_increments = diff_batch_size // self.batch_size_increment + self.ramup_samples = ramup_samples + assert self.ramup_samples >= 0 + self.rampup_samples_per_increment = self.ramup_samples / num_increments + + # Initialize number of microbatches. + self.update(0, False) + + def update(self, consumed_samples, consistency_check): + + if consumed_samples > self.ramup_samples: + self.current_global_batch_size = self.global_batch_size + else: + steps = int(consumed_samples / self.rampup_samples_per_increment) + self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment + assert self.current_global_batch_size <= self.global_batch_size + + if consistency_check: + assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, ( + "current global " + "batch size ({}) is not divisible by micro-batch-size ({}) times" + "data parallel size ({})".format( + self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size + ) + ) + self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size diff --git a/apex/transformer/tensor_parallel/random.py b/apex/transformer/tensor_parallel/random.py new file mode 100644 index 000000000..fb4c8015f --- /dev/null +++ b/apex/transformer/tensor_parallel/random.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+# Parts of the code here are adapted from PyTorch
+# repo: https://github.com/pytorch/pytorch
+import contextlib
+
+import torch
+from torch import _C
+from torch.cuda import _lazy_call, device as device_ctx_manager
+from torch.utils.checkpoint import detach_variable
+
+from ..parallel_state import get_data_parallel_rank
+from ..parallel_state import get_tensor_model_parallel_group
+from ..parallel_state import get_tensor_model_parallel_rank
+from ..parallel_state import get_tensor_model_parallel_world_size
+from .memory import allocate_mem_buff
+
+
+# Default name for the model parallel rng tracker.
+_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
+
+
+# Whether to apply model parallelism to checkpointed hidden states.
+_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
+
+
+# TODO (mkozuki): Consider the possibility of removing `tensor_model_parallel_size`,
+# `get_tensor_model_parallel_world_size()` might be alternative.
+def init_checkpointed_activations_memory_buffer(
+    micro_batch_size,
+    max_position_embeddings,
+    hidden_size,
+    num_layers,
+    tensor_model_parallel_size,
+    checkpoint_num_layers,
+    fp16,
+):
+    """Initialize the memory buffer for the checkpointed activations."""
+
+    per_layer = micro_batch_size * max_position_embeddings * hidden_size // tensor_model_parallel_size
+    assert num_layers % checkpoint_num_layers == 0, "number of layers is not divisible by checkpoint-num-layers"
+    num_checkpointer_layers = num_layers // checkpoint_num_layers
+    numel = per_layer * num_checkpointer_layers
+    dtype = torch.half
+    if not fp16:
+        dtype = torch.float
+
+    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
+    assert (
+        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None
+    ), "checkpointed activations memory buffer is already allocated."
+    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
+        "checkpointed activations", numel, dtype, track_usage=False
+    )
+
+
+def reset_checkpointed_activations_memory_buffer():
+    """Reset the memory used for checkpointing."""
+    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
+
+
+def _set_cuda_rng_state(new_state, device=-1):
+    """Sets the random number generator state of the current GPU.
+
+    Arguments:
+        new_state (torch.ByteTensor): The desired state
+    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
+    with a single change: the input state is not cloned. Cloning caused
+    major performance issues for +4 GPU cases.
+ """ + if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + + else: + # newer PyTorch + if device == -1: + device = torch.device("cuda") + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +def split_tensor_into_1d_equal_chunks(tensor): + """Break a tensor into equal 1D chunks.""" + data = tensor.view(-1) + partition_size = torch.numel(data) // get_tensor_model_parallel_world_size() + start_index = partition_size * get_tensor_model_parallel_rank() + end_index = start_index + partition_size + return data[start_index:end_index] + + +def gather_split_1d_tensor(tensor): + """Opposite of above function, gather values from model parallel ranks.""" + world_size = get_tensor_model_parallel_world_size() + numel = torch.numel(tensor) + numel_gathered = world_size * numel + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) + chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] + torch.distributed.all_gather(chunks, tensor, group=get_tensor_model_parallel_group()) + return gathered + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception("seed {} already exists".format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception("cuda rng state {} already exists".format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception("cuda rng state {} is not added".format(name)) + # Store current rng state. 
+        orig_cuda_rng_state = torch.cuda.get_rng_state()
+        # Set rng state to the desired one
+        _set_cuda_rng_state(self.states_[name])
+        # Do the stuff we wanted to do.
+        try:
+            yield
+        finally:
+            # Update the current rng state for later use.
+            self.states_[name] = torch.cuda.get_rng_state()
+            # And set the state to the original state we started with.
+            _set_cuda_rng_state(orig_cuda_rng_state)
+
+
+# RNG tracker object.
+_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+
+
+def get_cuda_rng_tracker():
+    """Get cuda rng tracker."""
+    return _CUDA_RNG_STATE_TRACKER
+
+
+def model_parallel_cuda_manual_seed(seed):
+    """Initialize model parallel cuda seed.
+
+    This function should be called after the model parallel is
+    initialized. Also, no torch.cuda.manual_seed should be called
+    after this function. Basically, this is a replacement for that
+    function.
+    Two sets of RNG states are tracked:
+        default state: This is for data parallelism and is the same among a
+                       set of model parallel GPUs but different across
+                       different model parallel groups. This is used for
+                       example for dropout in the non-tensor-model-parallel regions.
+        tensor-model-parallel state: This state is different among a set of model
+                       parallel GPUs, but the same across data parallel
+                       groups. This is used for example for dropout in
+                       model parallel regions.
+    """
+    # 2718 is just for fun and any POSITIVE value will work.
+    offset = seed + 2718
+    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
+    # Data parallel gets the original seed.
+    data_parallel_seed = seed
+
+    _CUDA_RNG_STATE_TRACKER.reset()
+    # Set the default state.
+    torch.cuda.manual_seed(data_parallel_seed)
+    # and model parallel state.
+    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
+
+
+class CheckpointFunction(torch.autograd.Function):
+    """This function is adapted from torch.utils.checkpoint with
+    two main changes:
+        1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
+        2) the states in the model parallel tracker are also properly
+           tracked/set/reset.
+    """
+
+    @staticmethod
+    def forward(ctx, run_function, *args):
+        ctx.run_function = run_function
+
+        # Copy the rng states.
+        ctx.fwd_cpu_rng_state = torch.get_rng_state()
+        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+        with torch.no_grad():
+            outputs = run_function(*args)
+
+        # Divide hidden states across model parallel group and only keep
+        # the chunk corresponding to the current rank.
+        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+            ctx.input_0_shape = args[0].data.shape
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
+            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
+
+        # Store everything.
+        ctx.save_for_backward(*args)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        if not torch.autograd._is_checkpoint_valid():
+            raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible")
+        inputs = ctx.saved_tensors
+        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
+            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
+
+        # Store the current states.
+        bwd_cpu_rng_state = torch.get_rng_state()
+        bwd_cuda_rng_state = torch.cuda.get_rng_state()
+        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+        # Set the states to what they were before the forward pass.
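+        # This replay is what makes recomputation exact: dropout layers inside
+        # run_function draw the same masks as they did in the original forward.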
+ torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # Compute the forward pass. + detached_inputs = detach_variable(inputs) + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + torch.autograd.backward(outputs, args) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) + return (None,) + grads + + +def checkpoint(function, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint.""" + return CheckpointFunction.apply(function, *args) diff --git a/apex/transformer/tensor_parallel/tests/__init__.py b/apex/transformer/tensor_parallel/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apex/transformer/tensor_parallel/tests/arguments.py b/apex/transformer/tensor_parallel/tests/arguments.py new file mode 100644 index 000000000..19338ecc4 --- /dev/null +++ b/apex/transformer/tensor_parallel/tests/arguments.py @@ -0,0 +1,766 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Megatron arguments.""" +import argparse +import os + +import torch + + +def parse_args(extra_args_provider=None, defaults={}, + ignore_unknown_args=False): + """Parse all arguments.""" + parser = argparse.ArgumentParser(description='Megatron-LM Arguments', + allow_abbrev=False) + + # Standard arguments. + parser = _add_network_size_args(parser) + parser = _add_regularization_args(parser) + parser = _add_training_args(parser) + parser = _add_initialization_args(parser) + parser = _add_learning_rate_args(parser) + parser = _add_checkpointing_args(parser) + parser = _add_mixed_precision_args(parser) + parser = _add_distributed_args(parser) + parser = _add_validation_args(parser) + parser = _add_data_args(parser) + parser = _add_autoresume_args(parser) + parser = _add_biencoder_args(parser) + parser = _add_vit_args(parser) + parser = _add_logging_args(parser) + + # Custom arguments. + if extra_args_provider is not None: + parser = extra_args_provider(parser) + + # Parse. + if ignore_unknown_args: + args, _ = parser.parse_known_args() + else: + args = parser.parse_args() + + # Distributed args. + args.rank = int(os.getenv('RANK', '0')) + args.world_size = int(os.getenv("WORLD_SIZE", '1')) + # Tensor model parallel size. 
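+    # Worked example (hypothetical launch): WORLD_SIZE=8 with
+    # --tensor-model-parallel-size 2 and --pipeline-model-parallel-size 2
+    # passes the divisibility checks below and yields data_parallel_size = 2.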
+ args.tensor_model_parallel_size = min( + args.tensor_model_parallel_size, args.world_size) + assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ + ' ({}) is not divisible by tensor model parallel size ({})'.format( + args.world_size, args.tensor_model_parallel_size) + # Pipeline model parallel size. + args.pipeline_model_parallel_size = min( + args.pipeline_model_parallel_size, + (args.world_size // args.tensor_model_parallel_size)) + # Checks. + model_parallel_size = args.pipeline_model_parallel_size * \ + args.tensor_model_parallel_size + assert args.world_size % model_parallel_size == 0, 'world size is not'\ + ' divisible by tensor parallel size ({}) times pipeline parallel ' \ + 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, + args.pipeline_model_parallel_size) + args.data_parallel_size = args.world_size // model_parallel_size + if args.rank == 0: + print('using world size: {}, data-parallel-size: {}, ' + 'tensor-model-parallel size: {}, ' + 'pipeline-model-parallel size: {} '.format( + args.world_size, args.data_parallel_size, + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size), flush=True) + + # Deprecated arguments + assert args.batch_size is None, '--batch-size argument is no longer ' \ + 'valid, use --micro-batch-size instead' + del args.batch_size + assert args.warmup is None, '--warmup argument is no longer valid, use ' \ + '--lr-warmup-fraction instead' + del args.warmup + assert args.model_parallel_size is None, '--model-parallel-size is no ' \ + 'longer valid, use --tensor-model-parallel-size instead' + del args.model_parallel_size + + # Set input defaults. + for key in defaults: + # For default to be valid, it should not be provided in the + # arguments that are passed to the program. We check this by + # ensuring the arg is set to None. + if getattr(args, key) is not None: + if args.rank == 0: + print('WARNING: overriding default arguments for {key}:{v} \ + with {key}:{v2}'.format(key=key, v=defaults[key], + v2=getattr(args, key)), + flush=True) + else: + setattr(args, key, defaults[key]) + + # Batch size. + assert args.micro_batch_size is not None + assert args.micro_batch_size > 0 + if args.global_batch_size is None: + args.global_batch_size = args.micro_batch_size * args.data_parallel_size + if args.rank == 0: + print('setting global batch size to {}'.format( + args.global_batch_size), flush=True) + assert args.global_batch_size > 0 + if args.num_layers_per_virtual_pipeline_stage is not None: + assert args.pipeline_model_parallel_size > 2, \ + 'pipeline-model-parallel size should be greater than 2 with ' \ + 'interleaved schedule' + assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \ + 'number of layers is not divisible by number of layers per virtual ' \ + 'pipeline stage' + args.virtual_pipeline_model_parallel_size = \ + (args.num_layers // args.pipeline_model_parallel_size) // \ + args.num_layers_per_virtual_pipeline_stage + else: + args.virtual_pipeline_model_parallel_size = None + + # Parameters dtype. + args.params_dtype = torch.float + if args.fp16: + assert not args.bf16 + args.params_dtype = torch.half + if args.bf16: + assert not args.fp16 + args.params_dtype = torch.bfloat16 + # bfloat16 requires gradient accumulation and all-reduce to + # be done in fp32. 
+ if not args.accumulate_allreduce_grads_in_fp32: + args.accumulate_allreduce_grads_in_fp32 = True + if args.rank == 0: + print('accumulate and all-reduce gradients in fp32 for ' + 'bfloat16 data type.', flush=True) + + if args.rank == 0: + print('using {} for parameters ...'.format(args.params_dtype), + flush=True) + + # If we do accumulation and all-reduces in fp32, we need to have + # local DDP and we should set the use-contiguous-buffers-in-ddp. + if args.accumulate_allreduce_grads_in_fp32: + assert args.DDP_impl == 'local' + args.use_contiguous_buffers_in_ddp = True + + # If we use a contiguous buffer to hold main grads, we need to have + # local DDP. + if args.use_contiguous_buffers_in_ddp: + assert args.DDP_impl == 'local' + + if args.dataloader_type is None: + args.dataloader_type = 'single' + + # Consumed tokens. + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + + # Iteration-based training. + if args.train_iters: + # If we use iteration-based training, make sure the + # sample-based options are off. + assert args.train_samples is None, \ + 'expected iteration-based training' + assert args.lr_decay_samples is None, \ + 'expected iteration-based learning rate decay' + assert args.lr_warmup_samples == 0, \ + 'expected iteration-based learning rate warmup' + assert args.rampup_batch_size is None, \ + 'expected no batch-size rampup for iteration-based training' + if args.lr_warmup_fraction is not None: + assert args.lr_warmup_iters == 0, \ + 'can only specify one of lr-warmup-fraction and lr-warmup-iters' + + # Sample-based training. + if args.train_samples: + # If we use sample-based training, make sure the + # iteration-based options are off. + assert args.train_iters is None, \ + 'expected sample-based training' + assert args.lr_decay_iters is None, \ + 'expected sample-based learning rate decay' + assert args.lr_warmup_iters == 0, \ + 'expected sample-based learnig rate warmup' + if args.lr_warmup_fraction is not None: + assert args.lr_warmup_samples == 0, \ + 'can only specify one of lr-warmup-fraction ' \ + 'and lr-warmup-samples' + + # Check required arguments. + required_args = ['num_layers', 'hidden_size', 'num_attention_heads', + 'max_position_embeddings'] + for req_arg in required_args: + _check_arg_is_not_none(args, req_arg) + + # Checks. + if args.ffn_hidden_size is None: + args.ffn_hidden_size = 4 * args.hidden_size + + if args.kv_channels is None: + assert args.hidden_size % args.num_attention_heads == 0 + args.kv_channels = args.hidden_size // args.num_attention_heads + + if args.seq_length is not None: + assert args.encoder_seq_length is None + args.encoder_seq_length = args.seq_length + else: + assert args.encoder_seq_length is not None + args.seq_length = args.encoder_seq_length + + if args.seq_length is not None: + assert args.max_position_embeddings >= args.seq_length + if args.decoder_seq_length is not None: + assert args.max_position_embeddings >= args.decoder_seq_length + if args.lr is not None: + assert args.min_lr <= args.lr + if args.save is not None: + assert args.save_interval is not None + # Mixed precision checks. + if args.fp16_lm_cross_entropy: + assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' + if args.fp32_residual_connection: + assert args.fp16 or args.bf16, \ + 'residual connection in fp32 only supported when using fp16 or bf16.' + # Activation checkpointing. 
+    if args.distribute_checkpointed_activations:
+        assert args.checkpoint_activations, \
+            'for distribute-checkpointed-activations to work you '\
+            'need to enable checkpoint-activations'
+
+    _print_args(args)
+    return args
+
+
+def _print_args(args):
+    """Print arguments."""
+    if args.rank == 0:
+        print('------------------------ arguments ------------------------',
+              flush=True)
+        str_list = []
+        for arg in vars(args):
+            dots = '.' * (48 - len(arg))
+            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
+        for arg in sorted(str_list, key=lambda x: x.lower()):
+            print(arg, flush=True)
+        print('-------------------- end of arguments ---------------------',
+              flush=True)
+
+
+def _check_arg_is_not_none(args, arg):
+    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
+
+
+def _add_network_size_args(parser):
+    group = parser.add_argument_group(title='network size')
+
+    group.add_argument('--num-layers', type=int, default=None,
+                       help='Number of transformer layers.')
+    group.add_argument('--hidden-size', type=int, default=None,
+                       help='Transformer hidden size.')
+    group.add_argument('--ffn-hidden-size', type=int, default=None,
+                       help='Transformer Feed-Forward Network hidden size. '
+                       'This is set to 4*hidden-size if not provided')
+    group.add_argument('--num-attention-heads', type=int, default=None,
+                       help='Number of transformer attention heads.')
+    group.add_argument('--kv-channels', type=int, default=None,
+                       help='Projection weights dimension in multi-head '
+                       'attention. This is set to '
+                       '   args.hidden_size // args.num_attention_heads '
+                       'if not provided.')
+    group.add_argument('--max-position-embeddings', type=int, default=None,
+                       help='Maximum number of position embeddings to use. '
+                       'This is the size of position embedding.')
+    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
+                       help='Pad the vocab size to be divisible by this value. '
+                       'This is added for computational efficiency reasons.')
+    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
+                       help='Layer norm epsilon.')
+    group.add_argument('--apply-residual-connection-post-layernorm',
+                       action='store_true',
+                       help='If set, use original BERT residual connection '
+                       'ordering.')
+    group.add_argument('--openai-gelu', action='store_true',
+                       help='Use OpenAIs GeLU implementation.
This option' + 'should not be used unless for backward compatibility' + 'reasons.') + group.add_argument('--onnx-safe', type=bool, required=False, + help='Use workarounds for known problems with ' + 'Torch ONNX exporter') + group.add_argument('--bert-no-binary-head', action='store_false', + help='Disable BERT binary head.', + dest='bert_binary_head') + + return parser + + +def _add_logging_args(parser): + group = parser.add_argument_group(title='logging') + + group.add_argument('--log-params-norm', action='store_true', + help='If set, calculate and log parameters norm.') + group.add_argument('--log-num-zeros-in-grad', action='store_true', + help='If set, calculate and log the number of zeros in gradient.') + group.add_argument('--tensorboard-log-interval', type=int, default=1, + help='Report to tensorboard interval.') + group.add_argument('--tensorboard-queue-size', type=int, default=1000, + help='Size of the tensorboard queue for pending events ' + 'and summaries before one of the ‘add’ calls forces a ' + 'flush to disk.') + group.add_argument('--log-timers-to-tensorboard', action='store_true', + help='If set, write timers to tensorboard.') + group.add_argument('--log-batch-size-to-tensorboard', action='store_true', + help='If set, write batch-size to tensorboard.') + group.add_argument('--no-log-learnig-rate-to-tensorboard', + action='store_false', + help='Disable learning rate logging to tensorboard.', + dest='log_learning_rate_to_tensorboard') + group.add_argument('--no-log-loss-scale-to-tensorboard', + action='store_false', + help='Disable loss-scale logging to tensorboard.', + dest='log_loss_scale_to_tensorboard') + group.add_argument('--log-validation-ppl-to-tensorboard', + action='store_true', + help='If set, write validation perplexity to ' + 'tensorboard.') + group.add_argument('--log-memory-to-tensorboard', + action='store_true', + help='Enable memory logging to tensorboard.') + + return parser + + +def _add_regularization_args(parser): + group = parser.add_argument_group(title='regularization') + + group.add_argument('--attention-dropout', type=float, default=0.1, + help='Post attention dropout probability.') + group.add_argument('--hidden-dropout', type=float, default=0.1, + help='Dropout probability for hidden state transformer.') + group.add_argument('--weight-decay', type=float, default=0.01, + help='Weight decay coefficient for L2 regularization.') + group.add_argument('--clip-grad', type=float, default=1.0, + help='Gradient clipping based on global L2 norm.') + group.add_argument('--adam-beta1', type=float, default=0.9, + help='First coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--adam-beta2', type=float, default=0.999, + help='Second coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--adam-eps', type=float, default=1e-08, + help='Term added to the denominator to improve' + 'numerical stability') + group.add_argument('--sgd-momentum', type=float, default=0.9, + help='Momentum factor for sgd') + + return parser + + +def _add_training_args(parser): + group = parser.add_argument_group(title='training') + + group.add_argument('--micro-batch-size', type=int, default=None, + help='Batch size per model instance (local batch size). ' + 'Global batch size is local batch size times data ' + 'parallel size times number of micro batches.') + group.add_argument('--batch-size', type=int, default=None, + help='Old batch size parameter, do not use. 
' + 'Use --micro-batch-size instead') + group.add_argument('--global-batch-size', type=int, default=None, + help='Training batch size. If set, it should be a ' + 'multiple of micro-batch-size times data-parallel-size. ' + 'If this value is None, then ' + 'use micro-batch-size * data-parallel-size as the ' + 'global batch size. This choice will result in 1 for ' + 'number of micro-batches.') + group.add_argument('--rampup-batch-size', nargs='*', default=None, + help='Batch size ramp up with the following values:' + ' --rampup-batch-size ' + ' ' + ' ' + 'For example:' + ' --rampup-batch-size 16 8 300000 \ ' + ' --global-batch-size 1024' + 'will start with global batch size 16 and over ' + ' (1024 - 16) / 8 = 126 intervals will increase' + 'the batch size linearly to 1024. In each interval' + 'we will use approximately 300000 / 126 = 2380 samples.') + group.add_argument('--checkpoint-activations', action='store_true', + help='Checkpoint activation to allow for training ' + 'with larger models, sequences, and batch sizes.') + group.add_argument('--distribute-checkpointed-activations', + action='store_true', + help='If set, distribute checkpointed activations ' + 'across model parallel group.') + group.add_argument('--checkpoint-num-layers', type=int, default=1, + help='chunk size (number of layers) for checkpointing.') + group.add_argument('--train-iters', type=int, default=None, + help='Total number of iterations to train over all ' + 'training runs. Note that either train-iters or ' + 'train-samples should be provided.') + group.add_argument('--train-samples', type=int, default=None, + help='Total number of samples to train over all ' + 'training runs. Note that either train-iters or ' + 'train-samples should be provided.') + group.add_argument('--log-interval', type=int, default=100, + help='Report loss and timing interval.') + group.add_argument('--exit-interval', type=int, default=None, + help='Exit the program after the iteration is divisible ' + 'by this value.') + group.add_argument('--exit-duration-in-mins', type=int, default=None, + help='Exit the program after this many minutes.') + group.add_argument('--tensorboard-dir', type=str, default=None, + help='Write TensorBoard logs to this directory.') + group.add_argument('--no-masked-softmax-fusion', + action='store_false', + help='Disable fusion of query_key_value scaling, ' + 'masking, and softmax.', + dest='masked_softmax_fusion') + group.add_argument('--no-bias-gelu-fusion', action='store_false', + help='Disable bias and gelu fusion.', + dest='bias_gelu_fusion') + group.add_argument('--no-bias-dropout-fusion', action='store_false', + help='Disable bias and dropout fusion.', + dest='bias_dropout_fusion') + group.add_argument('--optimizer', type=str, default='adam', + choices=['adam', 'sgd'], + help='Optimizer function') + group.add_argument('--dataloader-type', type=str, default=None, + choices=['single', 'cyclic'], + help='Single pass vs multiple pass data loader') + return parser + + +def _add_initialization_args(parser): + group = parser.add_argument_group(title='initialization') + + group.add_argument('--seed', type=int, default=1234, + help='Random seed used for python, numpy, ' + 'pytorch, and cuda.') + group.add_argument('--init-method-std', type=float, default=0.02, + help='Standard deviation of the zero mean normal ' + 'distribution used for weight initialization.') + group.add_argument('--init-method-xavier-uniform', action='store_true', + help='Enable Xavier uniform parameter initialization') + + return parser + + +def 
_add_learning_rate_args(parser): + group = parser.add_argument_group(title='learning rate') + + group.add_argument('--lr', type=float, default=None, + help='Initial learning rate. Depending on decay style ' + 'and initial warmup, the learing rate at each ' + 'iteration would be different.') + group.add_argument('--lr-decay-style', type=str, default='linear', + choices=['constant', 'linear', 'cosine'], + help='Learning rate decay function.') + group.add_argument('--lr-decay-iters', type=int, default=None, + help='number of iterations to decay learning rate over,' + ' If None defaults to `--train-iters`') + group.add_argument('--lr-decay-samples', type=int, default=None, + help='number of samples to decay learning rate over,' + ' If None defaults to `--train-samples`') + group.add_argument('--lr-warmup-fraction', type=float, default=None, + help='fraction of lr-warmup-(iters/samples) to use ' + 'for warmup (as a float)') + group.add_argument('--lr-warmup-iters', type=int, default=0, + help='number of iterations to linearly warmup ' + 'learning rate over.') + group.add_argument('--lr-warmup-samples', type=int, default=0, + help='number of samples to linearly warmup ' + 'learning rate over.') + group.add_argument('--warmup', type=int, default=None, + help='Old lr warmup argument, do not use. Use one of the' + '--lr-warmup-* arguments above') + group.add_argument('--min-lr', type=float, default=0.0, + help='Minumum value for learning rate. The scheduler' + 'clip values below this threshold.') + group.add_argument('--override-lr-scheduler', action='store_true', + help='Reset the values of the scheduler (learning rate,' + 'warmup iterations, minimum learning rate, maximum ' + 'number of iterations, and decay style from input ' + 'arguments and ignore values from checkpoints. Note' + 'that all the above values will be reset.') + group.add_argument('--use-checkpoint-lr-scheduler', action='store_true', + help='Use checkpoint to set the values of the scheduler ' + '(learning rate, warmup iterations, minimum learning ' + 'rate, maximum number of iterations, and decay style ' + 'from checkpoint and ignore input arguments.') + + return parser + + +def _add_checkpointing_args(parser): + group = parser.add_argument_group(title='checkpointing') + + group.add_argument('--save', type=str, default=None, + help='Output directory to save checkpoints to.') + group.add_argument('--save-interval', type=int, default=None, + help='Number of iterations between checkpoint saves.') + group.add_argument('--no-save-optim', action='store_true', default=None, + help='Do not save current optimizer.') + group.add_argument('--no-save-rng', action='store_true', default=None, + help='Do not save current rng state.') + group.add_argument('--load', type=str, default=None, + help='Directory containing a model checkpoint.') + group.add_argument('--no-load-optim', action='store_true', default=None, + help='Do not load optimizer when loading checkpoint.') + group.add_argument('--no-load-rng', action='store_true', default=None, + help='Do not load rng state when loading checkpoint.') + group.add_argument('--finetune', action='store_true', + help='Load model for finetuning. Do not load optimizer ' + 'or rng state from checkpoint and set iteration to 0. 
+def _add_mixed_precision_args(parser):
+    group = parser.add_argument_group(title='mixed precision')
+
+    group.add_argument('--fp16', action='store_true',
+                       help='Run the model in fp16 mode.')
+    group.add_argument('--bf16', action='store_true',
+                       help='Run the model in bfloat16 mode.')
+    group.add_argument('--loss-scale', type=float, default=None,
+                       help='Static loss scaling; positive power-of-2 '
+                       'values can improve fp16 convergence. If None, '
+                       'dynamic loss scaling is used.')
+    group.add_argument('--initial-loss-scale', type=float, default=2**32,
+                       help='Initial loss scale for dynamic loss scaling.')
+    group.add_argument('--min-loss-scale', type=float, default=1.0,
+                       help='Minimum loss scale for dynamic loss scaling.')
+    group.add_argument('--loss-scale-window', type=float, default=1000,
+                       help='Window over which to raise/lower the dynamic scale.')
+    group.add_argument('--hysteresis', type=int, default=2,
+                       help='Hysteresis for dynamic loss scaling.')
+    group.add_argument('--fp32-residual-connection', action='store_true',
+                       help='Move residual connections to fp32.')
+    group.add_argument('--no-query-key-layer-scaling', action='store_false',
+                       help='Do not scale Q * K^T by 1 / layer-number.',
+                       dest='apply_query_key_layer_scaling')
+    group.add_argument('--attention-softmax-in-fp32', action='store_true',
+                       help='Run attention masking and softmax in fp32. '
+                       'This flag is ignored unless '
+                       '--no-query-key-layer-scaling is specified.')
+    group.add_argument('--accumulate-allreduce-grads-in-fp32',
+                       action='store_true',
+                       help='Gradient accumulation and all-reduce in fp32.')
+    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
+                       help='Move the unreduced cross entropy loss '
+                       'calculation for the lm head to fp16.')
+
+    return parser
+
+
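# Editor's sketch (illustrative, not part of this patch): how the dynamic
# loss-scaling knobs above typically interact. The class is a hypothetical toy,
# not the scaler the patch ships.
class _ToyLossScaler:
    def __init__(self, initial_loss_scale=2**32, min_loss_scale=1.0,
                 loss_scale_window=1000, hysteresis=2):
        self.scale = initial_loss_scale
        self.min_scale = min_loss_scale
        self.window = loss_scale_window   # good steps before the scale is raised
        self.hysteresis = hysteresis      # tolerated consecutive overflows
        self._good_steps = 0
        self._overflows_left = hysteresis

    def update(self, found_overflow):
        if found_overflow:
            self._good_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:
                # Back off: halve the scale, but never drop below the minimum.
                self.scale = max(self.scale / 2.0, self.min_scale)
                self._overflows_left = self.hysteresis
        else:
            self._good_steps += 1
            if self._good_steps >= self.window:
                # Ramp back up after a clean window of steps.
                self.scale *= 2.0
                self._good_steps = 0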
+def _add_distributed_args(parser):
+    group = parser.add_argument_group(title='distributed')
+
+    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
+                       help='Degree of tensor model parallelism.')
+    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
+                       help='Degree of pipeline model parallelism.')
+    group.add_argument('--model-parallel-size', type=int, default=None,
+                       help='Old model parallel argument, do not use. Use '
+                       '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage.')
+    group.add_argument('--distributed-backend', default='nccl',
+                       choices=['nccl', 'gloo'],
+                       help='Which backend to use for distributed training.')
+    group.add_argument('--DDP-impl', default='local',
+                       choices=['local', 'torch'],
+                       help='Which DistributedDataParallel implementation '
+                       'to use.')
+    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
+                       help='If set, use a contiguous buffer in DDP. Note that '
+                       'this option only works with local DDP.')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='Do not use scatter/gather to optimize the '
+                       'communication of tensors in the pipeline.',
+                       dest='scatter_gather_tensors_in_pipeline')
+    group.add_argument('--local_rank', type=int, default=None,
+                       help='Local rank passed from the distributed launcher.')
+    group.add_argument('--lazy-mpu-init', type=bool, required=False,
+                       help='If set to True, initialize_megatron() '
+                       'skips DDP initialization and returns a function to '
+                       'complete it instead. Also turns on the '
+                       '--use-cpu-initialization flag. This is for an '
+                       'external DDP manager.')
+    group.add_argument('--use-cpu-initialization', action='store_true',
+                       default=None, help='If set, affine parallel weights '
+                       'initialization uses CPU.')
+    group.add_argument('--empty-unused-memory-level', default=0, type=int,
+                       choices=[0, 1, 2],
+                       help='Call torch.cuda.empty_cache() each iteration '
+                       '(training and eval) to reduce fragmentation. '
+                       '0=off, 1=moderate, 2=aggressive.')
+    return parser
+
+
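# Editor's sketch (illustrative, not part of this patch): how the parallelism
# degrees above combine. Megatron-style setups assume
#   world_size == tensor_mp_size * pipeline_mp_size * data_parallel_size.
def _data_parallel_size(world_size, tensor_mp_size, pipeline_mp_size):
    model_parallel_size = tensor_mp_size * pipeline_mp_size
    assert world_size % model_parallel_size == 0
    return world_size // model_parallel_size   # e.g. 16 // (2 * 4) = 2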
+def _add_validation_args(parser):
+    group = parser.add_argument_group(title='validation')
+
+    group.add_argument('--eval-iters', type=int, default=100,
+                       help='Number of iterations to run '
+                       'validation/test for during evaluation.')
+    group.add_argument('--eval-interval', type=int, default=1000,
+                       help='Interval between running evaluation on '
+                       'the validation set.')
+
+    return parser
+
+
+def _add_data_args(parser):
+    group = parser.add_argument_group(title='data and dataloader')
+
+    group.add_argument('--data-path', nargs='*', default=None,
+                       help='Path to the training dataset. Accepted formats: '
+                       '1) a single data path, 2) multiple datasets in the '
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
+    group.add_argument('--split', type=str, default='969, 30, 1',
+                       help='Comma-separated list of proportions for training,'
+                       ' validation, and test split. For example the split '
+                       '`90,5,5` will use 90%% of data for training, 5%% for '
+                       'validation and 5%% for test.')
+    group.add_argument('--vocab-file', type=str, default=None,
+                       help='Path to the vocab file.')
+    group.add_argument('--merge-file', type=str, default=None,
+                       help='Path to the BPE merge file.')
+    group.add_argument('--vocab-extra-ids', type=int, default=0,
+                       help='Number of additional vocabulary tokens. '
+                       'They are used for span masking in the T5 model.')
+    group.add_argument('--seq-length', type=int, default=None,
+                       help='Maximum sequence length to process.')
+    group.add_argument('--encoder-seq-length', type=int, default=None,
+                       help='Maximum encoder sequence length to process. '
+                       'This should be exclusive of --seq-length.')
+    group.add_argument('--decoder-seq-length', type=int, default=None,
+                       help='Maximum decoder sequence length to process.')
+    group.add_argument('--retriever-seq-length', type=int, default=256,
+                       help='Maximum sequence length for the biencoder model '
+                       'used for the retriever.')
+    group.add_argument('--sample-rate', type=float, default=1.0,
+                       help='Sample rate for training data. Supposed to be '
+                       '0 < sample_rate < 1.')
+    group.add_argument('--mask-prob', type=float, default=0.15,
+                       help='Probability of replacing a token with mask.')
+    group.add_argument('--short-seq-prob', type=float, default=0.1,
+                       help='Probability of producing a short sequence.')
+    group.add_argument('--mmap-warmup', action='store_true',
+                       help='Warm up mmap files.')
+    group.add_argument('--num-workers', type=int, default=2,
+                       help='Dataloader number of workers.')
+    group.add_argument('--tokenizer-type', type=str,
+                       default=None,
+                       choices=['BertWordPieceLowerCase',
+                                'BertWordPieceCase',
+                                'GPT2BPETokenizer'],
+                       help='What type of tokenizer to use.')
+    group.add_argument('--data-impl', type=str, default='infer',
+                       choices=['lazy', 'cached', 'mmap', 'infer'],
+                       help='Implementation of indexed datasets.')
+    group.add_argument('--reset-position-ids', action='store_true',
+                       help='Reset position ids after the end-of-document token.')
+    group.add_argument('--reset-attention-mask', action='store_true',
+                       help='Reset the self attention mask after the '
+                       'end-of-document token.')
+    group.add_argument('--eod-mask-loss', action='store_true',
+                       help='Mask the loss for end-of-document tokens.')
+
+    return parser
+
+
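# Editor's sketch (illustrative, not part of this patch): normalizing the
# comma-separated --split weights above into fractions. The helper name is
# hypothetical.
def _normalize_split(split):
    weights = [float(s) for s in split.split(',')]
    total = sum(weights)
    return [w / total for w in weights]   # '90,5,5' -> [0.90, 0.05, 0.05]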
+def _add_autoresume_args(parser):
+    group = parser.add_argument_group(title='autoresume')
+
+    group.add_argument('--adlr-autoresume', action='store_true',
+                       help='Enable autoresume on the adlr cluster.')
+    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
+                       help='Interval over which to check for the autoresume '
+                       'termination signal.')
+
+    return parser
+
+
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')
+
+    # network size
+    group.add_argument('--ict-head-size', type=int, default=None,
+                       help='Size of block embeddings to be used in ICT and '
+                       'REALM (paper default: 128).')
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
+                       help='Size of the projection head used in the biencoder '
+                       '(paper default: 128).')
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query '
+                       'and context models or not.')
+
+    # checkpointing
+    group.add_argument('--ict-load', type=str, default=None,
+                       help='Directory containing an ICTBertModel checkpoint.')
+    group.add_argument('--bert-load', type=str, default=None,
+                       help='Directory containing a BertModel checkpoint '
+                       '(needed to start ICT and REALM).')
+
+    # data
+    group.add_argument('--titles-data-path', type=str, default=None,
+                       help='Path to the titles dataset used for ICT.')
+    group.add_argument('--query-in-block-prob', type=float, default=0.1,
+                       help='Probability of keeping the query in the block for '
+                       'the ICT dataset.')
+    group.add_argument('--use-one-sent-docs', action='store_true',
+                       help='Whether to use one-sentence documents in ICT.')
+    group.add_argument('--evidence-data-path', type=str, default=None,
+                       help='Path to the Wikipedia Evidence from the DPR paper.')
+
+    # training
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
+                       default=[], help="Which top-k accuracies to report "
+                       "(e.g. '1 5 20').")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help='Whether to scale retriever scores by the inverse '
+                       'square root of the hidden size.')
+
+    # faiss index
+    group.add_argument('--block-data-path', type=str, default=None,
+                       help='Where to save/load BlockData to/from.')
+    group.add_argument('--embedding-path', type=str, default=None,
+                       help='Where to save/load Open-Retrieval Embedding '
+                       'data to/from.')
+
+    # indexer
+    group.add_argument('--indexer-batch-size', type=int, default=128,
+                       help='How large of batches to use when doing indexing '
+                       'jobs.')
+    group.add_argument('--indexer-log-interval', type=int, default=1000,
+                       help='After how many batches the indexer should '
+                       'report progress.')
+    return parser
+
+
+def _add_vit_args(parser):
+    group = parser.add_argument_group(title="vit")
+
+    group.add_argument('--num-classes', type=int, default=1000,
+                       help='Number of classes in the vision classification task.')
+    group.add_argument('--img-dim', type=int, default=224,
+                       help='Image size for the vision classification task.')
+    group.add_argument('--num-channels', type=int, default=3,
+                       help='Number of channels in the input image data.')
+    group.add_argument('--patch-dim', type=int, default=16,
+                       help='Patch dimension used in ViT.')
+
+    return parser
diff --git a/apex/transformer/tensor_parallel/tests/commons.py b/apex/transformer/tensor_parallel/tests/commons.py
new file mode 100644
index 000000000..020be9a69
--- /dev/null
+++ b/apex/transformer/tensor_parallel/tests/commons.py
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import random
+
+import numpy
+import torch
+
+from apex import transformer
+from apex.transformer.tensor_parallel.tests import global_vars
+
+
+TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
+
+
+class IdentityLayer(torch.nn.Module):
+    def __init__(self, size, scale=1.0):
+        super(IdentityLayer, self).__init__()
+        self.weight = torch.nn.Parameter(scale * torch.randn(size))
+
+    def forward(self):
+        return self.weight
+
+
+def set_random_seed(seed):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    numpy.random.seed(seed)
+    torch.manual_seed(seed)
+    transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)
+
+
+def initialize_distributed(backend='nccl'):
+    """Initialize torch.distributed."""
+    # Get local rank in case it is provided.
+    # parser = argparse.ArgumentParser()
+    # parser.add_argument('--local_rank', type=int, default=None,
+    #                     help='local rank passed from distributed launcher')
+    # args = parser.parse_args()
+    args = global_vars.get_args()
+    local_rank = args.local_rank
+
+    # Get rank and world size.
+    rank = int(os.getenv('RANK', '0'))
+    world_size = int(os.getenv("WORLD_SIZE", '1'))
+
+    print('> initializing torch.distributed with local rank: {}, '
+          'rank: {}, world size: {}'.format(local_rank, rank, world_size))
+
+    # Set the device id.
+ device = rank % torch.cuda.device_count() + if local_rank is not None: + device = local_rank + torch.cuda.set_device(device) + + # Call the init process. + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + init_method=init_method) + + +def print_separator(message): + torch.distributed.barrier() + filler_len = (78 - len(message)) // 2 + filler = '-' * filler_len + string = '\n' + filler + ' {} '.format(message) + filler + if torch.distributed.get_rank() == 0: + print(string, flush=True) + torch.distributed.barrier() diff --git a/apex/transformer/tensor_parallel/tests/global_vars.py b/apex/transformer/tensor_parallel/tests/global_vars.py new file mode 100644 index 000000000..1df9fee5f --- /dev/null +++ b/apex/transformer/tensor_parallel/tests/global_vars.py @@ -0,0 +1,260 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron global variables.""" +import os +import sys +import time + +import torch + +from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator +from apex.transformer.tensor_parallel.tests.arguments import parse_args + +_GLOBAL_ARGS = None +_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None +_GLOBAL_TOKENIZER = None +_GLOBAL_TENSORBOARD_WRITER = None +_GLOBAL_ADLR_AUTORESUME = None +_GLOBAL_TIMERS = None + + +def get_args(): + """Return arguments.""" + _ensure_var_is_initialized(_GLOBAL_ARGS, 'args') + return _GLOBAL_ARGS + + +def get_num_microbatches(): + return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() + + +def get_current_global_batch_size(): + return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() + + +def update_num_microbatches(consumed_samples, consistency_check=True): + _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, + consistency_check) + + +# def get_tokenizer(): +# """Return tokenizer.""" +# _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer') +# return _GLOBAL_TOKENIZER + + +def get_tensorboard_writer(): + """Return tensorboard writer. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_TENSORBOARD_WRITER + + +def get_adlr_autoresume(): + """ADLR autoresume object. 
It can be None, so there is no need
+    to check if it is initialized."""
+    return _GLOBAL_ADLR_AUTORESUME
+
+
+def get_timers():
+    """Return timers."""
+    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
+    return _GLOBAL_TIMERS
+
+
+def set_global_variables(extra_args_provider=None, args_defaults={},
+                         ignore_unknown_args=False):
+    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
+    args = _parse_args(extra_args_provider=extra_args_provider,
+                       defaults=args_defaults,
+                       ignore_unknown_args=ignore_unknown_args)
+    _build_num_microbatches_calculator(args)
+    # if args.vocab_file:
+    #     _ = _build_tokenizer(args)
+    _set_tensorboard_writer(args)
+    _set_adlr_autoresume(args)
+    _set_timers()
+
+
+def _parse_args(extra_args_provider=None, defaults={},
+                ignore_unknown_args=False):
+    """Parse entire arguments."""
+    global _GLOBAL_ARGS
+    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
+    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
+                              defaults=defaults,
+                              ignore_unknown_args=ignore_unknown_args)
+    return _GLOBAL_ARGS
+
+
+def _build_num_microbatches_calculator(args):
+
+    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
+                                   'num microbatches calculator')
+
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
+        args)
+
+
+# def _build_tokenizer(args):
+#     """Initialize tokenizer."""
+#     global _GLOBAL_TOKENIZER
+#     _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
+#     _GLOBAL_TOKENIZER = build_tokenizer(args)
+#     return _GLOBAL_TOKENIZER
+
+
+# def rebuild_tokenizer(args):
+#     global _GLOBAL_TOKENIZER
+#     _GLOBAL_TOKENIZER = None
+#     return _build_tokenizer(args)
+
+
+def _set_tensorboard_writer(args):
+    """Set tensorboard writer."""
+    global _GLOBAL_TENSORBOARD_WRITER
+    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
+                                   'tensorboard writer')
+
+    if hasattr(args, 'tensorboard_dir') and \
+       args.tensorboard_dir and args.rank == (args.world_size - 1):
+        try:
+            from torch.utils.tensorboard import SummaryWriter
+            print('> setting tensorboard ...')
+            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
+                log_dir=args.tensorboard_dir,
+                max_queue=args.tensorboard_queue_size)
+        except ModuleNotFoundError:
+            print('WARNING: TensorBoard writing requested but is not '
+                  'available (are you using PyTorch 1.1.0 or later?), '
+                  'no TensorBoard logs will be written.', flush=True)
+
+
+def _set_adlr_autoresume(args):
+    """Initialize ADLR autoresume."""
+    global _GLOBAL_ADLR_AUTORESUME
+    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
+
+    if args.adlr_autoresume:
+        if args.rank == 0:
+            print('enabling autoresume ...', flush=True)
+        sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
+        try:
+            from userlib.auto_resume import AutoResume
+        except BaseException:
+            print('ADLR autoresume is not available, exiting ...')
+            sys.exit()
+
+        _GLOBAL_ADLR_AUTORESUME = AutoResume
+
+
+def _set_timers():
+    """Initialize timers."""
+    global _GLOBAL_TIMERS
+    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
+    _GLOBAL_TIMERS = Timers()
+
+
+def _ensure_var_is_initialized(var, name):
+    """Make sure the input variable is not None."""
+    assert var is not None, '{} is not initialized.'.format(name)
+
+
+def _ensure_var_is_not_initialized(var, name):
+    """Make sure the input variable is None (i.e., not yet initialized)."""
+    assert var is None, '{} is already initialized.'.format(name)
+
+
+class _Timer:
+    """Timer."""
+
+    def __init__(self, name):
+        self.name_ = name
+        self.elapsed_ = 0.0
+        self.started_ = False
+        self.start_time = time.time()
+
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, 'timer has already been started'
+        torch.cuda.synchronize()
+        self.start_time = time.time()
+        self.started_ = True
+
+    def stop(self):
+        """Stop the timer."""
+        assert self.started_, 'timer is not started'
+        torch.cuda.synchronize()
+        self.elapsed_ += (time.time() - self.start_time)
+        self.started_ = False
+
+    def reset(self):
+        """Reset timer."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+
+    def elapsed(self, reset=True):
+        """Calculate the elapsed time."""
+        started_ = self.started_
+        # If timing is in progress, end it first.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time.
+        if reset:
+            self.reset()
+        # If timing was in progress, restart it.
+        if started_:
+            self.start()
+        return elapsed_
+
+
+class Timers:
+    """Group of timers."""
+
+    def __init__(self):
+        self.timers = {}
+
+    def __call__(self, name):
+        if name not in self.timers:
+            self.timers[name] = _Timer(name)
+        return self.timers[name]
+
+    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+        """Write timers to a tensorboard writer."""
+        # Currently, when using add_scalars,
+        # torch.utils.add_scalars makes each timer its own run, which
+        # pollutes the runs list, so we just add each as a scalar.
+        assert normalizer > 0.0
+        for name in names:
+            value = self.timers[name].elapsed(reset=reset) / normalizer
+            writer.add_scalar(name + '-time', value, iteration)
+
+    def log(self, names, normalizer=1.0, reset=True):
+        """Log a group of timers."""
+        assert normalizer > 0.0
+        string = 'time (ms)'
+        for name in names:
+            elapsed_time = self.timers[name].elapsed(
+                reset=reset) * 1000.0 / normalizer
+            string += ' | {}: {:.2f}'.format(name, elapsed_time)
+        if torch.distributed.is_initialized():
+            if torch.distributed.get_rank() == (
+                    torch.distributed.get_world_size() - 1):
+                print(string, flush=True)
+        else:
+            print(string, flush=True)
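A brief usage sketch of the `Timers` group above (illustrative, not part of the
patch; `run_forward_pass` is a hypothetical workload):

    timers = Timers()
    timers('forward').start()
    run_forward_pass()
    timers('forward').stop()
    timers.log(['forward'])   # prints e.g. "time (ms) | forward: 12.34"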
diff --git a/apex/transformer/tensor_parallel/utils.py b/apex/transformer/tensor_parallel/utils.py
new file mode 100644
index 000000000..64202afc0
--- /dev/null
+++ b/apex/transformer/tensor_parallel/utils.py
@@ -0,0 +1,64 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that the numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
+
+
+def divide(numerator, denominator):
+    """Ensure that the numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
+    """Split a tensor along its last dimension.
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor into.
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # Note: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+class VocabUtility:
+    """Split the vocabulary into `world_size` chunks and return the
+    first and last index of the vocabulary belonging to the `rank`
+    partition. Note that indices are in [first, last)."""
+
+    @staticmethod
+    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
+        index_f = rank * per_partition_vocab_size
+        index_l = index_f + per_partition_vocab_size
+        return index_f, index_l
+
+    @staticmethod
+    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+        per_partition_vocab_size = divide(global_vocab_size, world_size)
+        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp
index 7b24042b8..df5d4b404 100644
--- a/csrc/layer_norm_cuda.cpp
+++ b/csrc/layer_norm_cuda.cpp
@@ -130,12 +130,13 @@ std::vector<at::Tensor> layer_norm(
   int n1,n2;
   check_args(input,normalized_shape,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,NULL,NULL,epsilon);
   return {output, mean, invvar};
 }
+
 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
     #ifdef VERSION_GE_1_1
@@ -152,13 +153,35 @@ std::vector<at::Tensor> layer_norm_affine(
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
+  at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,&gamma,&beta,epsilon);
   return {output, mean, invvar};
 }
+
+std::vector<at::Tensor> layer_norm_affine_mixed_dtypes(
+    at::Tensor input,
+    #ifdef VERSION_GE_1_1
+    at::IntArrayRef normalized_shape,
+    #else
+    at::IntList normalized_shape,
+    #endif
+    at::Tensor gamma,
+    at::Tensor beta,
+    double epsilon) {
+  CHECK_INPUT(input);
+  int n1, n2;
+  check_args(input, normalized_shape, n1, n2);
+  at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ?
at::ScalarType::Float : input.scalar_type())); + at::Tensor invvar = at::empty_like(mean); + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon); + return {output, mean, invvar}; +} + void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, @@ -202,6 +225,7 @@ at::Tensor layer_norm_gradient( &grad_input,NULL,NULL); return grad_input; } + std::vector layer_norm_gradient_affine( at::Tensor dout, at::Tensor mean, @@ -237,5 +261,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &layer_norm, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)"); + + m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation"); } diff --git a/csrc/layer_norm_cuda_kernel.cu b/csrc/layer_norm_cuda_kernel.cu index 1f8033bf0..0a380a0a4 100644 --- a/csrc/layer_norm_cuda_kernel.cu +++ b/csrc/layer_norm_cuda_kernel.cu @@ -56,7 +56,7 @@ void cuWelfordMuSigma2( const int i1, U& mu, U& sigma2, - U* buf) + U* buf) { // Assumptions: // 1) blockDim.x == warpSize @@ -140,7 +140,7 @@ void cuWelfordMuSigma2( const int i1, float& mu, float& sigma2, - float* buf) + float* buf) { // Assumptions: // 1) blockDim.x == warpSize @@ -173,7 +173,7 @@ void cuWelfordMuSigma2( for (int k = 0; k < 8; k+=2) { float2 curr = __half22float2(*((__half2*)(lvals+l+k))); cuWelfordOnlineSum(curr.x,mu,sigma2,count); - cuWelfordOnlineSum(curr.y,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); } } for (; l < n2; ++l) { @@ -276,18 +276,18 @@ struct SharedMemory }; } -template __global__ -void cuApplyLayerNorm( - T* __restrict__ output_vals, +template __device__ +void cuApplyLayerNorm_( + V* __restrict__ output_vals, U* __restrict__ mean, U* __restrict__ invvar, const T* __restrict__ vals, const int n1, const int n2, const U epsilon, - const T* __restrict__ gamma, - const T* __restrict__ beta - ) + const V* __restrict__ gamma, + const V* __restrict__ beta + ) { // Assumptions: // 1) blockDim.x == warpSize @@ -299,19 +299,19 @@ void cuApplyLayerNorm( U mu,sigma2; cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); const T* lvals = vals + i1*n2; - T* ovals = output_vals + i1*n2; + V* ovals = output_vals + i1*n2; U c_invvar = rsqrt(sigma2 + epsilon); const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; if (gamma != NULL && beta != NULL) { for (int i = thrx; i < n2; i+=numx) { U curr = static_cast(lvals[i]); - ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; } } else { for (int i = thrx; i < n2; i+=numx) { U curr = static_cast(lvals[i]); - ovals[i] = static_cast(c_invvar * (curr - mu)); + ovals[i] = static_cast(c_invvar * (curr - mu)); } } if (threadIdx.x == 0 && threadIdx.y == 0) { @@ -321,7 +321,24 @@ void cuApplyLayerNorm( } } -template __device__ +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta); +} + + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -331,7 
+348,7 @@ void cuLoadWriteStridedInputs( U* warp_buf1, U* warp_buf2, const T* input, - const T* dout, + const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, @@ -348,9 +365,9 @@ void cuLoadWriteStridedInputs( int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; } else { warp_buf1[write_idx] = U(0); warp_buf2[write_idx] = U(0); @@ -365,7 +382,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -375,7 +392,7 @@ void cuLoadAddStridedInputs( U* warp_buf1, U* warp_buf2, const T* input, - const T* dout, + const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, @@ -392,17 +409,17 @@ void cuLoadAddStridedInputs( int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; } } } } -template __global__ +template __global__ void cuComputePartGradGammaBeta( - const T* __restrict__ dout, + const V* __restrict__ dout, const T* __restrict__ input, const int n1, const int n2, @@ -449,11 +466,11 @@ void cuComputePartGradGammaBeta( for (int offset = blockDim.y/2; offset > 1; offset /= 2) { if (threadIdx.y < offset) { int row1 = threadIdx.y; - int row2 = threadIdx.y + offset; - int idx1 = row1*row_stride + threadIdx.x; - int idx2 = row2*row_stride + threadIdx.x; - warp_buf1[idx1] += warp_buf1[idx2]; - warp_buf2[idx1] += warp_buf2[idx2]; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; } __syncthreads(); } @@ -468,19 +485,19 @@ void cuComputePartGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradGammaBeta( const U* part_grad_gamma, const U* part_grad_beta, const int part_size, const int n1, const int n2, - T* grad_gamma, - T* grad_beta) + V* grad_gamma, + V* grad_beta) { // sum partial gradients for gamma and beta SharedMemory shared; - U* buf = shared.getPointer(); + U* buf = shared.getPointer(); int i2 = blockIdx.x * blockDim.x + threadIdx.x; if (i2 < n2) { // each warp does sequential reductions until reduced part_size is num_warps @@ -519,16 +536,16 @@ void cuComputeGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradInput( - const T* __restrict__ dout, + const V* __restrict__ dout, const T* __restrict__ input, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, - const T* gamma, + const V* gamma, T* grad_input) { for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { @@ -537,7 +554,7 @@ void cuComputeGradInput( const U c_mean = mean[i1]; const U c_invvar = invvar[i1]; const T* k_input = input + i1*n2; - const T* k_dout = dout + i1*n2; + const V* k_dout = dout + i1*n2; const int numx = blockDim.x * blockDim.y; const int thrx = 
threadIdx.x + threadIdx.y * blockDim.x; if (gamma != NULL) { @@ -581,7 +598,7 @@ void cuComputeGradInput( // inter-warp reductions if (blockDim.y > 1) { SharedMemory shared; - U* buf = shared.getPointer(); + U* buf = shared.getPointer(); for (int offset = blockDim.y/2; offset > 0; offset /= 2) { // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { @@ -606,7 +623,7 @@ void cuComputeGradInput( if (threadIdx.y !=0) { sum_loss1 = buf[2*threadIdx.x]; sum_loss2 = buf[2*threadIdx.x+1]; - } + } } // all threads now have the two sums over l U fH = (U)n2; @@ -636,35 +653,29 @@ void cuComputeGradInput( } } -template +template void HostApplyLayerNorm( - T* output, + V* output, U* mean, U* invvar, const T* input, int n1, int n2, double epsilon, - const T* gamma, - const T* beta + const V* gamma, + const V* beta ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); const dim3 threads(32,4,1); const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? - threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : 0; cuApplyLayerNorm<<>>( - output, - mean, - invvar, - input, - n1,n2, - U(epsilon), - gamma,beta); + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); } void cuda_layer_norm( @@ -684,34 +695,35 @@ void cuda_layer_norm( double epsilon) { using namespace at; - DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel", - using accscalar_t = at::acc_type; - HostApplyLayerNorm( - output->DATA_PTR(), - mean->DATA_PTR(), - invvar->DATA_PTR(), - input->DATA_PTR(), - n1,n2, - epsilon, - gamma != NULL ? gamma->DATA_PTR() : NULL, - beta != NULL ? beta->DATA_PTR() : NULL); + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel", + using accscalar_t = at::acc_type; + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL); ) } -template +template void HostLayerNormGradient( - const T* dout, + const V* dout, const U* mean, const U* invvar, at::Tensor* input, int n1, int n2, - const T* gamma, - const T* beta, + const V* gamma, + const V* beta, double epsilon, T* grad_input, - T* grad_gamma, - T* grad_beta + V* grad_gamma, + V* grad_beta ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -724,7 +736,13 @@ void HostLayerNormGradient( const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type())); + // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that + // the `cuda_layer_norm_gradient` doesn't support double. + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? 
+ at::ScalarType::Float : + input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); cuComputePartGradGammaBeta<<>>( dout, @@ -787,21 +805,23 @@ void cuda_layer_norm_gradient( at::Tensor* grad_beta) { using namespace at; - DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput", - using accscalar_t = at::acc_type; - HostLayerNormGradient( - dout->DATA_PTR(), - mean->DATA_PTR(), - invvar->DATA_PTR(), - input, - n1,n2, + // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", + using accscalar_t = at::acc_type; + HostLayerNormGradient( + dout->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. - gamma != NULL ? gamma->DATA_PTR() : NULL, - gamma != NULL ? beta->DATA_PTR() : NULL, - epsilon, - grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL); - ) + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL); + ) } diff --git a/csrc/megatron/scaled_masked_softmax.cpp b/csrc/megatron/scaled_masked_softmax.cpp new file mode 100644 index 000000000..1852aee6f --- /dev/null +++ b/csrc/megatron/scaled_masked_softmax.cpp @@ -0,0 +1,97 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/csrc/megatron/scaled_masked_softmax.h b/csrc/megatron/scaled_masked_softmax.h new file mode 100644 index 000000000..45e8dcea2 --- /dev/null +++ b/csrc/megatron/scaled_masked_softmax.h @@ -0,0 +1,505 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. 
+ int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = 
log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} diff --git a/csrc/megatron/scaled_masked_softmax_cuda.cu b/csrc/megatron/scaled_masked_softmax_cuda.cu new file mode 100644 index 000000000..902d36dd0 --- /dev/null +++ b/csrc/megatron/scaled_masked_softmax_cuda.cu @@ -0,0 +1,117 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/csrc/megatron/scaled_masked_softmax_cuda.cu b/csrc/megatron/scaled_masked_softmax_cuda.cu
new file mode 100644
index 000000000..902d36dd0
--- /dev/null
+++ b/csrc/megatron/scaled_masked_softmax_cuda.cu
@@ -0,0 +1,117 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "scaled_masked_softmax.h"
+#include "type_shim.h"
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_masked_softmax {
+
+int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input,
+    torch::Tensor const& mask,
+    float scale_factor)
+{
+  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
+  const int batches = input.size(0);
+  const int pad_batches = mask.size(0);
+  const int attn_heads = input.size(1);
+  const int query_seq_len = input.size(2);
+  const int key_seq_len = input.size(3);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
+  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
+  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
+  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
+
+  // Output
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor softmax_results =
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+
+  // Softmax Intermediate Result Ptr
+  void* input_ptr = static_cast<void*>(input.data_ptr());
+  void* mask_ptr = static_cast<void*>(mask.data_ptr());
+  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_masked_softmax_forward",
+      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          reinterpret_cast<const uint8_t*>(mask_ptr),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads,
+          pad_batches);
+  );
+  return softmax_results;
+}
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads_,
+    torch::Tensor const& softmax_results_,
+    float scale_factor)  {
+
+  auto output_grads = output_grads_.contiguous();
+  auto softmax_results = softmax_results_.contiguous();
+
+  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
+  const int batches = output_grads.size(0);
+  const int attn_heads = output_grads.size(1);
+  const int query_seq_len = output_grads.size(2);
+  const int key_seq_len = output_grads.size(3);
+
+  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+
+  //Softmax Grad
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_masked_softmax_backward",
+      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads);
+  );
+
+  //backward pass is completely in-place
+  return output_grads;
+}
+}
+}
+}
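A hedged usage sketch of the resulting extension. The module name comes from setup.py later in this patch; the `forward`/`backward` binding names are an assumption here, mirroring the upper-triang bindings shown below, since scaled_masked_softmax.cpp itself is not part of this hunk:

import torch
import scaled_masked_softmax_cuda  # name registered in setup.py (see below)

b, h, q, k = 8, 16, 128, 128
x = torch.randn(b, h, q, k, dtype=torch.half, device="cuda")
# mask is uint8 with shape [b or 1, 1, q, k]; assumed semantics: entries
# set to 1 are excluded from the softmax
mask = torch.zeros(b, 1, q, k, dtype=torch.uint8, device="cuda")
probs = scaled_masked_softmax_cuda.forward(x, mask, 1.0)          # assumed binding name
dy = torch.randn_like(probs)
dx = scaled_masked_softmax_cuda.backward(dy, probs, 1.0)          # assumed binding name
# note: bwd_cuda reuses the output_grads buffer, so the backward is in-place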
diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp
new file mode 100644
index 000000000..ea283588d
--- /dev/null
+++ b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp
@@ -0,0 +1,72 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+#include <vector>
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_upper_triang_masked_softmax {
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input,
+    float scale_factor);
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads,
+    torch::Tensor const& softmax_results,
+    float scale_factor);
+
+torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
+  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+
+  return fwd_cuda(input, scale_factor);
+}
+
+torch::Tensor bwd(
+    torch::Tensor const& output_grads,
+    torch::Tensor const& softmax_results,
+    float scale_factor) {
+
+  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
+  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
+
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+
+  return bwd_cuda(output_grads, softmax_results, scale_factor);
+}
+
+} // end namespace scaled_upper_triang_masked_softmax
+} // end namespace fused_softmax
+} // end namespace multihead_attn
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward",
+        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
+        "Self Multihead Attention scaled, time masked softmax -- Forward.");
+  m.def("backward",
+        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
+        "Self Multihead Attention scaled, time masked softmax -- Backward.");
+}
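The bindings above expose a 3D-tensor API (fp16/bf16 only). A minimal round-trip sketch, assuming the extension is built under the name registered in setup.py below:

import torch
import scaled_upper_triang_masked_softmax_cuda

attn_batches, seq_len = 32, 128  # seq_len must be <= 2048
x = torch.randn(attn_batches, seq_len, seq_len, dtype=torch.half, device="cuda")
probs = scaled_upper_triang_masked_softmax_cuda.forward(x, 1.0)
# row i of each [seq_len, seq_len] slice attends only to positions <= i
# (causal/diagonal masking), so the strict upper triangle is zero
grads = scaled_upper_triang_masked_softmax_cuda.backward(
    torch.randn_like(probs), probs, 1.0)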
diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.h b/csrc/megatron/scaled_upper_triang_masked_softmax.h
new file mode 100644
index 000000000..6df83fc10
--- /dev/null
+++ b/csrc/megatron/scaled_upper_triang_masked_softmax.h
@@ -0,0 +1,513 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <assert.h>
+#include <cuda_fp16.h>
+#include <cfloat>
+#include <limits>
+#include <stdint.h>
+#include <c10/macros/Macros.h>
+
+namespace {
+
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
+
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_zero_vector(Datatype *dst);
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+
+int log2_ceil(int value) {
+    int log2_value = 0;
+    while ((1 << log2_value) < value) ++log2_value;
+    return log2_value;
+}
+
+template<typename T>
+struct Add {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+template<typename T>
+struct Max {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? b : a;
+  }
+};
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+    return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+    return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
+__device__ __forceinline__ void warp_reduce(acc_t* sum) {
+    ReduceOp<acc_t> r;
+    #pragma unroll
+    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+        #pragma unroll
+        for (int i = 0; i < WARP_BATCH; ++i) {
+            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
+            sum[i] = r(sum[i], b);
+        }
+    }
+}
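+
+// warp_reduce performs a butterfly (XOR) reduction: with WARP_SIZE = 32 the
+// offsets 16, 8, 4, 2, 1 leave every lane holding the full reduction of its
+// WARP_BATCH partial values, so no extra broadcast step is needed afterwards.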
+
+/*
+ * Extended softmax (from native aten pytorch) with following additional features
+ * 1) input scaling
+ * 2) Implicit time (diagonal masking)
+ */
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_upper_triang_masked_softmax_warp_forward(
+    output_t *dst,
+    const input_t *src,
+    const acc_t scale,
+    int micro_batch_size,
+    int stride,
+    int element_count)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+    // warp_size of method warp_softmax_forward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+    int local_seq = blockIdx.x + 1;
+    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;
+
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to be computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+
+    // load data from global memory
+    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : local_seq;
+
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+
+            if (element_index < batch_element_count) {
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if ((element_index + element) < batch_element_count) {
+                        elements[i][it+element] = (acc_t)temp_data[element] * scale;
+                    } else {
+                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                    }
+                }
+            } else {
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
+            }
+        }
+    }
+
+    // compute max_value
+    acc_t max_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        max_value[i] = elements[i][0];
+        #pragma unroll
+        for (int it = 1; it < WARP_ITERATIONS; ++it) {
+            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+    acc_t sum[WARP_BATCH] { 0.0f };
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; ++it) {
+            if (it < warp_iteration_limit) {
+                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+                sum[i] += elements[i][it];
+            }
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+
+            if (element_index < local_seq) {
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < local_seq) {
+                        out[element] = elements[i][it + element] / sum[i];
+                    } else {
+                        out[element] = 0;
+                    }
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
+            } else if (element_index < element_count) {
+                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
+            } else {
+                break;
+            }
+        }
+    }
+}
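+
+// The backward kernel below applies the usual softmax Jacobian identity:
+// with y = softmax(scale * x) and upstream gradient dy,
+//   dx = scale * y * (dy - sum_j(dy_j * y_j)).
+// grad_reg caches y * dy so the row sum and the final expression reuse it.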
+
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_upper_triang_masked_softmax_warp_backward(
+    output_t *gradInput,
+    input_t *grad,
+    const input_t *output,
+    acc_t scale,
+    int micro_batch_size,
+    int stride,
+    int element_count)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+    // warp_size of method warp_softmax_backward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+    int local_seq = blockIdx.x + 1;
+
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to be computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    // the first element to process by the current thread
+    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    grad += thread_offset;
+    output += thread_offset;
+    gradInput += thread_offset;
+
+    // load data from global memory
+    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : local_seq;
+
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < batch_element_count) {
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
+                        output_reg[i][it + element] = (acc_t)temp_output[element];
+                    }
+                }
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
+                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                    }
+                }
+            }
+        }
+    }
+
+    acc_t sum[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        sum[i] = grad_reg[i][0];
+        #pragma unroll
+        for (int it = 1; it < WARP_ITERATIONS; ++it) {
+            sum[i] += grad_reg[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                // compute gradients
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
+            }
+        }
+    }
+}
+
+} // end of anonymous namespace
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_upper_triang_masked_softmax_forward(
+    output_t *dst,
+    const input_t *src,
+    const input_t scale,
+    int softmax_elements,
+    int softmax_elements_stride,
+    int attn_batches)
+{
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
+    if (softmax_elements == 0) {
+        return;
+    } else {
+        int log2_elements = log2_ceil(softmax_elements);
+        const int next_power_of_two = 1 << log2_elements;
+        int seq_len = softmax_elements;
+        int batch_count = attn_batches * seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
+        int blocks_per_seq = attn_batches / batches_per_block;
+        dim3 blocks(seq_len, blocks_per_seq, 1);
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 1: // 2
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 2: // 4
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 3: // 8
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 4: // 16
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 5: // 32
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 6: // 64
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 7: // 128
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 8: // 256
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 9: // 512
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 10: // 1024
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 11: // 2048
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            default:
+                break;
+        }
+    }
+}
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_upper_triang_masked_softmax_backward(
+    output_t *grad_input,
+    input_t *grad,
+    const input_t *output,
+    const acc_t scale,
+    int softmax_elements,
+    int softmax_elements_stride,
+    int attn_batches)
+{
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
+    if (softmax_elements == 0) {
+        return;
+    } else {
+        int log2_elements = log2_ceil(softmax_elements);
+        const int next_power_of_two = 1 << log2_elements;
+        int seq_len = softmax_elements;
+        int batch_count = attn_batches * seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
+        int blocks_per_seq = attn_batches / batches_per_block;
+        dim3 blocks(seq_len, blocks_per_seq, 1);
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 1: // 2
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 2: // 4
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 3: // 8
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 4: // 16
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 5: // 32
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 6: // 64
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 7: // 128
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 8: // 256
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 9: // 512
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 10: // 1024
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 11: // 2048
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            default:
+                break;
+        }
+    }
+}
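A quick numerical cross-check one could run after building the extension (illustrative sketch; the tolerances are arbitrary, and the eager reference reproduces the kernel's causal masking with an explicit triu mask):

import torch
import scaled_upper_triang_masked_softmax_cuda

x = torch.randn(16, 64, 64, dtype=torch.half, device="cuda")
scale = 0.5
# strict upper triangle corresponds to the positions the kernel zeroes out
mask = torch.triu(torch.ones(64, 64, device="cuda", dtype=torch.bool), diagonal=1)
ref = torch.softmax((x.float() * scale).masked_fill(mask, float("-inf")), dim=-1).half()
out = scaled_upper_triang_masked_softmax_cuda.forward(x, scale)
torch.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)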
diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu
new file mode 100644
index 000000000..5efc3d412
--- /dev/null
+++ b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu
@@ -0,0 +1,98 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "scaled_upper_triang_masked_softmax.h"
+#include "type_shim.h"
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_upper_triang_masked_softmax {
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input,
+    float scale_factor)
+{
+  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+  const int attn_batches = input.size(0);
+  const int seq_len = input.size(1);
+  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+
+  // Output
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor softmax_results =
+      torch::empty({attn_batches, seq_len, seq_len}, act_options);
+
+  // Softmax Intermediate Result Ptr
+  void* input_ptr = static_cast<void*>(input.data_ptr());
+  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_forward",
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );
+  return softmax_results;
+}
+
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads_,
+    torch::Tensor const& softmax_results_,
+    float scale_factor)  {
+
+  auto output_grads = output_grads_.contiguous();
+  auto softmax_results = softmax_results_.contiguous();
+
+  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+  const int attn_batches = output_grads.size(0);
+  const int seq_len = output_grads.size(1);
+  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
+
+  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+
+  //Softmax Grad
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_backward",
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );
+
+  //backward pass is completely in-place
+  return output_grads;
+}
+}
+}
+}
diff --git a/csrc/type_shim.h b/csrc/type_shim.h
index fc2bb03b9..b9d03a227 100644
--- a/csrc/type_shim.h
+++ b/csrc/type_shim.h
@@ -34,6 +34,32 @@
 }
 
 
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_##LEVEL = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+
 #define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
   switch(TYPE) \
   { \
@@ -106,6 +132,160 @@
 }
 
 
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+  switch(TYPEIN) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_in = float; \
+      switch(TYPEOUT) \
+      { \
+        case at::ScalarType::Float: \
+        { \
+          using scalar_t_out = float; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+          using scalar_t_out = at::Half; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+          using scalar_t_out = at::BFloat16; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        default: \
+          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+      } \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_in = at::Half; \
+      using scalar_t_out = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_in = at::BFloat16; \
+      using scalar_t_out = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
+  }
+
+
+#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+  switch(TYPEIN) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_in = double; \
+      switch(TYPEOUT) \
+      { \
+        case at::ScalarType::Double: \
+        { \
+          using scalar_t_out = double; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::Float: \
+        { \
+          using scalar_t_out = float; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+          using scalar_t_out = at::Half; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+          using scalar_t_out = at::BFloat16; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        default: \
+          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+      } \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_in = float; \
+      switch(TYPEOUT) \
+      { \
+        case at::ScalarType::Float: \
+        { \
+          using scalar_t_out = float; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+          using scalar_t_out = at::Half; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+          using scalar_t_out = at::BFloat16; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        default: \
+          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+      } \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_in = at::Half; \
+      using scalar_t_out = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_in = at::BFloat16; \
+      using scalar_t_out = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
+  }
+
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes (T *x,
diff --git a/setup.py b/setup.py
index 22382bc02..532e24a0a 100644
--- a/setup.py
+++ b/setup.py
@@ -206,6 +206,30 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                               extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                                   'nvcc':['-O3'] + version_dependent_macros}))
 
+    ext_modules.append(
+        CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
+                      sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
+                               'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
+                      include_dirs=[os.path.join(this_dir, 'csrc')],
+                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                          'nvcc':['-O3',
+                                                  '-U__CUDA_NO_HALF_OPERATORS__',
+                                                  '-U__CUDA_NO_HALF_CONVERSIONS__',
+                                                  '--expt-relaxed-constexpr',
+                                                  '--expt-extended-lambda'] + version_dependent_macros}))
+
+    ext_modules.append(
+        CUDAExtension(name='scaled_masked_softmax_cuda',
sources=['csrc/megatron/scaled_masked_softmax.cpp', + 'csrc/megatron/scaled_masked_softmax_cuda.cu'], + include_dirs=[os.path.join(this_dir, 'csrc')], + extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, + 'nvcc':['-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] + version_dependent_macros})) + if "--bnp" in sys.argv: sys.argv.remove("--bnp") @@ -495,6 +519,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')], extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag})) + setup( name='apex', version='0.1', diff --git a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py index 9103eb80f..6d56df69f 100644 --- a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py +++ b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py @@ -1,42 +1,171 @@ +import itertools import unittest -import os -import random import torch + import apex -from torch.autograd import Variable - + class TestFusedLayerNorm(unittest.TestCase): + dtype = torch.float + elementwise_affine = False + normalized_shape = [32, 16] + rtol, atol = None, None + fwd_thresholds = dict(rtol=None, atol=None) + bwd_thresholds = dict(rtol=None, atol=None) + def setUp(self): # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one - self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cpu() - self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda() + self.module_cpu_ = apex.normalization.FusedLayerNorm( + normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu() + self.module_cuda_ = apex.normalization.FusedLayerNorm( + normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype) - def _test_same_output(self, batch_size): + def _check_same_output(self, batch_size, contiguous): torch.cuda.manual_seed(42) - self.input_ = torch.randn((batch_size, *self.module_cpu_.normalized_shape), device="cpu").requires_grad_(True) - self.input_cuda_ = self.input_.cuda().detach().requires_grad_(True) - out_cpu_ = self.module_cpu_(self.input_) + if contiguous: + input_shape = [batch_size] + self.normalized_shape + input_ = torch.randn(input_shape, device="cpu").requires_grad_(True) + input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True) + self.assertTrue(input_.is_contiguous()) + self.assertTrue(input_cuda_.is_contiguous()) + else: + input_shape = [batch_size] + self.normalized_shape + input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3] + input_src_ = torch.randn(input_shape, device="cpu") + input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True) + input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True) + # make sure that tensors are NOT contiguous. 
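+            # (slicing with a step > 1 produces a strided view: input_src_ is
+            # over-allocated by 3x/5x/3x per dimension, so taking every
+            # 3rd/5th/3rd element keeps the target shape but breaks contiguity)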
+ self.assertFalse(input_.is_contiguous()) + self.assertFalse(input_cuda_.is_contiguous()) + out_cpu_ = self.module_cpu_(input_) gO = torch.rand_like(out_cpu_) out_cpu_.backward(gO) - out_cuda_ = self.module_cuda_(self.input_cuda_) - gO = gO.cuda() + out_cuda_ = self.module_cuda_(input_cuda_) + gO = gO.to(device="cuda", dtype=self.dtype) out_cuda_.backward(gO) - assert out_cpu_.is_cuda == False - assert out_cuda_.is_cuda == True - torch.testing.assert_allclose(out_cpu_, out_cuda_.cpu()) - torch.testing.assert_allclose(self.input_.grad, self.input_cuda_.grad.cpu()) + self.assertFalse(out_cpu_.is_cuda) + self.assertTrue(out_cuda_.is_cuda) + # TODO (mkozuki): `torch.testing.assert_allclose` is deprecated. + # Use `torch.testing.assert_close`. + # See https://github.com/pytorch/pytorch/issues/61844 + torch.testing.assert_allclose( + out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_, **self.fwd_thresholds) + torch.testing.assert_allclose( + input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds) + + def _test_same_output(self, batch_size): + for contiguous in (True, False): + with self.subTest(contiguous=contiguous): + self._check_same_output(batch_size, contiguous) def test_layer_norm(self): self._test_same_output(16) def test_large_batch(self): self._test_same_output(65536) - - + + class TestFusedLayerNormElemWise(TestFusedLayerNorm): + elementwise_affine = True + + +class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise): + dtype = torch.half + + def test_large_batch(self): + self.skipTest("Skip to save time") + + +# Megatron style Layer Norm +class TestFusedLayerNormElemWiseMixedDtypes(TestFusedLayerNorm): def setUp(self): - self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cpu() - self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cuda() + self.module_cpu_ = apex.normalization.MixedFusedLayerNorm( + normalized_shape=self.normalized_shape, elementwise_affine=True).cpu() + self.module_cuda_ = apex.normalization.MixedFusedLayerNorm( + normalized_shape=self.normalized_shape, elementwise_affine=True).to(device="cuda", dtype=self.dtype) + + def test_init_exception(self): + with self.assertRaisesRegex(RuntimeError, "MixedFusedLayerNorm does not support `elementwise_affine = False`"): + apex.normalization.MixedFusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda() + + +class TestFusedLayerNormElemWiseMixedDtypesHalf(TestFusedLayerNormElemWiseMixedDtypes): + dtype = torch.half + + def test_large_batch(self): + self.skipTest("Skip to save time") + + +# NOTE (mkozuki): With the larger threshold values, still flaky. 
+class TestFusedLayerNormElemWiseMixedDtypesBFloat16(TestFusedLayerNormElemWiseMixedDtypesHalf): + dtype = torch.bfloat16 + # NOTE (mkozuki): [BFloat16 Layer Norm flakiness] + # Use thresholds larger than those used in pytorch, see + # https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26 + fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) + bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) + + +class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise): + dtype = torch.bfloat16 + # See [BFloat16 Layer Norm flakiness] + fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) + bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) + + def test_large_batch(self): + self.skipTest("Skip to save time") + + +def _prep_layers(normalized_shape, elementwise_affine, dtype): + native = torch.nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + ).to(device="cuda", dtype=dtype) + fused = apex.normalization.FusedLayerNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + ).cuda() + return native, fused + + +def _prep_inputs(batch_size, normalized_shape, dtype): + shape = (batch_size, *normalized_shape) + fused = torch.randn(shape).cuda().requires_grad_(True) + with torch.no_grad(): + native = fused.clone().to(dtype).requires_grad_(True) + return native, fused + + +autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) + + +class TestAutocastFusedLayerNorm(unittest.TestCase): + bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) + bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) + + def setUp(self): + self.batch_size = 16 + self.normalized_shape = [32, 16] + + def _run_test(self, dtype, elementwise_affine): + native, fused = _prep_layers(self.normalized_shape, elementwise_affine, dtype) + native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype) + + expected = native(native_x) + with torch.cuda.amp.autocast(dtype=dtype): + actual = fused(fused_x) + tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_fwd_thresholds + torch.testing.assert_allclose(actual, expected, **tols) + + g_native = torch.rand_like(expected) + with torch.no_grad(): + g_fused = g_native.clone() + expected.backward(g_native) + actual.backward(g_fused) + + tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_bwd_thresholds + torch.testing.assert_allclose(native_x.grad, fused_x.grad, **tols) + def test_autocast(self): + for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)): + with self.subTest(f"{dtype}-{elementwise_affine}"): + self._run_test(dtype, elementwise_affine) diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py index 8a4135d5f..14527a891 100644 --- a/tests/L0/run_test.py +++ b/tests/L0/run_test.py @@ -1,20 +1,72 @@ +"""L0 Tests Runner. + +How to run this script? + +1. Run all the tests: `python /path/to/apex/tests/L0/run_test.py` +2. Run one of the tests (e.g. fused layer norm): + `python /path/to/apex/tests/L0/run_test.py --include run_fused_layer_norm` +3. Run two or more of the tests (e.g. 
optimizers and fused layer norm): + `python /path/to/apex/tests/L0/run_test.py --include run_optimizers run_fused_layer_norm` +""" +import argparse +import os import unittest import sys -test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"] -runner = unittest.TextTestRunner(verbosity=2) +TEST_ROOT = os.path.dirname(os.path.abspath(__file__)) +TEST_DIRS = [ + "run_amp", + "run_fp16util", + "run_optimizers", + "run_fused_layer_norm", + "run_pyprof_nvtx", + "run_pyprof_data", + "run_mlp", + "run_transformer", +] +DEFAULT_TEST_DIRS = [ + "run_optimizers", + "run_fused_layer_norm", + "run_mlp", + "run_transformer", +] + + +def parse_args(): + parser = argparse.ArgumentParser( + description="L0 test runner", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--include", + nargs="+", + choices=TEST_DIRS, + default=DEFAULT_TEST_DIRS, + help="select a set of tests to run (defaults to ALL tests).", + ) + args, _ = parser.parse_known_args() + return args + + +def main(args): + runner = unittest.TextTestRunner(verbosity=2) + errcode = 0 + for test_dir in args.include: + test_dir = os.path.join(TEST_ROOT, test_dir) + print(test_dir) + suite = unittest.TestLoader().discover(test_dir) -errcode = 0 + print("\nExecuting tests from " + test_dir) -for test_dir in test_dirs: - suite = unittest.TestLoader().discover(test_dir) + result = runner.run(suite) - print("\nExecuting tests from " + test_dir) + if not result.wasSuccessful(): + errcode = 1 - result = runner.run(suite) + sys.exit(errcode) - if not result.wasSuccessful(): - errcode = 1 -sys.exit(errcode) +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/tests/L0/run_transformer/__init__.py b/tests/L0/run_transformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/L0/run_transformer/run_cross_entropy_test.py b/tests/L0/run_transformer/run_cross_entropy_test.py new file mode 100644 index 000000000..e670941d9 --- /dev/null +++ b/tests/L0/run_transformer/run_cross_entropy_test.py @@ -0,0 +1,105 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
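+# This test builds identical logits on every rank, computes the reference loss
+# with torch.nn.functional.cross_entropy, then scatters the logits across the
+# tensor-model-parallel group and checks vocab_parallel_cross_entropy against it.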
+import torch +import torch.nn.functional as F + +from apex.transformer.tensor_parallel.tests.commons import set_random_seed +from apex.transformer.tensor_parallel.tests.commons import IdentityLayer +from apex.transformer.tensor_parallel.tests.commons import print_separator +from apex.transformer.tensor_parallel.tests.commons import initialize_distributed +from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE +from apex.transformer import parallel_state +from apex.transformer import tensor_parallel +from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy +from apex.transformer.tensor_parallel.tests import global_vars + + +global_vars.set_global_variables() + + +def torch_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).cuda() + logits = identity() + target = torch.cuda.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size) + loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), + target.view(-1), + reduction='none').view_as(target).mean() + loss.backward() + return loss, identity.weight.grad + + +def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda() + logits = identity() + logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits) + target = torch.cuda.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size) + loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() + loss.backward() + return loss, identity.weight.grad + + +def test_cross_entropy(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cross entropy with model parallel size {} ...'. 
+ format(tensor_model_parallel_size)) + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + batch_size = 13 + seq_length = 17 + vocab_size_per_partition = 11 + logits_scale = 1000.0 + vocab_size = vocab_size_per_partition * tensor_model_parallel_size + seed = 1234 + + loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed) + loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed) + + error = loss_torch.sub_(loss_mpu).abs().max() + print(' max error in loss on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = grad_torch.sub_(grad_mpu).abs().max() + print(' max error in grad on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(TEST_SUCCESS_MESSAGE) + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test cross entropy') + test_cross_entropy(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/tests/L0/run_transformer/run_data_test.py b/tests/L0/run_transformer/run_data_test.py new file mode 100644 index 000000000..ebaadada0 --- /dev/null +++ b/tests/L0/run_transformer/run_data_test.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import operator + +import torch + +from apex.transformer import parallel_state +from apex.transformer.tensor_parallel import data as data_utils +from apex.transformer.tensor_parallel.tests import global_vars +from apex.transformer.tensor_parallel.tests.commons import print_separator +from apex.transformer.tensor_parallel.tests.commons import initialize_distributed +from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE + +global_vars.set_global_variables() + + +def test_broadcast_data(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing broadcast_data with model parallel size {} ...'. 
+              format(tensor_model_parallel_size))
+
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
+    torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
+    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
+
+    key_size_t = {
+        'key1': [7, 11],
+        'key2': [8, 2, 1],
+        'key3': [13],
+        'key4': [5, 1, 2],
+        'key5': [5, 12],
+    }
+    keys = list(key_size_t.keys())
+
+    data = {}
+    data_t = {}
+    for key in key_size_t:
+        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
+        data_t[key] = data[key].clone()
+    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
+    data_t['keyX'] = data['keyX'].clone()
+    if parallel_state.get_tensor_model_parallel_rank() != 0:
+        data = None
+
+    data_utils._check_data_types(keys, data_t, torch.int64)
+    key_size, key_numel, \
+        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
+    for key in keys:
+        assert key_size[key] == key_size_t[key]
+    total_numel_t = 0
+    for key in keys:
+        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
+        assert key_numel[key] == target_size
+        total_numel_t += target_size
+    assert total_numel == total_numel_t
+
+    data_b = data_utils.broadcast_data(keys, data, torch.int64)
+    for key in keys:
+        tensor = data_t[key].cuda()
+        assert data_b[key].sub(tensor).abs().max() == 0
+
+    # Reset groups
+    parallel_state.destroy_model_parallel()
+
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(TEST_SUCCESS_MESSAGE)
+
+
+if __name__ == '__main__':
+
+    initialize_distributed()
+    world_size = torch.distributed.get_world_size()
+
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        print_separator('test broadcast data')
+        test_broadcast_data(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
diff --git a/tests/L0/run_transformer/run_initialize_test.py b/tests/L0/run_transformer/run_initialize_test.py
new file mode 100644
index 000000000..9fa9d524c
--- /dev/null
+++ b/tests/L0/run_transformer/run_initialize_test.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
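+# Rank layout assumed by the checks below: with world size W and tensor-model-
+# parallel size T, consecutive ranks form a tensor-parallel group, so
+# tp_rank = rank % T and dp_rank = rank // T (e.g. W=8, T=2 gives data-parallel
+# groups {0, 2, 4, 6} and {1, 3, 5, 7}).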
+import torch + +from apex.transformer import parallel_state +from apex.transformer.tensor_parallel.tests import global_vars +from apex.transformer.tensor_parallel.tests.commons import print_separator +from apex.transformer.tensor_parallel.tests.commons import initialize_distributed +from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE + + +global_vars.set_global_variables() + + +def test_initialize_model_parallel(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing initialize_model_parallel with size {} ...'.format( + tensor_model_parallel_size)) + tensor_model_parallel_size_ = min( + tensor_model_parallel_size, + torch.distributed.get_world_size(), + ) + assert not parallel_state.model_parallel_is_initialized() + parallel_state.initialize_model_parallel(tensor_model_parallel_size_) + assert parallel_state.model_parallel_is_initialized() + + # Checks. + def check(group, world_size, rank): + assert world_size == torch.distributed.get_world_size(group=group) + assert rank == torch.distributed.get_rank(group=group) + + # Model parallel. + world_size = tensor_model_parallel_size_ + rank = torch.distributed.get_rank() % tensor_model_parallel_size_ + assert world_size == parallel_state.get_tensor_model_parallel_world_size() + assert rank == parallel_state.get_tensor_model_parallel_rank() + check(parallel_state.get_tensor_model_parallel_group(), world_size, rank) + + # Data parallel. + world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ + rank = torch.distributed.get_rank() // tensor_model_parallel_size + assert world_size == parallel_state.get_data_parallel_world_size() + assert rank == parallel_state.get_data_parallel_rank() + check(parallel_state.get_data_parallel_group(), world_size, rank) + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(TEST_SUCCESS_MESSAGE) + + +def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): + + if torch.distributed.get_rank() == 0: + print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( + tensor_model_parallel_size_)) + tensor_model_parallel_size = min( + tensor_model_parallel_size_, + torch.distributed.get_world_size(), + ) + assert not parallel_state.model_parallel_is_initialized() + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + assert parallel_state.model_parallel_is_initialized() + + # Checks + src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank() + assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test initialize model parallel') + test_initialize_model_parallel(tensor_model_parallel_size) + print_separator('test model parallel source rank') + test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/tests/L0/run_transformer/run_layers_test.py b/tests/L0/run_transformer/run_layers_test.py new file mode 100644 index 000000000..7320b0c5b --- /dev/null +++ b/tests/L0/run_transformer/run_layers_test.py @@ -0,0 +1,559 @@ +# 
coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from apex.transformer import parallel_state +from apex.transformer.tensor_parallel import layers +from apex.transformer.tensor_parallel.tests import global_vars +from apex.transformer.tensor_parallel.tests.commons import set_random_seed +from apex.transformer.tensor_parallel.tests.commons import print_separator +from apex.transformer.tensor_parallel.tests.commons import initialize_distributed +from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE + + +global_vars.set_global_variables() + + +def test_parallel_embedding(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing parallel embedding with model parallel size {} ...'. + format(tensor_model_parallel_size)) + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + batch_size = 17 + seq_length = 23 + vocab_size = 48 + hidden_size = 16 + seed = 1236 + + set_random_seed(123) + input_data = torch.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size).cuda() + loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() + + set_random_seed(seed) + embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() + + output = embedding_original(input_data) + loss_original = torch.mul(output, loss_weight).sum() + loss_original.backward() + + set_random_seed(seed) + embedding_parallel = layers.ParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_parallel(input_data) + loss_parallel = torch.mul(output, loss_weight).sum() + loss_parallel.backward() + + set_random_seed(seed) + embedding_vocab_parallel = layers.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_vocab_parallel(input_data) + loss_vocab_parallel = torch.mul(output, loss_weight).sum() + loss_vocab_parallel.backward() + + torch.distributed.barrier() + error = loss_parallel.sub(loss_original).abs() + print(' error in loss (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + torch.distributed.barrier() + error = loss_vocab_parallel.sub(loss_original).abs() + print(' error in loss (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + hidden_size // tensor_model_parallel_size, + 1)[parallel_state.get_tensor_model_parallel_rank()] + error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() + print(' error in grad (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) 
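+
+    # ParallelEmbedding shards the weight along the hidden dimension (dim 1),
+    # while VocabParallelEmbedding below shards along the vocab dimension
+    # (dim 0), hence the different torch.split axes on the reference gradient.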
+ + weight_grad_orig = torch.split(embedding_original.weight.grad, + vocab_size // tensor_model_parallel_size, + 0)[parallel_state.get_tensor_model_parallel_rank()] + error = embedding_vocab_parallel.weight.grad.sub( + weight_grad_orig).abs().max() + print(' error in grad (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_initialize_affine_weight(tensor_model_parallel_size, device): + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing initialize_affine_weight with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + seed = 12345 + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + + # --------------- + # Column parallel + # --------------- + weight = torch.empty(output_size_coeff, input_size) + set_random_seed(seed) + if device == 'cpu': + layers._initialize_affine_weight_cpu(weight, output_size, input_size, + output_size_coeff, 0, + torch.nn.init.normal_, + params_dtype=global_vars.get_args().params_dtype, + ) + else: + layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0) + + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = parallel_state.get_tensor_model_parallel_rank() + my_weight = torch.split(master_weight, output_size_coeff, + dim=0)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' column parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # ------------ + # Row parallel + # ------------ + weight = torch.empty(output_size, input_size_coeff) + set_random_seed(seed) + if device == 'cpu': + layers._initialize_affine_weight_cpu( + weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_, + params_dtype=global_vars.get_args().params_dtype) + + else: + layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1) + + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = parallel_state.get_tensor_model_parallel_rank() + my_weight = torch.split(master_weight, input_size_coeff, + dim=1)[rank].contiguous().clone() + + # Compare. 
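+ # The row-parallel case partitions the weight along the input dimension + # (dim 1), so each rank should exactly reproduce its dim-1 slice of the + # master weight.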
+ error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' row parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer2D(torch.nn.Module): + def __init__(self, m, n): + super(IdentityLayer2D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + + +def test_column_parallel_linear(tensor_model_parallel_size): + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing ColumnParallelLinear with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = layers.ColumnParallelLinear( + input_size, output_size, keep_master_weight_for_test=True, + params_dtype=global_vars.get_args().params_dtype, + use_cpu_initialization=global_vars.get_args().use_cpu_initialization, + ).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output, _ = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. 
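+ # Analytic reference gradients: for Y = X A^T + b with L = sum(Y * dLdY), + # dL/dA = dLdY^T X, dL/db = 1^T dLdY, and dL/dX = dLdY A.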
+ dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = parallel_state.get_tensor_model_parallel_rank() + my_dLdA = torch.split(dLdA, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + my_dLdb = torch.split(dLdb, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def test_row_parallel_linear(tensor_model_parallel_size): + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing RowParallelLinear with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = layers.RowParallelLinear( + input_size, output_size, keep_master_weight_for_test=True, + params_dtype=global_vars.get_args().params_dtype, + use_cpu_initialization=global_vars.get_args().use_cpu_initialization, + ).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output, _ = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. 
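+ # Same analytic reference as the column-parallel case, but RowParallelLinear + # shards A along the input dimension (dim 1) and keeps the full bias on every + # rank, so dLdA is compared per-shard while dLdb and dLdX are compared whole.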
+ dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = parallel_state.get_tensor_model_parallel_rank() + my_dLdA = torch.split(dLdA, input_size_coeff, + dim=1)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer3D(torch.nn.Module): + def __init__(self, m, n, k): + super(IdentityLayer3D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n, k)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + + +def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, + sequence_length): + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + attention_layer = parallel_state.BertParallelSelfAttention(hidden_size, num_att_heads, + dropout_prob).cuda() + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = attention_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = parallel_state.get_tensor_model_parallel_rank() + parallel_state.destroy_model_parallel() + return rank, hidden_size, tensor_model_parallel_size, loss, \ + attention_layer, identity_layer + + +def test_parallel_self_attention(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelSelfAttention with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + dropout_prob = 0.0 # has to be zero + batch_size = 5 + sequence_length = 13 + + rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \ + attention_layer_1, identity_layer_1 = parallel_self_attention( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + + rank, hidden_size, tensor_model_parallel_size, loss, \ + attention_layer, identity_layer = parallel_self_attention( + tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + assert hidden_size_1 == hidden_size + 
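+ # The size-1 run above is the single-GPU reference; the model-parallel run + # must reproduce its loss and gradients to within tolerance.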
+ error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + my_lin_grad_list = torch.split( + attention_layer_1.query_key_value.weight.grad, + hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size] + my_lin_grad = torch.cat(my_lin_grad_list, dim=0) + error = my_lin_grad.sub( + attention_layer.query_key_value.weight.grad).abs().max() + torch.distributed.barrier() + print(' weight gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length): + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + intermediate_size = 4 * hidden_size + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + transformer_layer = parallel_state.BertParallelTransformerLayer( + hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, + torch.nn.functional.relu, 1.0e-5).cuda() + + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = transformer_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = parallel_state.get_tensor_model_parallel_rank() + parallel_state.destroy_model_parallel() + return rank, hidden_size, tensor_model_parallel_size, loss, \ + transformer_layer, identity_layer + + +def test_parallel_transformer_layer(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelTransformerLayer with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + batch_size = 5 + sequence_length = 13 + + rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \ + transformer_layer_1, identity_layer_1 = parallel_transformer( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + rank, hidden_size, tensor_model_parallel_size, loss, \ + transformer_layer, identity_layer = parallel_transformer( + tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-5, 'error: {}'.format(error) + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert 
error < 5.0e-5, 'error: {}'.format(error) + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(TEST_SUCCESS_MESSAGE) + + +if __name__ == '__main__': + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + print_separator('test initialize affine weight cpu') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_initialize_affine_weight(tensor_model_parallel_size, 'cpu') + tensor_model_parallel_size *= 2 + # Reset groups + parallel_state.destroy_model_parallel() + + print_separator('test initialize affine weight gpu') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_initialize_affine_weight(tensor_model_parallel_size, 'gpu') + tensor_model_parallel_size *= 2 + + # Deleted, replaced with vocab parallel embedding? + #tensor_model_parallel_size = 1 + #while tensor_model_parallel_size <= world_size: + # print_separator('test parallel embedding') + # test_parallel_embedding(tensor_model_parallel_size) + # tensor_model_parallel_size *= 2 + + print_separator('test column-parallel linear') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_column_parallel_linear(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + print_separator('test row-parallel linear') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_row_parallel_linear(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + # Deleted + #print_separator('test parallel self-attention') + #tensor_model_parallel_size = 1 + #while tensor_model_parallel_size <= world_size: + # test_parallel_self_attention(tensor_model_parallel_size) + # tensor_model_parallel_size *= 2 + + # Deleted because ParallelTransformerLayer no longer exists + # print_separator('test parallel transformer') + # tensor_model_parallel_size = 1 + # while tensor_model_parallel_size <= world_size: + # test_parallel_transformer_layer(tensor_model_parallel_size) + # tensor_model_parallel_size *= 2 diff --git a/tests/L0/run_transformer/run_mappings_test.py b/tests/L0/run_transformer/run_mappings_test.py new file mode 100644 index 000000000..f3825c956 --- /dev/null +++ b/tests/L0/run_transformer/run_mappings_test.py @@ -0,0 +1,61 @@ +import torch + +from apex.transformer import parallel_state +from apex.transformer.tensor_parallel.tests.commons import initialize_distributed +from apex.transformer.tensor_parallel import mappings +from apex.transformer.tensor_parallel.tests import global_vars + +global_vars.set_global_variables() + + +def test__reduce(args, tensor_model_parallel_size): + print("Testing reduction size =", tensor_model_parallel_size) + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + assert torch.equal( + mappings._reduce(torch.full((10, 10, 10, 10), 50)), + torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size), + ) + parallel_state.destroy_model_parallel() + print("Passed!") + + +def test__split(args, tensor_model_parallel_size): + print("Testing splitting size =", tensor_model_parallel_size) + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + listy = [] + for i in range(tensor_model_parallel_size): + listy.append(torch.randn(10, 1)) + x
= torch.cat(tuple(listy), 1) + out = mappings._split(x) + assert torch.equal(out, listy[parallel_state.get_tensor_model_parallel_rank()]) + parallel_state.destroy_model_parallel() + print("Passed!") + + +def test__gather(args, tensor_model_parallel_size): + + print("Testing gathering size =", tensor_model_parallel_size) + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + assert torch.equal( + mappings._gather(torch.tensor([parallel_state.get_tensor_model_parallel_rank()])), + torch.tensor(list(range(tensor_model_parallel_size))), + ) + parallel_state.destroy_model_parallel() + print("Passed!") + + +if __name__ == "__main__": + initialize_distributed() + + world_size = torch.distributed.get_world_size() + args = global_vars.get_args() + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test__reduce(args, tensor_model_parallel_size) + test__split(args, tensor_model_parallel_size) + test__gather(args, tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + print(">> passed the test :-)") diff --git a/tests/L0/run_transformer/run_random_test.py b/tests/L0/run_transformer/run_random_test.py new file mode 100644 index 000000000..481535899 --- /dev/null +++ b/tests/L0/run_transformer/run_random_test.py @@ -0,0 +1,211 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from apex.transformer import parallel_state +from apex.transformer import tensor_parallel +from apex.transformer.tensor_parallel.tests import global_vars +from apex.transformer.tensor_parallel.tests.commons import print_separator +from apex.transformer.tensor_parallel.tests.commons import initialize_distributed +from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE + + +global_vars.set_global_variables() + + +def test_set_cuda_rng_state(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing set_rng_state with size {} ...'. + format(tensor_model_parallel_size)) + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + size = 123 + seed = 1234 + torch.cuda.manual_seed(seed) + tensor = torch.cuda.FloatTensor(size) + + # Get the state + rng_state = torch.cuda.get_rng_state() + rng_state_copy = rng_state.clone() + + # Do some stuff. + for _ in range(5): + torch.randn(size, out=tensor) + result_1 = tensor.clone() + + assert rng_state.sub(rng_state_copy).max() == 0 + assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 + + # State should be different. + new_rng_state = torch.cuda.get_rng_state() + max_diff = new_rng_state.sub(rng_state).max() + print(' max diff in rng state (should be non-zero) on global rank {}: {}'. 
+ format(torch.distributed.get_rank(), max_diff)) + assert max_diff > 0 + + # Reset the rng state and do the same stuff. + tensor_parallel.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + tensor_parallel.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + result_2 = tensor.clone() + + # Results should be the same + error = result_2.sub(result_1).abs().max() + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Input state should have remained intact. + error = rng_state.sub(rng_state_copy).max() + print(' max error in rng state (should be zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), error)) + assert error == 0 + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(TEST_SUCCESS_MESSAGE) + + +def test_cuda_rng_tracker(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cuda rng tracker with size {} ...'. + format(tensor_model_parallel_size)) + + parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + seed_1 = 1234 + seed_2 = 4321 + size = [12, 21] + tensor = torch.cuda.FloatTensor(size) + + # Set to seed_1 and generate two tensors. + torch.cuda.manual_seed(seed_1) + torch.randn(size, out=tensor) + target_11 = tensor.clone() + torch.randn(size, out=tensor) + target_12 = tensor.clone() + + # Set to seed_2 and generate two tensors. + torch.cuda.manual_seed(seed_2) + torch.randn(size, out=tensor) + target_21 = tensor.clone() + torch.randn(size, out=tensor) + target_22 = tensor.clone() + + # Now if we interleave seed_1 and seed_2, + # we should still get the same tensors + torch.cuda.manual_seed(seed_1) + tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2) + + torch.randn(size, out=tensor) + result_11 = tensor.clone() + + with tensor_parallel.random.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_21 = tensor.clone() + + torch.randn(size, out=tensor) + result_12 = tensor.clone() + + with tensor_parallel.random.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_22 = tensor.clone() + + diff = result_11.sub(result_21).abs().max() + diff = min(diff, result_12.sub(result_22).abs().max()) + print(' max diff in generated tensors (should be non-zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) + assert diff > 1.0e-6 + error = max(result_11.sub(target_11).abs().max(), + result_12.sub(target_12).abs().max()) + error = max(error, result_21.sub(target_21).abs().max()) + error = max(error, result_22.sub(target_22).abs().max()) + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset the tracker + tensor_parallel.random.get_cuda_rng_tracker().reset() + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(TEST_SUCCESS_MESSAGE) + + +def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print( + '> testing model parallel cuda manual seed with size {} ...'.format( + tensor_model_parallel_size)) + + 
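# model_parallel_cuda_manual_seed leaves the default CUDA generator on the + # given seed and offsets the tensor-model-parallel generator by + # 2718 + tensor-model-parallel rank, which the asserts below verify. + 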
parallel_state.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + + tensor_parallel.random.model_parallel_cuda_manual_seed(12345) + assert torch.cuda.initial_seed() == 12345 + with tensor_parallel.random.get_cuda_rng_tracker().fork(): + assert ( + torch.cuda.initial_seed() == + 12345 + 2718 + parallel_state.get_tensor_model_parallel_rank() + ) + + # Reset the tracker + tensor_parallel.random.get_cuda_rng_tracker().reset() + + # Reset groups + parallel_state.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(TEST_SUCCESS_MESSAGE) + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test set rng state') + test_set_cuda_rng_state(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test cuda rng tracker') + test_cuda_rng_tracker(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test model parallel cuda manual seed') + test_model_parallel_cuda_manual_seed(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/tests/L0/run_transformer/run_utils_test.py b/tests/L0/run_transformer/run_utils_test.py new file mode 100644 index 000000000..662e01c3e --- /dev/null +++ b/tests/L0/run_transformer/run_utils_test.py @@ -0,0 +1,20 @@ +import torch + +from apex.transformer.tensor_parallel import utils + + +def test_divide(): + assert utils.divide(8, 4) == 2 + + +def test_split_tensor_along_last_dim(): + inputy = torch.randn((100, 100, 100)) + splits = utils.split_tensor_along_last_dim(inputy, 10) + last_dim_shapes = torch.tensor([int(split.size()[-1]) for split in splits]) + assert torch.equal(last_dim_shapes, torch.full((10,), 10)) + + +if __name__ == "__main__": + test_divide() + test_split_tensor_along_last_dim() + print(">> passed the test :-)") diff --git a/tests/L0/run_transformer/test_fused_softmax.py b/tests/L0/run_transformer/test_fused_softmax.py new file mode 100644 index 000000000..b25c59f0e --- /dev/null +++ b/tests/L0/run_transformer/test_fused_softmax.py @@ -0,0 +1,137 @@ +"""Test for fused softmax functions. 
+ +Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py +""" # NOQA +import itertools +import unittest + +import torch + +from apex.transformer import AttnMaskType +from apex.transformer.functional import FusedScaleMaskSoftmax + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) + + +class TestFusedScaleMaskSoftmax(unittest.TestCase): + + def _setup_fused_softmax(self, input_in_fp16, input_in_bf16, scale=None, softmax_in_fp32=False, attn_mask_type=AttnMaskType.padding): + fused_fn = FusedScaleMaskSoftmax( + input_in_fp16=input_in_fp16, + input_in_bf16=input_in_bf16, + mask_func=attention_mask_func, + scale=scale, + softmax_in_fp32=softmax_in_fp32, + attn_mask_type=attn_mask_type, + scaled_masked_softmax_fusion=True, + ) + torch_fn = FusedScaleMaskSoftmax( + input_in_fp16=input_in_fp16, + input_in_bf16=input_in_bf16, + mask_func=attention_mask_func, + scale=scale, + softmax_in_fp32=softmax_in_fp32, + attn_mask_type=attn_mask_type, + scaled_masked_softmax_fusion=False, + ) + return fused_fn, torch_fn + + def test_fused_scale_mask_softmax(self): + """ + attention_scores.shape = [4, 12, 24, 24] + mask.shape = [4, 1, 24, 24] + """ + for (dtype, scale, softmax_in_fp32) in itertools.product( + (torch.half, torch.bfloat16), + (None, 2.0), + (False, True), + ): + with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"): + input_in_fp16 = dtype == torch.half + input_in_bf16 = dtype == torch.bfloat16 + if not (scale is None or softmax_in_fp32): + with self.assertRaises(RuntimeError): + self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding) + continue + fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding) + + attention_scores = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype) + mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool() + actual = fused_fn(attention_scores, mask) + expected = torch_fn(attention_scores, mask) + torch.testing.assert_allclose(actual, expected) + + def test_autocast_fused_scale_mask_softmax(self): + for dtype in autocast_dtypes: + with self.subTest(f"{dtype}"): + input_in_fp16 = dtype == torch.half + input_in_bf16 = dtype == torch.bfloat16 + fused_fn, torch_fn = self._setup_fused_softmax( + input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding) + attention_scores = torch.randn((4, 12, 24, 24)).cuda() + mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda() + + with torch.cuda.amp.autocast(dtype=dtype): + actual = fused_fn(attention_scores, mask) + self.assertEqual(actual.dtype, dtype) + with torch.no_grad(): + expected = torch_fn(attention_scores.to(dtype), mask) + torch.testing.assert_allclose(actual, expected) + + def test_fused_upper_triangle_mask_softmax(self): + """ + attn_weights.shape: [4, 12, 24, 24] + total_mask.shape: [4, 1, 24, 24] + + total_mask[0, 0] is a 24x24 matrix whose upper-triangular elements are + True and whose lower-triangular and diagonal elements are False.
+ """ + for (dtype, scale, softmax_in_fp32) in itertools.product( + (torch.half, torch.bfloat16), + (None, 2.0), + (False, True), + ): + with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"): + input_in_fp16 = dtype == torch.half + input_in_bf16 = dtype == torch.bfloat16 + if not (scale is None or softmax_in_fp32): + with self.assertRaises(RuntimeError): + self._setup_fused_softmax( + input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal) + return + fused_fn, torch_fn = self._setup_fused_softmax( + input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal) + + attn_weights = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype) + total_mask = (~( + torch.tril(torch.randn((24, 24), device="cuda")).bool() + ).unsqueeze(0).unsqueeze(0)) + total_mask = total_mask.repeat((4, 1, 1, 1)) + reference = fused_fn(attn_weights, total_mask) + actual = torch_fn(attn_weights, total_mask) + torch.testing.assert_allclose(actual, reference) + + def test_autocast_fused_upper_triangle_mask_softmax(self): + for dtype in autocast_dtypes: + with self.subTest(f"{dtype}"): + input_in_fp16 = dtype == torch.half + input_in_bf16 = dtype == torch.bfloat16 + fused_fn, torch_fn = self._setup_fused_softmax( + input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal) + attn_weights = torch.randn((4, 12, 24, 24)).cuda() + total_mask = (~( + torch.tril(torch.randn((24, 24), device="cuda")).bool() + ).unsqueeze(0).unsqueeze(0)) + + with torch.cuda.amp.autocast(dtype=dtype): + actual = fused_fn(attn_weights, total_mask) + self.assertEqual(actual.dtype, dtype) + with torch.no_grad(): + expected = torch_fn(attn_weights.to(dtype), total_mask) + torch.testing.assert_allclose(actual, expected) diff --git a/tests/L0/run_transformer/test_mpu.py b/tests/L0/run_transformer/test_mpu.py new file mode 100644 index 000000000..20458ab81 --- /dev/null +++ b/tests/L0/run_transformer/test_mpu.py @@ -0,0 +1,52 @@ +import os +import subprocess +import sys +import unittest + + +def run_mpu_tests(): + python_executable_path = sys.executable + # repository_root = os.path.join(os.path.dirname(__file__), "../../../") + # directory = os.path.abspath(os.path.join(repository_root, "tests/mpu")) + directory = os.path.dirname(__file__) + files = [ + os.path.join(directory, f) for f in os.listdir(directory) + if f.startswith("run_") and os.path.isfile(os.path.join(directory, f)) + ] + print("#######################################################") + print(f"# Python executable path: {python_executable_path}") + print(f"# {len(files)} tests: {files}") + print("#######################################################") + errors = [] + for i, test_file in enumerate(files, 1): + test_run_cmd = f"NVIDIA_TF32_OVERRIDE=0 {python_executable_path} {test_file} --micro-batch-size 2 --num-layers 1 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings 32 --encoder-seq-length 32 --use-cpu-initialization" # NOQA + print(f"### {i} / {len(files)}: cmd: {test_run_cmd}") + try: + output = subprocess.check_output( + test_run_cmd, shell=True + ).decode(sys.stdout.encoding).strip() + except Exception as e: + errors.append((test_file, str(e))) + else: + if '>> passed the test :-)' not in output: + errors.append(test_file, output) + else: + if not errors: + print("### PASSED") + else: + print("### FAILED") + short_msg = f"{len(errors)} out of {len(files)} tests failed" + print(short_msg) + for (filename, log) in errors: + print(f"File: {filename}\nLog: {log}") + raise RuntimeError(short_msg) + + +class 
TestMPU(unittest.TestCase): + + def test_mpu(self): + run_mpu_tests() + + +if __name__ == '__main__': + unittest.main()