Autotp training #6922
base: master
@@ -15,7 +15,8 @@
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode
from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject import replace_transformer_layer, generic_injection
@@ -247,6 +248,11 @@ def _post_forward_hook(self, module, input, output):
            self._model_times.append(elapsed_time)

    def _create_model_parallel_group(self, config):

        if is_autotp_training_mode():

[Review comment] Conceptually, control flow for training should not come here. I think some refactoring/restructuring is needed for code quality.

            groups._init_tp_mesh_device(config.tensor_parallel.tp_size)
            self.mp_group = groups.get_tensor_model_parallel_group()
            return
        # Call the init process
        if InferenceEngine.inference_mp_group is None:
            init_distributed()
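The review comment above argues that training control flow should not run through the inference engine. As a rough illustration of one possible restructuring (not part of this PR), the training-mode group setup could live in a small shared helper that both the engine and a training-side entry point call. The helper name `init_autotp_training_groups` below is hypothetical; `is_autotp_training_mode`, `groups._init_tp_mesh_device`, and `groups.get_tensor_model_parallel_group` are the calls visible in the hunk and are assumed to behave as they do there.

```python
# Sketch only: a hypothetical helper that keeps AutoTP-training group setup out of
# InferenceEngine, following the review comment. The groups/is_autotp_training_mode
# calls are taken from the diff above; the helper itself is not part of the PR.
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def init_autotp_training_groups(tp_size: int):
    """Initialize the tensor-parallel device mesh once and return the TP process group."""
    groups._init_tp_mesh_device(tp_size)
    return groups.get_tensor_model_parallel_group()


# With such a helper, the engine branch above would reduce to:
#     if is_autotp_training_mode():
#         self.mp_group = init_autotp_training_groups(config.tensor_parallel.tp_size)
#         return
```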
@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearALlreduce, Yuan_LinearLayer, GLM_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer

[Review comment, truncated in source] Original coding style is

from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device):
@@ -324,10 +326,18 @@ def tp_parser(model):
        return policy_list

    def set_tensor_parallel_config(self, mp_size, mp_group):

        if is_autotp_training_mode():
            self.mp_group = groups.get_tensor_model_parallel_group()
            self.mp_size = groups.get_tensor_model_parallel_world_size()
            return

        self.mp_size = mp_size
        self.mp_group = mp_group

    def _replace(self, child, name, conv_linear_layer):
        # This function should clearly define the routing rules for specific layers
        # and avoid any complex shard-related logic.
        if getattr(child, "replaced", False) == True:
            return
        weight_shape = child.weight.shape
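The `set_tensor_parallel_config` method added in the hunk above makes the source of the TP settings explicit: in AutoTP training mode the caller-supplied `mp_size`/`mp_group` are ignored and both values are read from `deepspeed.utils.groups`; otherwise the arguments are stored as before. A minimal sketch of that rule follows, written as a hypothetical free function (`resolve_tp_config` is not part of the PR; the `groups` and `is_autotp_training_mode` calls are the ones shown in the diff):

```python
# Minimal sketch, not PR code: mirrors what set_tensor_parallel_config stores on the
# replace policy in each mode. Assumes the TP mesh was already initialized
# (e.g. groups._init_tp_mesh_device(tp_size) as in the engine hunk).
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def resolve_tp_config(mp_size, mp_group):
    """Return the (mp_size, mp_group) pair the replace policy will actually use."""
    if is_autotp_training_mode():
        # AutoTP training: the globally initialized TP group is the single source
        # of truth; the arguments are ignored, exactly as in the diff above.
        return (groups.get_tensor_model_parallel_world_size(),
                groups.get_tensor_model_parallel_group())
    # Inference path: keep whatever the caller passed in.
    return mp_size, mp_group
```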
@@ -339,77 +349,37 @@ def _replace(self, child, name, conv_linear_layer):
        # For Yuan model
        if 'Yuan' in str(self.module):
            if 'v_proj' in name:
                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                         dist.get_world_size(), True)
                return LinearLayer(weight=weight, bias=bias)
                return Yuan_LinearLayer(child, self.mp_group)

            elif 'o_proj' in name:
                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                         dist.get_world_size(), False)
                return LinearAllreduce(weight, bias, self.mp_group)
        # For Arctic model, bypass to all_reduce replacement for w2 weights
                return Yuan_LinearALlreduce(child, self.mp_group)

        # For MLP including chunk layer.
        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):

[Review comment] This additional code block is trying to deal with "MLP including chunk layer" (general case), but the returned module/object is named with a GLM prefix.

            return GLM_LinearLayer(child, self.mp_group)
        # For Arctic model, bypass to all_reduce replacement for w2 weights
        arctic_w2_all_reduce_linear = False
        if 'Arctic' in str(self.module) and 'w2' in name:
            arctic_w2_all_reduce_linear = True
        # For MLP including chunk layer.
        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
            weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
            return LinearLayer(weight=weight, bias=bias)
        if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
            # else [weight_shape[0], weight_shape[1] // mp_size]

            setattr(child, "replaced", True)
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
            data = child.weight.data.split(get_shard_size_list(
                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                           dim=1)
            data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
            del data
                return Conv_LinearALlreduce(child, self.mp_group, name=name)
            elif name == "lm_head" or name == 'embed_out':
                return LmHeadLinearAllreduce(child, self.mp_group)

            setattr(child, "replaced", True)
            if name == "lm_head" or name == 'embed_out':
                return LmHeadLinearAllreduce(
                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
                        move(child.bias,
                             get_accelerator().current_device_name())), self.mp_group)
            return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
                torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
            return LinearAllreduce(child, self.mp_group, name=name)
        else:

            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
            # else [weight_shape[0] // mp_size, weight_shape[1]]
            setattr(child, "replaced", True)
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

            if require_tp_fused_qkvw(name, self.mp_size):
                conv_LinearLayer(child, self.mp_group)
            elif require_tp_fused_qkvw(name, self.mp_size):
                #Check and handle fused qkv for TP
                #The copy is a regular copy, The shape of dst and src is the same
                data_dc = move(
                    prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
                    get_accelerator().current_device_name())

                bias_data_dc = None if child.bias is None else move(
                    prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
                    get_accelerator().current_device_name())
            else:
                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                               dim=1 if self.conv_linear_layer else 0)
                data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
                del data

                if child.bias is not None:
                    bias_data = child.bias.data.split(get_shard_size_list(
                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                      dim=0)
                    bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
                    bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                    del bias_data
                else:
                    bias_data_dc = None
                return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

            setattr(child, "replaced", True)
            return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
            return LinearLayer(child, self.mp_group, name=name)

    def _slice_embedding(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False) == True:
[Review comment] This logic should ideally be outside of the inference module. For example, in a deepspeed/runtime/tensor_parallel module?
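Taken together, the `_replace` rewrite turns the method into a routing table: the sharding, weight copies, and device moves that used to be inlined here are now encapsulated in the layer classes imported at the top of the file, each constructed from the original `child` module and the TP group. The condensed sketch below is reconstructed only from the lines shown in this diff; the attributes on `policy` and the omitted model-specific branches (the Arctic w2 flag, the conv path inside the else branch, the `replaced` bookkeeping) are assumptions, and the class names may not exist outside this PR's branch.

```python
# Condensed reconstruction of the new routing logic in _replace, based solely on the
# hunk above. `policy` stands in for the object that owns _replace (module, mp_group,
# mp_size, all_reduce_linears, conv_linear_layer are assumed attributes).
from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, LmHeadLinearAllreduce,
                                            Yuan_LinearALlreduce, Yuan_LinearLayer, GLM_LinearLayer,
                                            Conv_LinearALlreduce, fused_LinearLayer)
from deepspeed.module_inject.fusedqkv_utils import require_tp_fused_qkvw


def route_layer(policy, child, name):
    # Every returned class is expected to shard `child` internally using policy.mp_group.
    if 'Yuan' in str(policy.module):
        if 'v_proj' in name:
            return Yuan_LinearLayer(child, policy.mp_group)
        elif 'o_proj' in name:
            return Yuan_LinearALlreduce(child, policy.mp_group)
    # Chunked-MLP case; the review comment notes the GLM-prefixed name is broader than GLM.
    if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(policy.module)):
        return GLM_LinearLayer(child, policy.mp_group)
    if name in policy.all_reduce_linears:
        if policy.conv_linear_layer:
            return Conv_LinearALlreduce(child, policy.mp_group, name=name)
        elif name in ("lm_head", "embed_out"):
            return LmHeadLinearAllreduce(child, policy.mp_group)
        return LinearAllreduce(child, policy.mp_group, name=name)
    if require_tp_fused_qkvw(name, policy.mp_size):
        return fused_LinearLayer(child, policy.mp_group, fused_module=policy.module)
    return LinearLayer(child, policy.mp_group, name=name)
```

This is also where the two open review questions land: whether the chunked-MLP class should carry a more general name than `GLM_LinearLayer`, and whether the routing layer as a whole belongs in a runtime tensor-parallel module rather than under the inference-oriented module_inject path.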