Skip to content

Commit

Permalink
Fixed PP for 70B
Browse files Browse the repository at this point in the history
  • Loading branch information
SahilJain314 committed Jan 23, 2025
1 parent b6196fd commit 8c47fd0
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def fwd_output_and_loss_func(data_iterator, model):
required_keys.update(("response_tokens", "position_ids"))

if parallel_state.is_pipeline_last_stage():
required_keys.update(("response_tokens", "baseline", "mask", "rewards_with_kl", "is_end"))
required_keys.update(("response_tokens", "baseline", "mask", "is_end", "init_policy_kl", "init_log_probs", "rewards", "prompt_mask", "log_probs"))

batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}

Expand Down

0 comments on commit 8c47fd0

Please sign in to comment.