diff --git a/mttl/datamodule/preference_data_module.py b/mttl/datamodule/preference_data_module.py
index 677771a8e..445ad39ac 100644
--- a/mttl/datamodule/preference_data_module.py
+++ b/mttl/datamodule/preference_data_module.py
@@ -1,8 +1,10 @@
-from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
 from dataclasses import dataclass
-from mttl.models.library.expert_library import DatasetLibrary
+
 import torch
+from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
+from mttl.models.library.expert_library import DatasetLibrary
+
 
 
 @dataclass
 class DataCollatorForDPO(DefaultCollator):
diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py
index b8ae239e2..0fba66597 100644
--- a/mttl/models/expert_model.py
+++ b/mttl/models/expert_model.py
@@ -6,15 +6,16 @@
 from typing import Dict, List
 
 import torch
+import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer
 from transformers import PreTrainedModel
 
 from mttl.models.containers import add_expert_to_transformer
 from mttl.models.containers.expert_containers import ExpertContainer
 from mttl.models.containers.selectors import (
+    ArrowSelectorConfig,
     Selector,
     SelectorConfig,
-    ArrowSelectorConfig,
 )
 from mttl.models.expert_config import ExpertConfig
 from mttl.models.library.expert import Expert, ExpertInfo
@@ -31,8 +32,6 @@
 from mttl.models.utils import EfficientCheckpointModule, prepare_model_for_kbit_training
 from mttl.utils import logger
 
-import torch.nn.functional as F
-
 torch.set_float32_matmul_precision("high")
 
 
diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py
index dc5bbea05..1438f550d 100644
--- a/projects/modular_llm/train_dpo.py
+++ b/projects/modular_llm/train_dpo.py
@@ -1,22 +1,25 @@
+import copy
 import os
 import shutil
 import sys
 from tempfile import TemporaryDirectory
-import copy
+
 import torch
 from pytorch_lightning import Trainer, seed_everything
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
 
 from mttl.callbacks import LiveCheckpointCallback
+from mttl.datamodule.base import DatasetConfig
+from mttl.datamodule.preference_data_module import Preferencemodule
 
 # from mttl.datamodule.base import get_datamodule
 from mttl.models.expert_config import ExpertConfig
 from mttl.models.expert_model import (
     ExpertModel,
-    MultiExpertModel,
-    MoEModel,
     ExpertModelDPO,
+    MoEModel,
+    MultiExpertModel,
 )
 from mttl.models.library.expert import Expert, load_expert
 from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary
@@ -29,11 +32,9 @@
     remote_login,
     setup_logging,
 )
-from mttl.datamodule.base import DatasetConfig
-from mttl.datamodule.preference_data_module import Preferencemodule
+from projects.modular_llm.eval_library import patch_prototypes
 from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig
 from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix
-from projects.modular_llm.eval_library import patch_prototypes
 
 
 def create_transfer_matrix(args, checkpoint):