Skip to content

Commit

Permalink
fix isort
Browse files Browse the repository at this point in the history
  • Loading branch information
shuishen112 committed Jul 15, 2024
1 parent ec73c5a commit 93b5526
Showing 3 changed files with 13 additions and 11 deletions.
6 changes: 4 additions & 2 deletions mttl/datamodule/preference_data_module.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
from dataclasses import dataclass
from mttl.models.library.expert_library import DatasetLibrary

import torch

from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
class DataCollatorForDPO(DefaultCollator):
5 changes: 2 additions & 3 deletions mttl/models/expert_model.py
Original file line number Diff line number Diff line change
@@ -6,15 +6,16 @@
from typing import Dict, List

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from transformers import PreTrainedModel

from mttl.models.containers import add_expert_to_transformer
from mttl.models.containers.expert_containers import ExpertContainer
from mttl.models.containers.selectors import (
ArrowSelectorConfig,
Selector,
SelectorConfig,
ArrowSelectorConfig,
)
from mttl.models.expert_config import ExpertConfig
from mttl.models.library.expert import Expert, ExpertInfo
@@ -31,8 +32,6 @@
from mttl.models.utils import EfficientCheckpointModule, prepare_model_for_kbit_training
from mttl.utils import logger

import torch.nn.functional as F

torch.set_float32_matmul_precision("high")


13 changes: 7 additions & 6 deletions projects/modular_llm/train_dpo.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import copy
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import copy

import torch
from pytorch_lightning import Trainer, seed_everything

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from mttl.callbacks import LiveCheckpointCallback
from mttl.datamodule.base import DatasetConfig
from mttl.datamodule.preference_data_module import Preferencemodule

# from mttl.datamodule.base import get_datamodule
from mttl.models.expert_config import ExpertConfig
from mttl.models.expert_model import (
ExpertModel,
MultiExpertModel,
MoEModel,
ExpertModelDPO,
MoEModel,
MultiExpertModel,
)
from mttl.models.library.expert import Expert, load_expert
from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary
@@ -29,11 +32,9 @@
remote_login,
setup_logging,
)
from mttl.datamodule.base import DatasetConfig
from mttl.datamodule.preference_data_module import Preferencemodule
from projects.modular_llm.eval_library import patch_prototypes
from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig
from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix
from projects.modular_llm.eval_library import patch_prototypes


def create_transfer_matrix(args, checkpoint):

0 comments on commit 93b5526

Please sign in to comment.