diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e59f84bc8453..00b22aac81d8 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -342,13 +342,11 @@ def set_lm_head(module): module.lm_head, "weight") and module.lm_head.weight.is_meta: module.lm_head.weight = embedding_weight # enable tensor parallel for the last linear - if hasattr(module, "lm_head") and hasattr(module.lm_head, - "weight") and not module.lm_head.weight.is_meta and isinstance( - module.lm_head, torch.nn.Linear): + if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance( + module.lm_head, torch.nn.Linear): module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head") - elif hasattr(module, "embed_out") and hasattr(module.embed_out, - "weight") and not module.embed_out.weight.is_meta and isinstance( - module.embed_out, torch.nn.Linear): + elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance( + module.embed_out, torch.nn.Linear): module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out") elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"): module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head") @@ -389,7 +387,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size): checkpoint=checkpoint_file) pbar.update(1) gc.collect() - replaced_module = set_lm_head(replaced_module) # conv2d tp module replace # Now is for yuan model. Add model list and conv policy to decide whether to replace conv. if 'Yuan' in str(replaced_module): @@ -399,6 +396,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size): orig_class=orig_layer_impl, replace_fn=replace_fn, _replace_policy=config.injection_policy_tuple) + # AutoTP default set lm_head tp + if not config.replace_with_kernel_inject: + replaced_module = set_lm_head(replaced_module) quantizer = GroupQuantizer(q_int8=quantize) world_size = dist.get_world_size() if dist.is_initialized() else 1