Skip to content

Commit

Permalink
fix gsm python code
Browse files Browse the repository at this point in the history
  • Loading branch information
shuishen112 committed Dec 21, 2024
1 parent e91ab0d commit 9711704
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions mttl/evaluators/gsm_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,24 @@ def get_predictions(predictions_texts, batch, all_predictions, all_targets):
all_predictions.append(float("inf"))
all_targets.extend(batch["labels_texts"])

def extract_code(text):
match = re.search(r"```(.*?)```", text, re.DOTALL)
if match:
return match.group(1).strip()
return None

def print_python_code(predictions_texts, batch, file):

for i, (pred, source, target) in enumerate(
zip(predictions_texts, batch["sources_texts"], batch["labels_texts"])
):
outputs = (
pred.split("### Response:")[-1].strip().split("### Instruction:")[0]
)
outputs = pred[len(source) - 1 :]

# convert it to code
code = extract_code(outputs)
data = {}
data["answer"] = float(target)
data["output_pred"] = outputs
data["output_pred"] = code
file.write(json.dumps(data) + "\n")
file.flush()

Expand Down

0 comments on commit 9711704

Please sign in to comment.