fix and generate docs for FusedRMSNorm (NVIDIA#1285)
eqy authored Feb 7, 2022
1 parent 684c473 commit a786ca0
Showing 2 changed files with 9 additions and 6 deletions.
12 changes: 6 additions & 6 deletions apex/normalization/fused_layer_norm.py
@@ -303,19 +303,19 @@ class FusedRMSNorm(torch.nn.Module):
     Currently only runs on cuda() tensors.
     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+        y = \frac{x}{\mathrm{RMS}[x]} * \gamma
-    The mean and standard-deviation are calculated separately over the last
+    The root-mean-square is calculated separately over the last
     certain number dimensions which have to be of the shape specified by
     :attr:`normalized_shape`.
-    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+    :math:`\gamma` is a learnable affine transform parameter of
     :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
     .. note::
         Unlike Batch Normalization and Instance Normalization, which applies
         scalar scale and bias for each entire channel/plane with the
-        :attr:`affine` option, Layer Normalization applies per-element scale and
-        bias with :attr:`elementwise_affine`.
+        :attr:`affine` option, RMS Normalization applies per-element scale
+        with :attr:`elementwise_affine`.
     This layer uses statistics computed from input data in both training and
     evaluation modes.
@@ -353,7 +353,7 @@ class FusedRMSNorm(torch.nn.Module):
     >>> # Activating the module
     >>> output = m(input)
-    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
+    .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf
     """

     def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
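The docstring change above replaces LayerNorm's mean/variance formula with the RMSNorm one. A minimal pure-Python sketch contrasting the two (the function names and the placement of `eps` inside the square root are illustrative assumptions here, not the fused CUDA kernel's exact implementation):

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    # y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv * g + b for v, g, b in zip(x, gamma, beta)]

def rms_norm(x, gamma, eps=1e-5):
    # y = x / RMS[x] * gamma, with RMS[x] = sqrt(mean(x_i^2) + eps);
    # note there is no mean subtraction and no beta term.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, gamma)]
```

The key difference documented by the diff: RMSNorm only rescales (one learnable `gamma`), while LayerNorm also re-centers and shifts (`gamma` and `beta`).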
3 changes: 3 additions & 0 deletions docs/source/layernorm.rst
@@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm
 .. autoclass:: FusedLayerNorm
     :members:

+.. autoclass:: FusedRMSNorm
+    :members: