add cp_comm_type param to Mistral config #12049

Open · wants to merge 4 commits into base: main
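
For context: cp_comm_type selects the communication pattern Megatron Core uses for context parallelism (values such as "p2p", "a2a", "allgather", or "a2a+p2p"). This PR defaults it to "a2a" for the Mistral 7B config and leaves it as None for the Mistral NeMo 12B/123B configs. A minimal usage sketch (not part of this diff), assuming the classes are imported from the module the diff touches:

    # Sketch only: override the new parameter when building a config.
    # Accepted values ultimately depend on the installed Megatron Core /
    # TransformerEngine versions.
    from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel

    config = MistralConfig7B(cp_comm_type="p2p")  # this PR's default is "a2a"
    model = MistralModel(config)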
24 changes: 24 additions & 0 deletions nemo/collections/llm/gpt/model/mistral.py
@@ -37,6 +37,10 @@

@dataclass
class MistralConfig7B(GPTConfig):
"""
Mistral 7B config.
"""

normalization: str = "RMSNorm"
activation_func: Callable = F.silu
position_embedding_type: str = "rope"
@@ -56,6 +60,7 @@ class MistralConfig7B(GPTConfig):
init_method_std: float = 0.02
layernorm_epsilon: float = 1e-5
window_size: List[int] = field(default_factory=lambda: [4096, 0])
cp_comm_type: str = "a2a"


@dataclass
@@ -70,6 +75,7 @@ class MistralNeMoConfig12B(MistralConfig7B):
seq_length: int = 4096 # but "max_position_embeddings": 1024000,

window_size: List[int] = None
cp_comm_type: str = None
rotary_percent: float = 1.0
rotary_base: float = 1000000.0

@@ -88,11 +94,14 @@ class MistralNeMoConfig123B(MistralConfig7B):
seq_length: int = 4096 # but "max_position_embeddings": 131072,

window_size: List[int] = None
cp_comm_type: str = None
rotary_percent: float = 1.0
rotary_base: float = 1000000.0


class MistralModel(GPTModel):
""" """

def __init__(
self,
config: Annotated[Optional[MistralConfig7B], Config[MistralConfig7B]] = None,
@@ -107,6 +116,8 @@ def __init__(

@io.model_importer(MistralModel, "hf")
class HFMistralImporter(io.ModelConnector["MistralForCausalLM", MistralModel]):
""" """

def init(self) -> MistralModel:
return MistralModel(self.config, tokenizer=self.tokenizer)

@@ -127,6 +138,7 @@ def apply(self, output_path: Path) -> Path:
return output_path

def convert_state(self, source, target):
""" """
mapping = {
"model.embed_tokens.weight": "embedding.word_embeddings.weight",
"model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
@@ -141,12 +153,14 @@ def convert_state(self, source, target):

@property
def tokenizer(self) -> "AutoTokenizer":
""" """
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)))

@property
def config(self) -> MistralConfig7B:
""" """
from transformers import MistralConfig

source = MistralConfig.from_pretrained(str(self))
@@ -175,6 +189,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
gated_linear_unit=True,
make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
window_size=window_size,
cp_comm_type=getattr(source, "cp_comm_type", "a2a"),  # HF MistralConfig has no cp_comm_type field; fall back to the class default
share_embeddings_and_output_weights=False,
fp16=(dtype_from_hf(source) == torch.float16),
bf16=(dtype_from_hf(source) == torch.bfloat16),
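
For reference (outside this diff), the importer above is normally driven through NeMo 2.0's checkpoint-import entry point; a minimal sketch, assuming llm.import_ckpt and the top-level llm re-exports are available, with an illustrative Hugging Face repo id:

    # Sketch only: convert a Hugging Face Mistral checkpoint into the NeMo format
    # via HFMistralImporter. The repo id below is illustrative.
    from nemo.collections import llm

    llm.import_ckpt(
        model=llm.MistralModel(llm.MistralConfig7B()),
        source="hf://mistralai/Mistral-7B-v0.1",
    )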
@@ -186,6 +201,8 @@ def make_vocab_size_divisible_by(mistral_vocab_size):

@io.model_exporter(MistralModel, "hf")
class HFMistralExporter(io.ModelConnector[MistralModel, "MistralForCausalLM"]):
""" """

def init(self, dtype=torch.bfloat16) -> "MistralForCausalLM":
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import no_init_weights
@@ -209,6 +226,7 @@ def apply(self, output_path: Path) -> Path:
return output_path

def convert_state(self, source, target):
""" """
mapping = {
"decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
"decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
@@ -226,10 +244,12 @@ def convert_state(self, source, target):

@property
def tokenizer(self):
""" """
return io.load_context(str(self)).model.tokenizer.tokenizer

@property
def config(self) -> "MistralConfig":
""" """
source: MistralConfig7B = io.load_context(str(self)).model.config

from transformers import MistralConfig as HfMistralConfig
@@ -259,6 +279,7 @@ def config(self) -> "MistralConfig":
target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
""" """
megatron_config = ctx.target.config

head_num = megatron_config.num_attention_heads
@@ -301,6 +322,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
""" """
megatron_config = ctx.source.config

head_num = megatron_config.num_attention_heads
@@ -333,6 +355,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _import_linear_fc1(down, gate):
""" """
return torch.cat((down, gate), axis=0)


@@ -341,6 +364,7 @@ def _import_linear_fc1(down, gate):
target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
)
def _export_linear_fc1(linear_fc1):
""" """
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj
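
The two MLP transforms are inverses: the import concatenates the two HF input projections into one fused linear_fc1 weight, and the export chunks it back into gate_proj and up_proj. A small round-trip sketch with dummy tensors, assuming the fused layout is [gate; up] along dim 0 as the export above implies:

    # Sketch only: round-trip of the fused MLP input projection.
    import torch

    hidden, ffn = 16, 64
    gate = torch.randn(ffn, hidden)  # stands in for model.layers.N.mlp.gate_proj.weight
    up = torch.randn(ffn, hidden)    # stands in for model.layers.N.mlp.up_proj.weight

    fused = torch.cat((gate, up), dim=0)                # what _import_linear_fc1 builds
    gate_back, up_back = torch.chunk(fused, 2, dim=0)   # what _export_linear_fc1 recovers

    assert torch.equal(gate, gate_back) and torch.equal(up, up_back)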