Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move async position into data dict for valid intermediate jsonl #333

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions nemo_skills/inference/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class GenerateSolutionsConfig:
# data to engine at the same time (batch size is ignored) and then write the output as soon as it's ready
# to `output_file`-async (and put it back in order after all generations are done)
use_async_loop: bool = True
async_position_key: str = "_async_position" # key to use for preserving position in async loop in data dict

# can add this flag to just print the first prompt instead of running generation
# useful to double check that your data can be loaded and prompt has what you expect
Expand Down Expand Up @@ -240,7 +241,7 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params
filled_positions = set()
with open(cfg.output_file + '-async', "rt", encoding="utf-8") as fin:
for line in fin:
filled_positions.add(int(json.loads(line)[0]))
filled_positions.add(int(json.loads(line)[cfg.async_position_key]))
data = [dp for idx, dp in enumerate(data) if idx not in filled_positions]
original_positions = [idx for idx in original_positions if idx not in filled_positions]
except FileNotFoundError:
Expand Down Expand Up @@ -291,7 +292,11 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params
for key in gen_dict:
data[gen_pos].pop(key, None)
gen_dict.update(data[gen_pos])
fout.write(json.dumps([original_positions[gen_pos], gen_dict]) + "\n")

# insert async position information
gen_dict[cfg.async_position_key] = original_positions[gen_pos]

fout.write(json.dumps(gen_dict) + "\n")

time.sleep(1)
pbar.close()
Expand All @@ -302,7 +307,8 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params

ordered_generations = [None] * len(generations)
for gen_dict in generations:
ordered_generations[gen_dict[0]] = gen_dict[1]
async_pos = gen_dict.pop(cfg.async_position_key)
ordered_generations[async_pos] = gen_dict

with open(cfg.output_file, "wt", encoding="utf-8") as fout:
for gen_dict in ordered_generations:
Expand Down
Loading