
Remove compile wrapper to simplify access to model attributes #5581

Merged: 10 commits, Jun 17, 2024
153 changes: 1 addition & 152 deletions deepspeed/runtime/compiler.py
@@ -3,165 +3,14 @@

# DeepSpeed Team

from typing import Union, Callable, Dict, Any
import importlib
import torch
from ..pydantic_v1 import validator
from .config_utils import DeepSpeedConfigModel

COMPILE_CONFIG = "compile"


def is_compile_supported():
return hasattr(torch, "compiler")
return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")


def disable(func):
if is_compile_supported():
return torch.compiler.disable(func)
return func


def get_compile_config(param_dict):
if COMPILE_CONFIG in param_dict:
compile_config_dict = param_dict[COMPILE_CONFIG]
else:
compile_config_dict = {}
return CompileConfig(**compile_config_dict)


def get_backend_fn(backend: Union[str, Callable]) -> Union[str, Callable]:
if isinstance(backend, Callable):
return backend

elif isinstance(backend, str):
if backend in torch._dynamo.list_backends(exclude_tags=()):
return backend

# Get module name from backend name
module_name = '.'.join(backend.split('.')[:-1])
fn_name = backend.split('.')[-1]

try:
module = importlib.import_module(module_name)
backend_fn = getattr(module, fn_name)
except ImportError:
raise ValueError(
f"The backend {backend} is not in the list of available backends and could not be imported.")
return backend_fn

raise ValueError(f"backend for torch.compile must be a string or Callable: {backend}")


class CompileConfig(DeepSpeedConfigModel):
"""
[EXPERIMENTAL] This configuration enables users to activate `torch.compile` within DeepSpeed and customize its settings.
Please be aware that these features and API designs are experimental and subject to change.
"""

enabled: bool = False
"""
Enable torch.compile when True.
"""

backend: str = "inductor"
"""
Passed to `backend` argument of torch.compile.
If the given value is not in torch._dynamo.list_backends(),
DeepSpeed attempts to import and instantiate the module with the given name.
"""

kwargs: Dict[str, Any] = {}
"""
Passed to `kwargs` argument of torch.compile.
"""

@validator("enabled")
def validate_enabled(cls, field_value, values):
if field_value and not is_compile_supported():
raise ValueError("torch.compile is not supported on this version of PyTorch.")
return field_value


def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):

class wrapper(mod.__class__):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
self.__dict__ = {k: module.__dict__[k] for k in module.__dict__ if not k in self.__class__.__dict__}

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.

Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.

Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.

Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)

engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped,
backend=self._backend,
**self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn

return wrapper(mod, compile_config)
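
After this change, `compiler.py` is reduced to a thin set of helpers, and the updated `is_compile_supported()` additionally requires the in-place `torch.nn.Module.compile` method that the engine now calls directly instead of substituting a `CompiledModuleWrapper` subclass for the client model. A minimal, self-contained sketch of the pattern the engine switches to; the toy `Net` module is illustrative and not part of this PR:

```python
import torch


def is_compile_supported() -> bool:
    # Mirrors the updated check above: both torch.compiler and the in-place
    # nn.Module.compile method (available in recent PyTorch releases) are needed.
    return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")


class Net(torch.nn.Module):
    """Toy module, purely illustrative."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


net = Net()
if is_compile_supported():
    # Compiles the module in place: object identity and attributes such as
    # net.linear stay directly reachable, unlike with the removed
    # CompiledModuleWrapper subclass that shadowed the wrapped module.
    net.compile(backend="inductor")

print(type(net))   # still Net, not a generated wrapper class
print(net.linear)  # attributes remain directly accessible
```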
3 changes: 0 additions & 3 deletions deepspeed/runtime/config.py
@@ -31,7 +31,6 @@
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from .compiler import get_compile_config

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -911,8 +910,6 @@ def _initialize_params(self, param_dict):
self.weight_quantization_config = WeightQuantConfig(
**param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

self.compile_config = get_compile_config(param_dict)

self.timers_config = get_timers_config(param_dict)

def _batch_assertion(self):
22 changes: 19 additions & 3 deletions deepspeed/runtime/engine.py
@@ -90,7 +90,7 @@

from .pipe.module import PipelineModule
from .utils import get_ma_status
from .compiler import CompiledModuleWrapper
from .compiler import is_compile_supported
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
@@ -361,8 +361,7 @@ def __init__(self,
self.flatten = _flatten_dense_tensors
self.unflatten = _unflatten_dense_tensors

if self._config.compile_config.enabled:
self._set_client_model(CompiledModuleWrapper(self.module, self._config.compile_config))
self._is_compiled = False

def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
@@ -3604,3 +3603,20 @@ def empty_partition_cache(self):
self.optimizer.empty_partition_cache()
gc.collect()
get_accelerator().empty_cache()

def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
"""Compile the module using the specified backend and kwargs.
If a compiler_fn is set, it will be used instead of torch.compile().
"""
if not is_compile_supported():
raise RuntimeError("compile is not supported in your version of PyTorch.")

if self.is_compiled:
return

self.module.compile(backend=backend, **compile_kwargs)
self._is_compiled = True

@property
def is_compiled(self) -> bool:
return self._is_compiled
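
For users, the practical change visible here is that `torch.compile` is no longer switched on through a `compile` section of the DeepSpeed config; it is requested explicitly on the engine after initialization. A hedged usage sketch, in which the toy model, config values, and single-process environment setup are illustrative assumptions rather than anything taken from this PR:

```python
import os

import torch
import deepspeed

# Single-process environment so deepspeed.initialize() can set up distributed
# state without the deepspeed launcher (an assumption for this sketch).
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

model = torch.nn.Linear(16, 16)  # toy model, not from this PR
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    # Note: no "compile" section; config-driven compilation is removed by this PR.
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Explicit opt-in added in engine.py above; the backend defaults to the
# accelerator's preferred compile backend, and the call is a no-op if the
# engine has already been compiled.
engine.compile()
assert engine.is_compiled
```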
10 changes: 1 addition & 9 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -933,15 +933,7 @@ def __init__(
_ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path,
mpu) if config_dict_or_path is not None else None
if _ds_config is not None:
if _ds_config.zero_config.memory_efficient_linear and _ds_config.compile_config.enabled:
# memory_efficient_linear displays numerous errors when torch.compile is enabled.
# Refer to https://github.com/pytorch/pytorch/issues/119059 for details.
# Further investigation into performance is necessary, even after resolving this issue because
# the `memory_efficient_linear` module may lead to more graph breaks compared to the original implementation.
logger.warning(f'memory_efficient_linear is disabled when torch.compile is enabled.')
mem_efficient_linear = False
else:
mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear
mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear

super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype)
if not dist.is_initialized():
3 changes: 3 additions & 0 deletions tests/unit/common.py
@@ -203,10 +203,13 @@ def _launch_non_daemonic_procs(self, num_procs):
master_port = get_master_port()
skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason
processes = []
prev_start_method = mp.get_start_method()
mp.set_start_method('spawn', force=True)
for local_rank in range(num_procs):
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
p.start()
processes.append(p)
mp.set_start_method(prev_start_method, force=True)

# Now loop and wait for a test to complete. The spin-wait here isn't a big
# deal because the number of processes will be O(#GPUs) << O(#CPUs).
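
The `tests/unit/common.py` change pins the multiprocessing start method to `spawn` while the worker processes are launched and then restores whatever method was in effect before. A standalone illustration of that save-and-restore pattern; the `_worker` function is a placeholder, not test code from this PR:

```python
import multiprocessing as mp


def _worker(rank: int) -> None:
    print(f"worker {rank} started")  # stand-in for the real _dist_run target


if __name__ == "__main__":
    prev_start_method = mp.get_start_method()
    mp.set_start_method("spawn", force=True)  # use spawn only while launching
    procs = [mp.Process(target=_worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    mp.set_start_method(prev_start_method, force=True)  # restore the previous method
    for p in procs:
        p.join()
```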
85 changes: 0 additions & 85 deletions tests/unit/runtime/compile/test_compile_wrapper.py

This file was deleted.

4 changes: 0 additions & 4 deletions tests/unit/runtime/compile/test_compile_zero.py
@@ -50,10 +50,6 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
},
"zero_optimization": {
"stage": zero_stage,
},
"compile": {
"enabled": True,
"backend": get_accelerator().get_compile_backend()
}
}
