Autotp training #6922

Open · wants to merge 51 commits into base: master

Commits (51)
674a873
auto tp training
inkcherry Apr 3, 2024
a2e4c47
update parallel_states
inkcherry Apr 23, 2024
f4eb142
Merge branch 'master' into HEAD
inkcherry Nov 19, 2024
dd081ed
WA skips assertions, the loss remains exactly consistent with the low…
inkcherry Nov 19, 2024
cdaed2f
save/load ckpt & save/load hf model basic POC
inkcherry Nov 22, 2024
9aad0e7
finish all the basic functionalities
inkcherry Nov 27, 2024
2bb11fd
update
inkcherry Nov 28, 2024
e75c1c2
use groups for parallel_states
inkcherry Dec 2, 2024
840a5f2
enable bwd allreduce, enable scale loss by gas
inkcherry Dec 2, 2024
60bd6ab
add dataloader check
inkcherry Dec 4, 2024
9266383
refactor autoTP step1
inkcherry Dec 4, 2024
07174a9
rm parallel_states
inkcherry Dec 5, 2024
ee6323e
refactor autoTP step2
inkcherry Dec 5, 2024
6461b84
update ut step1
inkcherry Dec 10, 2024
4d73011
update
inkcherry Dec 11, 2024
c79c3bb
add uts
inkcherry Dec 11, 2024
97e659c
finished all ut code base
inkcherry Dec 12, 2024
a15905b
addllr scheduler test
inkcherry Dec 12, 2024
e9802b0
refine ut
inkcherry Dec 12, 2024
88b8acf
fix bcast_objlist
inkcherry Dec 15, 2024
868be0b
refine layers.py
inkcherry Dec 15, 2024
3788e07
refine gather
inkcherry Dec 15, 2024
27b24f6
pass codegen350M +TP2 ut
inkcherry Dec 16, 2024
3d7b89f
add mode choice
inkcherry Dec 16, 2024
47a6b0b
fix chatglm
inkcherry Dec 16, 2024
3a23997
fix chatglm2 with transformers=4.40 version
inkcherry Dec 16, 2024
e3ec46e
uneven
inkcherry Dec 16, 2024
9685879
fix uneven
inkcherry Dec 16, 2024
7b99b03
fix training
inkcherry Dec 16, 2024
570645f
refine code
inkcherry Dec 17, 2024
3729b64
remove skip bcase&reduce
inkcherry Dec 17, 2024
62d8858
fix typo
inkcherry Dec 17, 2024
dd17313
format
inkcherry Dec 17, 2024
93cf6f5
refine code
inkcherry Dec 18, 2024
87c4bc2
refine code
inkcherry Dec 18, 2024
1714bb5
refine
inkcherry Dec 18, 2024
dadf915
update yuan
inkcherry Dec 19, 2024
86c9399
optimize usage of move function
inkcherry Dec 19, 2024
2526dc6
refine args usage
inkcherry Dec 19, 2024
c9fd699
format
inkcherry Dec 19, 2024
797e71f
zero1 compatible
inkcherry Dec 19, 2024
86ae65e
remove wa
inkcherry Dec 22, 2024
3e40024
fix cpu device name
inkcherry Dec 22, 2024
7d94b77
fix lm-head
inkcherry Dec 23, 2024
b297950
add detach
inkcherry Dec 23, 2024
67ce220
fix ipex intergration
inkcherry Dec 23, 2024
f818be9
fix tied_embedding
inkcherry Dec 24, 2024
11c98f6
Merge remote-tracking branch 'origin/master' into autotp_training
inkcherry Jan 2, 2025
e22b625
format
inkcherry Jan 2, 2025
8531b64
Merge branch 'master' into autotp_training
tjruwase Jan 6, 2025
8d19e01
Merge branch 'master' into autotp_training
loadams Jan 6, 2025
Files changed
27 changes: 25 additions & 2 deletions deepspeed/__init__.py
@@ -32,12 +32,12 @@
from .runtime.hybrid_engine import DeepSpeedHybridEngine
from .runtime.pipe.engine import PipelineEngine
from .inference.engine import InferenceEngine
from .inference.config import DeepSpeedInferenceConfig
from .inference.config import DeepSpeedInferenceConfig, AUTOTP_MODE
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode

from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed
@@ -364,3 +364,26 @@ def init_inference(model, config=None, **kwargs):
engine = InferenceEngine(model, config=ds_inference_config)

return engine


def tp_model_init(model, tp_size, dtype):
"""
Initialize the model for tensor parallelism.

Args:
model (torch.nn.Module): The model to be initialized.
tp_size (int): The tensor parallelism size.
dtype (torch.dtype): The data type to be used for the model.

Returns:
torch.nn.Module: The initialized model with tensor parallelism.
"""
# avoid re-entry
assert not hasattr(
model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."

set_autotp_mode(training=True)
model = init_inference(model=model, mp_size=tp_size, dtype=dtype, replace_with_kernel_inject=False).module
setattr(model, 'ds_autotp_parsed', True)

return model
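For context, a minimal usage sketch of the new entry point (not part of the diff): the model name, dtype, and config path below are illustrative, and the surrounding `deepspeed.initialize` call follows the usual training flow.

```python
# Hypothetical usage of tp_model_init; names and config are illustrative only.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any HF model

# Shard the model across a tensor-parallel group of size 2 before engine creation.
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)

# The sharded model is then wrapped by the regular DeepSpeed training engine.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config="ds_config.json")
```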
6 changes: 6 additions & 0 deletions deepspeed/comm/comm.py
@@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)


@timed_op
def broadcast_object_list(object_list, src, group=None, device=None):
global cdb
return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)


@timed_op
def all_gather(tensor_list,
tensor,
4 changes: 4 additions & 0 deletions deepspeed/comm/torch.py
@@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@disable_compiler_collective
def broadcast_object_list(self, object_list, src, group=None, device=None):
return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)

@disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
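A small usage sketch of the new collective wrapper (assuming the distributed backend is already initialized; the payload dict is illustrative). It mirrors `torch.distributed.broadcast_object_list` semantics.

```python
# Hypothetical call of the deepspeed.comm.broadcast_object_list wrapper added here.
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()  # comm backend must be up before any collective

# Rank 0 owns the real payload; other ranks pass a placeholder list of equal length.
payload = [{"resume_step": 1000}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(payload, src=0)
# After the call, every rank holds rank 0's objects in `payload`.
```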
5 changes: 5 additions & 0 deletions deepspeed/inference/config.py
@@ -31,6 +31,11 @@ class MoETypeEnum(str, Enum):
standard = "standard"


class AUTOTP_MODE(Enum):
Reviewer comment (Contributor): This logic should ideally be outside of the inference module. For example, in the deepspeed/runtime/tensor_parallel module?

TRAINING = "TRAINING"
INFERENCE = "INFERENCE"


class DeepSpeedTPConfig(DeepSpeedConfigModel):
""" Configure tensor parallelism settings """

8 changes: 7 additions & 1 deletion deepspeed/inference/engine.py
@@ -15,7 +15,8 @@
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode
from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject import replace_transformer_layer, generic_injection
@@ -247,6 +248,11 @@ def _post_forward_hook(self, module, input, output):
self._model_times.append(elapsed_time)

def _create_model_parallel_group(self, config):

if is_autotp_training_mode():
Reviewer comment (Contributor): Conceptually, control flow for training should not come here. I think some refactoring/restructuring is needed for code quality.

groups._init_tp_mesh_device(config.tensor_parallel.tp_size)
self.mp_group = groups.get_tensor_model_parallel_group()
return
# Call the init process
if InferenceEngine.inference_mp_group is None:
init_distributed()
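For reference, once the training path above has initialized the TP mesh, the group can be queried through `deepspeed.utils.groups`; a short sketch using only the helpers that appear in this diff:

```python
from deepspeed.utils import groups

# Valid after groups._init_tp_mesh_device(tp_size) has run, e.g. via tp_model_init().
tp_group = groups.get_tensor_model_parallel_group()
tp_world_size = groups.get_tensor_model_parallel_world_size()  # equals tp_size
```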
2 changes: 1 addition & 1 deletion deepspeed/module_inject/__init__.py
@@ -6,5 +6,5 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
from .policy import DSPolicy
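A brief sketch of the mode switch that this export enables (the behavior of `is_autotp_training_mode` is inferred from how the PR uses it, not from separate documentation):

```python
from deepspeed.module_inject import set_autotp_mode
from deepspeed.module_inject.layers import is_autotp_training_mode

# tp_model_init() calls this internally before running the AutoTP parser, so the
# replaced layers are built with training support (gradients, backward allreduce).
set_autotp_mode(training=True)
assert is_autotp_training_mode()
```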
90 changes: 30 additions & 60 deletions deepspeed/module_inject/auto_tp.py
@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearALlreduce, Yuan_LinearLayer, GLM_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer

Reviewer comment: Original coding style is LinearAllreduce instead of LinearALlreduce.

from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device):
@@ -324,10 +326,18 @@ def tp_parser(model):
return policy_list

def set_tensor_parallel_config(self, mp_size, mp_group):

if is_autotp_training_mode():
self.mp_group = groups.get_tensor_model_parallel_group()
self.mp_size = groups.get_tensor_model_parallel_world_size()
return

self.mp_size = mp_size
self.mp_group = mp_group

def _replace(self, child, name, conv_linear_layer):
# This function should clearly define the routing rules for specific layers
# and avoid any complex shard-related logic.
if getattr(child, "replaced", False) == True:
return
weight_shape = child.weight.shape
@@ -339,77 +349,37 @@ def _replace(self, child, name, conv_linear_layer):
# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), True)
return LinearLayer(weight=weight, bias=bias)
return Yuan_LinearLayer(child, self.mp_group)

elif 'o_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
return Yuan_LinearALlreduce(child, self.mp_group)

# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):

Reviewer comment: This additional code block is meant to handle the general "MLP including chunk layer" case, but the returned module/object is named with a GLM prefix. It would be better to rename GLM_LinearLayer to something like GateUpPack_LinearLayer.

return GLM_LinearLayer(child, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
del data
return Conv_LinearALlreduce(child, self.mp_group, name=name)
elif name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(child, self.mp_group)

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias,
get_accelerator().current_device_name())), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
return LinearAllreduce(child, self.mp_group, name=name)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0] // mp_size, weight_shape[1]]
setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

if require_tp_fused_qkvw(name, self.mp_size):
conv_LinearLayer(child, self.mp_group)
elif require_tp_fused_qkvw(name, self.mp_size):
#Check and handle fused qkv for TP
#The copy is a regular copy, The shape of dst and src is the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())

bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
bias_data_dc = None
return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

setattr(child, "replaced", True)
return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
return LinearLayer(child, self.mp_group, name=name)

def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
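To illustrate the "routing rules" comment added to `_replace` above, here is a condensed, hypothetical sketch of the dispatch it now performs. It omits the conv-linear, fused-QKV, and Arctic special cases and does not reproduce the exact branch order; the wrapper classes are the ones imported in this diff.

```python
# Condensed routing sketch (illustrative, not the PR's exact control flow).
from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, LmHeadLinearAllreduce,
                                            Yuan_LinearLayer, Yuan_LinearALlreduce, GLM_LinearLayer)

def route_layer(child, name, module_str, mp_group, all_reduce_linears):
    # Model-specific special cases first.
    if 'Yuan' in module_str and 'v_proj' in name:
        return Yuan_LinearLayer(child, mp_group)
    if 'Yuan' in module_str and 'o_proj' in name:
        return Yuan_LinearALlreduce(child, mp_group)
    if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in module_str):
        return GLM_LinearLayer(child, mp_group)
    # Layers whose outputs need an allreduce (row-parallel style).
    if name in ("lm_head", "embed_out"):
        return LmHeadLinearAllreduce(child, mp_group)
    if name in all_reduce_linears:
        return LinearAllreduce(child, mp_group, name=name)
    # Everything else becomes a sharded LinearLayer (column-parallel style).
    return LinearLayer(child, mp_group, name=name)
```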