Fine tune eval (#513)
add code to evaluate fine-tuned knowledge extraction model
steveluc authored Dec 27, 2024
1 parent 44f2912 commit 94704de
Showing 6 changed files with 139 additions and 21 deletions.
5 changes: 4 additions & 1 deletion python/fineTuning/chaparral/models/data.py
@@ -79,7 +79,10 @@ def create_train_eval_sets(self, split_ratio: float = 0.8):
def format(self, model_name: str) -> dict:
format_map = {
"mistralai/Mixtral-8x7b-v0.1": MixtralDataset,
"google/gemma-2-2b": MixtralDataset
"google/gemma-2-2b": MixtralDataset,
"meta-llama/Llama-3.1-8B": MixtralDataset,
"meta-llama/Llama-3.2-3B-Instruct": MixtralDataset,
"meta-llama/Llama-3.2-1B-Instruct": MixtralDataset
}

dataset_type = format_map.get(model_name, None)
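Note: format() resolves the model name to a dataset formatter class, and the new Llama 3.x entries simply reuse the MixtralDataset prompt format already used for Mixtral and Gemma. A minimal, self-contained sketch of the same registry pattern (the names below are illustrative, not the repository's API):

FORMATTERS = {
    "mistralai/Mixtral-8x7b-v0.1": "mixtral-style",
    "meta-llama/Llama-3.2-1B-Instruct": "mixtral-style",  # new in this commit
}

def resolve_formatter(model_name: str) -> str:
    # Look up the prompt format registered for a checkpoint name.
    style = FORMATTERS.get(model_name)
    if style is None:
        raise ValueError(f"No dataset format registered for {model_name}")
    return style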
98 changes: 83 additions & 15 deletions python/fineTuning/chaparral/train/hf_model.py
@@ -2,10 +2,10 @@
# Licensed under the MIT License.

from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model
@@ -14,6 +14,7 @@
import torch
from dotenv import load_dotenv


class HFModel:

model_name: str
@@ -34,21 +35,83 @@ def init_peft(self):
config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=[ "w1", "w2", "w3"], #Only Training the "expert" layers
# Only Training the "expert" layers
target_modules=["w1", "w2", "w3"],
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM"
)

self.model = get_peft_model(self.model, config)

def save_model(self, path: str):
self.model.save_pretrained(path)
self.tokenizer.save_pretrained(path)

def load_local_model(self, path: str):
self.model = AutoModelForCausalLM.from_pretrained(
path,
torch_dtype=torch.float32,
cache_dir=self.params.cache_dir,
load_in_4bit=True if self.params.use_peft else False,
device_map="auto",
)

self.tokenizer = AutoTokenizer.from_pretrained(
path,
cache_dir=self.params.cache_dir
)

self.tokenizer.pad_token = self.params.pad_token
if self.params.use_peft:
self.init_peft()

def predict(self, dataset: Dataset):
data_dict = dataset.format(self.model_name)
test_data = list(
map(lambda x: self.tokenize(str(x)), data_dict["items"]))

trainer = Trainer(
model=self.model,
data_collator=DataCollatorForLanguageModeling(
self.tokenizer, mlm=False)
)

return trainer.predict(test_data)

def evaluate(self, dataset: Dataset):
data_dict = dataset.format(self.model_name)
eval_data = list(
map(lambda x: self.tokenize(str(x)), data_dict["items"]))

trainer = Trainer(
model=self.model,
eval_dataset=eval_data,
data_collator=DataCollatorForLanguageModeling(
self.tokenizer, mlm=False)
)

return trainer.evaluate()

def generate(self, prompt: str, max_length: int = 3000):
encoding = self.tokenizer.encode_plus(prompt, return_tensors="pt")
print(encoding)
input_ids = encoding.input_ids
# move the input tensor to the device
input_ids = input_ids.to(self.model.device)
output = self.model.generate(input_ids, max_length=max_length,
attention_mask=encoding.attention_mask.to(
self.model.device),
pad_token_id=self.tokenizer.eos_token_id)
return self.tokenizer.decode(output[0], skip_special_tokens=True)

def load_model(self):
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float32,
cache_dir=self.params.cache_dir,
load_in_4bit=True if self.params.use_peft else False,
device_map="auto"
device_map="auto",
)

self.tokenizer = AutoTokenizer.from_pretrained(
@@ -67,7 +130,7 @@ def tokenize(self, text: str):
return self.tokenizer(
text + self.tokenizer.eos_token,
truncation=True,
max_length=self.cutoff_length,
max_length=self.params.cutoff_length,
padding="max_length"
)

@@ -77,12 +140,13 @@ def train(self):
raise ValueError("No training data loaded")

data_dict = self.train_set.format(self.model_name)
training_data = list(map(lambda x: self.tokenize(str(x)), data_dict["items"]))
training_data = list(
map(lambda x: self.tokenize(str(x)), data_dict["items"]))

trainer = Trainer(
model = self.model,
train_dataset = training_data,
args = TrainingArguments(
model=self.model,
train_dataset=training_data,
args=TrainingArguments(
per_device_train_batch_size=self.params.hf_trainer_params.per_device_train_batch_size,
gradient_accumulation_steps=self.params.hf_trainer_params.gradient_accumulation_steps,
num_train_epochs=self.params.hf_trainer_params.num_train_epochs,
@@ -92,16 +156,20 @@
save_strategy=self.params.hf_trainer_params.save_strategy,
output_dir=self.params.hf_trainer_params.output_dir
),
data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
data_collator=DataCollatorForLanguageModeling(
self.tokenizer, mlm=False)
)

self.model.config.use_cache = False
trainer.train()

def print_trainable_parameters(self):
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
trainable_params = sum(p.numel()
for p in self.model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in self.model.parameters())
print(f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params}")
print(f"trainable params: {trainable_params} || all params: {
all_params} || trainable%: {100 * trainable_params / all_params}")


if __name__ == "__main__":
load_dotenv(".env")
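Note: the new evaluate() method returns whatever Trainer.evaluate() produces, which with the causal-LM collator includes the mean cross-entropy as eval_loss. A hedged sketch of turning that into an approximate perplexity, assuming model and eval_set are built as in eval.py below (the variable names are illustrative):

import math

metrics = model.evaluate(eval_set)            # e.g. {"eval_loss": 1.83, "eval_runtime": ...}
perplexity = math.exp(metrics["eval_loss"])   # exp of the mean token cross-entropy
print(f"eval_loss={metrics['eval_loss']:.3f}  perplexity={perplexity:.1f}")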
2 changes: 1 addition & 1 deletion python/fineTuning/chaparral/train/hf_params.py
@@ -8,7 +8,7 @@
class HFTrainerParams:

output_dir: str
per_device_train_batch_size: int = 8
per_device_train_batch_size: int = 2
gradient_accumulation_steps: int = 1
num_train_epochs: int = 3
learning_rate: float = 1e-4
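Note: the default per-device batch size drops from 8 to 2, presumably to fit the larger checkpoints in memory. The effective batch size is per_device_train_batch_size times gradient_accumulation_steps times the number of devices, so accumulation can compensate, as in this small sketch (the values are hypothetical):

per_device_train_batch_size = 2
gradient_accumulation_steps = 4   # raise this to keep the optimizer step size
effective_batch_per_device = per_device_train_batch_size * gradient_accumulation_steps
assert effective_batch_per_device == 8   # matches the old per-device default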
5 changes: 4 additions & 1 deletion python/fineTuning/chaparral/util/datareader.py
@@ -11,7 +11,10 @@ class DataReader:

def load_text_file(self, filename: str) -> Dataset:
with open(filename, "r") as file:
raw_data = json.load(file)
try:
raw_data = json.load(file)
except json.JSONDecodeError:
raise ValueError("Invalid JSON file")

dataset = None
if isinstance(raw_data, list):
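Note: with this change a malformed dataset file surfaces as a ValueError instead of a raw json.JSONDecodeError. A short usage sketch (the path is a placeholder):

from chaparral.util.datareader import DataReader

try:
    dataset = DataReader().load_text_file("data/knowledge.json")  # placeholder path
except ValueError as err:
    print(f"Could not load dataset: {err}")
    raise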
40 changes: 40 additions & 0 deletions python/fineTuning/eval.py
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation and Henry Lucco.
# Licensed under the MIT License.

from chaparral.util.datareader import DataReader
from chaparral.train.hf_model import HFModel
from chaparral.train.hf_params import HFParams
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description="Fine-tune a model with given dataset.")
    parser.add_argument("--dataset_file", help="Path to the dataset file.")
    parser.add_argument("--model_name", help="Name of the model to fine-tune.")
    parser.add_argument("--params", help="Path to params file")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    dataset_file = args.dataset_file
    params_file = args.params

    # load params
    params = HFParams.from_file(params_file)

    # load dataset
    dataset = DataReader().load_text_file(dataset_file)

    # format data into train and eval sets
    train_set, eval_set = dataset.create_train_eval_sets()

    model = HFModel(params)

    print("Model loaded")

    model.load_local_model("./test_output")
    print(model.evaluate(eval_set))
    print(model.generate(dataset.get_filled_prompt(
        "The quick brown fox jumps over the lazy dog")))
10 changes: 7 additions & 3 deletions python/fineTuning/tune.py
@@ -5,7 +5,6 @@
from chaparral.train.hf_model import HFModel
from chaparral.train.hf_params import HFParams
import argparse

def parse_args():
parser = argparse.ArgumentParser(description="Fine-tune a model with given dataset.")
parser.add_argument("--dataset_file", help="Path to the dataset file.")
@@ -32,7 +31,12 @@ def parse_args():
print("Model loaded")

model.load_training_data(train_set)

model.load_model()

model.train()
model.train()

model.save_model("./test_output")

model.load_local_model("./test_output")
print(model.evaluate(eval_set))
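Note: one caveat for the new save/reload round-trip: when use_peft is enabled, save_pretrained() on a get_peft_model()-wrapped model writes only the LoRA adapter weights, so reloading from ./test_output may require attaching the adapter to the base checkpoint explicitly. A hedged sketch using the peft API (the base checkpoint name is illustrative):

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base checkpoint, then apply the saved LoRA adapter on top of it.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", device_map="auto")
model = PeftModel.from_pretrained(base, "./test_output")
tokenizer = AutoTokenizer.from_pretrained("./test_output")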
