diff --git a/eval_gsm.py b/eval_gsm.py new file mode 100644 index 00000000..72347e18 --- /dev/null +++ b/eval_gsm.py @@ -0,0 +1,39 @@ +# open the json file and read the data +import json +import re +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--file", type=str, default="experiment/gsm.jsonl") +args = parser.parse_args() + + +def extract_code(text): + match = re.search(r"```(.*?)```", text, re.DOTALL) + if match: + return match.group(1).strip() + return None + + +correct = 0 +all_count = 0 +with open(args.file, "r") as f: + for line in f: + all_count += 1 + data = json.loads(line) + answer = data["answer"] + code = data["output_pred"] + if code is None: + continue + predict_answer = None + try: + exec(code) + exec("predict_answer = solution()") + # exec("print(predict_answer, answer)") + # compute the accuracy + except Exception as e: + print(e) + if predict_answer == answer: + correct += 1 +print(correct, all_count) +print("Accuracy:", correct / all_count) diff --git a/gsm_evaluator_with_lora_soup.py b/gsm_evaluator_with_lora_soup.py new file mode 100644 index 00000000..ebebc928 --- /dev/null +++ b/gsm_evaluator_with_lora_soup.py @@ -0,0 +1,38 @@ +# here, we train experts and we upload them to a local library (repository) of experts. + +import os +from mttl.arguments import ExpertConfig +from mttl.datamodule.base import get_datamodule +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.expert_model import ( + ExpertModel, + MultiExpertModel, + MultiExpertModelConfig, + ExpertModelConfig, +) +from mttl.models.train_utils import train_model + +from mttl.evaluators.gsm_evaluator import GsmEvaluator +from mttl.evaluators.rouge_evaluator import RougeEvaluator +from mttl.models.containers.selectors.base import UniformSelectorConfig +from mttl.arguments import EvaluationConfig, ExpertConfig +from mttl.models.lightning.expert_module import ExpertModule +import torch +from mttl.logging import setup_logging + +device = "cuda" if torch.cuda.is_available() else "cpu" +setup_logging() + +args = EvaluationConfig.parse() + +datamodule = get_datamodule(args, for_generation=True) +evaluator = GsmEvaluator(datamodule) + +# +module = ExpertModule(**vars(args)).to(device) + +if args.checkpoint is not None: + checkpoint = torch.load(args.checkpoint, weights_only=False)["state_dict"] + module.load_state_dict(checkpoint) +## evaluate +result = evaluator.evaluate(module.model, split="test") diff --git a/mttl/arguments.py b/mttl/arguments.py index cc10e578..afa45f07 100644 --- a/mttl/arguments.py +++ b/mttl/arguments.py @@ -590,6 +590,7 @@ class EvaluationConfig(MultiExpertConfig, TransformArgs): es_metric: str = "loss" n_ng_iterations: int = 30 # number of iterations for LoraHub recompute_prototypes: bool = False + gsm_template: str = "cot" @dataclass diff --git a/mttl/dataloader/alpaca_dataset_readers.py b/mttl/dataloader/alpaca_dataset_readers.py index 350d6517..32628dd7 100644 --- a/mttl/dataloader/alpaca_dataset_readers.py +++ b/mttl/dataloader/alpaca_dataset_readers.py @@ -96,3 +96,19 @@ def read_all_instructions(self): for data in self.dataset: all_instructions.append(data["instruction"]) return all_instructions + + +class AlpacaCodeDataset(AlpacaDataset): + def __init__(self): + super().__init__() + self.dataset = DatasetLibrary.pull_dataset( + "zhan1993/code_alpaca_20k", split="train" + ) + + +class MathQaAlpacaCodeDataset(AlpacaDataset): + def __init__(self): + super().__init__() + self.dataset = DatasetLibrary.pull_dataset( + "zhan1993/metamath_code_alpaca_10k", split="train" + ) diff --git a/mttl/datamodule/alpaca_data_module.py b/mttl/datamodule/alpaca_data_module.py index 64181d69..2fde1269 100644 --- a/mttl/datamodule/alpaca_data_module.py +++ b/mttl/datamodule/alpaca_data_module.py @@ -1,4 +1,8 @@ -from mttl.dataloader.alpaca_dataset_readers import AlpacaDataset +from mttl.dataloader.alpaca_dataset_readers import ( + AlpacaCodeDataset, + AlpacaDataset, + MathQaAlpacaCodeDataset, +) from mttl.datamodule.base import DataModule, DatasetConfig @@ -18,9 +22,46 @@ def setup_dataset(self): self.test_dataset = self.dev_dataset -class AlpacaPretrainDataModule(AlpacaDataModule): - pass +@DataModule.register("alpaca_code", config_cls=DatasetConfig) +class AlpacaCodeDataModule(DataModule): + @property + def all_instructions(self): + return self.dataset.read_all_instructions() + + def __init__(self, config, for_generation=False, val_mixin=None): + super().__init__(config, for_generation, val_mixin) + + def setup_dataset(self): + dataset = AlpacaCodeDataset() + + self.train_dataset, self.dev_dataset = self.create_train_valid_split(dataset) + self.test_dataset = self.dev_dataset + + +@DataModule.register("mathqa_alpaca_code", config_cls=DatasetConfig) +class MathQaAlpacaCodeDataModule(AlpacaDataModule): + def setup_dataset(self): + dataset = MathQaAlpacaCodeDataset() + self.train_dataset, self.dev_dataset = self.create_train_valid_split(dataset) + self.test_dataset = self.dev_dataset class AlpacaFinetuneDataModule(AlpacaDataModule): pass + + +if __name__ == "__main__": + # alpaca_data_module = AlpacaDataModule( + # DatasetConfig(model="meta-llama/Llama-2-7b-hf") + # ) + # alpaca_data_module.setup_dataset() + # print(alpaca_data_module.train_dataset) + + mathqa_alpaca_code_data_module = MathQaAlpacaCodeDataModule( + DatasetConfig(model="meta-llama/Llama-2-7b-hf") + ) + mathqa_alpaca_code_data_module.setup_dataset() + val_dataloder = mathqa_alpaca_code_data_module.val_dataloader() + for batch in val_dataloder: + print(batch) + breakpoint() diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index dc9dcfab..238c381e 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -944,6 +944,12 @@ def get_datamodule(args, for_generation=False, dataset_override=None): HellaswagDataConfig, HellaswagMultiChoiceDataModule, ) + from mttl.datamodule.mathqa_data_module import MathQADataConfig, MathQADataModule + from mttl.datamodule.gsm_data_module import GsmDataConfig, GsmDataModule + from mttl.datamodule.base import DatasetConfig + from mttl.datamodule.alpaca_data_module import ( + AlpacaCodeDataModule, + ) from mttl.datamodule.mmlu_data_module import MMLUDataConfig, MMLUDataModule from mttl.datamodule.mt_seq_to_seq_module import ( FlanConfig, @@ -1063,6 +1069,23 @@ def get_datamodule(args, for_generation=False, dataset_override=None): pack_sequences=args.pack_sequences, ) dm = FlatMultiTaskModule(config, for_generation=for_generation) + elif "mathqa" in dataset: + config = MathQADataConfig( + **common_kwargs, + ) + dm = MathQADataModule(config, for_generation=for_generation) + elif "gsm" in dataset: + config = GsmDataConfig( + **common_kwargs, + gsm_template=args.gsm_template, + ) + dm = GsmDataModule(config, for_generation=for_generation) + + elif "alpaca_code" in dataset: + config = DatasetConfig( + **common_kwargs, + ) + dm = AlpacaCodeDataModule(config, for_generation=for_generation) elif "mmlu" in dataset: config = MMLUDataConfig( **common_kwargs, diff --git a/mttl/datamodule/gsm_data_module.py b/mttl/datamodule/gsm_data_module.py new file mode 100644 index 00000000..66425e3d --- /dev/null +++ b/mttl/datamodule/gsm_data_module.py @@ -0,0 +1,96 @@ +import os +from dataclasses import dataclass + +from mttl.datamodule.base import DataModule, DatasetConfig +from mttl.models.library.dataset_library import DatasetLibrary +import json + + +@dataclass +class GsmDataConfig(DatasetConfig): + gsm_template: str = ( + "cot" # the template we will use for the prompt, for code generation or chain of thought. + ) + + +# code refer to https://github.com/aksh555/LoRA-Soups/blob/main/evaluate.py#L208 +def generate_math_prompt_with_python(instruction, input=None): + with open("mttl/datamodule/math.json", "r") as f: + cot_data = json.load(f) + prompt = """Let's use Python to solve math problems step by step. Below are a few Instruction-Response pairs on how to do it.""" + prompt += "\n\n" + for data in cot_data: + prompt += f"### Instruction:\n{data['instruction']}\n\n### Response:\n{data['output']}\n\n" + prompt += "Now write a function 'solution' encolsed in ``` in Python to solve this Instruction. Write only a code block. Write only valid Python code without using any units with the numerical values and any invalid symbols.\n\n" + prompt += f"### Instruction:\n{instruction}\n\n### Response:\n" + return prompt + + +def instruct_template_python(example): + example["source"] = generate_math_prompt_with_python(example["input"]) + example["target"] = str(example["answer"]) + return example + + +def instruct_template_cot(example): + + PREAMBLE = """As an expert problem solver solve step by step the following mathematical questions.""" + PROMPT = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? + A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6. + + Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? + A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5. + + Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? + A: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39. + + Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? + A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8. + + Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? + A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9. + + Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? + A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29. + + Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? + A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33. + + Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.""" + + TEMPLATE = """ + Q: {question} + A:""" + + full_prompt = ( + PREAMBLE + "\n\n" + PROMPT + "\n" + TEMPLATE.format(question=example["input"]) + ) + example["source"] = full_prompt + example["target"] = str(example["answer"]) + return example + + +@DataModule.register("gsm", config_cls=GsmDataConfig) +class GsmDataModule(DataModule): + def setup_dataset(self): + n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 4)) + dataset = DatasetLibrary.pull_dataset("reasoning-machines/gsm-hard") + dataset = dataset.rename_column("target", "answer") + if self.config.gsm_template == "cot": + dataset = dataset.map(instruct_template_cot, num_proc=n_proc) + elif self.config.gsm_template == "python": + dataset = dataset.map(instruct_template_python, num_proc=n_proc) + self.train_dataset = dataset["train"] + self.dev_dataset = self.test_dataset = dataset["train"] + + +if __name__ == "__main__": + config = GsmDataConfig(model="microsoft/Phi-3-mini-4k-instruct", gsm_template="cot") + + datamodule = GsmDataModule(config, for_generation=True) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + print(batch) + breakpoint() diff --git a/mttl/datamodule/math.json b/mttl/datamodule/math.json new file mode 100644 index 00000000..29903895 --- /dev/null +++ b/mttl/datamodule/math.json @@ -0,0 +1,50 @@ +[ + { + "instruction": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?", + "input": "", + "output": "Let\u2019s go ahead and solve this using Python:\n```\ndef solution():\n money_initial = 23\n bagels = 5\n bagel_cost = 3\n money_spent = bagels * bagel_cost\n money_left = money_initial - money_spent\n result = money_left\n return result\n```", + "answer": "8" + }, + { + "instruction": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?", + "input": "", + "output": "Let\u2019s go ahead and solve this using Python:\n```\ndef solution():\n golf_balls_initial = 58\n golf_balls_lost_tuesday = 23\n golf_balls_lost_wednesday = 2\n golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n result = golf_balls_left\n return result\n```", + "answer": "33" + }, + { + "instruction": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?", + "input": "", + "output": "Let\u2019s go ahead and solve this using Python:\n```\ndef solution():\n computers_initial = 9\n computers_per_day = 5\n num_days = 4 # 4 days between monday and thursday\n computers_added = computers_per_day * num_days\n computers_total = computers_initial + computers_added\n result = computers_total\n return result\n```", + "answer": "29" + }, + { + "instruction": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?", + "input": "", + "output": "Let\u2019s go ahead and solve this using Python:\n```\ndef solution():\n toys_initial = 5\n mom_toys = 2\n dad_toys = 2\n total_received = mom_toys + dad_toys\n total_toys = toys_initial + total_received\n result = total_toys\n return result\n```", + "answer": "9" + }, + { + "instruction": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?", + "input": "", + "output": "Let\u2019s go ahead and solve this using Python:\n```\ndef solution():\n jason_lollipops_initial = 20\n jason_lollipops_after = 12\n denny_lollipops = jason_lollipops_initial - jason_lollipops_after\n result = denny_lollipops\n return result\n```", + "answer": "8" + }, + { + "instruction": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?", + "input": "", + "output": "Let\u2019s go ahead and solve this using Python:\n```\ndef solution():\n leah_chocolates = 32\n sister_chocolates = 42\n total_chocolates = leah_chocolates + sister_chocolates\n chocolates_eaten = 35\n chocolates_left = total_chocolates - chocolates_eaten\n result = chocolates_left\n return result\n```", + "answer": "39" + }, + { + "instruction": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?", + "input": "", + "output": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\nLet\u2019s go ahead and solve this using Python:\n```\ndef solution():\n cars_initial = 3\n cars_arrived = 2\n total_cars = cars_initial + cars_arrived\n result = total_cars\n return result\n```", + "answer": "5" + }, + { + "instruction": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?", + "input": "", + "output": "Let\u2019s go ahead and solve this using Python:\n```\ndef solution():\n trees_initial = 15\n trees_after = 21\n trees_added = trees_after - trees_initial\n result = trees_added\n return result\n```", + "answer": "6" + } +] \ No newline at end of file diff --git a/mttl/datamodule/mathqa_data_module.py b/mttl/datamodule/mathqa_data_module.py new file mode 100644 index 00000000..2d7cc153 --- /dev/null +++ b/mttl/datamodule/mathqa_data_module.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass + +from mttl.datamodule.base import DataModule, DatasetConfig, DefaultCollator +from mttl.models.library.dataset_library import DatasetLibrary + + +@dataclass +class MathQADataConfig(DatasetConfig): + pass + + +@dataclass +class MathQADataModule(DataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + train_dataset = DatasetLibrary.pull_dataset_with_retry("meta-math/MetaMathQA")[ + "train" + ] + + train_dataset = train_dataset.rename_column("query", "source") + train_dataset = train_dataset.rename_column("response", "target") + + # filter out the rows where the source is empty + train_dataset = train_dataset.filter(lambda x: x["source"] != "") + + self.train_dataset = train_dataset + self.test_dataset = self.dev_dataset = train_dataset + + self.print_infos() + + @property + def collate_fn(self): + return DefaultCollator( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +if __name__ == "__main__": + config = MathQADataConfig(model="microsoft/Phi-3-mini-4k-instruct") + + datamodule = MathQADataModule(config) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + print(batch) + breakpoint() diff --git a/mttl/evaluators/gsm_evaluator.py b/mttl/evaluators/gsm_evaluator.py new file mode 100644 index 00000000..43322553 --- /dev/null +++ b/mttl/evaluators/gsm_evaluator.py @@ -0,0 +1,140 @@ +import os + +from tqdm.auto import tqdm +from mttl.evaluators.base import GenerativeEvaluator, switch_to_eval_mode +import re +from mttl.logging import logger +import json + + +class GsmEvaluator(GenerativeEvaluator): + def __init__( + self, + datamodule, + use_vllm=False, + generation_kwargs=None, + prepend_source=True, + split="test", + ): + super().__init__( + datamodule=datamodule, + use_vllm=use_vllm, + generation_kwargs=generation_kwargs, + ) + + self.split = split + self.prepend_source = prepend_source + os.environ["HF_ALLOW_CODE_EVAL"] = "1" + if self.config.gsm_template == "python": + self.save_file = ( + f"experiment/{self.config.model}-{self.config.dataset}.jsonl" + ) + + if not os.path.exists(f"experiment/{self.config.model}"): + os.makedirs(f"experiment/{self.config.model}") + + @switch_to_eval_mode + def evaluate( + self, + model, + split=None, + subsample=-1, + num_batches=None, + verbose=True, + shuffle=False, + output_path=None, + ): + dataloader = self.get_dataloader(split, subsample, shuffle=shuffle) + + pbar = tqdm( + enumerate(dataloader), + total=len(dataloader), + ) + + all_predictions = [] + all_targets = [] + + # https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/gsm8k_eval.ipynb + def find_numbers(x: str) -> list[str]: + """Finds all numbers in a string.""" + # Search for number, possibly negative (hyphen), with thousand separators + # (comma), and with a decimal point (period inbetween digits). + numbers = re.compile( + r"-?[\d,]*\.?\d+", + re.MULTILINE | re.DOTALL | re.IGNORECASE, + ).findall(x) + return numbers + + def get_predictions(predictions_texts, batch, all_predictions, all_targets): + # iterate over the predictions and targets + + for i, (pred, source, target) in enumerate( + zip(predictions_texts, batch["sources_texts"], batch["labels_texts"]) + ): + pred = pred[len(source) :] + fields = pred.split("The answer is") + if len(fields) != 0: + pred_item = fields[0] + predictions = pred_item.replace(",", "") + # code is from https://github.com/aksh555/LoRA-Soups/blob/main/utils.py#L224 + pred = find_numbers(predictions) + if not pred: + all_predictions.append(float("inf")) + else: + pred_answer = float(pred[-1]) + all_predictions.append(pred_answer) + logger.info(f"Predictions: {pred_answer}, Targets: {target}") + print(f"Predictions: {pred_answer}, Targets: {target}") + else: + all_predictions.append(float("inf")) + all_targets.extend(batch["labels_texts"]) + + def extract_code(text): + match = re.search(r"```(.*?)```", text, re.DOTALL) + if match: + return match.group(1).strip() + return None + + def print_python_code(predictions_texts, batch, file): + + for i, (pred, source, target) in enumerate( + zip(predictions_texts, batch["sources_texts"], batch["labels_texts"]) + ): + outputs = pred[len(source) - 1 :] + + # convert it to code + code = extract_code(outputs) + data = {} + data["answer"] = float(target) + data["output_pred"] = code + file.write(json.dumps(data) + "\n") + file.flush() + + with open(self.save_file, "w") as f: + for num_batch, batch in pbar: + predictions = self.generate_for_batch(model, batch) + predictions_texts = predictions.sequences_texts + if self.config.gsm_template == "cot": + get_predictions( + predictions_texts, batch, all_predictions, all_targets + ) + elif self.config.gsm_template == "python": + print_python_code(predictions_texts, batch, f) + else: + raise ValueError("Invalid templete") + if len(all_predictions) != 0: + metrics = self.compute_metrics(all_predictions, all_targets) + return metrics + else: + raise ValueError("No predictions found") + + def compute_metrics(self, predictions, targets): + # compute the accuracy based on the cot prompt + correct = 0 + + for pred_answer, target in tqdm(zip(predictions, targets)): + if pred_answer == float(target): + correct += 1 + + accuracy = correct / len(predictions) + return accuracy