From b01cbc98bfd91a23cb4e950806b030e748a0e376 Mon Sep 17 00:00:00 2001
From: Valerie Sarge
Date: Tue, 4 Feb 2025 16:47:59 -0800
Subject: [PATCH] Guard TE import & clean up unused imports

Signed-off-by: Valerie Sarge
---
 nemo/collections/llm/gpt/data/mlperf_govreport.py | 7 +------
 nemo/collections/llm/gpt/model/llama.py           | 4 ++++
 scripts/llm/performance/mlperf_lora_llama2_70b.py | 2 +-
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/llm/gpt/data/mlperf_govreport.py b/nemo/collections/llm/gpt/data/mlperf_govreport.py
index a394aa3bb39d..259fa532504c 100644
--- a/nemo/collections/llm/gpt/data/mlperf_govreport.py
+++ b/nemo/collections/llm/gpt/data/mlperf_govreport.py
@@ -14,18 +14,13 @@
 import shutil
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from datasets import DatasetDict, load_dataset
 import numpy as np
-import torch
-from torch import nn
-
 from nemo.collections.llm.gpt.data.core import get_dataset_root
 from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
-from nemo.collections.llm.utils import Config
-from nemo.lightning import OptimizerModule
 from nemo.lightning.io.mixin import IOMixin
 from nemo.utils import logging
diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py
index 0f74a8a77633..f81e9a80a979 100644
--- a/nemo/collections/llm/gpt/model/llama.py
+++ b/nemo/collections/llm/gpt/model/llama.py
@@ -267,6 +267,10 @@ def __init__(
     ):
         super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform)
 
+        from nemo.utils.import_utils import safe_import
+        _, HAVE_TE = safe_import("transformer_engine")
+        assert HAVE_TE, "TransformerEngine is required for MLPerfLoRALlamaModel."
+
     def configure_model(self):
         # Apply context managers to reduce memory by (1) avoiding unnecessary gradients
         # and (2) requesting that TE initialize params as FP8.
diff --git a/scripts/llm/performance/mlperf_lora_llama2_70b.py b/scripts/llm/performance/mlperf_lora_llama2_70b.py
index a01807464f9f..ec4a3a6e748a 100644
--- a/scripts/llm/performance/mlperf_lora_llama2_70b.py
+++ b/scripts/llm/performance/mlperf_lora_llama2_70b.py
@@ -26,7 +26,7 @@
 from nemo import lightning as nl
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
-from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule
+from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
 from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
 from utils import (
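
Note (not part of the patch): a minimal sketch of how the guarded import added to llama.py above is expected to behave, assuming nemo.utils.import_utils.safe_import returns a (module-or-placeholder, success-flag) pair, as the `_, HAVE_TE = safe_import(...)` usage in the hunk implies. The script name and print statements are illustrative only.

    # check_te.py -- hypothetical standalone check, not part of the NeMo tree
    from nemo.utils.import_utils import safe_import

    # safe_import is assumed not to raise when the package is absent; instead it
    # returns the module (or a placeholder) plus a boolean indicating success.
    te, HAVE_TE = safe_import("transformer_engine")

    if HAVE_TE:
        print("TransformerEngine found:", getattr(te, "__version__", "unknown"))
    else:
        # This is the condition the new assert in MLPerfLoRALlamaModel.__init__
        # turns into an immediate, readable failure at construction time rather
        # than a later error inside configure_model().
        print("TransformerEngine missing; MLPerfLoRALlamaModel cannot be used.")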