diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 5cdd99ff0b90..74fc7f74f6bd 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -6,6 +6,7 @@
 import torch
 import time
 import os
+import deepspeed
 from deepspeed import comm as dist
 from deepspeed.utils.logging import log_dist

@@ -13,6 +14,7 @@
 from packaging import version as pkg_version
 from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
 from deepspeed.utils.timer import SynchronizedWallClockTimer
+from deepspeed.runtime.compiler import is_compile_supported

 from ..runtime.state_dict_factory import SDLoaderFactory
 from ..runtime.weight_quantizer import WeightQuantization
@@ -185,6 +187,7 @@ def __init__(self, model, config):

         # Check if local CUDA graphs can be created in replacement modules
         self.local_cuda_graph = self._local_cuda_graph_used(self.module)
+        self._is_compiled = False

     def destroy(self):
         # Have to import here because inference_module is a global, but python
@@ -634,3 +637,22 @@ def _generate(self, *inputs, **kwargs):
                    )

        return self.module.generate(*inputs, **kwargs)
+
+    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
+        """
+        Compile the module using the specified backend and kwargs.
+        """
+        if not is_compile_supported():
+            raise RuntimeError("compile is not supported in your version of PyTorch.")
+
+        if self._is_compiled:
+            return
+
+        # Avoid graph breaks
+        deepspeed.utils.nvtx.enable_nvtx = False
+        self.module.compile(backend=backend, **compile_kwargs)
+        self._is_compiled = True
+
+    @property
+    def is_compiled(self) -> bool:
+        return self._is_compiled
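
Below is a minimal usage sketch of the compile() API introduced by this diff. It is not part of the change set: it assumes a PyTorch >= 2.0 install (so is_compile_supported() returns True), a single-process run, and a toy model; the tensor shapes, dtype, and the "reduce-overhead" compile mode are illustrative placeholders.

# Illustrative usage of InferenceEngine.compile() (assumptions noted above).
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator

# Toy model; any module accepted by deepspeed.init_inference would do.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
engine = deepspeed.init_inference(model, dtype=torch.float32)

# Uses the accelerator's default compile backend; extra torch.compile options
# are forwarded via compile_kwargs. Calling compile() again is a no-op once
# the engine's _is_compiled flag is set.
engine.compile(compile_kwargs={"mode": "reduce-overhead"})
assert engine.is_compiled

# Run a forward pass on the accelerator device the engine placed the module on.
out = engine(torch.randn(4, 16, device=get_accelerator().current_device_name()))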