diff --git a/R/meta-method-converged.R b/R/meta-method-converged.R index 5fbfb019..a72d65f3 100644 --- a/R/meta-method-converged.R +++ b/R/meta-method-converged.R @@ -25,6 +25,7 @@ lcMetaConverged = function(method, maxRep = Inf) { #' @rdname lcMetaMethod-interface setMethod('fit', 'lcMetaConverged', function(method, data, envir, verbose) { attempt = 1L + repeat { enter(verbose, level = verboseLevels$fine, suffix = '') model = fit(getLcMethod(method), data = data, envir = envir, verbose = verbose) @@ -44,6 +45,12 @@ setMethod('fit', 'lcMetaConverged', function(method, data, envir, verbose) { return (model) } else { attempt = attempt + 1L + seed = sample.int(.Machine$integer.max, 1L) + set.seed(seed) + if (has_lcMethod_args(getLcMethod(method), 'seed')) { + # update fit method with new seed + method@arguments$method = update(getLcMethod(method), seed = seed, .eval = TRUE) + } if (is.infinite(method$maxRep)) { cat(verbose, sprintf('Method failed to converge. Retrying... attempt %d', attempt)) @@ -53,3 +60,13 @@ setMethod('fit', 'lcMetaConverged', function(method, data, envir, verbose) { } } }) + +#' @rdname lcMetaMethod-interface +setMethod('validate', 'lcMetaConverged', function(method, data, envir = NULL, ...) { + callNextMethod() + + validate_that( + has_lcMethod_args(method, 'maxRep'), + is.count(method$maxRep) + ) +}) diff --git a/man/lcMetaMethod-interface.Rd b/man/lcMetaMethod-interface.Rd index 0986c033..f0d763cd 100644 --- a/man/lcMetaMethod-interface.Rd +++ b/man/lcMetaMethod-interface.Rd @@ -10,6 +10,7 @@ \alias{postFit,lcMetaMethod-method} \alias{validate,lcMetaMethod-method} \alias{fit,lcMetaConverged-method} +\alias{validate,lcMetaConverged-method} \title{lcMetaMethod methods} \usage{ \S4method{compose}{lcMetaMethod}(method, envir = NULL) @@ -29,6 +30,8 @@ \S4method{validate}{lcMetaMethod}(method, data, envir = NULL, ...) \S4method{fit}{lcMetaConverged}(method, data, envir, verbose) + +\S4method{validate}{lcMetaConverged}(method, data, envir = NULL, ...) } \description{ lcMetaMethod methods diff --git a/tests/testthat/test-meta-methods.R b/tests/testthat/test-meta-methods.R index 36d8c808..09f906fa 100644 --- a/tests/testthat/test-meta-methods.R +++ b/tests/testthat/test-meta-methods.R @@ -1,5 +1,36 @@ method = lcMethodLMKM(Value ~ Assessment, id = 'Traj', time = 'Assessment', nClusters = 2) +setClass('lcMethodConv', contains = 'lcMethod') + +lcMethodConv = function( + response = 'Value', + time = 'Assessment', + id = 'Traj', + nClusters = 1, + nAttempts = 1, + ... +) { + mc = match.call.all() + mc$Class = 'lcMethodConv' + do.call(new, as.list(mc)) +} + +setMethod('preFit', 'lcMethodConv', function(method, data, envir, verbose) { + convAttempts <<- 0 + callNextMethod() +}) + +setMethod('fit', 'lcMethodConv', function(method, data, envir, verbose) { + convAttempts <<- convAttempts + 1 + lcModelPartition( + data = data, + response = method$response, + trajectoryAssignments = rep(1, uniqueN(data[[method$id]])), + converged = convAttempts >= method$nAttempts + ) +}) + + test_that('specify converged', { metaMethod = lcMetaConverged(method) expect_s4_class(metaMethod, 'lcMetaConverged') @@ -40,3 +71,37 @@ test_that('meta converged fit', { model = latrend(metaMethod, testLongData) }) }) + +test_that('meta converged fit until converged', { + metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 2), maxRep = 3) + + # workaround because testthat::expect_message() is failing to capture the output... + out = capture.output({ + model = latrend(metaMethod, testLongData, verbose = TRUE) + }, type = 'message') + expect_match(paste0(out, collapse = '\n'), regexp = 'attempt 2') + expect_true(converged(model)) +}) + +test_that('meta converged fit always fails', { + metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 3), maxRep = 2) + expect_warning({ + model = latrend(metaMethod, testLongData) + }, regexp = 'Failed to obtain converged') + + expect_false(converged(model)) +}) + +test_that('meta converged fit with seed on first attempt', { + metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 1, seed = 13)) + model = latrend(metaMethod, testLongData) + + expect_equal(getLcMethod(model)$method$seed, 13) +}) + +test_that('meta converged fit different seed on second attempt', { + metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 2, seed = 13)) + model = latrend(metaMethod, testLongData) + + expect_true(getLcMethod(model)$method$seed != 13) +})