Commit

Fine tuning updates + unsloth experimentation (#518)
hlucco authored Jan 7, 2025
1 parent c420d1b commit 9f2577c
Showing 9 changed files with 290 additions and 73 deletions.
33 changes: 23 additions & 10 deletions python/fineTuning/chaparral/models/data.py
@@ -3,7 +3,7 @@

from dataclasses import dataclass
from chaparral.models.knowledge import KnowledgeResponse
from chaparral.models.mistral import MixtralDataset
from chaparral.models.mistral import MixtralFormat
from chaparral.prompts.knowledge import get_knowledge_prompt
from typing import List

@@ -35,7 +35,7 @@ def to_dict(self):
}

@dataclass
class Dataset:
class ChapparalDataset:
prompt: str
info_pairs: list[InfoPair]

@@ -71,18 +71,31 @@ def to_dict(self):
def create_train_eval_sets(self, split_ratio: float = 0.8):
train_size = int(len(self.info_pairs) * split_ratio)

train_set = Dataset(self.prompt, self.info_pairs[:train_size])
eval_set = Dataset(self.prompt, self.info_pairs[train_size:])
train_set = ChapparalDataset(self.prompt, self.info_pairs[:train_size])
eval_set = ChapparalDataset(self.prompt, self.info_pairs[train_size:])

return train_set, eval_set

def format_v2(self, eos_token: str = "\n\n####\n\n") -> dict:
items = []
for pair in self.info_pairs:
items.append({
"prompt" : get_knowledge_prompt(pair.message)+eos_token,
"completion" : pair.knowledge.to_str()
})

return {
"items": items
}


def format(self, model_name: str) -> dict:
format_map = {
"mistralai/Mixtral-8x7b-v0.1": MixtralDataset,
"google/gemma-2-2b": MixtralDataset,
"meta-llama/Llama-3.1-8B": MixtralDataset,
"meta-llama/Llama-3.2-3B-Instruct": MixtralDataset,
"meta-llama/Llama-3.2-1B-Instruct": MixtralDataset
"mistralai/Mixtral-8x7b-v0.1": MixtralFormat,
"google/gemma-2-2b": MixtralFormat,
"meta-llama/Llama-3.1-8B": MixtralFormat,
"meta-llama/Llama-3.2-3B-Instruct": MixtralFormat,
"meta-llama/Llama-3.2-1B-Instruct": MixtralFormat
}

dataset_type = format_map.get(model_name, None)
@@ -93,4 +106,4 @@ def format(self, model_name: str) -> dict:
return dataset_type.from_dataset(self).to_dict()

def get_filled_prompt(self, message: str) -> str:
return get_knowledge_prompt(message)
return get_knowledge_prompt(message)
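
For orientation, a minimal sketch (not part of the commit) of how the new format_v2 output could be consumed, assuming dataset is an already-populated ChapparalDataset:

# Hedged sketch: dataset is assumed to be a populated ChapparalDataset instance.
train_set, eval_set = dataset.create_train_eval_sets(split_ratio=0.8)

# format_v2 returns {"items": [{"prompt": ..., "completion": ...}, ...]}
train_items = train_set.format_v2()["items"]
eval_items = eval_set.format_v2()["items"]

# each prompt ends with the default separator token "\n\n####\n\n"
assert train_items[0]["prompt"].endswith("\n\n####\n\n")
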
2 changes: 1 addition & 1 deletion python/fineTuning/chaparral/models/knowledge.py
@@ -247,4 +247,4 @@ def to_dict(self):
}

def to_str(self):
return str(self.to_dict())
return str(self.to_dict())
6 changes: 3 additions & 3 deletions python/fineTuning/chaparral/models/mistral.py
@@ -25,11 +25,11 @@ def to_dict(self):
}

@dataclass
class MixtralDataset:
class MixtralFormat:
items: List[MixtralChat]

@classmethod
def from_dataset(cls, dataset: "Type[Dataset]") -> "MixtralDataset":
def from_dataset(cls, dataset: "Type[ChapparalDataset]") -> "MixtralFormat":
items = []
for pair in dataset.info_pairs:
populated_message = dataset.get_filled_prompt(pair.message)
@@ -46,4 +46,4 @@ def from_dataset(cls, dataset: "Type[Dataset]") -> "MixtralDataset":
def to_dict(self):
return {
"items": [chat.to_dict() for chat in self.items]
}
}
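
For context, the renamed MixtralFormat is still reached through ChapparalDataset.format and its model-name map; a short sketch of that dispatch path, again assuming a pre-built dataset:

# Hedged sketch: chat-style export via the format_map in data.py.
formatted = dataset.format("meta-llama/Llama-3.2-1B-Instruct")
# equivalent to MixtralFormat.from_dataset(dataset).to_dict()
chats = formatted["items"]  # one chat record per InfoPair, per the from_dataset loop
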
184 changes: 146 additions & 38 deletions python/fineTuning/chaparral/train/hf_model.py
@@ -1,48 +1,68 @@
# Copyright (c) Microsoft Corporation and Henry Lucco.
# Licensed under the MIT License.

from transformers import TextStreamer
from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.lora import LoraConfig
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedTokenizer,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling
DataCollatorForLanguageModeling,
PreTrainedModel,
)
from peft import LoraConfig, get_peft_model
from chaparral.models.data import Dataset
from trl import SFTTrainer
from chaparral.train.hf_params import HFParams
import torch
from dotenv import load_dotenv

from chaparral.models.data import ChapparalDataset
from unsloth import to_sharegpt, standardize_sharegpt, apply_chat_template
from datasets import load_dataset

class HFModel:

model_name: str
model: AutoModelForCausalLM
tokenizer: AutoTokenizer
train_set: Dataset | None = None
model: PreTrainedModel
tokenizer: PreTrainedTokenizer
params: HFParams
train_set: ChapparalDataset | None = None
peft_model: PeftModel | None = None

def __init__(self, params: HFParams):
self.params = params
self.model_name = params.model_name

def init_peft(self):
LORA_R = 8
LORA_ALPHA = 2 * LORA_R
LORA_DROPOUT = 0.1
LORA_R = 16
LORA_ALPHA = 16
LORA_DROPOUT = 0

config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
# Only Training the "expert" layers
target_modules=["w1", "w2", "w3"],
# target_modules=["w1", "w2", "w3"],
target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM"
)
peft_model = get_peft_model(self.model, config)
if not isinstance(peft_model, PeftModel):
raise ValueError("PEFT model not initialized properly")

self.model = get_peft_model(self.model, config)
self.peft_model = peft_model

def save_model(self, path: str):
self.model.save_pretrained(path)
@@ -66,23 +86,29 @@ def load_local_model(self, path: str):
if self.params.use_peft:
self.init_peft()

def predict(self, dataset: Dataset):
data_dict = dataset.format(self.model_name)
test_data = list(
map(lambda x: self.tokenize(str(x)), data_dict["items"]))
def prep_dataset(self, dataset: ChapparalDataset):
data_dict = dataset.format_v2()
print(data_dict["items"][0])
training_data = list(map(lambda x: self.tokenize(str(x)), data_dict["items"]))
return training_data

def predict(self, dataset: ChapparalDataset):
test_data = self.prep_dataset(dataset)

trainer = Trainer(
model=self.model,
data_collator=DataCollatorForLanguageModeling(
self.tokenizer, mlm=False)
)

return trainer.predict(test_data)
# NOTE: this may need to be a pytorch dataset
# instead of a huggingface dataset;
# adding the ignore flag here because the type
# is correct at runtime
return trainer.predict(test_data) #type: ignore

def evaluate(self, dataset: Dataset):
data_dict = dataset.format(self.model_name)
eval_data = list(
map(lambda x: self.tokenize(str(x)), data_dict["items"]))
def evaluate(self, dataset: ChapparalDataset):
eval_data = self.prep_dataset(dataset)

trainer = Trainer(
model=self.model,
@@ -104,49 +130,130 @@ def generate(self, prompt: str, max_length: int = 3000):
self.model.device),
pad_token_id=self.tokenizer.eos_token_id)
return self.tokenizer.decode(output[0], skip_special_tokens=True)

def generate_v2(self, message: str):
messages = [
{"role": "user", "content": message},
]
input_ids = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt = True,
return_tensors = "pt",
).to("cuda")

text_streamer = TextStreamer(self.tokenizer, skip_prompt = True)
_ = self.model.generate(input_ids, streamer = text_streamer, max_new_tokens = 128, pad_token_id = self.tokenizer.eos_token_id)

def load_model(self):
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float32,
cache_dir=self.params.cache_dir,
load_in_4bit=True if self.params.use_peft else False,
device_map="auto",
# load_in_4bit=True if self.params.use_peft else False,
# device_map="auto",
)

self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
cache_dir=self.params.cache_dir
)

self.tokenizer.pad_token = self.params.pad_token
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.params.use_peft:
self.init_peft()

def load_training_data(self, dataset: Dataset):
def load_training_data(self, dataset: ChapparalDataset):
self.train_set = dataset

def init_data(self):
dataset = load_dataset("hlucco/npr_gpt4o_train_200")["train"]
print(dataset.column_names)

dataset = to_sharegpt(
dataset,
merged_prompt = "{instruction}",
output_column_name = "output",
conversation_extension = 3, # Select more to handle longer conversations
)

dataset = standardize_sharegpt(dataset)

chat_template = """Below are some instructions that describe some tasks. Write responses that appropriately complete each request.
### Instruction:
{INPUT}
### Response:
{OUTPUT}"""

dataset = apply_chat_template(
dataset,
tokenizer = self.tokenizer,
chat_template = chat_template,
# default_system_message = "You are a helpful assistant", << [OPTIONAL]
)

self.dataset = dataset

def tokenize(self, text: str):
eos_token = self.tokenizer.eos_token
if not isinstance(eos_token, str):
eos_token = self.tokenizer.decode(eos_token)

return self.tokenizer(
text + self.tokenizer.eos_token,
text + eos_token,
truncation=True,
max_length=self.params.cutoff_length,
padding="max_length"
)

def sft_train(self):
trainer = SFTTrainer(
model = self.model,
tokenizer = self.tokenizer,
train_dataset = self.dataset,
dataset_text_field = "text",
max_seq_length = 8024,
dataset_num_proc = 2,
packing = False, # Can make training 5x faster for short sequences.
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 5,
max_steps = 10,
# num_train_epochs = 1, # For longer training runs!
learning_rate = 2e-4,
fp16 = True,
bf16 = False,
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none", # Use this for WandB etc
),
)
self.model.config.use_cache = False
print("training has started...")
trainer.train()

def train(self):

if not self.train_set:
raise ValueError("No training data loaded")

data_dict = self.train_set.format(self.model_name)
training_data = list(
map(lambda x: self.tokenize(str(x)), data_dict["items"]))
# training_data = self.prep_dataset(self.train_set)

print("initialization of trainer")
trainer = Trainer(
model=self.model,
train_dataset=training_data,
tokenizer=self.tokenizer,
# train_dataset=training_data,
train_dataset=self.dataset,
dataset_text_field = "text",
max_seq_length = 8024,
args=TrainingArguments(
fp16=True,
max_steps=20,
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
per_device_train_batch_size=self.params.hf_trainer_params.per_device_train_batch_size,
gradient_accumulation_steps=self.params.hf_trainer_params.gradient_accumulation_steps,
num_train_epochs=self.params.hf_trainer_params.num_train_epochs,
@@ -161,18 +268,19 @@ def train(self):
)

self.model.config.use_cache = False
print("training has started...")
trainer.train()

def print_trainable_parameters(self):
trainable_params = sum(p.numel()
for p in self.model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in self.model.parameters())
print(f"trainable params: {trainable_params} || all params: {
all_params} || trainable%: {100 * trainable_params / all_params}")
print(f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params}")


if __name__ == "__main__":
load_dotenv(".env")
model = HFModel("mistralai/Mixtral-8x7b-v0.1")
params = HFParams.from_file("params.json")
model = HFModel(params)
print("Model loaded successfully ✅")
model.print_trainable_parameters()
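
Putting the new pieces together, a plausible end-to-end driver for the unsloth-style path this commit adds. The commit itself does not include such a script, so the call ordering is an assumption; the params.json filename, the hlucco/npr_gpt4o_train_200 dataset, and the method names are taken from the diff, and a CUDA device is assumed by generate_v2:

# Hedged sketch of a driver script, not part of the commit.
from dotenv import load_dotenv
from chaparral.train.hf_params import HFParams
from chaparral.train.hf_model import HFModel

load_dotenv(".env")
params = HFParams.from_file("params.json")

model = HFModel(params)
model.load_model()      # base model + tokenizer; applies LoRA when params.use_peft is set
model.init_data()       # loads hlucco/npr_gpt4o_train_200 and applies the chat template
model.sft_train()       # short SFTTrainer run (max_steps=10 in the committed args)
model.generate_v2("Summarize today's top story.")   # streams tokens via TextStreamer
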