GRPO RL for Reasoning #500

Closed · wants to merge 17 commits
30 changes: 22 additions & 8 deletions examples/nlp/gpt/conf/gpt_reinforce_actor.yaml
@@ -17,10 +17,15 @@ trainer:
gradient_clip_val: 1.0

# REINFORCE args to generate the data for training
-initial_policy_kl_penalty: 0.01
+initial_policy_kl_penalty: 0.03
max_length_penalty: 0.0
use_absolute_kl: True
num_rollouts_per_prompt: 4

prompt_rollouts_per_microbatch: 4
val_prompt_rollouts_per_microbatch: 1
# Prompts whose accuracy is below the min threshold or above the max threshold set below are filtered out of training.
online_filtering_min_accuracy_threshold: 0.0
online_filtering_max_accuracy_threshold: 1.0

# the sequence length to pad the rollout batch for training to
# reduce fragmentation at the cost of using more
@@ -62,6 +67,7 @@ trainer:
# *do not change this*
model_gbs: ${model.global_batch_size}
model_mbs: ${model.micro_batch_size}
generation_save_dir: /tmp/

# no need to change these
logger: False # logger provided by exp_manager
@@ -78,7 +84,7 @@ remote_rm:

# reward model server
reward_model:
-name: reward_model
+name: math_grader
ip: localhost
port: 5555

@@ -97,7 +103,7 @@ exp_manager:
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_rewards
-save_top_k: 1
+save_top_k: -1
mode: max
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits
@@ -113,6 +119,11 @@ model:
# training generation mbs
rollout_micro_batch_size: 8
num_rollout_samples: 512
prompt_rollouts_per_microbatch: ${trainer.reinforce.prompt_rollouts_per_microbatch}

initial_policy_kl_penalty: ${trainer.reinforce.initial_policy_kl_penalty}
use_grpo_loss: False
grpo_eps: 1e-2 # does nothing unless use_grpo_loss is True

# mbs to do log prob inference, can be set to
# lower than rollout_micro_batch_size to reduce
@@ -143,9 +154,10 @@ model:
# length argument for autoregressive sampling
# max length means max amount of tokens to generate
length_params:
-max_length: ${int_div:${model.encoder_seq_length}, 2}
+max_length: 3072 #${int_div:${model.encoder_seq_length}, 2}
min_length: 1

disable_baseline: False
trt_llm: ${trainer.reinforce.trt_llm}

peft:
@@ -194,23 +206,25 @@ model:
name: CosineAnnealing
warmup_steps: 10
constant_steps: 1000
-min_lr: 9e-8
+min_lr: 9e-7

precision: ${trainer.precision}

data:
data_impl: jsonl
shuffle_train_data: True
splits_string: null
-seq_length: ${model.encoder_seq_length}
+seq_length: ${subtract:${model.encoder_seq_length}, ${model.reinforce.length_params.max_length}}
skip_warmup: True
num_workers: 0
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_prefix: null
prompt_file: null

# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True
data_prefix: True
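
The new trainer.reinforce settings above combine group sampling (num_rollouts_per_prompt completions per prompt), an optional GRPO objective (use_grpo_loss, grpo_eps), and online accuracy filtering. As a minimal sketch of what these knobs conventionally control, assuming group-normalized advantages and a PPO-style clip, the objective can be pictured as below; this is illustrative only, not the NeMo-Aligner implementation, and every tensor name is made up for the example.

import torch


def grpo_loss(logprobs, old_logprobs, rewards, eps=1e-2):
    """Clipped surrogate with group-normalized advantages.

    logprobs / old_logprobs: (num_prompts, rollouts_per_prompt) summed
    per-completion log-probs; rewards: same shape, e.g. 0/1 math grades.
    """
    # Group baseline: normalize each prompt's rewards over its own rollouts.
    advantages = (rewards - rewards.mean(dim=1, keepdim=True)) / (
        rewards.std(dim=1, keepdim=True) + 1e-8
    )
    # Importance ratio between the current policy and the rollout policy.
    ratio = torch.exp(logprobs - old_logprobs)
    # grpo_eps plays the role of the clip range.
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    return -torch.min(ratio * advantages, clipped).mean()


def keep_prompt_mask(rewards, min_acc=0.0, max_acc=1.0):
    # cf. online_filtering_min/max_accuracy_threshold: keep only prompt
    # groups whose mean accuracy lies inside [min_acc, max_acc].
    acc = rewards.mean(dim=1)
    return (acc >= min_acc) & (acc <= max_acc)


if __name__ == "__main__":
    torch.manual_seed(0)
    lp = torch.randn(4, 4, requires_grad=True)   # 4 prompts x 4 rollouts
    old_lp = (lp + 0.01 * torch.randn(4, 4)).detach()
    r = torch.randint(0, 2, (4, 4)).float()      # binary grader rewards
    print(grpo_loss(lp, old_lp, r), keep_prompt_mask(r))

In the full objective, a KL penalty against the initial policy (initial_policy_kl_penalty) would typically be added on top of this surrogate.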
180 changes: 180 additions & 0 deletions examples/nlp/gpt/conf/gpt_star_actor.yaml
@@ -0,0 +1,180 @@
defaults:
- optional tp_overlap@model.ub_tp_comm_overlap_cfg:

trainer:
# these args are respected
num_nodes: 8
devices: 8
accelerator: gpu
precision: bf16

star:
max_epochs: 1
max_steps: -1 # max star steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
gradient_clip_val: 1.0

# pick up from the model
# *do not change this*
model_gbs: ${model.global_batch_size}
model_mbs: ${model.micro_batch_size}

# the sequence length to pad the rollout batch to
# this reduces fragmentation at the cost of using more
# memory, set to null if we don't want to pad it
# to a constant size
rollout_batch_seq_length: null

# no need to change these
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_time: null
max_epochs: ${.star.max_epochs}
max_steps: ${.star.max_steps}

remote_rm:
# what to pad the inputs to
# set to None if no padding when sending data for reward model inference
pad_to_length: ${model.encoder_seq_length}

# reward model server
reward_model:
name: math_grader
ip: localhost
port: 5555


exp_manager:
explicit_log_dir: /results
exp_dir: null
name: megatron_gpt_star_actor
create_wandb_logger: False
wandb_logger_kwargs:
project: nemo_aligner_star
name: gpt3_star_2b
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_global_rewards
save_top_k: 1
mode: max
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits
filename: 'megatron_gpt-{step}-{consumed_samples}-{rs_optimization_step}-{epoch}-{val_global_rewards:.3f}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}

pretrained_checkpoint:
restore_from_path: null

model:

star:
# training generation mbs
rollout_micro_batch_size: 8
num_rollout_samples: 512

# mbs to do log prob inference, can be set to
# lower than rollout_micro_batch_size to reduce
# memory usage
forward_micro_batch_size: ${.rollout_micro_batch_size}

num_rollouts_per_prompt: 4 # Number of completions to sample per prompt
top_n_rollouts: 1 # Number of completions to select based on reward and train upon (per prompt)

# val generation mbs
val_rollout_micro_batch_size: ${.rollout_micro_batch_size}
num_val_samples: ${.num_rollout_samples}

# to offload during generation or not
offload_adam_states: True

# params for generation
sampling_params:
use_greedy: False
temperature: 1.0
top_k: 0
top_p: 1.0
repetition_penalty: 1.0
add_BOS: False
all_probs: False
compute_logprob: False
end_strings: ["<|endoftext|>", "<extra_id_1>"]

# length argument for autoregressive sampling
# max length means max amount of tokens to generate
length_params:
max_length: ${int_div:${model.encoder_seq_length}, 2}
min_length: 1

#peft
peft:
peft_scheme: "none" # ["lora", "none"]
restore_from_path: null
restore_from_ckpt:
checkpoint_dir: null
checkpoint_name: null

lora_tuning:
target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all'
adapter_dim: 32
adapter_dropout: 0.0
column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
weight_tying: False
position_embedding_strategy: null # used only when weight_tying is True

mcore_gpt: True
# these control the mbs/gbs during RS training
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True

encoder_seq_length: 4096
max_position_embeddings: ${model.encoder_seq_length}

## Sequence Parallelism
sequence_parallel: False

# miscellaneous
seed: 1234

optim:
name: distributed_fused_adam
bucket_cap_mb: 200
overlap_grad_sync: False
contiguous_grad_buffer: True
lr: 9e-7
weight_decay: 0.1
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 10
constant_steps: 1000
min_lr: 9e-8

precision: ${trainer.precision}

data:
data_impl: jsonl
splits_string: null
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 0
dataloader_type: single # cyclic
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_prefix: null

# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True
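
In the STaR config above, num_rollouts_per_prompt and top_n_rollouts describe a sample-then-filter loop: generate several completions per prompt, grade them with the remote math_grader server, and fine-tune only on the best-scoring ones. A minimal sketch of that selection step follows; the function and variable names are illustrative and not part of this PR's code.

from typing import List, Tuple


def select_top_rollouts(
    completions: List[str], rewards: List[float], top_n: int = 1
) -> List[Tuple[str, float]]:
    # Keep the top_n highest-reward completions for a single prompt
    # (cf. model.star.top_n_rollouts); sorted() is stable, so ties keep
    # their original order.
    ranked = sorted(zip(completions, rewards), key=lambda cr: cr[1], reverse=True)
    return ranked[:top_n]


if __name__ == "__main__":
    rollouts = ["attempt 1", "attempt 2", "attempt 3", "attempt 4"]
    grades = [0.0, 1.0, 1.0, 0.0]  # binary correctness from the math grader
    print(select_top_rollouts(rollouts, grades, top_n=1))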
14 changes: 14 additions & 0 deletions examples/nlp/gpt/serve_math_grader.py
@@ -0,0 +1,14 @@
from functools import partial

from nemo_aligner.algorithms.math_grader_server import SimpleMathGrader, extract_and_check

ENDPOINT_BIND_ADDRESS = "0.0.0.0"


def main() -> None:
server = SimpleMathGrader(grading_function=extract_and_check, port=5555, process_count=16,)
server.run_server()


if __name__ == "__main__":
main()
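
serve_math_grader.py exposes a grading function over HTTP on port 5555, matching the reward_model endpoint in the configs above. As a hedged illustration of the kind of check such a grader performs, the sketch below pulls the final \boxed{...} answer out of a response and compares it with the ground truth; the real extract_and_check in nemo_aligner.algorithms.math_grader_server may normalize answers differently, so grade_answer here is a hypothetical stand-in.

import re
from fractions import Fraction
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    # Return the contents of the last \boxed{...} in the response, if any.
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None


def grade_answer(pred: str, ground_truth: str) -> float:
    # 1.0 if the boxed answer matches the ground truth exactly (as a
    # fraction or as a plain string), else 0.0.
    answer = extract_boxed(pred)
    if answer is None:
        return 0.0
    try:
        return float(Fraction(answer) == Fraction(ground_truth))
    except (ValueError, ZeroDivisionError):
        return float(answer.strip() == ground_truth.strip())


if __name__ == "__main__":
    print(grade_answer("The answer is \\boxed{.5}", "1/2"))  # 1.0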
2 changes: 1 addition & 1 deletion examples/nlp/gpt/serve_reward_model.py
@@ -13,7 +13,7 @@
# limitations under the License.

import torch
-from pytorch_lightning.trainer.trainer import Trainer
+from lightning.pytorch.trainer.trainer import Trainer

from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.core.config import hydra_runner
21 changes: 21 additions & 0 deletions examples/nlp/gpt/star.sh
@@ -0,0 +1,21 @@
#!/bin/sh
TRAIN_DATA_PATH=/opt/NeMo-Skills/nemo_skills/dataset/gsm8k/train_full.jsonl
VALID_DATA_PATH=/opt/NeMo-Skills/nemo_skills/dataset/gsm8k/test.jsonl
TEST_DATA_PATH=/opt/NeMo-Skills/nemo_skills/dataset/gsm8k/test.jsonl
CHECKPOINT=/opt/NeMo/checkpoints/qwen2-1-5b-it.nemo
#CHECKPOINT=/opt/NeMo/checkpoints/llama3-1-8B-instruct.nemo
#CHECKPOINT=/opt/NeMo/checkpoints/llama3-2-1B-instruct.nemo
#CHECKPOINT=/opt/NeMo-Aligner/checkpoints/2b_mcore_actor.nemo

python train_gpt_star.py \
pretrained_checkpoint.restore_from_path=$CHECKPOINT \
"model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \
trainer.num_nodes=1 \
trainer.devices=2 \
model.star.rollout_micro_batch_size=4 \
model.star.num_rollout_samples=32 \
model.star.num_rollouts_per_prompt=4 \
model.star.top_n_rollouts=4 \
model.micro_batch_size=2 \
model.global_batch_size=32 \
+model.tensor_model_parallel_size=1
32 changes: 32 additions & 0 deletions examples/nlp/gpt/test_client_math_grader.py
@@ -0,0 +1,32 @@
from typing import List

import numpy as np

from nemo_aligner.servers.http_communicator import HTTPCommunicator

communicator = HTTPCommunicator.create_http_communicator_from_dict({"math_grader": ("0.0.0.0", 5555)})

preds = [
"\\boxed{.5}",
"There is a lot of random reasoning done here, but eventually, the answer is \\boxed{123/456}",
"\\boxed{15}",
]
gt = ["1/2", "0.2697368421", "7"]


def triton_textencode(text_batch: List[str]):
    # Encode each string as UTF-8 bytes and arrange the batch as a
    # (batch_size, 1) numpy array, the layout used when sending text
    # batches to the grader server.
    enc = np.array([[np.char.encode(i, "utf-8")] for i in text_batch])
    enc = np.reshape(enc, (enc.shape[0], 1))

    return enc


data = {
"pred_responses": triton_textencode(preds),
"ground_truth": triton_textencode(gt),
}

future = communicator.send_data_to_server("math_grader", data)
print(future)
v = future.result()
print(v)
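
Running this client assumes the grader from serve_math_grader.py is already listening on 0.0.0.0:5555; the future returned by send_data_to_server resolves once the server has graded the batch of predictions against the ground truths.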