From a786ca0cf2d2dd790cecd96b19cc478d7e661cb5 Mon Sep 17 00:00:00 2001 From: eqy Date: Mon, 7 Feb 2022 08:36:43 -0800 Subject: [PATCH] fix and generate docs for FusedRMSNorm (#1285) --- apex/normalization/fused_layer_norm.py | 12 ++++++------ docs/source/layernorm.rst | 3 +++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index db7a9afa7..8558f7a5e 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -303,19 +303,19 @@ class FusedRMSNorm(torch.nn.Module): Currently only runs on cuda() tensors. .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + y = \frac{x}{\mathrm{RMS}[x]} * \gamma - The mean and standard-deviation are calculated separately over the last + The root-mean-square is calculated separately over the last certain number dimensions which have to be of the shape specified by :attr:`normalized_shape`. - :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + :math:`\gamma` is a learnable affine transform parameter of :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. .. note:: Unlike Batch Normalization and Instance Normalization, which applies scalar scale and bias for each entire channel/plane with the - :attr:`affine` option, Layer Normalization applies per-element scale and - bias with :attr:`elementwise_affine`. + :attr:`affine` option, RMS Normalization applies per-element scale + with :attr:`elementwise_affine`. This layer uses statistics computed from input data in both training and evaluation modes. @@ -353,7 +353,7 @@ class FusedRMSNorm(torch.nn.Module): >>> # Activating the module >>> output = m(input) - .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 + .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf """ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): diff --git a/docs/source/layernorm.rst b/docs/source/layernorm.rst index 36dcb845b..6eedb4ed2 100644 --- a/docs/source/layernorm.rst +++ b/docs/source/layernorm.rst @@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm .. autoclass:: FusedLayerNorm :members: + +.. autoclass:: FusedRMSNorm + :members: