Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled configurable auto Tensor Parallelism (TP) for the inference of diverse models #6553

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 19 additions & 24 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
import os
import ast


def move(tensor, device):
Expand Down Expand Up @@ -270,6 +272,7 @@ def kernel_supported(module_list):
return True
return False

## tp parser based on autoTP config in environment
def tp_parser(model):
policy_list = []
module_list = []
Expand All @@ -279,40 +282,27 @@ def tp_parser(model):
module_list = AutoTP.get_module_list(model)
assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2']
#ln_1 , ln_2 for Qwen

allReduceLinearItems = os.environ['allReduceLinearItems']
gyou2021 marked this conversation as resolved.
Show resolved Hide resolved
allReduceLinearItems = ast.literal_eval(allReduceLinearItems)

for module in module_list:
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
gyou2021 marked this conversation as resolved.
Show resolved Hide resolved
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)

for i, layer in enumerate(layer_list):
if layer == 'ln':
if layer_list[i - 1] != 'ln':
gem_list = gem_list + [layer_list[i - 1]]
elif 'out_proj' in layer:
gem_list = gem_list + [layer]
elif 'o_proj' in layer:
gem_list = gem_list + [layer]
elif 'down_proj' in layer:
gem_list = gem_list + [layer]
elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
gem_list = gem_list + [layer]
elif 'self_attention.dense' in layer and 'falcon' in str(
type(module)): # this is a hack to get the right linear layer for this model!
gem_list = gem_list + [layer]
# Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
elif 'w2' in layer and 'Mixtral' in str(type(module)):
gem_list = gem_list + [layer]
elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
gem_list = gem_list + [layer]
elif 'self_attention.dense' in layer and 'ChatGLM' in str(model):
gem_list = gem_list + [layer]
elif 'dense_4h_to_h' in layer and 'ChatGLM' in str(model):
gem_list = gem_list + [layer]
continue
for item in allReduceLinearItems:
if item in layer:
gem_list = gem_list + [layer]

layer_list = []
if gem_list != []:
Expand Down Expand Up @@ -473,7 +463,12 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
if len(child._buffers) != 0 and self.state_dict is not None:
Loading.load_buffer(child, self.state_dict, checking_key)
if child.__class__ in self.linear_policies:
setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
keepLinearItems = os.environ['keepLinearItems']
keepLinearItems = ast.literal_eval(keepLinearItems)

if any(item not in checking_key for item in keepLinearItems):
setattr(
r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
self.conv_linear_layer))
elif any(isinstance(child, lp) for lp in self.linear_policies):
# Added for falcon model support
Expand Down