diff --git a/nimbleModel/NAMESPACE b/nimbleModel/NAMESPACE index a7cc763..8848a3e 100644 --- a/nimbleModel/NAMESPACE +++ b/nimbleModel/NAMESPACE @@ -47,6 +47,8 @@ export(setTopRules) export(getDependencies) export(getParents) export(getNodes) +export(getNodeNames) +export(expandNodeNames) export(applyGraphRule) export(getSymbolicParentNodes) export(getSymbolicParentNodesRecurse) @@ -70,5 +72,9 @@ export(messageIfVerbose) export(calc_dmnormAltParams) export(calc_dwishAltParams) +export(nimbleCode) export(nimbleModel) export(makeInstrList) + +export(setNimbleModelOption) +export(getNimbleModelOption) diff --git a/nimbleModel/R/model.R b/nimbleModel/R/model.R index efac3f5..9532d5e 100644 --- a/nimbleModel/R/model.R +++ b/nimbleModel/R/model.R @@ -323,11 +323,19 @@ modelClass <- R6Class( # Determine nodes of interest, potentially of particular types. # Incorporates functionality formerly in `getNodeNames` and `expandNodeNames` getNodes <- function(model, nodes = NULL, - stochOnly = FALSE, determOnly = FALSE, + determOnly = FALSE, stochOnly = FALSE, includeData = TRUE, dataOnly = FALSE, - includePredictive = TRUE, predictiveOnly = FALSE, includeRHSonly = FALSE, - topOnly = FALSE, latentOnly = FALSE, endOnly = FALSE) { + topOnly = FALSE, latentOnly = FALSE, endOnly = FALSE, + includePredictive = TRUE, predictiveOnly = FALSE, + nodesAsChars = getNimbleModelOption('nodesAsChars'), + returnScalarComponents = FALSE, + .sort = FALSE + ) { + # A single nodeRange can have elements that don't share a sortID when converted to calcRange representation, + # so we can't sort nodeRanges. + if (.sort && !nodesAsChars) + warning("`.sort=TRUE` is provided only for back compatibility and requires the use of character representations of nodes") # `nodes` may contain one or more varRanges or varNames. if (topOnly + latentOnly + endOnly > 1) { stop("only one of `topOnly`, `latentOnly`, `endOnly` can be `TRUE`.") @@ -342,7 +350,7 @@ getNodes <- function(model, nodes = NULL, if (inherits(nodes, "varRangeClass")) { nodes <- list(nodes) } - if (!all(is.character(nodes) || sapply(nodes, function(node) inherits(node, "varRangeClass")))) { + if (!all(sapply(nodes, function(node) is.character(node) || inherits(node, "varRangeClass")))) { stop("`nodes` must be variable names or `varRange`s.") } } @@ -387,12 +395,74 @@ getNodes <- function(model, nodes = NULL, if (includeRHSonly && !stochOnly && !determOnly) { # RHSonly are considered neither determ not stoch. rhsResult <- lapply(nodes, function(node) applyRules(model$modelDef$rhsOnlyRules, node)) - result <- c(result, flatten(rhsResult)) + if (!.sort) + result <- c(result, flatten(rhsResult)) # TODO: flatten() seems to be deprecated; can we use unlist? } + if (.sort) { + # Ordering is only relevant at calcRange stage and a single nodeRange can contain + # elements with various sortIDs, so we convert to nodeChars first and then get their + # sortID by creating a temporary calcRange for each. + nodeChars <- unlist(sapply(result, \(x) x$toNodeChars())) + + calcRanges <- unlist(lapply(nodeChars, function(node) { + lapply(model$modelDef$calcRules[[getVarName(node)]]$rules, function(rule) { + rule$makeCalcRange(rule$apply(node)) + }) + })) + if (length(nodeChars) != length(calcRanges)) + stop("unexpected mismatch between node character representation and calcRanges in `getNodes` sorting") + ord <- order(sapply(calcRanges, \(x) x$sortID)) + result <- nodeChars[ord] + if (includeRHSonly) + result <- c(sapply(flatten(rhsResult), \(x) x$toNodeChars()), result) + if (returnScalarComponents) + result <- lapply(result, \(x) varRangeClass$new(x)$toVarChars(expandScalars = TRUE)) + result <- unlist(result) + } else { + if (nodesAsChars) { + if (returnScalarComponents) { + result <- lapply(result, \(x) x$toVarRange()$toVarChars(expandScalars = TRUE)) + } else result <- lapply(result, \(x) x$toNodeChars()) + result <- unlist(result) + } else { + if (returnScalarComponents) # TODO: put into new messaging system + warning("one must request result as characters via `nodesAsChars` in order to use `returnScalarComponents`") + } + } if (!length(result)) { return(NULL) } + return(result) +} + +# Provided for backward compatibility. +#' @export +getNodeNames <- function(model, determOnly = FALSE, stochOnly = FALSE, + includeData = TRUE, dataOnly = FALSE, includeRHSonly = FALSE, + topOnly = FALSE, latentOnly = FALSE, endOnly = FALSE, + includePredictive = TRUE, predictiveOnly = FALSE, + returnType = "names", + returnScalarComponents = FALSE) { + if (returnType != "names") + stop("In nimble2, one can only request 'names' as the `returnType`") + return(getNodes(model, nodes = NULL, determOnly, stochOnly, includeData, dataOnly, + includeRHSonly, topOnly, latentOnly, endOnly, + includePredictive, predictiveOnly, + nodesAsChars = TRUE, + returnScalarComponents, .sort = TRUE)) +} - return(removeDuplicateVarRanges(result)) +# Provided for backward compatibility. +# Need a test case where unique=FALSE retains duplicates. +# This should not do any exclusions of nodes based on types. +#' @export +expandNodeNames <- function(model, nodes, returnScalarComponents = FALSE, + returnType = "names", sort = FALSE, unique = TRUE) { + if (returnType != "names") + stop("In nimble2, one can only request 'names' as the `returnType`") + result <- getNodes(model, nodes, includeRHSonly = TRUE, nodesAsChars = TRUE, + returnScalarComponents = returnScalarComponents, .sort = sort) + if (unique) result <- unique(result) + return(result) } diff --git a/nimbleModel/R/modelBaseClass.R b/nimbleModel/R/modelBaseClass.R index f6e216d..d8fdab9 100644 --- a/nimbleModel/R/modelBaseClass.R +++ b/nimbleModel/R/modelBaseClass.R @@ -307,25 +307,54 @@ modelBase_nClass <- nClass( return(expr) } }, - getDependencies = function(nodes, self = TRUE, downstream = FALSE, immediateOnly = FALSE) { - nimbleModel::getDependencies(modelDef, nodes, self, downstream, immediateOnly) + getDependencies = function(nodes, self = TRUE, downstream = FALSE, immediateOnly = FALSE, + nodesAsChars = getNimbleModelOption('nodesAsChars'), + returnScalarComponents = FALSE + ) { + nimbleModel::getDependencies(modelDef, nodes, self, downstream, immediateOnly, + nodesAsChars, returnScalarComponents) }, - getParents = function(nodes, self = TRUE, upstream = FALSE, immediateOnly = FALSE) { - nimbleModel::getParents(modelDef, nodes, self, upstream, immediateOnly) + getParents = function(nodes, self = TRUE, upstream = FALSE, immediateOnly = FALSE, + nodesAsChars = getNimbleModelOption('nodesAsChars'), + returnScalarComponents = FALSE + ) { + nimbleModel::getParents(modelDef, nodes, self, upstream, immediateOnly, + nodesAsChars, returnScalarComponents) }, # TODO: not working because `nimbleModel::getNodes` needs the model not just modelDef. # Once we integrate modelClass with modelBase_nClass, we should be able to pass `self`. - getNodes = function(nodes = NULL, stochOnly = FALSE, determOnly = FALSE, + getNodes = function(nodes = NULL, determOnly = FALSE, stochOnly = FALSE, includeData = TRUE, dataOnly = FALSE, - includePredictive = TRUE, predictiveOnly = FALSE, includeRHSonly = FALSE, - topOnly = FALSE, latentOnly = FALSE, endOnly = FALSE) { + topOnly = FALSE, latentOnly = FALSE, endOnly = FALSE, + includePredictive = TRUE, predictiveOnly = FALSE, + nodesAsChars = getNimbleModelOption('nodesAsChars'), + returnScalarComponents = FALSE, + .sort = FALSE) { nimbleModel::getNodes( - self, nodes, stochOnly, determOnly, includeData, dataOnly, - includePredictive, predictiveOnly, includeRHSonly, - topOnly, latentOnly, endOnly + self, nodes, determOnly, stochOnly, includeData, dataOnly, + includeRHSonly, + topOnly, latentOnly, endOnly, + includePredictive, predictiveOnly, + nodesAsChars, returnScalarComponents, .sort + ) + }, + getNodeNames = function(determOnly = FALSE, stochOnly = FALSE, + includeData = TRUE, dataOnly = FALSE, includeRHSonly = FALSE, + topOnly = FALSE, latentOnly = FALSE, endOnly = FALSE, + includePredictive = TRUE, predictiveOnly = FALSE, + returnType = "names", + returnScalarComponents = FALSE) { + nimbleModel::getNodeNames( + self, determOnly, stochOnly, includeData, dataOnly, + includeRHSonly, topOnly, latentOnly, endOnly, + includePredictive, predictiveOnly, returnType, returnScalarComponents ) }, + expandNodeNames = function(nodes, returnScalarComponents = FALSE, + returnType = "names", sort = FALSE, unique = TRUE) { + nimbleModel::expandNodeNames(self, nodes, returnScalarComponents, "names", sort, unique) + }, calc_op = function(instr, fn, fn_cpp) { if (missing(instr)) { instr <- getVarNames() diff --git a/nimbleModel/R/modelDef.R b/nimbleModel/R/modelDef.R index 370c5eb..5cb38d7 100644 --- a/nimbleModel/R/modelDef.R +++ b/nimbleModel/R/modelDef.R @@ -999,21 +999,29 @@ modelDefClass <- R6Class( getDependencies <- function(modelDef, nodes, self = TRUE, - downstream = FALSE, immediateOnly = FALSE) { + downstream = FALSE, immediateOnly = FALSE, + nodesAsChars = getNimbleModelOption('nodesAsChars'), + returnScalarComponents = FALSE + ) { traverseGraph(modelDef$downstreamRules, modelDef$declRules, nodes = nodes, down = TRUE, self = self, - follow = downstream, immediateOnly = immediateOnly + follow = downstream, immediateOnly = immediateOnly, + nodesAsChars = nodesAsChars, returnScalarComponents = returnScalarComponents ) } getParents <- function(modelDef, nodes, self = FALSE, - upstream = FALSE, immediateOnly = FALSE) { + upstream = FALSE, immediateOnly = FALSE, + nodesAsChars = getNimbleModelOption('nodesAsChars'), + returnScalarComponents = FALSE + ) { traverseGraph(modelDef$upstreamRules, modelDef$declRules, nodes = nodes, down = FALSE, self = self, - follow = upstream, immediateOnly = immediateOnly + follow = upstream, immediateOnly = immediateOnly, + nodesAsChars = nodesAsChars, returnScalarComponents = returnScalarComponents ) } diff --git a/nimbleModel/R/nodeRules.R b/nimbleModel/R/nodeRules.R index 24b8b25..7b5976d 100644 --- a/nimbleModel/R/nodeRules.R +++ b/nimbleModel/R/nodeRules.R @@ -347,8 +347,7 @@ calcRuleClass <- R6Class( ) ) -# calcRanges manage the calculation for one or more nodes, handling the indexing, and -# calling out to the declFun `calculate` function. +# calcRanges manage the calculation information for one or more nodes, handling the indexing calcRangeClass <- R6Class( classname = "calcRangeClass", portable = FALSE, @@ -362,7 +361,6 @@ calcRangeClass <- R6Class( initialize = function(varName, indexingRange, declID, sortID, multiSortIDindex) { varName <<- varName indexingRange <<- indexingRange - calcFun <<- calcFun # note that calcFun itself is not vectorized sortID <<- sortID declID <<- declID multiSortIDindex <<- multiSortIDindex diff --git a/nimbleModel/R/options.R b/nimbleModel/R/options.R index 37e4037..41e47ad 100644 --- a/nimbleModel/R/options.R +++ b/nimbleModel/R/options.R @@ -8,11 +8,13 @@ # but this has .GlobalEnv as a parent. processBackwardsModelIndexRanges = TRUE, disallowMultivariateArgumentExpressions = TRUE, + nodesAsChars = FALSE, verbose = TRUE ) ) # sets a single option +#' @export setNimbleModelOption <- function(name, value) { assign(name, value, envir = .nimbleModelOptions) invisible(value) diff --git a/nimbleModel/R/processModelGraph.R b/nimbleModel/R/processModelGraph.R index 12325e6..10c2afd 100644 --- a/nimbleModel/R/processModelGraph.R +++ b/nimbleModel/R/processModelGraph.R @@ -242,7 +242,10 @@ setSortIDs <- function(calcRules) { # pass result through `getNodes`. traverseGraph <- function(streamRules, declRules, nodes, down, self = TRUE, - follow = FALSE, immediateOnly = FALSE) { + follow = FALSE, immediateOnly = FALSE, + nodesAsChars = getNimbleModelOption('nodesAsChars'), + returnScalarComponents = FALSE + ) { if (inherits(nodes, "varRangeClass")) nodes <- list(nodes) # We use `lapply` on 'nodes' later. results <- traverseGraphRecurse(streamRules, nodes, down, follow, immediateOnly) @@ -321,7 +324,14 @@ traverseGraph <- function(streamRules, declRules, if (!length(results)) { return(NULL) } - return(removeDuplicateVarRanges(results)) + results <- removeDuplicateVarRanges(results) + if (nodesAsChars) { + return(unlist(sapply(results, \(x) x$toVarChars(expandScalars = returnScalarComponents)))) + } else { + if (returnScalarComponents) # TODO: put into new messaging system + warning("one must request result as characters via `nodesAsChars` in order to use `returnScalarComponents`") + } + return(results) } traverseGraphRecurse <- function(rules, nodes, down, follow = FALSE, immediateOnly = FALSE, firstPass = TRUE) {