Commit

wip
shuishen112 committed Jul 15, 2024
1 parent 330813c commit ec73c5a
Showing 2 changed files with 52 additions and 18 deletions.
29 changes: 23 additions & 6 deletions mttl/models/expert_model.py
@@ -11,7 +11,11 @@
 
 from mttl.models.containers import add_expert_to_transformer
 from mttl.models.containers.expert_containers import ExpertContainer
-from mttl.models.containers.selectors import Selector, SelectorConfig
+from mttl.models.containers.selectors import (
+    Selector,
+    SelectorConfig,
+    ArrowSelectorConfig,
+)
 from mttl.models.expert_config import ExpertConfig
 from mttl.models.library.expert import Expert, ExpertInfo
 from mttl.models.library.expert_library import ExpertLibrary
@@ -640,9 +644,23 @@ def __init__(self, expert_model, ref_expert_model, **kwargs):
         super().__init__(**kwargs)
         self.expert_model = expert_model
         self.ref_expert_model = ref_expert_model
+        self.trainable_param_names = kwargs.get("trainable_param_names", None)
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+        params = []
+        # for param_name, param in self.named_parameters():
+        #     param.requires_grad = False
+        #     if self.trainable_param_names and re.fullmatch(
+        #         self.trainable_param_names, param_name
+        #     ):
+        #         param.requires_grad = True
+        #         params.append(param)
+
+        #         logger.info(f"Setting {param_name} to trainable.")
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3
+        )
+
         return optimizer
 
     def training_step(self, batch, _):
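
The commented-out block above spells out the intent behind the new optimizer: freeze everything, re-enable only the parameters whose names match a regex, and give Adam just those. A minimal standalone sketch of that pattern (the `configure_adam` helper and its signature are illustrative, not part of mttl):

```python
import re
from typing import Optional

import torch
from torch import nn


def configure_adam(model: nn.Module,
                   trainable_param_names: Optional[str] = None,
                   lr: float = 1e-3) -> torch.optim.Adam:
    """Freeze all parameters except those whose full name matches the regex,
    then build Adam over whatever is left trainable."""
    if trainable_param_names is not None:
        for name, param in model.named_parameters():
            # Only parameters matching the pattern (e.g. ".*prototypes.*") stay trainable.
            param.requires_grad = bool(re.fullmatch(trainable_param_names, name))
    return torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=lr
    )
```

Called as `configure_adam(model, r".*prototypes.*")`, only the selector-prototype parameters would be updated, matching the `args.trainable_param_names = ".*prototypes.*"` setting introduced in `train_dpo.py` below.
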
@@ -865,15 +883,14 @@ def __init__(self, expert_library: ExpertLibrary = None, **kwargs):
                 self.hparams.library_id
             )
             for i, expert in enumerate(sorted(list(expert_library.keys()))):
-                self.add_expert_instance(expert_library[expert], expert_name=f"e{i}")
-
+                self.add_expert_instance(expert_library[expert], expert_name=expert)
             self.moe_num_experts = i + 1
             if isinstance(
-                self.selector_config, (ArrowConfig, HiddenStateComputerConfig)
+                self.selector_config, (ArrowSelectorConfig, HiddenStateComputerConfig)
             ):
                 from projects.modular_llm.eval_library import patch_prototypes
 
-                patch_prototypes(self, expert_library, self.selector_config)
+                patch_prototypes(self, expert_library, self.hparams)
 
     def training_step(self, batch, _):
         loss = super().training_step(batch, _)
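
Taken together, this hunk registers each expert under its library name instead of a synthetic `e{i}` label and, when an Arrow-style selector is configured, patches routing prototypes from the library. A rough sketch of that flow as a standalone helper, reusing the `add_expert_instance` and `patch_prototypes` interfaces shown in the diff; the wrapper itself and its `prototype_selector_types` argument are illustrative (in the commit the tuple is `(ArrowSelectorConfig, HiddenStateComputerConfig)`):

```python
def register_library_experts(model, expert_library, selector_config, hparams,
                             prototype_selector_types=()):
    """Add every expert from the library to the model, keyed by its library name,
    then patch selector prototypes when a prototype-based selector is configured."""
    expert_names = sorted(expert_library.keys())
    for name in expert_names:
        # Library names survive as expert names, so routing stays traceable.
        model.add_expert_instance(expert_library[name], expert_name=name)
    model.moe_num_experts = len(expert_names)

    if prototype_selector_types and isinstance(selector_config, prototype_selector_types):
        from projects.modular_llm.eval_library import patch_prototypes

        # As in the diff, patch_prototypes receives the full hparams object.
        patch_prototypes(model, expert_library, hparams)
    return model
```

Passing the selector-config classes in as an argument keeps the sketch free of repository-specific imports.
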
41 changes: 29 additions & 12 deletions projects/modular_llm/train_dpo.py
@@ -2,17 +2,22 @@
 import shutil
 import sys
 from tempfile import TemporaryDirectory
-
+import copy
 import torch
 from pytorch_lightning import Trainer, seed_everything
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
 
-from mttl.callbacks import LiveCheckpointCallback, NanoMMLUCallback, RougeCallback
+from mttl.callbacks import LiveCheckpointCallback
 
 # from mttl.datamodule.base import get_datamodule
 from mttl.models.expert_config import ExpertConfig
-from mttl.models.expert_model import ExpertModel, MoEModel, ExpertModelDPO
+from mttl.models.expert_model import (
+    ExpertModel,
+    MultiExpertModel,
+    MoEModel,
+    ExpertModelDPO,
+)
 from mttl.models.library.expert import Expert, load_expert
 from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary
 from mttl.models.monitors import get_monitors
@@ -24,10 +29,11 @@
     remote_login,
     setup_logging,
 )
-from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
+from mttl.datamodule.base import DatasetConfig
 from mttl.datamodule.preference_data_module import Preferencemodule
 from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig
 from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix
+from projects.modular_llm.eval_library import patch_prototypes
 
 
 def create_transfer_matrix(args, checkpoint):
@@ -83,23 +89,34 @@ def create_library(args):
     loggers = get_pl_loggers(args)
     # select dataloader
     if args.model_modifier == "poly":
         args.init_from_scratch = True
         model_class = MoEModel
     else:
         model_class = ExpertModel
 
     config = DatasetConfig(model=args.model)
     dm = Preferencemodule(config)
 
     # dm = get_datamodule(args)
     # args.n_tasks = len(dm._task_names)
     # args.task_names = dm._task_names
 
-    model = model_class(**vars(args), tokenizer=dm.tokenizer)
+    # if args.router_selector == "arrow_router":
+    args.trainable_param_names = None
+    ref_model = model_class(
+        **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library
+    )
 
     if args.rl_training == "dpo":
-        ref_model = model_class(**vars(args), tokenizer=dm.tokenizer)
-        module = ExpertModelDPO(model, ref_model)
+        args.trainable_param_names = ".*prototypes.*"
+        model = model_class(
+            **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library
+        )
+        # if args.library_id:
+        #     model.add_experts_from_library(expert_library)
+        #     patch_prototypes(model, expert_library, args)
+
+        # # ref_model = copy.deepcopy(model)
+        # ref_model.add_experts_from_library(expert_library)
+        # patch_prototypes(ref_model, expert_library, args)
+        module = ExpertModelDPO(model, ref_model, **vars(args))
 
     # get metric monitors for models
     callbacks = get_monitors(args)
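
For context on why `ExpertModelDPO` takes both a policy and a reference model: DPO scores each preference pair by how much the policy's log-probability ratio of chosen over rejected exceeds the frozen reference's. A generic sketch of that objective (the standard DPO loss, not necessarily the repository's exact `training_step`):

```python
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Direct Preference Optimization loss (Rafailov et al., 2023).

    Each argument is the summed log-probability of a completion under the
    policy or the frozen reference model; beta scales the implicit reward.
    """
    policy_logratio = policy_chosen_logps - policy_rejected_logps
    ref_logratio = ref_chosen_logps - ref_rejected_logps
    # Maximize the margin between the policy's and the reference's log-ratios.
    return -F.logsigmoid(beta * (policy_logratio - ref_logratio)).mean()
```

The reference log-probabilities enter only as constants, so the reference model never needs gradient updates; only the policy's trainable parameters (here restricted to `.*prototypes.*`) receive gradients.
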
@@ -130,8 +147,8 @@ def create_library(args):
         val_check_interval = args.total_steps
 
     trainer = Trainer(
-        # devices=-1,
-        # accelerator="gpu",
+        devices=-1,
+        accelerator="gpu",
         logger=loggers,
         num_sanity_val_steps=0,
         default_root_dir=args.output_dir,
