Skip to content
Snippets Groups Projects
Commit 1f459106 authored by pbac's avatar pbac
Browse files

Made backward step function

parent df4ca77f
Branches
No related tags found
No related merge requests found
...@@ -23,7 +23,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -23,7 +23,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
# #
# The horizons to fit for # The horizons to fit for
kseq = NA, kseq = NA,
# The (transformation stage) parameters used for the fit # The (transformation stage) parameters (only the ones set in last call of insert_prm())
prm = NA, prm = NA,
# Stores the maximum lag for AR terms # Stores the maximum lag for AR terms
maxlagAR = NA, maxlagAR = NA,
...@@ -83,7 +83,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -83,7 +83,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
#---------------------------------------------------------------- #----------------------------------------------------------------
# Get the transformation parameters # Get the transformation parameters (set for optimization)
get_prmbounds = function(nm){ get_prmbounds = function(nm){
if(nm == "init"){ if(nm == "init"){
if(is.null(dim(self$prmbounds))){ if(is.null(dim(self$prmbounds))){
...@@ -130,8 +130,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -130,8 +130,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
} }
# MUST INCLUDE SOME checks here and print useful messages if something is not right # MUST INCLUDE SOME checks here and print useful messages if something is not right
if(any(is.na(prm))){ stop(pst("None of the parameters (in prm) must be NA: prm=",prm)) } if(any(is.na(prm))){ stop(pst("None of the parameters (in prm) must be NA: prm=",prm)) }
# Keep the prm given
# Keep the prm
self$prm <- prm self$prm <- prm
# Find if any opt parameters, first the one with "__" hence for the inputs # Find if any opt parameters, first the one with "__" hence for the inputs
pinputs <- prm[grep("__",nams(prm))] pinputs <- prm[grep("__",nams(prm))]
...@@ -152,7 +151,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -152,7 +151,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
# Find if the input i have prefix match with the opt. parameter ii # Find if the input i have prefix match with the opt. parameter ii
if(pnms[ii]==nams(self$inputs)[i]){ if(pnms[ii]==nams(self$inputs)[i]){
# if the opt. parameter is in the expr, then replace # if the opt. parameter is in the expr, then replace
self$inputs[[i]]$expr <- private$replace_value(name = pprm[ii], self$inputs[[i]]$expr <- private$replace_prmvalue(name = pprm[ii],
value = pinputs[ii], value = pinputs[ii],
expr = self$inputs[[i]]$expr) expr = self$inputs[[i]]$expr)
} }
...@@ -160,12 +159,12 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -160,12 +159,12 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
} }
} }
# ################ # ################
# For the fit parameters, insert from prm if any found # For the regression parameters, insert from prm if any found
if (length(preg) & any(!is.na(self$regprmexpr))) { if (length(preg) & any(!is.na(self$regprmexpr))) {
nams(preg) nams(preg)
for(i in 1:length(preg)){ for(i in 1:length(preg)){
# if the opt. parameter is in the expr, then replace # if the opt. parameter is in the expr, then replace
self$regprmexpr <- private$replace_value(name = nams(preg)[i], self$regprmexpr <- private$replace_prmvalue(name = nams(preg)[i],
value = preg[i], value = preg[i],
expr = self$regprmexpr) expr = self$regprmexpr)
} }
...@@ -175,6 +174,32 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -175,6 +174,32 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
#---------------------------------------------------------------- #----------------------------------------------------------------
#----------------------------------------------------------------
# Return the values of the parameter names given
get_prmvalues = function(prmnames){
#
regprm <- eval(parse(text = self$regprmexpr))
# From the input parameters
val <- sapply(prmnames, function(nm){
if(length(grep("__",nm))){
tmp <- strsplit(nm, "__")[[1]]
if(tmp[1] %in% names(self$inputs)){
return(as.numeric(private$get_exprprmvalue(tmp[2], self$inputs[[tmp[1]]]$expr)))
}else{
return(NA)
}
}else{
if(nm %in% names(regprm)){
return(as.numeric(regprm[nm]))
}else{
return(NA)
}
}
})
return(val)
},
#----------------------------------------------------------------
#---------------------------------------------------------------- #----------------------------------------------------------------
# Function for transforming the input data to the regression data # Function for transforming the input data to the regression data
transform_data = function(data){ transform_data = function(data){
...@@ -289,7 +314,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -289,7 +314,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
#---------------------------------------------------------------- #----------------------------------------------------------------
# Replace the value in "name=value" in expr # Replace the value in "name=value" in expr
replace_value = function(name, value, expr){ replace_prmvalue = function(name, value, expr){
# First make regex # First make regex
pattern <- gsub("\\.", ".*", name) pattern <- gsub("\\.", ".*", name)
# Try to find it in the input # Try to find it in the input
...@@ -298,7 +323,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -298,7 +323,7 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
if(pos>0){ if(pos>0){
pos <- c(pos+attr(pos,"match.length")) pos <- c(pos+attr(pos,"match.length"))
# Find the substr to replace with the prm value # Find the substr to replace with the prm value
(tmp <- substr(expr, pos, nchar(expr))) tmp <- substr(expr, pos, nchar(expr))
pos2 <- regexpr(",|)", tmp) pos2 <- regexpr(",|)", tmp)
# Insert the prm value and return # Insert the prm value and return
expr <- pst(substr(expr,1,pos-1), "=", value, substr(expr,pos+pos2-1,nchar(expr))) expr <- pst(substr(expr,1,pos-1), "=", value, substr(expr,pos+pos2-1,nchar(expr)))
...@@ -309,6 +334,30 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list( ...@@ -309,6 +334,30 @@ forecastmodel <- R6::R6Class("forecastmodel", public = list(
}, },
#---------------------------------------------------------------- #----------------------------------------------------------------
#----------------------------------------------------------------
get_exprprmvalue = function(name, expr){
#name <- "degree"
#expr <- "bspline(tday, Boundary.knots = c(start=6,18), degree = 5, intercept=TRUE) %**% ones() + 2 + ones()"
#expr <- "one()"
expr <- gsub(" ", "", expr)
# First make regex
pattern <- gsub("\\.", ".*", name)
# Try to find it in the input
pos <- regexpr(pattern, expr)
# Only replace if prm was found
if(pos>0){
pos <- c(pos+attr(pos,"match.length"))
# Find the substr to replace with the prm value
(tmp <- substr(expr, pos, nchar(expr)))
pos2 <- regexpr(",|)", tmp)
return(substr(tmp, 2, pos2-1))
}else{
return(NA)
}
},
#----------------------------------------------------------------
#---------------------------------------------------------------- #----------------------------------------------------------------
# For deep cloning, in order to get the inputs list of R6 objects copied # For deep cloning, in order to get the inputs list of R6 objects copied
deep_clone = function(name, value) { deep_clone = function(name, value) {
...@@ -344,9 +393,9 @@ print.forecastmodel <- function(x, ...){ ...@@ -344,9 +393,9 @@ print.forecastmodel <- function(x, ...){
model <- x model <- x
# cat("\nObject of class forecastmodel (R6::class)\n\n") # cat("\nObject of class forecastmodel (R6::class)\n\n")
cat("\nOutput:",model$output) cat("\nOutput:",model$output)
cat("Inputs: ") cat("\nInputs: ")
if(length(model$inputs) == 0 ){ if(length(model$inputs) == 0 ){
cat("No inputs\n") cat("\nNo inputs")
}else{ }else{
cat(names(model$inputs)[1],"=",model$inputs[[1]]$expr,"\n") cat(names(model$inputs)[1],"=",model$inputs[[1]]$expr,"\n")
for(i in 2:length(model$inputs)){ for(i in 2:length(model$inputs)){
......
...@@ -135,7 +135,7 @@ ...@@ -135,7 +135,7 @@
#---------------------------------------------------------------- #----------------------------------------------------------------
#' @section \code{$insert_prm(prm)}: #' @section \code{$insert_prm(prm)}:
#' Insert the transformation parameters prm in the input expressions and regression expressions, and keep them (simply string manipulation). #' Insert the transformation parameters prm in the input expressions and regression expressions, and keep them in $prm (simply string manipulation).
#' #'
#' @examples #' @examples
#' #'
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#' @title Optimize parameters for onlineforecast model fitted with LM #' @title Optimize parameters for onlineforecast model fitted with LM
#' @param model The onlineforecast model, including inputs, output, kseq, p #' @param model The onlineforecast model, including inputs, output, kseq, p
#' @param data The data.list including the variables used in the model. #' @param data The data.list including the variables used in the model.
#' @param kseq The horizons to fit for (if not set, then model$kseq is used)
#' @param scorefun The function to be score used for calculating the score to be optimized. #' @param scorefun The function to be score used for calculating the score to be optimized.
#' @param cachedir A character specifying the path (and prefix) of the cache file name. If set to \code{""}, then no cache will be loaded or written. See \url{https://onlineforecasting.org/vignettes/nice-tricks.html} for examples. #' @param cachedir A character specifying the path (and prefix) of the cache file name. If set to \code{""}, then no cache will be loaded or written. See \url{https://onlineforecasting.org/vignettes/nice-tricks.html} for examples.
#' @param printout A logical determining if the score function is printed out in each iteration of the optimization. #' @param printout A logical determining if the score function is printed out in each iteration of the optimization.
...@@ -55,7 +56,7 @@ ...@@ -55,7 +56,7 @@
#' #'
#' @importFrom stats optim #' @importFrom stats optim
#' @export #' @export
lm_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, method="L-BFGS-B", ...){ lm_optim <- function(model, data, kseq = NA, scorefun = rmse, cachedir="", cachererun=FALSE, printout=TRUE, method="L-BFGS-B", ...){
## Take the parameters bounds from the parameter bounds set in the model ## Take the parameters bounds from the parameter bounds set in the model
init <- model$get_prmbounds("init") init <- model$get_prmbounds("init")
lower <- model$get_prmbounds("lower") lower <- model$get_prmbounds("lower")
...@@ -64,21 +65,33 @@ lm_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, m ...@@ -64,21 +65,33 @@ lm_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, m
if(any(is.na(lower))){ lower[is.na(lower)] <- -Inf} if(any(is.na(lower))){ lower[is.na(lower)] <- -Inf}
if(any(is.na(upper))){ lower[is.na(upper)] <- Inf} if(any(is.na(upper))){ lower[is.na(upper)] <- Inf}
# Clone the model no matter what (at least model$kseq should not be changed no matter if optimization is stopped)
m <- model$clone_deep()
if(!is.na(kseq[1])){
m$kseq <- kseq
}
## Caching the results based on some of the function arguments ## Caching the results based on some of the function arguments
if(cachedir != ""){ if(cachedir != ""){
## Have to insert the parameters in the expressions # Have to insert the parameters in the expressions to get the right state of the model for unique checksum
model$insert_prm(init) m$insert_prm(init)
# Have to reset the state first to remove dependency of previous calls
m$reset_state()
## Give all the elements to calculate the unique cache name ## Give all the elements to calculate the unique cache name
cnm <- cache_name(lm_fit, lm_optim, model$outputrange, model$regprm, model$transform_data(data), cnm <- cache_name(lm_fit, lm_optim, m$outputrange, m$regprm, m$transform_data(data),
data[[model$output]], scorefun, init, lower, upper, cachedir = cachedir) data[[m$output]], scorefun, init, lower, upper, cachedir = cachedir)
## Maybe load the cached result # Load the cached result if it exists
if(file.exists(cnm)){ return(readRDS(cnm)) } if(file.exists(cnm) & !cachererun){
res <- readRDS(cnm)
# Set the optimized parameters into the model
model$insert_prm(res$par)
return(res)
}
} }
## Run the optimization # Run the optimization
res <- optim(par = init, res <- optim(par = init,
fn = lm_fit, fn = lm_fit,
model = model, model = m,
data = data, data = data,
scorefun = scorefun, scorefun = scorefun,
printout = printout, printout = printout,
...@@ -87,8 +100,9 @@ lm_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, m ...@@ -87,8 +100,9 @@ lm_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, m
upper = upper, upper = upper,
method = method, method = method,
...) ...)
## Save the result in the cachedir # Save the result in the cachedir
if(cachedir != ""){ cache_save(res, cnm) } if(cachedir != ""){ cache_save(res, cnm) }
## Return the result # Set the optimized parameters into the model
model$insert_prm(res$par)
return(res) return(res)
} }
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#' @title Optimize parameters for onlineforecast model fitted with RLS #' @title Optimize parameters for onlineforecast model fitted with RLS
#' @param model The onlineforecast model, including inputs, output, kseq, p #' @param model The onlineforecast model, including inputs, output, kseq, p
#' @param data The data.list including the variables used in the model. #' @param data The data.list including the variables used in the model.
#' @param kseq The horizons to fit for (if not set, then model$kseq is used)
#' @param scorefun The function to be score used for calculating the score to be optimized. #' @param scorefun The function to be score used for calculating the score to be optimized.
#' @param cachedir A character specifying the path (and prefix) of the cache file name. If set to \code{""}, then no cache will be loaded or written. See \url{https://onlineforecasting.org/vignettes/nice-tricks.html} for examples. #' @param cachedir A character specifying the path (and prefix) of the cache file name. If set to \code{""}, then no cache will be loaded or written. See \url{https://onlineforecasting.org/vignettes/nice-tricks.html} for examples.
#' @param printout A logical determining if the score function is printed out in each iteration of the optimization. #' @param printout A logical determining if the score function is printed out in each iteration of the optimization.
...@@ -59,7 +60,7 @@ ...@@ -59,7 +60,7 @@
#' #'
#' #'
#' @export #' @export
rls_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, method="L-BFGS-B", ...){ rls_optim <- function(model, data, kseq = NA, scorefun = rmse, cachedir="", cachererun=FALSE, printout=TRUE, method="L-BFGS-B", ...){
# Take the parameters bounds from the parameter bounds set in the model # Take the parameters bounds from the parameter bounds set in the model
init <- model$get_prmbounds("init") init <- model$get_prmbounds("init")
lower <- model$get_prmbounds("lower") lower <- model$get_prmbounds("lower")
...@@ -68,24 +69,35 @@ rls_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, ...@@ -68,24 +69,35 @@ rls_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE,
if(any(is.na(lower))){ lower[is.na(lower)] <- -Inf} if(any(is.na(lower))){ lower[is.na(lower)] <- -Inf}
if(any(is.na(upper))){ lower[is.na(upper)] <- Inf} if(any(is.na(upper))){ lower[is.na(upper)] <- Inf}
# Clone the model no matter what (at least model$kseq should not be changed no matter if optimization is stopped)
m <- model$clone_deep()
if(!is.na(kseq[1])){
m$kseq <- kseq
}
# Caching the results based on some of the function arguments # Caching the results based on some of the function arguments
if(cachedir != ""){ if(cachedir != ""){
# Have to insert the parameters in the expressions to get the right state of the model for unique checksum # Have to insert the parameters in the expressions to get the right state of the model for unique checksum
model$insert_prm(init) m$insert_prm(init)
# Give all the elements needed to calculate the unique cache name
# This is maybe smarter, don't have to calculate the transformation of the data: cnm <- cache_name(model$regprm, getse(model$inputs, nms="expr"), model$output, model$prmbounds, model$kseq, data, objfun, init, lower, upper, cachedir = cachedir)
# Have to reset the state first to remove dependency of previous calls # Have to reset the state first to remove dependency of previous calls
model$reset_state() m$reset_state()
cnm <- cache_name(rls_fit, rls_optim, model$outputrange, model$regprm, model$transform_data(data), data[[model$output]], scorefun, init, lower, upper, cachedir = cachedir) # Give all the elements needed to calculate the unique cache name
# Maybe load the cached result # This is maybe smarter, don't have to calculate the transformation of the data: cnm <- cache_name(m$regprm, getse(m$inputs, nms="expr"), m$output, m$prmbounds, m$kseq, data, objfun, init, lower, upper, cachedir = cachedir)
if(file.exists(cnm)){ return(readRDS(cnm)) } cnm <- cache_name(rls_fit, rls_optim, m$outputrange, m$regprm, m$transform_data(data), data[[m$output]], scorefun, init, lower, upper, kseq, cachedir = cachedir)
# Load the cached result if it exists
if(file.exists(cnm) & !cachererun){
res <- readRDS(cnm)
# Set the optimized parameters into the model
model$insert_prm(res$par)
return(res)
}
} }
# Run the optimization # Run the optimization
res <- optim(par = init, res <- optim(par = init,
fn = rls_fit, fn = rls_fit,
# Parameters to pass to rls_fit # Parameters to pass to rls_fit
model = model, model = m,
data = data, data = data,
scorefun = scorefun, scorefun = scorefun,
printout = printout, printout = printout,
...@@ -95,9 +107,9 @@ rls_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE, ...@@ -95,9 +107,9 @@ rls_optim <- function(model, data, scorefun = rmse, cachedir="", printout=TRUE,
upper = upper, upper = upper,
method = method, method = method,
...) ...)
# Save the result in the cachedir # Save the result in the cachedir
if(cachedir != ""){ cache_save(res, cnm)} if(cachedir != ""){ cache_save(res, cnm)}
# Return the result # Set the optimized parameters into the model
model$insert_prm(res$par)
return(res) return(res)
} }
#' @importFrom parallel mclapply ## #' @importFrom parallel mclapply
rls_reduce <- function(model, data, preduce=list(NA), scorefun = rmse){ ## rls_reduce <- function(model, data, prmreduce=list(NA), kseq = NA, scorefun = rmse){
## prm test ## ## prm test
##preduce <- list(I__degree = c(min=1, init=7), mu_tday__nharmonics = c(min=1, init=7)) ## ##prmreduce <- list(I__degree = c(min=1, init=7), mu_tday__nharmonics = c(min=1, init=7))
prmin <- unlist(getse(preduce, 1)) ## prmin <- unlist(getse(prmreduce, 1))
pr <- unlist(getse(preduce, 2)) ## pr <- unlist(getse(prmreduce, 2))
##!! deep=TRUE didn't work, gave: "Error: C stack usage 9524532 is too close to the limit" ## #
m <- model$clone_deep() ## m <- model$clone_deep()
## Insert the starting p reduction values ## # Insert the starting p reduction values
if(!is.na(preduce[1])){ ## if(!is.na(prmreduce[1])){
m$insert_prm(pr) ## m$insert_prm(pr)
} ## }
## ## #
valref <- rls_optim(m, data, printout=FALSE)$value ## valref <- rls_optim(m, data, kseq, printout=FALSE)$value
## ## #
while(TRUE){ ## while(TRUE){
## ## #
message("------------------------------------") ## message("------------------------------------")
message("Reference score value",valref) ## message("Reference score value",valref)
## -------- ## # --------
## Remove inputs one by one ## # Remove inputs one by one
message("\nRemoving inputs one by one") ## message("\nRemoving inputs one by one")
valsrm <- mclapply(1:length(model$inputs), function(i){ ## valsrm <- mclapply(1:length(model$inputs), function(i){
mr <- m$clone_deep() ## mr <- m$clone_deep()
mr$inputs[[i]] <- NULL ## mr$inputs[[i]] <- NULL
rls_optim(mr, data, printout=FALSE)$value ## rls_optim(mr, data, kseq, printout=FALSE)$value
}) ## })
valsrm <- unlist(valsrm) ## valsrm <- unlist(valsrm)
names(valsrm) <- names(m$inputs) ## names(valsrm) <- names(m$inputs)
message("Scores") ## message("Scores")
print(valsrm) ## print(valsrm)
## -------- ## # --------
## Reduce parameter values if specified ## # Reduce parameter values if specified
if(!is.na(pr[1])){ ## if(!is.na(pr[1])){
message("\nReducing prm with -1 one by one") ## message("\nReducing prm with -1 one by one")
valspr <- mclapply(1:length(pr), function(i){ ## valspr <- mclapply(1:length(pr), function(i){
mr <- m$clone_deep() ## mr <- m$clone_deep()
p <- pr ## p <- pr
## Only count down if above minimum ## # Only count down if above minimum
if( p[i] >= prmin[i] ){ ## if( p[i] >= prmin[i] ){
p[i] <- p[i] - 1 ## p[i] <- p[i] - 1
} ## }
mr$insert_prm(p) ## mr$insert_prm(p)
val <- rls_optim(mr, data, printout=FALSE)$value ## val <- rls_optim(mr, data, kseq, printout=FALSE)$value
## ## #
return(val) ## return(val)
}) ## })
valspr <- unlist(valspr) ## valspr <- unlist(valspr)
names(valspr) <- names(pr) ## names(valspr) <- names(pr)
message("Scores") ## message("Scores")
print(valspr) ## print(valspr)
} ## }
## Is one the reduced smaller than the current ref? ## # Is one the reduced smaller than the current ref?
if( min(c(valsrm,valspr)) < valref ){ ## if( min(c(valsrm,valspr)) < valref ){
if(which.min(c(min(valsrm),min(valspr))) == 1){ ## if(which.min(c(min(valsrm),min(valspr))) == 1){
## One of the models with one of the inputs removed is best ## # One of the models with one of the inputs removed is best
imin <- which.min(valsrm) ## imin <- which.min(valsrm)
message("Removing input",names(m$inputs)[imin]) ## message("Removing input",names(m$inputs)[imin])
m$inputs[[imin]] <- NULL ## m$inputs[[imin]] <- NULL
}else{ ## }else{
## One of the models with reduced parameter values is best ## # One of the models with reduced parameter values is best
imin <- which.min(valspr) ## imin <- which.min(valspr)
pr[imin] <- pr[imin] - 1 ## pr[imin] <- pr[imin] - 1
m$insert_prm(pr) ## m$insert_prm(pr)
message("Reduced parameter",names(pr)[imin],"to:",pr[imin]) ## message("Reduced parameter",names(pr)[imin],"to:",pr[imin])
} ## }
valref <- min(c(valsrm,valspr)) ## valref <- min(c(valsrm,valspr))
}else{ ## }else{
## No improvement obtained from reduction, so return the current model ## # No improvement obtained from reduction, so return the current model
message("------------------------------------\n\nDone") ## message("------------------------------------\n\nDone")
return(m) ## return(m)
} ## }
} ## }
} ## }
...@@ -80,7 +80,7 @@ rls_summary <- function(object, scoreperiod = NA, scorefun = rmse, usecomplete = ...@@ -80,7 +80,7 @@ rls_summary <- function(object, scoreperiod = NA, scorefun = rmse, usecomplete =
if(!printit){ if(!printit){
return(retval) return(retval)
} }
# Insert the optimized parameters # Insert the optimized parameters (or actually $prm are just the last parameters given to insert_prm())
m <- fit$model$clone_deep() m <- fit$model$clone_deep()
m$prm[names(m$prm)] <- signif(m$prm, digits=3) m$prm[names(m$prm)] <- signif(m$prm, digits=3)
m$insert_prm(m$prm) m$insert_prm(m$prm)
......
#' @importFrom parallel mclapply
#'
step_backward <- function(object, data, kseq = NA, prm=list(NA), optimfun = rls_optim, scorefun = rmse, ...){
# Do:
# - Maybe have "cloneit" argument in optimfun, then don't clone inside optim.
# - Add argument controlling how much is kept in each iteration (e.g all fitted models)
#
# - Help: prm <- list(I__degree = c(min=1, max=7), mu_tday__nharmonics = c(min=1, max=7))
# - help: It's not checked that it's the score is calculated on the same values! WARNING should be printed if some models don't forecast same points
#
model <- object
#
m <- model$clone_deep()
# Insert the starting prm reduction values
if(!is.na(prm[1])){
prmMin <- unlist(getse(prm, "min"))
# ??insert_prm should keep only the ones that can be changed
m$insert_prm(unlist(getse(prm, "max")))
}
# For keeping all the results
L <- list()
istep <- 1
# Optimize the reference model
res <- optimfun(m, data, kseq, printout=TRUE, ...)
valRef <- res$value
L[[istep]] <- list(model = m$clone_deep(), result = res)
#
done <- FALSE
while(!done){
#
istep <- istep + 1
# Insert the optimized parameters from last step
m$prmbounds[names(res$par),"init"] <- res$par
#
message("------------------------------------")
message("Reference score value: ",valRef)
# --------
# Generate the reduced models
mReduced <- mclapply(1:length(m$inputs), function(i){
mr <- m$clone_deep()
# Insert the optimized parameters from the reference model
mr$inputs[[i]] <- NULL
return(mr)
})
names(mReduced) <- names(m$inputs)
if(!is.na(prm[1])){
tmp <- mclapply(1:length(prm), function(i){
p <- m$get_prmvalues(names(prm[i]))
# If the input is not in model, then p is NA, so don't include it for fitting
if(!is.na(p)){
# Only the ones with prms above minimum
if(p > prmMin[i]){
p <- p - 1
mr <- m$clone_deep()
mr$insert_prm(p)
return(mr)
}
}
return(NA)
})
names(tmp) <- names(prm)
tmp <- tmp[!is.na(tmp)]
mReduced <- c(mReduced, tmp)
}
resReduced <- lapply(1:length(mReduced), function(i, ...){
res <- optimfun(mReduced[[i]], data, kseq, printout=FALSE, ...)
message(names(mReduced)[[i]], ": ", res$value)
return(res)
}, ...)
names(resReduced) <- names(mReduced)
valReduced <- unlist(getse(resReduced, "value"))
imin <- which.min(valReduced)
# Is one the reduced smaller than the current ref?
if( valReduced[imin] < valRef ){
# Keep the best model
m <- mReduced[[imin]]
res <- resReduced[[imin]]
valRef <- res$value
# Keep for the result
L[[istep]] <- list(model = m$clone_deep(), result = resReduced[[imin]])
}else{
# No improvement obtained from reduction, so return the current model (last in the list)
message("------------------------------------\n\nDone")
done <- TRUE
}
}
invisible(L)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment