Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
- name: CmdStan installation cacheing
id: cache-cmdstan
if: ${{ !startswith(needs.get-cmdstan-version.outputs.version, 'git:') }}
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ~/.cmdstan
key: ${{ runner.os }}-cmdstan-${{ needs.get-cmdstan-version.outputs.version }}-${{ hashFiles('**/install_cmdstan.py') }}
Expand Down
44 changes: 29 additions & 15 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
CmdStanArgs,
GenerateQuantitiesArgs,
LaplaceArgs,
Method,
OptimizeArgs,
PathfinderArgs,
SamplerArgs,
Expand Down Expand Up @@ -434,16 +433,21 @@ def optimize(
)
runset.raise_for_timeouts()

if not runset._check_retcodes():
converged = runset._check_retcodes()
if not converged:
msg = "Error during optimization! Command '{}' failed: {}".format(
' '.join(runset.cmd(0)), runset.get_err_msgs()
)
if 'Line search failed' in msg and not require_converged:
get_logger().warning(msg)
else:
raise RuntimeError(msg)
mle = CmdStanMLE(runset)
return mle
return CmdStanMLE.from_files(
csv_file=runset.csv_files[0],
config_file=runset.config_files[0],
stdout_file=runset.stdout_files[0],
converged=converged,
)

# pylint: disable=too-many-arguments
def sample(
Expand Down Expand Up @@ -1040,7 +1044,12 @@ def generate_quantities(
),
):
fit_object = previous_fit
fit_csv_files = previous_fit.runset.csv_files
if isinstance(
previous_fit, (CmdStanPathfinder, CmdStanLaplace, CmdStanMLE)
):
fit_csv_files = [previous_fit.csv_file]
else:
fit_csv_files = previous_fit.runset.csv_files
elif isinstance(previous_fit, list):
if len(previous_fit) < 1:
raise ValueError(
Expand Down Expand Up @@ -1072,7 +1081,7 @@ def generate_quantities(
elif isinstance(fit_object, CmdStanMLE):
chains = 1
chain_ids = [1]
if fit_object._save_iterations:
if fit_object.config.method_config.save_iterations:
get_logger().warning(
'MLE contains saved iterations which will be used '
'to generate additional quantities of interest.'
Expand Down Expand Up @@ -1553,7 +1562,11 @@ def pathfinder(
' '.join(runset.cmd(0)), runset.get_err_msgs()
)
raise RuntimeError(msg)
return CmdStanPathfinder(runset)
return CmdStanPathfinder.from_files(
csv_file=runset.csv_files[0],
config_file=runset.config_files[0],
stdout_file=runset.stdout_files[0],
)

def log_prob(
self,
Expand Down Expand Up @@ -1739,24 +1752,20 @@ def laplace_sample(
else:
cmdstan_mode = mode

if cmdstan_mode.runset.method != Method.OPTIMIZE:
if not isinstance(cmdstan_mode, CmdStanMLE):
raise ValueError(
"Mode must be a CmdStanMLE or a path to an optimize CSV"
)

mode_jacobian = (
cmdstan_mode.runset._args.method_args.jacobian # type: ignore
)
mode_jacobian = cmdstan_mode.config.method_config.jacobian
if mode_jacobian != jacobian:
raise ValueError(
"Jacobian argument to optimize and laplace must match!\n"
f"Laplace was run with jacobian={jacobian},\n"
f"but optimize was run with jacobian={mode_jacobian}"
)

laplace_args = LaplaceArgs(
cmdstan_mode.runset.csv_files[0], draws, jacobian
)
laplace_args = LaplaceArgs(cmdstan_mode.csv_file, draws, jacobian)

with temp_single_json(data) as _data:
args = CmdStanArgs(
Expand All @@ -1780,7 +1789,12 @@ def laplace_sample(
timeout=timeout,
)
runset.raise_for_timeouts()
return CmdStanLaplace(runset, cmdstan_mode)
return CmdStanLaplace.from_files(
csv_file=runset.csv_files[0],
config_file=runset.config_files[0],
stdout_file=runset.stdout_files[0],
mode=cmdstan_mode,
)

def _run_cmdstan(
self,
Expand Down
131 changes: 69 additions & 62 deletions cmdstanpy/stanfit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@
import glob
import os

from cmdstanpy.cmdstan_args import (
CmdStanArgs,
LaplaceArgs,
OptimizeArgs,
PathfinderArgs,
SamplerArgs,
VariationalArgs,
)
from cmdstanpy.cmdstan_args import CmdStanArgs, SamplerArgs, VariationalArgs
from cmdstanpy.utils import check_sampler_csv, get_logger, stancsv

from .gq import CmdStanGQ, PrevFit
Expand All @@ -35,7 +28,7 @@
]


def from_csv(
def from_csv( # pylint: disable=too-many-return-statements
path: str | list[str] | os.PathLike | None = None,
method: str | None = None,
) -> (
Expand Down Expand Up @@ -93,7 +86,7 @@ def from_csv(
if os.path.splitext(file)[1] == ".csv":
csvfiles.append(os.path.join(path, file))
elif os.path.exists(path):
csvfiles.append(str(path))
csvfiles.append(os.fspath(path))
else:
raise ValueError('Invalid path specification: {}'.format(path))
else:
Expand Down Expand Up @@ -179,32 +172,52 @@ def from_csv(
fit.draws()
return fit
elif config_dict['method'] == 'optimize':
if len(csvfiles) != 1:
raise ValueError(
'Expecting a single optimize Stan CSV file, '
f'found {len(csvfiles)}'
)
csv_file = csvfiles[0]
config_file = os.path.splitext(csv_file)[0] + '_config.json'
if os.path.exists(config_file):
return CmdStanMLE.from_files(
csv_file=csv_file, config_file=config_file
)
# Legacy path: no config file, build config from CSV metadata
if 'algorithm' not in config_dict:
raise ValueError(
"Cannot find optimization algorithm in file {}.".format(
csvfiles[0]
csv_file
)
)
algorithm: str = config_dict['algorithm'] # type: ignore
save_iterations = config_dict['save_iterations'] == 1
jacobian = config_dict.get('jacobian', 0) == 1
from .metadata import OptimizeConfig, StanConfig

optimize_args = OptimizeArgs(
algorithm=algorithm,
save_iterations=save_iterations,
jacobian=jacobian,
opt_config = OptimizeConfig(
algorithm=config_dict['algorithm'], # type: ignore
save_iterations=config_dict.get('save_iterations', 0) == 1,
jacobian=config_dict.get('jacobian', 0) == 1,
)
cmdstan_args = CmdStanArgs(
model_name=model,
model_exe=model,
chain_ids=None,
method_args=optimize_args,
stan_config = StanConfig[OptimizeConfig].model_validate(
{
'model_name': config_dict['model'],
'stan_major_version': str(
config_dict.get('stan_version_major', '')
),
'stan_minor_version': str(
config_dict.get('stan_version_minor', '')
),
'stan_patch_version': str(
config_dict.get('stan_version_patch', '')
),
'method_config': opt_config.model_dump(),
}
)
return CmdStanMLE(
metadata=InferenceMetadata.from_csv(csv_file),
model_name=stan_config.model_name,
csv_file=csv_file,
config=stan_config,
)
runset = RunSet(args=cmdstan_args)
runset._csv_files = csvfiles
for i in range(len(runset._retcodes)):
runset._set_retcode(i, 0)
return CmdStanMLE(runset)
elif config_dict['method'] == 'variational':
if 'algorithm' not in config_dict:
raise ValueError(
Expand Down Expand Up @@ -234,43 +247,37 @@ def from_csv(
runset._set_retcode(i, 0)
return CmdStanVB(runset)
elif config_dict['method'] == 'laplace':
jacobian = config_dict['jacobian'] == 1
laplace_args = LaplaceArgs(
mode=config_dict['mode'], # type: ignore
draws=config_dict['draws'], # type: ignore
jacobian=jacobian,
)
cmdstan_args = CmdStanArgs(
model_name=model,
model_exe=model,
chain_ids=None,
method_args=laplace_args,
)
runset = RunSet(args=cmdstan_args)
runset._csv_files = csvfiles
for i in range(len(runset._retcodes)):
runset._set_retcode(i, 0)
mode: CmdStanMLE = from_csv(
config_dict['mode'], # type: ignore
method='optimize',
if len(csvfiles) != 1:
raise ValueError(
'Expecting a single Laplace Stan CSV file, '
f'found {len(csvfiles)}'
)
csv_file = csvfiles[0]
config_file = os.path.splitext(csv_file)[0] + '_config.json'
if not os.path.exists(config_file):
raise ValueError(
'Laplace config file not found at expected path: '
f'{config_file}'
)
return CmdStanLaplace.from_files(
csv_file=csv_file, config_file=config_file
)
return CmdStanLaplace(runset, mode=mode)
elif config_dict['method'] == 'pathfinder':
pathfinder_args = PathfinderArgs(
num_draws=config_dict['num_draws'], # type: ignore
num_paths=config_dict['num_paths'], # type: ignore
)
cmdstan_args = CmdStanArgs(
model_name=model,
model_exe=model,
chain_ids=None,
method_args=pathfinder_args,
if len(csvfiles) != 1:
raise ValueError(
'Expecting a single Pathfinder Stan CSV file, '
f'found {len(csvfiles)}'
)
csv_file = csvfiles[0]
config_file = os.path.splitext(csv_file)[0] + '_config.json'
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is temporary, but it's worth adding a code comment saying so, since deriving the config file path from the CSV file name is a fairly nasty assumption.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, definitely temporary. I was planning on doing away with this function entirely once the full changes are in.

if not os.path.exists(config_file):
raise ValueError(
'Pathfinder config file not found at expected path: '
f'{config_file}'
)
return CmdStanPathfinder.from_files(
csv_file=csv_file, config_file=config_file
)
runset = RunSet(args=cmdstan_args)
runset._csv_files = csvfiles
for i in range(len(runset._retcodes)):
runset._set_retcode(i, 0)
return CmdStanPathfinder(runset)
else:
get_logger().warning(
'Unable to process CSV output files from method %s.',
Expand Down
Loading