Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: UL2 merge #23

Open
wants to merge 144 commits into
base: multi-query-attention
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
144 commits
Select commit Hold shift + click to select a range
928a200
Remove deprecated destination argument to state_dict functions and ma…
jaredcasper Jul 21, 2022
5df9e1f
Remove old merge tool.
jaredcasper Jul 26, 2022
c464a10
Merge branch 'del_merge' into 'main'
jaredcasper Jul 26, 2022
e36cdd7
added a flag to be able to switch between pytorch and ring exchange p2p
shoeybi Jul 26, 2022
8df49e7
Merge branch 'add_ring_exchange_flag' into 'main'
jaredcasper Jul 27, 2022
76db958
support for all mask in fused kernel + avoiding inplace operation in …
kvareddy Jul 28, 2022
189e72a
Merge branch 'fused_softmax_kernel_fixes' into 'main'
jaredcasper Jul 29, 2022
45f4ee5
yttm + BytelevelBPE + setencepeice tokenizer support
kvareddy Aug 4, 2022
b7b2d6a
fix a bug for size mismatch
pxuab Aug 6, 2022
83d7867
Merge branch 'beam_search' into 'main'
jaredcasper Aug 6, 2022
a44360e
adress review comments
kvareddy Aug 8, 2022
77efccc
Timing levels
shoeybi Aug 10, 2022
d207391
Merge branch 'timing' into 'main'
jaredcasper Aug 10, 2022
27bc133
fixed grad scalar warning so it only prints it for fp16
shoeybi Aug 16, 2022
91384a5
Merge branch 'fix_grad_scalar_warning' into 'main'
jaredcasper Aug 16, 2022
aaa5715
fixed grad scalar warning for bf16
shoeybi Aug 16, 2022
d63c254
Merge branch 'fix_grad_scalar_warning' into 'main'
jaredcasper Aug 16, 2022
e38d41c
Memory safety checks were incorrect for the tokens_to_generate=0 case
Sep 2, 2022
8b68628
Merge branch 'fixing_safety' into 'main'
jaredcasper Sep 12, 2022
1afe354
Merge branch 'state_dict_fix' into 'main'
jaredcasper Sep 12, 2022
981c3df
support separate datasets for train, valid and test
anmolgupt Sep 22, 2022
fabad46
Clean up licensing.
jaredcasper Sep 23, 2022
28ba253
Merge branch 'licensing' into 'main'
jaredcasper Sep 23, 2022
2e6a46e
Start Megatron-Core with vocab parallel cross entropy
jaredcasper Sep 23, 2022
209f91c
Bring mpu.data into megatron.core.
jaredcasper Sep 23, 2022
c2ea914
Move layers from mpu to core.tensor_parallel.
jaredcasper Sep 23, 2022
5942af9
Alias core.parallel_state as mpu and use it throughout code. RIP mpu.
jaredcasper Sep 23, 2022
a94d0a6
Move get_num_layers into transformer.py.
jaredcasper Sep 24, 2022
e00a1ca
Improve docstrings, destory global memory buffer.
jaredcasper Sep 24, 2022
cbf780d
Update exports.
jaredcasper Sep 26, 2022
e7e9972
Check for pipeline_parallel > 2 when using interleaving.
jaredcasper Sep 26, 2022
5f4ddd9
Add basic setup.py for core.
jaredcasper Sep 26, 2022
77753d0
Small fixes.
jaredcasper Sep 27, 2022
55817ec
Correct some merge errors.
jaredcasper Sep 27, 2022
2366716
Error, not warn, if gradient_accumulation_fusion is requested but not…
jaredcasper Sep 27, 2022
07916bf
Support gradient accumulation fusion in fp16.
jaredcasper Sep 27, 2022
57bfa7c
Perform distributed optimizer's all-gather in param dtype (instead of…
lmcafee-nvidia Sep 30, 2022
fc7f4f0
Merge branch 'lmcafee/byte-buffer' into 'main'
jaredcasper Sep 30, 2022
41276b6
Merge branch 'main' into nmt-main
kvareddy Oct 3, 2022
b9ae7ba
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 3, 2022
05d731a
Setting up code coverage
shanmugamr1992 Oct 4, 2022
fb8c09e
Code coverage setup
shanmugamr1992 Oct 4, 2022
cbf8250
different encoder/decoder num-layers support
kvareddy Oct 4, 2022
6ab70f5
Adding some basic unit tests
shanmugamr1992 Oct 5, 2022
63e5994
support for separate dataset files for train, valid and test
Oct 5, 2022
2514892
fixed the timer issue for the case with no pipelining
shoeybi Oct 5, 2022
96b7559
Merge branch 'fix_backward_no_pipeline' into 'main'
jaredcasper Oct 6, 2022
6defe18
Setter for pipeline parallel split rank, remove print
ericharper Oct 6, 2022
6d41789
Merge branch 'changes_for_nemo' into 'core'
jaredcasper Oct 6, 2022
b69e219
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
136cf03
Merge branch 'core' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega…
shanmugamr1992 Oct 6, 2022
056fc7c
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
423623c
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
56934a2
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
74ee8c0
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
44c94f5
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
e9f2000
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
4ec95a2
Adding some basic unit tests
shanmugamr1992 Oct 7, 2022
11392f0
Changes'
shanmugamr1992 Oct 7, 2022
94dd94e
Changes'
shanmugamr1992 Oct 7, 2022
2fd9ea1
Code covearage
shanmugamr1992 Oct 7, 2022
c0329d8
Code covearage
shanmugamr1992 Oct 7, 2022
f861467
Code covearage
shanmugamr1992 Oct 7, 2022
45cd4e0
removed assert for the case of evaluation only without training
Oct 10, 2022
69f3249
address review comments
kvareddy Oct 11, 2022
a95fda7
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 11, 2022
c7d57ff
Merge branch 'anmolg/validation_1' into 'main'
jaredcasper Oct 11, 2022
8b94a16
Adding proper test cases
shanmugamr1992 Oct 13, 2022
8806ba7
Merge branch 'properTest' into 'core'
jaredcasper Oct 13, 2022
2a86fa2
Merge branch 'main' into core
jaredcasper Oct 13, 2022
5da3bb9
Merge branch 'core-merge-main' into 'core'
jaredcasper Oct 14, 2022
dbed5e0
inverse_square_root learning param schedule
kvareddy Oct 14, 2022
bdd9731
Remove noop used to try to force scheduling and check for environment…
jaredcasper Oct 14, 2022
d3a416c
Merge branch 'core-noop' into 'core'
jaredcasper Oct 14, 2022
abf60f7
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 14, 2022
544e250
Disable newline after colon
pxuab Oct 20, 2022
f4a8b1d
Merge branch 'disable_newline_after_colon' into 'main'
jaredcasper Oct 20, 2022
2fdd54e
Sending in prompts with the wrong type hangs the server. This is a c…
Oct 27, 2022
fdc801e
Merge branch 'check_prompts_is_list' into 'main'
jaredcasper Nov 2, 2022
42c4071
Merge branch 'core' into 'main'
jaredcasper Nov 2, 2022
e0a12fe
Fix merge error.
jaredcasper Nov 8, 2022
1a26b29
Merge branch 'core-fix' into 'main'
jaredcasper Nov 8, 2022
fabd3e4
ViT Backbone Tensor Shape Fix
yaoyu-33 Nov 10, 2022
b4297c6
Merge branch 'yuya/vit_fix' into 'main'
jaredcasper Nov 10, 2022
c3e688d
Support for variable sequence lengths across micro-batches
kvareddy Nov 11, 2022
1ad1e1b
Merge branch 'nmt-main' into 'main'
jaredcasper Nov 11, 2022
7fc9611
Data Preprocessing Optimizations
kvareddy Nov 17, 2022
7016945
Merge branch 'nmt-main' into 'main'
Nov 17, 2022
6d45a90
Fix DropPath for hidden shape [s, b, h]
yaoyu-33 Nov 22, 2022
d48d95a
Open sourcing lm detoxification code
boxin-wbx Nov 24, 2022
8ce8256
Merge branch 'boxin/detoxify_lm_cr' into 'main'
boxin-wbx Nov 24, 2022
84a43b1
bug fixes in partitioned data preprocessor
Nov 29, 2022
b24f4ad
Merge branch 'partition_fixes' into 'main'
jaredcasper Nov 29, 2022
52e6368
Merge branch 'yuya/drop_path_fix' into 'main'
jaredcasper Nov 29, 2022
f298a85
Fix typo
janEbert Dec 13, 2022
df3ca00
Set SentencePiece tokenizer global variable
janEbert Dec 13, 2022
072b3a6
Refactor masked LM sampling style selection
janEbert Dec 13, 2022
2c94801
Add more masked LM sampling styles
janEbert Dec 13, 2022
e2bc55c
Allow Prefix-LM style masked LM
janEbert Dec 13, 2022
53f0300
Add UL2 pretraining for T5 model
janEbert Dec 13, 2022
35f232c
Refactor span merging
janEbert Dec 13, 2022
6bd44e7
Allow non-causal GPT models
janEbert Dec 13, 2022
9304618
Support UL2 for decoder-only models
janEbert Dec 13, 2022
9add693
Add custom exceptions
janEbert Dec 14, 2022
20b7acd
Error out on too long sequences
janEbert Dec 14, 2022
b5bef77
Remove additional sequence truncation
janEbert Dec 14, 2022
3e46e3c
Prefer array-from-list creation
janEbert Dec 14, 2022
7bb655c
Remove redundant imports
janEbert Jan 3, 2023
4474556
Fix sometimes not inserting prefixes
janEbert Jan 3, 2023
6f88858
Do not insert `extra_id` tokens for PrefixLM task
janEbert Jan 3, 2023
69fa541
Document `max_seq_length_dec` argument
janEbert Jan 3, 2023
020dd64
Skip redundant computations
janEbert Jan 3, 2023
1820f2b
Fix PrefixLM mean location
janEbert Jan 3, 2023
c4a5b40
Pad decoder-only inputs to same length
janEbert Jan 3, 2023
324d70d
Fix decoder-only attention mask shape
janEbert Jan 3, 2023
eb3dd43
Fix `max_ngrams` for normal sampling style
janEbert Jan 23, 2023
2d1b32d
Do not limit `max_predictions_per_seq`
janEbert Jan 23, 2023
10ef283
Calculate and use amount of filtered tokens
janEbert Jan 23, 2023
6b29f42
Document normal sampling style
janEbert Jan 23, 2023
27fc9fb
Fix PrefixLM possible spans calculation
janEbert Jan 23, 2023
359742e
Avoid mutable pointer in arguments
janEbert Jan 23, 2023
11e3d24
Allow passing callable for getting `model_type`
janEbert Jan 23, 2023
2a67e97
Fix getting model type
janEbert Jan 23, 2023
2dc7587
Allow recognizing when UL2 is used
janEbert Jan 23, 2023
7a4a94d
Only add UL2 tokens if using UL2 pretrain script
janEbert Jan 23, 2023
3c852c0
Support UL2 tokens for all tokenizers
janEbert Jan 23, 2023
c03a7be
Add SEP token to GPT tokenizer if using UL2
janEbert Jan 23, 2023
959daaa
Fix enum name
janEbert Jan 23, 2023
49f6b0f
Fix private UL2 argument default value
janEbert Jan 23, 2023
aa9a1c7
Use binary search for PrefixLM first tail index
janEbert Jan 24, 2023
d906cc1
Calculate n-gram indices lazily
janEbert Jan 24, 2023
3805df7
Prefer list comprehensions
janEbert Jan 24, 2023
69a3519
Merge branch 'janEbert-ul2' into ul2-merge
RaymondLi0 Feb 6, 2023
f5d0df1
support UL2 with HFtokenizer
RaymondLi0 Feb 7, 2023
9f024dc
scale normal distribution variance with its mean, and truncate the di…
RaymondLi0 Feb 7, 2023
f845e38
in the decoder-only case, truncate the masked sequence
RaymondLi0 Feb 7, 2023
ea79fe8
refactor: UL2Dataset does not inherit T5Dataset anymore
RaymondLi0 Feb 7, 2023
d1aed24
fix: mpu.get_cuda_rng_tracker() -> tensor_parallel.get_cuda_rng_track…
RaymondLi0 Feb 10, 2023
e712e7e
remove debug print
RaymondLi0 Feb 10, 2023
458ecf8
move is_ul2 to arguments
RaymondLi0 Feb 14, 2023
b9fa5f7
adjust attention-mask in generation for prefix-lm models
RaymondLi0 Feb 15, 2023
3a305eb
fix assert in tokenizer
RaymondLi0 Feb 17, 2023
96d18f7
Merge branch 'ul2-merge' of github.com:bigcode-project/Megatron-LM in…
RaymondLi0 Feb 17, 2023
fe05ccd
fix pretrain_ul2 for causal-decoder
Mar 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[html]
directory = coverage

[run]
data_file = .coverage_$LOCAL_RANK
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
__pycache__

*.so
build
.coverage_*
*.egg-info
11 changes: 7 additions & 4 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel

test:
tags:
- docker_gpu_enabled
script:
- pytest --junitxml=report.xml tests
- torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
artifacts:
when: always
reports:
junit: report.xml
paths:
- coverage
expire_in: 30 days

2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
The following applies to all files unless otherwise noted:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,12 @@ curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; ch

See [megatron/text_generation_server.py](megatron/text_generation_server.py) for more API options.

### Detoxify GPT via Self-generation
We include an example in `examples/detxoify_lm/` to detoxify language models by leveraging the generative power of language models.

See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus.


## GPT Evaluation
We include example scripts for GPT evaluation on WikiText perplexity evaluation and LAMBADA Cloze accuracy.

Expand Down
112 changes: 112 additions & 0 deletions examples/detxoify_lm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SGEAT: Detoxify Larger-scale Language Models

This is the official code base for our NeurIPS 2022 paper:

[Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models](https://arxiv.org/abs/2202.04173)

Boxin Wang, Wei Ping, Chaowei Xiao, Peng Xu, Mostofa Patwary, Mohammad Shoeybi, Bo Li, Anima Anandkumar, Bryan Catanzaro


## Citation

```
@article{WangExp2022,
title={Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models},
author={Wang, Boxin and Ping, Wei and Xiao, Chaowei and Xu, Peng and Patwary, Mostofa and Shoeybi, Mohammad and and Li, Bo and Anandkumar, Anima and Catanzaro, Bryan},
journal={NeurIPS},
year={2022}
}
```

## Usage

### Prepare your environment

The project environment is based on the standard [nvcr docker](nvcr.io/nvidia/pytorch:21.12-py3) of version `nvcr.io/nvidia/pytorch:21.12-py3`.

To run Perspective API, you need to install `google-api-python-client`
```bash
pip install --upgrade google-api-python-client
```

### Self Generation

#### SGEAT (Standard)
To perform unconditional generation for a Megatron LM, we provide an example script for 1.3B LM.

```bash
# [num of samples] [model checkpoint] [random seed]
bash examples/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh 1000 checkpoints/gpt3/gpt3-1.3b/ 2333
```
This will generate a jsonl file of 1000 generated text (as a toy example) at `selfgeneration/unconditional_generation_gpt3-1.3b/2333.out`.

Note that you may want to set your own gpt2 vocab and merge file dir, as well as your output data dir in `selfgenerate-1.3b-unconditional.sh`.

### Annotation

We then use Perspective API to annotate the self generated corpus. Note that you need to fill in your own Perspective API key in the `examples/detoxify_lm/perspective_api_annotate.py`.

```bash
python examples/detxoify_lm/perspective_api_annotate.py --data-path [input-data-path] --out-path [output-data-path] --workers 70
```

For example,

```bash
python examples/detxoify_lm/annotations/perspective_api_annotate.py --data-path selfgeneration/unconditional_generation_gpt3-1.3b/2333.out --out-path selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.out --workers 70
```

### Filtering

We then filter the self annotated generated corpus to get the most nontoxic 50% of the corus.

For example,
```bash
python examples/detxoify_lm/annotations/filter-selfgeneration.py --data-path selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.out --out-path selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.nontoxic.out
```

This will generate a jsonl file of 500 text of the lowest toxicity (as a toy example) at `selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.nontoxic.out`.


### Preprocess

We then preprocess the dataset so that Megatron LM can use the dumped dataset to fine-tune.

```
bash examples/detxoify_lm/annotations/preprocess.sh selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.nontoxic.out selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.nontoxic
```

This will generate two files as follows
```bash
selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.nontoxic_text_document.idx
selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.nontoxic_text_document.bin
```
which will be used in the following domain-adative training step.

### Fine-tuning

We then use the preprocess dataset as input to fine-tune our Megatron-LM.
```bash
# [fine-tuning dataset] [output-dir] [lr] [bs] [train-iters] [load checkpoint]
bash examples/detxoify_lm/finetune_gpt_distributed-1.3b.sh selfgeneration/unconditional_generation_gpt3-1.3b/2333.annotated.nontoxic_text_document gpt3-1.3b-toy-example-lr-2e-5-bs-512 2e-5 512 78 checkpoints/gpt3/gpt3-1.3b
```

This will dump the final checkpoint in `$SHARE_DATA/gpt3-1.3b-toy-example-lr-2e-5-bs-512`. (`$SHARE_DATA` is your current work dir, default to `$PWD`)

### Evaluation

We then use the fine-tuned checkpoint to perform conditional generation given RealToxicityPrompts:

```bash
# [input-prompts] [model-checkpoint]
bash examples/detxoify_lm/generate-1.3b.sh augmented_prompts.jsonl $SHARE_DATA/gpt3-1.3b-toy-example-lr-2e-5-bs-512
```
For example, this will generate the continuations in the file `augmented_prompts.jsonl_output_gpt3-1.3b-toy-example-lr-2e-5-bs-512_seed_31846.jsonl` (seed is a random generated number).

Note that the input prompts are augmented so that each prompts appear 25 times to calculate the Expected Maximum Toxicity over 25 generations and Toxicity Probability,

We then use Perspective API to evaluate the Expected Maximum Toxicity and Toxicity Probability.

```bash
python examples/detxoify_lm/perspective_api.py --data-path "augmented_prompts.jsonl_output_gpt3-1.3b-toy-example-lr-2e-5-bs-512_seed_31846.jsonl" --prompt-path prompts.jsonl --workers 30
```
75 changes: 75 additions & 0 deletions examples/detxoify_lm/annotations/filter-selfgeneration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
import time
from typing import Dict, Optional, List

import joblib
from googleapiclient import discovery
from googleapiclient.errors import HttpError

import argparse

from tqdm import tqdm

parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--data-path', type=str, default='',
help='data path to load the jsonl')
parser.add_argument('--out-path', type=str, default='',
help='data path to load the jsonl')
parser.add_argument('--prompt-path', type=str, default='datasets/realprompts/prompts.jsonl',
help='data path to load the prompt jsonl')
parser.add_argument('--workers', type=int, default=10,
help='Number of worker processes to launch')


