diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 5611f58fb..b8ae239e2 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -11,7 +11,11 @@ from mttl.models.containers import add_expert_to_transformer from mttl.models.containers.expert_containers import ExpertContainer -from mttl.models.containers.selectors import Selector, SelectorConfig +from mttl.models.containers.selectors import ( + Selector, + SelectorConfig, + ArrowSelectorConfig, +) from mttl.models.expert_config import ExpertConfig from mttl.models.library.expert import Expert, ExpertInfo from mttl.models.library.expert_library import ExpertLibrary @@ -640,9 +644,23 @@ def __init__(self, expert_model, ref_expert_model, **kwargs): super().__init__(**kwargs) self.expert_model = expert_model self.ref_expert_model = ref_expert_model + self.trainable_param_names = kwargs.get("trainable_param_names", None) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + params = [] + # for param_name, param in self.named_parameters(): + # param.requires_grad = False + # if self.trainable_param_names and re.fullmatch( + # self.trainable_param_names, param_name + # ): + # param.requires_grad = True + # params.append(param) + + # logger.info(f"Setting {param_name} to trainable.") + optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3 + ) + return optimizer def training_step(self, batch, _): @@ -865,15 +883,14 @@ def __init__(self, expert_library: ExpertLibrary = None, **kwargs): self.hparams.library_id ) for i, expert in enumerate(sorted(list(expert_library.keys()))): - self.add_expert_instance(expert_library[expert], expert_name=f"e{i}") - + self.add_expert_instance(expert_library[expert], expert_name=expert) self.moe_num_experts = i + 1 if isinstance( - self.selector_config, (ArrowConfig, HiddenStateComputerConfig) + self.selector_config, (ArrowSelectorConfig, HiddenStateComputerConfig) ): from projects.modular_llm.eval_library import patch_prototypes - patch_prototypes(self, expert_library, self.selector_config) + patch_prototypes(self, expert_library, self.hparams) def training_step(self, batch, _): loss = super().training_step(batch, _) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 820f85c65..dc5bbea05 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -2,17 +2,22 @@ import shutil import sys from tempfile import TemporaryDirectory - +import copy import torch from pytorch_lightning import Trainer, seed_everything sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from mttl.callbacks import LiveCheckpointCallback, NanoMMLUCallback, RougeCallback +from mttl.callbacks import LiveCheckpointCallback # from mttl.datamodule.base import get_datamodule from mttl.models.expert_config import ExpertConfig -from mttl.models.expert_model import ExpertModel, MoEModel, ExpertModelDPO +from mttl.models.expert_model import ( + ExpertModel, + MultiExpertModel, + MoEModel, + ExpertModelDPO, +) from mttl.models.library.expert import Expert, load_expert from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary from mttl.models.monitors import get_monitors @@ -24,10 +29,11 @@ remote_login, setup_logging, ) -from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule +from mttl.datamodule.base import DatasetConfig from mttl.datamodule.preference_data_module import Preferencemodule from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix +from projects.modular_llm.eval_library import patch_prototypes def create_transfer_matrix(args, checkpoint): @@ -83,23 +89,34 @@ def create_library(args): loggers = get_pl_loggers(args) # select dataloader if args.model_modifier == "poly": - args.init_from_scratch = True model_class = MoEModel else: model_class = ExpertModel - config = DatasetConfig(model=args.model) dm = Preferencemodule(config) # dm = get_datamodule(args) # args.n_tasks = len(dm._task_names) # args.task_names = dm._task_names - - model = model_class(**vars(args), tokenizer=dm.tokenizer) + # if args.router_selector == "arrow_router": + args.trainable_param_names = None + ref_model = model_class( + **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library + ) if args.rl_training == "dpo": - ref_model = model_class(**vars(args), tokenizer=dm.tokenizer) - module = ExpertModelDPO(model, ref_model) + args.trainable_param_names = ".*prototypes.*" + model = model_class( + **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library + ) + # if args.library_id: + # model.add_experts_from_library(expert_library) + # patch_prototypes(model, expert_library, args) + + # # ref_model = copy.deepcopy(model) + # ref_model.add_experts_from_library(expert_library) + # patch_prototypes(ref_model, expert_library, args) + module = ExpertModelDPO(model, ref_model, **vars(args)) # get metric monitors for models callbacks = get_monitors(args) @@ -130,8 +147,8 @@ def create_library(args): val_check_interval = args.total_steps trainer = Trainer( - # devices=-1, - # accelerator="gpu", + devices=-1, + accelerator="gpu", logger=loggers, num_sanity_val_steps=0, default_root_dir=args.output_dir,