From c99dbbac2f9384ba22c5f55fb0e96d90fc054c82 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Mon, 27 Jan 2025 22:27:13 -0800 Subject: [PATCH 1/3] Update tracking of position in async generate Signed-off-by: smajumdar --- nemo_skills/inference/generate.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index f8f2a4388..b994a1179 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -85,6 +85,7 @@ class GenerateSolutionsConfig: # data to engine at the same time (batch size is ignored) and then write the output as soon as it's ready # to `output_file`-async (and put it back in order after all generations are done) use_async_loop: bool = True + async_position_key: str = "_async_position" # key to use for preserving position in async loop in data dict # can add this flag to just print the first prompt instead of running generation # useful to double check that your data can be loaded and prompt has what you expect @@ -240,7 +241,7 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params filled_positions = set() with open(cfg.output_file + '-async', "rt", encoding="utf-8") as fin: for line in fin: - filled_positions.add(int(json.loads(line)[0])) + filled_positions.add(int(json.loads(line)[cfg.async_position_key])) data = [dp for idx, dp in enumerate(data) if idx not in filled_positions] original_positions = [idx for idx in original_positions if idx not in filled_positions] except FileNotFoundError: @@ -291,7 +292,11 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params for key in gen_dict: data[gen_pos].pop(key, None) gen_dict.update(data[gen_pos]) - fout.write(json.dumps([original_positions[gen_pos], gen_dict]) + "\n") + + # insert async position information + gen_dict[cfg.async_position_key] = original_positions[gen_pos] + + fout.write(json.dumps(gen_dict) + "\n") time.sleep(1) pbar.close() From 8762fd6d16033c63965659ae40db7e7fafd61c4b Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Mon, 27 Jan 2025 22:42:01 -0800 Subject: [PATCH 2/3] Update tracking of position in async generate Signed-off-by: smajumdar --- nemo_skills/inference/generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index b994a1179..19557a770 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -307,7 +307,8 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params ordered_generations = [None] * len(generations) for gen_dict in generations: - ordered_generations[gen_dict[0]] = gen_dict[1] + async_pos = gen_dict.pop(cfg.async_position_key) + ordered_generations[gen_dict[async_pos]] = gen_dict with open(cfg.output_file, "wt", encoding="utf-8") as fout: for gen_dict in ordered_generations: From 387c4a2d055a9bab940a22bdf16775243a0756e1 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Mon, 27 Jan 2025 22:47:43 -0800 Subject: [PATCH 3/3] Update tracking of position in async generate Signed-off-by: smajumdar --- nemo_skills/inference/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index 19557a770..73181945c 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -308,7 +308,7 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params ordered_generations = [None] * len(generations) for gen_dict in generations: async_pos = gen_dict.pop(cfg.async_position_key) - ordered_generations[gen_dict[async_pos]] = gen_dict + ordered_generations[async_pos] = gen_dict with open(cfg.output_file, "wt", encoding="utf-8") as fout: for gen_dict in ordered_generations: