forked from NVIDIA/apex
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Piotr Bialecki <pbialecki@nvidia.com> Co-authored-by: Eddie Yan <eddiey@nvidia.com> Co-authored-by: Rishi Puri <riship@nvidia.com> Co-authored-by: Sangkug Lym <slym@nvidia.com>
- Loading branch information
1 parent
bdac244
commit 365fdc1
Showing
46 changed files
with
6,894 additions
and
183 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ | |
from . import optimizers | ||
from . import normalization | ||
from . import pyprof | ||
from . import transformer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import torch | ||
|
||
|
||
def _cast_if_autocast_enabled(*args): | ||
if not torch.is_autocast_enabled(): | ||
return args | ||
else: | ||
return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .fused_layer_norm import FusedLayerNorm | ||
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# apex.transformer | ||
|
||
`apex.transformer` is a module which enables efficient large Transformer models at scale. | ||
|
||
`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module. |
Oops, something went wrong.