Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pure meta model lm_head tp #6812

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
14 changes: 7 additions & 7 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,11 @@ def set_lm_head(module):
module.lm_head, "weight") and module.lm_head.weight.is_meta:
module.lm_head.weight = embedding_weight
# enable tensor parallel for the last linear
if hasattr(module, "lm_head") and hasattr(module.lm_head,
"weight") and not module.lm_head.weight.is_meta and isinstance(
module.lm_head, torch.nn.Linear):
if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
module.lm_head, torch.nn.Linear):
module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
elif hasattr(module, "embed_out") and hasattr(module.embed_out,
"weight") and not module.embed_out.weight.is_meta and isinstance(
module.embed_out, torch.nn.Linear):
elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
module.embed_out, torch.nn.Linear):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
Expand Down Expand Up @@ -389,7 +387,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
checkpoint=checkpoint_file)
pbar.update(1)
gc.collect()
replaced_module = set_lm_head(replaced_module)
# conv2d tp module replace
# Now is for yuan model. Add model list and conv policy to decide whether to replace conv.
if 'Yuan' in str(replaced_module):
Expand All @@ -399,6 +396,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple)
# AutoTP default set lm_head tp
if not config.replace_with_kernel_inject:
replaced_module = set_lm_head(replaced_module)

quantizer = GroupQuantizer(q_int8=quantize)
world_size = dist.get_world_size() if dist.is_initialized() else 1
Expand Down