def get_corpus_scores(lines):
scores = []
corpus = []

for line in tqdm(lines):
data = json.loads(line)
score = data['score']
if score is not None:
scores.append(score['toxicity'])
corpus.append(data)
return scores, corpus


def main():
args = parser.parse_args()

with open(args.data_path) as f:
lines = f.readlines()

print(f"total line number: {len(lines)}")

scores, corpus = get_corpus_scores(lines)
import numpy as np
scores = np.array(scores)
indices = np.argsort(scores)

print(f"total valid samples: {len(scores)}")

print(f"Avg toxicity (all): {np.mean(scores)} +- {np.std(scores)}")
print(f"Avg toxicity (toxic): {np.mean(scores[scores > 0.5])} +- {np.std(scores[scores > 0.5])}")
print(f"Toxic Percentage {sum(scores > 0.5) / len(scores)}")
print(f"Avg toxicity (nontoxic): {np.mean(scores[scores <= 0.5])} +- {np.std(scores[scores <= 0.5])}")
print(f"Nontoxic Percentage {sum(scores <= 0.5) / len(scores)}")

samples_left = len(lines) // 2
print(f"After filtering: {samples_left} of samples are left")
nontoxic_indices = indices[:samples_left]
print(f"Avg toxicity (filtered): {np.mean(scores[nontoxic_indices])} +- {np.std(scores[nontoxic_indices])}")
print(f"Toxicity Range (filtered): {np.min(scores[nontoxic_indices])} ~ {np.max(scores[nontoxic_indices])}")
nontoxic_data = [corpus[ind] for ind in nontoxic_indices]
print(f"Total samples after filtering: {len(nontoxic_data)}")
print(f"Examples: {nontoxic_data[:3]}")

from sklearn.utils import shuffle
nontoxic_data = shuffle(nontoxic_data)

with open(args.out_path, 'w') as f:
for x in nontoxic_data:
f.write(json.dumps(x) + '\n')


main()
Loading