Revert "option to modify configure_optimizers"
This reverts commit 7d83799.
JQGoh committed Jan 28, 2025
1 parent 0266daf commit 5765c20
Showing 2 changed files with 0 additions and 98 deletions.
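
For context, the reverted helper let a user attach a pre-built optimizer/scheduler pair to a model so that configure_optimizers() returned that pair instead of constructing the defaults. Below is a minimal usage sketch reconstructed from the deleted code in this diff; the NHITS model and every hyperparameter value are illustrative assumptions, not taken from this commit.

# Hypothetical usage of the reverted set_configure_optimizers() helper,
# reconstructed from the deleted code shown in this diff. The model
# (NHITS) and all hyperparameters are illustrative assumptions.
import torch
from neuralforecast.models import NHITS

model = NHITS(h=12, input_size=24, max_steps=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# The helper stored the pair on the model; configure_optimizers() then
# returned it directly instead of building the default optimizer/scheduler.
model.set_configure_optimizers(
    optimizer=optimizer,
    scheduler=scheduler,
    interval="step",  # the helper's default; Lightning's own default is 'epoch'
)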
47 changes: 0 additions & 47 deletions nbs/common.base_model.ipynb
@@ -162,9 +162,6 @@
" self.lr_scheduler = lr_scheduler\n",
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
"\n",
" # customized by set_configure_optimizers()\n",
" self.config_optimizers = None\n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
" self.hist_exog_list = list(hist_exog_list) if hist_exog_list is not None else []\n",
@@ -401,9 +398,6 @@
" random.seed(self.random_seed)\n",
"\n",
" def configure_optimizers(self):\n",
" if self.config_optimizers is not None:\n",
" return self.config_optimizers\n",
" \n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
@@ -438,47 +432,6 @@
" )\n",
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
"\n",
" def set_configure_optimizers(\n",
" self, \n",
" optimizer=None,\n",
" scheduler=None,\n",
" interval='step',\n",
" frequency=1,\n",
" monitor='val_loss',\n",
" strict=True,\n",
" name=None\n",
" ):\n",
" \"\"\"Helper function to customize the lr_scheduler_config as detailed in \n",
" https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers\n",
"\n",
" Calling set_configure_optimizers() with valid `optimizer`, `scheduler` shall modify the returned \n",
" dictionary of key='optimizer', key='lr_scheduler' in configure_optimizers().\n",
" Note that the default choice of `interval` in set_configure_optiizers() is 'step',\n",
" which differs from the choice of 'epoch' used in lightning_module. \n",
" \"\"\"\n",
" lr_scheduler_config = {\n",
" 'interval': interval,\n",
" 'frequency': frequency,\n",
" 'monitor': monitor,\n",
" 'strict': strict,\n",
" 'name': name,\n",
" }\n",
"\n",
" if scheduler is not None and optimizer is not None:\n",
" if not isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler\")\n",
" if not isinstance(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid instance of torch.optim.Optimizer\") \n",
" \n",
" lr_scheduler_config[\"scheduler\"] = scheduler\n",
" self.config_optimizers = {\n",
" 'optimizer': optimizer,\n",
" 'lr_scheduler': lr_scheduler_config,\n",
" }\n",
" else:\n",
" # falls back to default option as specified in configure_optimizers()\n",
" self.config_optimizers = None\n",
"\n",
" def get_test_size(self):\n",
" return self.test_size\n",
"\n",
51 changes: 0 additions & 51 deletions neuralforecast/common/_base_model.py
@@ -121,9 +121,6 @@ def __init__(
lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}
)

# customized by set_configure_optimizers()
self.config_optimizers = None

# Variables
self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []
self.hist_exog_list = list(hist_exog_list) if hist_exog_list is not None else []
@@ -378,9 +375,6 @@ def on_fit_start(self):
random.seed(self.random_seed)

def configure_optimizers(self):
if self.config_optimizers is not None:
return self.config_optimizers

if self.optimizer:
optimizer_signature = inspect.signature(self.optimizer)
optimizer_kwargs = deepcopy(self.optimizer_kwargs)
@@ -421,51 +415,6 @@ def configure_optimizers(self):
)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

def set_configure_optimizers(
self,
optimizer=None,
scheduler=None,
interval="step",
frequency=1,
monitor="val_loss",
strict=True,
name=None,
):
"""Helper function to customize the lr_scheduler_config as detailed in
https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers
Calling set_configure_optimizers() with a valid `optimizer` and `scheduler` will override the
dictionary (keys 'optimizer' and 'lr_scheduler') returned by configure_optimizers().
Note that the default `interval` in set_configure_optimizers() is 'step',
which differs from the 'epoch' default used in lightning_module.
"""
lr_scheduler_config = {
"interval": interval,
"frequency": frequency,
"monitor": monitor,
"strict": strict,
"name": name,
}

if scheduler is not None and optimizer is not None:
if not isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler):
raise TypeError(
"scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler"
)
if not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError(
"optimizer is not a valid instance of torch.optim.Optimizer"
)

lr_scheduler_config["scheduler"] = scheduler
self.config_optimizers = {
"optimizer": optimizer,
"lr_scheduler": lr_scheduler_config,
}
else:
# falls back to default option as specified in configure_optimizers()
self.config_optimizers = None

def get_test_size(self):
return self.test_size
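
For context, with the helper removed, optimizer customization goes through the constructor arguments that the retained configure_optimizers() still consumes (self.optimizer, self.optimizer_kwargs, self.lr_scheduler, self.lr_scheduler_kwargs in the surviving code above). A minimal sketch under that assumption; the model choice and keyword values are illustrative.

# Sketch of the surviving customization path after this revert: pass an
# optimizer/scheduler class (not an instance) plus kwargs, and the
# retained configure_optimizers() instantiates them. Model choice and
# hyperparameters are illustrative assumptions.
import torch
from neuralforecast.models import NHITS

model = NHITS(
    h=12,
    input_size=24,
    optimizer=torch.optim.AdamW,
    optimizer_kwargs={"lr": 1e-3},
    lr_scheduler=torch.optim.lr_scheduler.StepLR,
    lr_scheduler_kwargs={"step_size": 10},
)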

