Skip to content

Commit

Permalink
Merge pull request #50 from Nixtla/feat/finetune_depth
Browse files Browse the repository at this point in the history
feat: add finetune_depth parameter
  • Loading branch information
MMenchero authored Dec 19, 2024
2 parents d0fcff1 + c2ee9cc commit 9c45bea
Show file tree
Hide file tree
Showing 49 changed files with 134,798 additions and 133,872 deletions.
8 changes: 5 additions & 3 deletions R/nixtla_client_cross_validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
#' @param quantiles Quantiles to forecast. Should be between 0 and 1.
#' @param n_windows Number of windows to evaluate.
#' @param step_size Step size between each cross validation window. If NULL, it will equal the forecast horizon (h).
#' @param finetune_steps Number of steps used to finetune 'TimeGPT' in the new data.
#' @param finetune_loss Loss function to use for finetuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param finetune_steps Number of steps used to fine-tune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Uses a scale from 1 to 5, where 1 means little fine-tuning and 5 means that the entire model is fine-tuned.
#' @param finetune_loss Loss function to use for fine-tuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param clean_ex_first Clean exogenous signal before making the forecasts using 'TimeGPT'.
#' @param model Model to use, either "timegpt-1" or "timegpt-1-long-horizon". Use "timegpt-1-long-horizon" if you want to forecast more than one seasonal period given the frequency of the data.
#'
Expand All @@ -27,7 +28,7 @@
#' fcst <- nixtlar::nixtla_client_cross_validation(df, h = 8, id_col = "unique_id", n_windows = 5)
#' }
#'
nixtla_client_cross_validation <- function(df, h=8, freq=NULL, id_col="unique_id", time_col="ds", target_col="y", level=NULL, quantiles=NULL, n_windows=1, step_size=NULL, finetune_steps=0, finetune_loss="default", clean_ex_first=TRUE, model="timegpt-1"){
nixtla_client_cross_validation <- function(df, h=8, freq=NULL, id_col="unique_id", time_col="ds", target_col="y", level=NULL, quantiles=NULL, n_windows=1, step_size=NULL, finetune_steps=0, finetune_depth=1, finetune_loss="default", clean_ex_first=TRUE, model="timegpt-1"){

# Validate input ----
if(!is.data.frame(df) & !inherits(df, "tbl_df") & !inherits(df, "tsibble")){
Expand Down Expand Up @@ -130,6 +131,7 @@ nixtla_client_cross_validation <- function(df, h=8, freq=NULL, id_col="unique_id
freq = freq,
clean_ex_first = clean_ex_first,
finetune_steps = finetune_steps,
finetune_depth = finetune_depth,
finetune_loss = finetune_loss
)

Expand Down
23 changes: 19 additions & 4 deletions R/nixtla_client_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
#' @param X_df A tsibble or a data frame with future exogenous variables.
#' @param level The confidence levels (0-100) for the prediction intervals.
#' @param quantiles Quantiles to forecast. Should be between 0 and 1.
#' @param finetune_steps Number of steps used to finetune 'TimeGPT' in the new data.
#' @param finetune_loss Loss function to use for finetuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param finetune_steps Number of steps used to fine-tune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Uses a scale from 1 to 5, where 1 means little fine-tuning and 5 means that the entire model is fine-tuned.
#' @param finetune_loss Loss function to use for fine-tuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param clean_ex_first Clean exogenous signal before making the forecasts using 'TimeGPT'.
#' @param add_history Return fitted values of the model.
#' @param model Model to use, either "timegpt-1" or "timegpt-1-long-horizon". Use "timegpt-1-long-horizon" if you want to forecast more than one seasonal period given the frequency of the data.
Expand All @@ -27,7 +28,7 @@
#' fcst <- nixtlar::nixtla_client_forecast(df, h=8, id_col="unique_id", level=c(80,95))
#' }
#'
nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_col="ds", target_col="y", X_df=NULL, level=NULL, quantiles=NULL, finetune_steps=0, finetune_loss="default", clean_ex_first=TRUE, add_history=FALSE, model="timegpt-1"){
nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_col="ds", target_col="y", X_df=NULL, level=NULL, quantiles=NULL, finetune_steps=0, finetune_depth=1, finetune_loss="default", clean_ex_first=TRUE, add_history=FALSE, model="timegpt-1"){

# Validate input ----
if(!is.data.frame(df) & !inherits(df, "tbl_df") & !inherits(df, "tsibble")){
Expand Down Expand Up @@ -116,6 +117,7 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_
freq = freq,
clean_ex_first = clean_ex_first,
finetune_steps = finetune_steps,
finetune_depth = finetune_depth,
finetune_loss = finetune_loss
)

Expand Down Expand Up @@ -254,7 +256,20 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_

# Add fitted values if required ----
if(add_history){
fitted <- nixtla_client_historic(df=df, freq=freq, id_col=id_col, time_col=time_col, target_col=target_col, level=level, quantiles=quantiles, finetune_steps=finetune_steps, finetune_loss=finetune_loss, clean_ex_first=clean_ex_first)
fitted <- nixtla_client_historic(
df=df,
freq=freq,
id_col=id_col,
time_col=time_col,
target_col=target_col,
level=level,
quantiles=quantiles,
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
clean_ex_first=clean_ex_first
)

forecast <- dplyr::bind_rows(fitted, forecast)
}

Expand Down
7 changes: 4 additions & 3 deletions R/nixtla_client_historic.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
#' @param target_col Column that contains the target variable.
#' @param level The confidence levels (0-100) for the prediction intervals.
#' @param quantiles Quantiles to forecast. Should be between 0 and 1.
#' @param finetune_steps Number of steps used to finetune 'TimeGPT' in the new data.
#' @param finetune_loss Loss function to use for finetuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param finetune_steps Number of steps used to fine-tune 'TimeGPT' in the new data.
#' @param finetune_depth The depth of the fine-tuning. Uses a scale from 1 to 5, where 1 means little fine-tuning and 5 means that the entire model is fine-tuned.
#' @param finetune_loss Loss function to use for fine-tuning. Options are: "default", "mae", "mse", "rmse", "mape", and "smape".
#' @param clean_ex_first Clean exogenous signal before making the forecasts using 'TimeGPT'.
#' @param model Model to use, either "timegpt-1" or "timegpt-1-long-horizon". Use "timegpt-1-long-horizon" if you want to forecast more than one seasonal period given the frequency of the data.
#'
Expand All @@ -24,7 +25,7 @@
#' fcst <- nixtlar::nixtla_client_historic(df, id_col="unique_id", level=c(80,95))
#' }
#'
nixtla_client_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", level=NULL, quantiles=NULL, finetune_steps=0, finetune_loss="default", clean_ex_first=TRUE, model="timegpt-1"){
nixtla_client_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", level=NULL, quantiles=NULL, finetune_steps=0, finetune_depth=1, finetune_loss="default", clean_ex_first=TRUE, model="timegpt-1"){

# Validate input ----
if(!is.data.frame(df) & !inherits(df, "tbl_df") & !inherits(df, "tsibble")){
Expand Down
7 changes: 5 additions & 2 deletions man/nixtla_client_cross_validation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions man/nixtla_client_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions man/nixtla_client_historic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/mocks/api.nixtla.io/model_params-3f263f-POST.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
"input_size": 120,
"horizon": 24
},
"request_id": "6UZDCP8S3J"
"request_id": "TFKG7KC8EP"
}
2 changes: 1 addition & 1 deletion tests/mocks/api.nixtla.io/model_params-9255c1-POST.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
"input_size": 840,
"horizon": 168
},
"request_id": "YDX6MPGAUH"
"request_id": "MBSFLHXE2U"
}
Loading

0 comments on commit 9c45bea

Please sign in to comment.