"""MCMC fitting kernel for the multi-isotopologue fluorescence model."""
from __future__ import annotations
from typing import Dict, Tuple, Sequence, Optional, Callable, Any
import warnings
import multiprocess as _mp
import os as _os
from contextlib import nullcontext
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import emcee
import corner
from astropy import constants as const
from .linelist import (
_as_list,
normalize_cn_systems_arg,
resolve_linelists_with_defaults,
attach_pumping_and_labels,
)
from .rates import (
build_rate_matrix_nbar,
solve_with_normalization,
solve_with_normalization_fast,
g_factors_fast_from_cache,
synth_spectrum_from_lines,
make_lsf,
)
from .collisions import (
_empty_scaffold,
is_atomic_species,
precompute_collision_scaffold_fast,
apply_collisions_inplace_fast,
)
from .config import UNSET, MCMCFitConfig
# Public API: only the MCMC entry point is exported; the underscore-prefixed
# helpers in this module are private.
__all__ = [
    "mcmc_fitting",
]
#: Planck constant, in erg*s.
# Taken from astropy in CGS units; multiplied by line frequencies (Hz) to get
# photon energies h*nu for the energy g-factors.
H_CGS: float = const.h.cgs.value
def _resolve_n_cores(n_cores: Optional[int]) -> int:
    """Normalize ``n_cores`` to a concrete worker count.

    The following all map to ``1`` (single-process):

    - ``None``,
    - values that cannot be converted to a finite integer (e.g. ``"abc"``,
      ``nan``, ``inf``),
    - values ``<= 1``.

    Larger values are clamped to the number of CPUs this process may
    actually run on. Where the platform exposes it (e.g. Linux), that is
    the scheduler affinity mask, which honours ``taskset``/cgroup/batch
    scheduler CPU restrictions. Otherwise it is :func:`os.cpu_count`.

    Parameters
    ----------
    n_cores : int or None
        Requested number of worker processes.

    Returns
    -------
    int
        Worker count, always ``>= 1``.
    """
    if n_cores is None:
        return 1
    try:
        n = int(n_cores)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) has no integer representation.
        return 1
    if n <= 1:
        return 1
    try:
        # CPUs this process is allowed to use; can be fewer than the machine total.
        available = len(_os.sched_getaffinity(0))
    except (AttributeError, OSError):
        # sched_getaffinity does not exist on macOS/Windows.
        available = _os.cpu_count() or 1
    return max(1, min(n, available))
def _worker_init(lnprob_callable: Optional[Callable[[Sequence[float]], float]] = None) -> None:
    """Pool initializer run once inside every ``spawn`` worker.

    Performs two set-up steps, in this order:

    1. When ``lnprob_callable`` is given, store it in the module-level
       :data:`_LNPROB_CALLABLE`, the target the picklable
       :func:`_lnprob_worker` proxy forwards to.
    2. Best-effort: restrict this worker's BLAS thread pool to one thread
       through ``threadpoolctl``. Parallelism comes from the pool itself, so
       extra BLAS threads would only oversubscribe the CPUs.

    Parameters
    ----------
    lnprob_callable : callable, optional
        The ``lnprob(theta)`` callable shipped through the pool's ``initargs``.
        ``None`` leaves any previously installed callable untouched.
    """
    global _LNPROB_CALLABLE
    if lnprob_callable is not None:
        _LNPROB_CALLABLE = lnprob_callable
    try:
        from threadpoolctl import threadpool_limits as _blas_limit
        _blas_limit(limits=1)
    except ImportError:
        # threadpoolctl is optional; without it BLAS keeps its default threading.
        return
def _build_mcmc_pool(
    n_cores: Optional[int],
    lnprob_callable: Optional[Callable[[Sequence[float]], float]],
):
    """Return a ``spawn``-context worker pool for emcee, or ``None`` for serial runs.

    The pool comes from :mod:`multiprocess`, a :mod:`dill`-based drop-in for
    :mod:`multiprocessing`. Because of that, ``lnprob_callable`` and any
    user-supplied ``lsf`` it holds can be serialized, including a lambda, a
    closure or a function defined in a Jupyter cell.

    The ``spawn`` start method is used on every platform. Forking a
    multi-threaded process deadlocks or segfaults with Apple Accelerate's
    GCD-backed BLAS on macOS.

    ``lnprob_callable`` is sent once per worker through ``initargs``.
    :func:`_worker_init` installs it as the module-level
    ``_LNPROB_CALLABLE`` and pins that worker's BLAS to a single thread.
    :func:`_lnprob_worker` then dispatches to it.

    Parameters
    ----------
    n_cores : int or None
        Requested worker count, normalized by :func:`_resolve_n_cores`.
    lnprob_callable : callable or None
        Log-probability callable to install on each worker.

    Returns
    -------
    multiprocess.pool.Pool or None
        ``None`` when the resolved worker count is ``<= 1``.
    """
    workers = _resolve_n_cores(n_cores)
    if workers > 1:
        spawn_ctx = _mp.get_context("spawn")
        return spawn_ctx.Pool(
            processes=workers,
            initializer=_worker_init,
            initargs=(lnprob_callable,),
        )
    return None
def _parent_blas_limiter(n_cores: Optional[int]):
    """Return a context manager that pins the parent process's BLAS to one thread.

    Workers already pin their own BLAS in :func:`_worker_init`. This does the
    same for the parent while the pool runs, so incidental parent-side BLAS
    work does not oversubscribe the CPUs alongside the workers.

    threadpoolctl applies the limit as soon as the limiter object is
    constructed, i.e. when this function is called. The previous limits are
    restored when the returned context manager exits.

    Parameters
    ----------
    n_cores : int or None
        Requested worker count, normalized by :func:`_resolve_n_cores`.

    Returns
    -------
    contextlib.AbstractContextManager
        A no-op :func:`contextlib.nullcontext` in two cases: when the run is
        single-process, or when ``threadpoolctl`` is not installed (the
        latter also emits a :class:`RuntimeWarning`). Otherwise a
        ``threadpoolctl.threadpool_limits`` object.
    """
    if _resolve_n_cores(n_cores) <= 1:
        return nullcontext()
    try:
        from threadpoolctl import threadpool_limits as _blas_limit
    except ImportError:
        warnings.warn(
            "threadpoolctl is not installed; multi-core MCMC may deadlock on "
            "macOS or oversubscribe BLAS threads on Linux. Install threadpoolctl "
            "or run with n_cores=1.",
            RuntimeWarning,
            stacklevel=2,
        )
        return nullcontext()
    return _blas_limit(limits=1)
#: ``lnprob(theta)`` callable active in this process; installed on pool
#: workers by ``_worker_init`` and read by ``_lnprob_worker``.
_LNPROB_CALLABLE: Optional[Callable[[Sequence[float]], float]] = None


def _lnprob_worker(theta):
    """Module-level proxy that dispatches to the active ``lnprob`` callable.

    ``Pool.map`` pickles its target function with the tasks it sends through
    the IPC pipe. A closure defined inside :func:`mcmc_fitting` would not
    survive that. This module-level function is trivially picklable and
    forwards to :data:`_LNPROB_CALLABLE`, which :func:`_worker_init` installs
    once per worker via the pool's ``initargs``.

    Parameters
    ----------
    theta : sequence of float
        Parameter vector for one walker.

    Returns
    -------
    float
        Whatever the installed callable returns for ``theta``.

    Raises
    ------
    RuntimeError
        If no callable has been installed in this process. Without this
        check, the failure would surface as an opaque
        ``'NoneType' object is not callable`` error.
    """
    fn = _LNPROB_CALLABLE
    if fn is None:
        raise RuntimeError(
            "_lnprob_worker called before an lnprob callable was installed; "
            "the pool must be created with initializer=_worker_init and "
            "initargs=(lnprob_callable,)."
        )
    return fn(theta)
class _LnProbCallable:
    """Picklable bundle of state required to evaluate ``lnprob(theta)``.

    Replaces the closure that :func:`mcmc_fitting` previously built locally,
    so the callable can be serialized and shipped to ``spawn`` workers. The
    earlier ``fork``-inheritance trick is unsafe with Apple Accelerate, which
    deadlocks or segfaults when a multi-threaded process forks.

    Serialized via :mod:`dill` (through the :mod:`multiprocess` library), so
    a user-provided ``lsf`` that is a lambda, a closure, or a top-level
    function defined in a Jupyter notebook cell all work transparently.

    Parameter resolution
    --------------------
    Each per-isotopologue quantity (``logQ``, ``T``, ``v_kms``, ``dlam``,
    ``logN``) is looked up in this order:

    1. the sampled ``<name>_<iso>``,
    2. the sampled shared ``<name>``,
    3. ``init_<name>_by_iso[iso]``,
    4. the global ``init_<name>``.

    LSF shape parameters use the sampled value when present, else
    ``init_<name>`` (``None`` when that fallback is unset).
    """

    #: LSF shape parameters, in the order they are handed to ``make_lsf``.
    _LSF_PARAM_NAMES = ("sigma", "sigma1", "sigma2", "sigma_G", "fwhm_L", "ratio")

    def __init__(
        self,
        *,
        priors,
        param_keys,
        cache,
        iso_list,
        x_fit,
        y_fit,
        y_err_fit,
        lsf,
        lsf_method,
        omega,
        include_rotations,
        init_logQ,
        init_logQ_by_iso,
        init_T,
        init_T_by_iso,
        init_v_kms,
        init_v_kms_by_iso,
        init_dlam,
        init_dlam_by_iso,
        init_logN,
        init_logN_by_iso,
        init_sigma,
        init_sigma1,
        init_sigma2,
        init_sigma_G,
        init_fwhm_L,
        init_ratio,
    ):
        """Store the fit configuration verbatim as instance attributes.

        All arguments are keyword-only; see :func:`mcmc_fitting` for their
        meaning. ``cache`` maps each isotopologue to its precomputed rate
        matrices, line table and scratch arrays.
        """
        self.priors = priors
        self.param_keys = param_keys
        self.cache = cache
        self.iso_list = iso_list
        self.x_fit = x_fit
        self.y_fit = y_fit
        self.y_err_fit = y_err_fit
        self.lsf = lsf
        self.lsf_method = lsf_method
        self.omega = omega
        self.include_rotations = include_rotations
        self.init_logQ = init_logQ
        self.init_logQ_by_iso = init_logQ_by_iso
        self.init_T = init_T
        self.init_T_by_iso = init_T_by_iso
        self.init_v_kms = init_v_kms
        self.init_v_kms_by_iso = init_v_kms_by_iso
        self.init_dlam = init_dlam
        self.init_dlam_by_iso = init_dlam_by_iso
        self.init_logN = init_logN
        self.init_logN_by_iso = init_logN_by_iso
        self.init_sigma = init_sigma
        self.init_sigma1 = init_sigma1
        self.init_sigma2 = init_sigma2
        self.init_sigma_G = init_sigma_G
        self.init_fwhm_L = init_fwhm_L
        self.init_ratio = init_ratio

    def theta_to_params(self, theta: Sequence[float]) -> Dict[str, float]:
        """Map a parameter vector to ``{name: value}`` following ``param_keys`` order."""
        return {k: float(v) for k, v in zip(self.param_keys, theta)}

    def _ln_prior(self, theta: Sequence[float]) -> float:
        """Log of the uniform (box) prior.

        Returns ``0.0`` when every parameter lies inside its closed
        ``[lo, hi]`` range, ``-inf`` otherwise. NaN values are rejected,
        because every comparison with NaN is False.
        """
        for val, name in zip(theta, self.param_keys):
            lo, hi = self.priors[name]
            if not lo <= val <= hi:
                return -np.inf
        return 0.0

    def _make_lsf_local(self, pars: Dict[str, float]) -> Optional[Callable[[np.ndarray], np.ndarray]]:
        """Return the user ``lsf`` if given, else build one from ``lsf_method``.

        Returns ``None`` when neither is set.
        """
        if self.lsf is not None:
            return self.lsf
        if self.lsf_method is None:
            return None
        return make_lsf(pars, self.lsf_method)

    def _resolve_iso_param(
        self,
        name: str,
        iso: str,
        pars: Dict[str, float],
        by_iso: Optional[Dict[str, Any]],
        default: Any,
        *,
        optional: bool,
    ) -> Optional[float]:
        """Resolve one per-isotopologue parameter through the fallback chain.

        The lookup order is: sampled ``<name>_<iso>``, then sampled
        ``<name>``, then ``by_iso[iso]``, then ``default``.

        When ``optional`` is true, two cases resolve to ``None`` instead of
        raising:

        - a ``by_iso`` entry that ``float()`` rejects with ``TypeError``
          (e.g. ``None``, an explicit per-iso opt-out);
        - a ``None`` default.
        """
        key = f"{name}_{iso}"
        if key in pars:
            return float(pars[key])
        if name in pars:
            return float(pars[name])
        if by_iso is not None and iso in by_iso:
            if not optional:
                return float(by_iso[iso])
            try:
                return float(by_iso[iso])
            except TypeError:
                return None
        if optional and default is None:
            return None
        return float(default)

    def _logQ_for_iso(self, iso: str, pars: Dict[str, float]) -> Optional[float]:
        """``log10 Q`` driving collisions for ``iso``; ``None`` means collisionless."""
        return self._resolve_iso_param(
            "logQ", iso, pars, self.init_logQ_by_iso, self.init_logQ, optional=True
        )

    def _T_for_iso(self, iso: str, pars: Dict[str, float]) -> float:
        """Temperature used for the collision rates of ``iso``."""
        return self._resolve_iso_param(
            "T", iso, pars, self.init_T_by_iso, self.init_T, optional=False
        )

    def _v_kms_for_iso(self, iso: str, pars: Dict[str, float]) -> float:
        """Velocity shift (km/s) applied to the emitted lines of ``iso``."""
        return self._resolve_iso_param(
            "v_kms", iso, pars, self.init_v_kms_by_iso, self.init_v_kms, optional=False
        )

    def _dlam_for_iso(self, iso: str, pars: Dict[str, float]) -> float:
        """Additive wavelength shift applied to the emitted lines of ``iso``."""
        return self._resolve_iso_param(
            "dlam", iso, pars, self.init_dlam_by_iso, self.init_dlam, optional=False
        )

    def _logN_for_iso(self, iso: str, pars: Dict[str, float]) -> Optional[float]:
        """``log10`` column density of ``iso``; ``None`` when nothing provides one."""
        return self._resolve_iso_param(
            "logN", iso, pars, self.init_logN_by_iso, self.init_logN, optional=True
        )

    def _lsf_params(self, pars: Dict[str, float]) -> Dict[str, Optional[float]]:
        """Collect the LSF shape parameters.

        Each parameter takes its sampled value, else ``init_<name>``, else
        ``None``.
        """
        out: Dict[str, Optional[float]] = {}
        for name in self._LSF_PARAM_NAMES:
            raw = pars[name] if name in pars else getattr(self, f"init_{name}")
            try:
                out[name] = float(raw)
            except TypeError:
                # Not sampled and no fallback given (init_<name> is None).
                out[name] = None
        return out

    def model_flux(self, theta: Sequence[float], wave: np.ndarray) -> np.ndarray:
        """Synthesize the summed model spectrum of all isotopologues on ``wave``.

        For each isotopologue the method:

        1. resets the working rate matrix to its radiative part;
        2. adds collisions in place when ``logQ`` resolves to a value with
           ``10**logQ > 0`` and ``include_rotations`` is enabled;
        3. solves for the level populations;
        4. converts the populations to per-line energy g-factors;
        5. renders the lines onto ``wave``, passing ``10**logN`` as the
           column density, ``omega`` as the solid angle, the ``v_kms`` /
           ``dlam`` shifts and the LSF.

        Raises
        ------
        ValueError
            If no column density (``logN``) can be resolved for an
            isotopologue.
        """
        pars = self.theta_to_params(theta)
        wmin = float(np.min(wave))
        wmax = float(np.max(wave))
        lsf_fun = self._make_lsf_local(self._lsf_params(pars))
        spec_total = np.zeros_like(wave, dtype=float)
        for iso in self.iso_list:
            logQ_i = self._logQ_for_iso(iso, pars)
            T_i = self._T_for_iso(iso, pars)
            v_kms_i = self._v_kms_for_iso(iso, pars)
            dlam_i = self._dlam_for_iso(iso, pars)
            logN_i = self._logN_for_iso(iso, pars)
            if logN_i is None:
                # Previously surfaced as an opaque "10.0 ** None" TypeError mid-sampling.
                raise ValueError(
                    f"No column density for isotopologue {iso!r}: sample 'logN' or "
                    f"'logN_{iso}' in priors, or set init_logN / init_logN_by_iso."
                )
            C = self.cache[iso]
            M = C["M_work"]
            # Start from a fresh copy of the radiative matrix; collisions are added in place.
            np.copyto(M, C["M_rad"])
            if logQ_i is not None:
                Q_i = 10.0 ** logQ_i if np.isfinite(logQ_i) else 0.0
                if Q_i > 0.0 and self.include_rotations:
                    apply_collisions_inplace_fast(
                        M, C["coll_scaf"], Q=Q_i, T=T_i, Cup_work=C["Cup_work"]
                    )
            n = solve_with_normalization_fast(M, C["A_work"], C["b_work"])
            _, g_en = g_factors_fast_from_cache(
                ui=C["ui"],
                A_ul=C["A_ul"],
                hnu=C["hnu"],
                n=n,
                out_g_ph=C["gph_work"],
                out_g_en=C["gen_work"],
            )
            _, spec_i = synth_spectrum_from_lines(
                C["lines_out"],
                g_line_energy=g_en,
                lam_min=wmin,
                lam_max=wmax,
                lam_col="Wave_vac_AA",
                N_col_cm2=10.0 ** logN_i,
                Omega_sr=self.omega,
                grid=wave,
                lsf=lsf_fun,
                v_shift_kms=v_kms_i,
                dlam_shift_A=dlam_i,
            )
            spec_total += spec_i
        return spec_total

    def lnlike(self, theta: Sequence[float]) -> float:
        """Gaussian log-likelihood of the fit data given the model at ``theta``.

        Returns ``-inf`` in two cases:

        - the model contains non-finite values (e.g. NaN populations from a
          degenerate rate matrix);
        - the model's shape differs from ``x_fit``.
        """
        y_model = self.model_flux(theta, self.x_fit)
        if (not np.all(np.isfinite(y_model))) or (y_model.shape != self.x_fit.shape):
            return -np.inf
        inv_sigma2 = 1.0 / (self.y_err_fit ** 2)
        return -0.5 * np.sum(
            np.log(2.0 * np.pi * self.y_err_fit ** 2)
            + (self.y_fit - y_model) ** 2 * inv_sigma2
        )

    def __call__(self, theta: Sequence[float]) -> float:
        """Log-posterior: prior + likelihood.

        Short-circuits to ``-inf`` outside the prior so the model is not
        evaluated there.
        """
        lp = self._ln_prior(theta)
        if not np.isfinite(lp):
            return -np.inf
        ll = self.lnlike(theta)
        if not np.isfinite(ll):
            return -np.inf
        return lp + ll
# --------------------------------------------------------------------------
# Public entry point
# --------------------------------------------------------------------------
def mcmc_fitting(
data: Any = None,
window: Optional[Tuple[float, float]] = None,
*,
pumping: Any = None,
isotopologues: str | Sequence[str] = "12C14N",
systems: str | Sequence[str] | None = None,
linelists: pd.DataFrame | dict[str, pd.DataFrame] | Sequence[pd.DataFrame] | None = None,
include_rotations: bool = True,
include_deltaJ0_parity_mix: bool = True,
require_X_only_for_rot: bool = True,
priors: Optional[Dict[str, Tuple[float, float]]] = None,
nwalkers: int = 50,
nsteps: int = 1000,
n_cores: Optional[int] = None,
lsf: Optional[Callable[[np.ndarray], np.ndarray]] = None,
lsf_method: Optional[str] = None,
make_plots: bool = False,
progress: bool = True,
A_min: Optional[float] = 1e4,
a: float = 3,
# NOTE: these control *pumping* wavelength shift for J_nu (radiative rates)
velocity_kms: float = 0.0,
delta_lambda_A: float = 0.0,
# NOTE: these are fallbacks for parameters not present in priors
init_logQ: Optional[float] = None,
init_logQ_by_iso: Optional[Dict[str, Optional[float]]] = None,
init_T: float = 300.0,
init_T_by_iso: Optional[Dict[str, float]] = None,
init_v_kms: float = 0.0,
init_v_kms_by_iso: Optional[Dict[str, float]] = None,
init_dlam: float = 0.0,
init_dlam_by_iso: Optional[Dict[str, float]] = None,
init_logN: Optional[float] = None,
init_logN_by_iso: Optional[Dict[str, float]] = None,
init_sigma: Optional[float] = None,
init_sigma1: Optional[float] = None,
init_sigma2: Optional[float] = None,
init_sigma_G: Optional[float] = None,
init_fwhm_L: Optional[float] = None,
init_ratio: Optional[float] = None,
fig_file: Optional[str] = None,
wave_col: str = "WAVE",
flux_col: str = "FLUX_STACK",
error_col: str = "ERR_STACK",
continuum_col: str = "CONTINUUM",
omega: Optional[float] = None,
verbose: bool = True,
pruning: bool = True,
N_Model: Optional[int] = 20000,
config: Optional[MCMCFitConfig] = None,
) -> Dict[str, Any]:
"""Run MCMC fitting for the fluorescence model in a wavelength window.
This routine builds the model for one or more isotopologues, applies an
optional line-spread function (LSF), and samples posterior distributions for
the parameters provided in ``priors``.
This is the key function for fitting. It can be used standalone, but is used by
:class:`cometspec.fluorescence.FluorescenceModel` as its default fitting method.
Defaults and selector behavior
------------------------------
- ``isotopologues`` defaults to ``"12C14N"``.
- ``systems`` defaults to ``None``, which maps to ``["BX00", "AX_dv1"]``. Only used for CN; ignored for other species.
- Accepted string selectors for ``systems`` include:
- ``"both"``/``"bx+ax"``/``"bxax"`` -> ``["BX00", "AX_dv1", "AX_dv2", "AX_dv3"]``
- ``"all"`` -> ``["ALL"]`` Extremely high computation cost; not recommended.
- ``"bx"``, ``"b-x"``, ``"bx(0,0)"``, ``"bx00"``, ``"bx_00"``, ``"b_x_00"`` -> ``["BX00"]``
- ``"ax"``/``"a-x"`` -> ``["AX_dv1", "AX_dv2"]``
- ``"ax(dv=1)"``/``"ax_dv1"`` -> ``["AX_dv1"]``
- ``"ax(dv=2)"``/``"ax_dv2"`` -> ``["AX_dv2"]``
- ``"ax(dv=3)"``/``"ax_dv3"`` -> ``["AX_dv3"]``
- emcee ``nwalkers=50`` and ``nsteps=1000`` by default.
- Collisions are gated by ``logQ``: if neither a per-iso ``logQ_{iso}`` nor a shared ``logQ`` prior is given it will fall back to ``init_logQ``, and if ``init_logQ`` or ``init_logQ_by_iso`` are also not given for an isotopologue, that isotopologue is treated as collisionless. Other collision controls default to ``include_deltaJ0_parity_mix=True`` (to allow :math:`\Delta J = 0`) and ``require_X_only_for_rot=True`` (to enable just collisions to the lower electronic state found).
- It is recommended to set explicitly ``include_rotations=False`` if no collisions is the desired behavior.
- Pumping-shift controls default to ``velocity_kms=0.0`` and ``delta_lambda_A=0.0``. These values are used when computing :math:`J_\\nu` for the radiative rates and they are different from ``v_kms`` and ``dlam`` parameters of a model which are shifts of the output spectrum lines.
- Fallback parameter values starts with ``init_`` check the default behavior here if this function is used directly, check them in :class:`cometspec.fluorescence.FluorescenceModel` if this function is used via the model's default fitting method.
- If parameters are not present in the priors, they will fall to the ``init_`` default values (if provided) or to the hardcoded defaults in the model_flux function (e.g. T=300K, v_kms=0, etc). This means that if you want to fit for a parameter but don't provide a prior for it, it will not be fitted and instead will use the fallback value. This allows you to control which parameters are fitted and which are fixed without having to change the code of this function.
.. note::
Priors consist on a dict mapping parameter name to a (min, max) tuple defining the uniform prior range for that parameter. The possible keys are: ``"logN"``, ``"logQ"``, ``"T"``, ``"dlam"``, ``"v_kms"``, 'logN_<isotopologue>', 'logQ_<isotopologue>', 'T_<isotopologue>', 'dlam_<isotopologue>', 'v_kms_<isotopologue>' (any iso not found falls back to <parameter>) and lsf priors, ``sigma``, ``sigma1``, ``sigma2``, ``sigma_G``, ``fwhm_L``, ``ratio``.
Parameters
----------
data : Any
Observed spectrum table or DataFrame which is going to be fitted.
window : tuple[float, float]
Wavelength fitting window ``(min_A, max_A)``.
pumping : Any
Pumping spectrum with ``WAVE`` and ``FLUX`` columns.
isotopologues : str or Sequence[str], optional, default "12C14N"
One or more isotopologue labels.
systems : str or Sequence[str], optional, default None
CN system selector(s). If ``None``, uses ``["BX00", "AX_dv1"]``.
linelists : pandas.DataFrame or dict[str, pd.DataFrame] or Sequence[pd.DataFrame], optional, default None
Optional normalized line-list DataFrame or isotopologue mapping.
priors : dict[str, tuple[float, float]]
Parameter prior ranges. Required.
nwalkers : int, optional, default 50
Number of walkers.
nsteps : int, optional, default 1000
Number of MCMC steps.
n_cores : int, optional, default None
Number of CPU cores used to parallelize the per-step walker likelihood
evaluations through the :mod:`multiprocess` library (a drop-in
replacement for the standard :mod:`multiprocessing` that uses
:mod:`dill` for serialization). ``None`` or ``1`` keeps the sampler
single-threaded (the prior default behavior). Values ``>1`` spawn
a ``spawn``-based pool that emcee uses to evaluate walkers in
parallel; the value is clamped to :func:`os.cpu_count`. The ``spawn``
start method is used on every platform so the pool is safe on macOS
(Apple Accelerate's GCD-backed BLAS deadlocks or segfaults when a
multi-threaded process forks) as well as on Linux and Windows. Using
``multiprocess`` (with ``dill``) instead of stock ``multiprocessing``
means user-provided ``lsf`` callables defined as lambdas, closures,
or top-level functions in a Jupyter notebook cell are serialized
correctly — stock pickle would silently fail on those and hang the
sampler. Each worker has its BLAS thread pool pinned to 1 via
``threadpoolctl`` to avoid thread oversubscription. The tqdm
progress bar (when ``progress=True``) keeps working under both modes.
.. important::
When calling this function with ``n_cores > 1`` from a ``.py``
script (not a Jupyter notebook), the call **must** be guarded by
``if __name__ == "__main__":``. Each ``spawn`` worker re-imports
the user's script, and without the guard the call would re-execute
recursively; Python detects this and aborts with a ``RuntimeError``.
No guard is needed in notebooks or when ``n_cores`` is ``None``/``1``.
.. note::
Whether ``n_cores>1`` is faster than ``n_cores=1`` depends on how
expensive each likelihood evaluation is, which scales with the
line-list size — driven by ``A_min``, the number of isotopologues
and systems, and the window width. Multiprocessing adds a fixed
IPC overhead per emcee step (pickle + pipe roundtrip per walker
half). When each ``lnprob`` call is slow (large line list — e.g.
``A_min=1e4`` for CN with B-X + A-X, or multi-isotopologue runs)
the per-step cost dwarfs IPC and ``n_cores=4`` typically gives a
~3-4x speedup. When each call is cheap (small line list — e.g.
``A_min`` ≥ 1e6 with a single isotopologue, or a narrow window
with few transitions) IPC overhead can dominate and ``n_cores=1``
may end up equal or faster. If unsure, time one short run at each
setting on your problem before committing to a long MCMC.
lsf : Callable[[numpy.ndarray], numpy.ndarray], optional, default None
Optional custom LSF callable.
lsf_method : str, optional, default None
Built-in LSF method name. It must be one of ``"2Gauss"``, ``"Gauss_Lorentz"``, ``"Gauss"``, or ``"Lorentz"``. Required if ``lsf`` is not provided.
include_rotations : bool, optional, default True
Enable rotational collisions in the rate matrix.
omega : float, optional, default None
Optional aperture solid angle in sr.
config : MCMCFitConfig, optional, default None
Optional grouped configuration. Any field set on this object overrides
the corresponding individual keyword argument. Fields left at
:data:`cometspec.config.UNSET` (the default) are ignored, so callers
that pass individual kwargs not in the config are unaffected. See
:class:`cometspec.config.MCMCFitConfig`.
Other Parameters
----------------
include_deltaJ0_parity_mix : bool, optional, default True
Allow parity-changing ``Delta J = 0`` collisions.
require_X_only_for_rot : bool, optional, default True
Restrict collisions to the lower electronic
state (auto-detected as the ``lower_es`` label with the smallest minimum
``E_lower_cm1``; works for any spectroscopic notation).
make_plots : bool, optional, default False
Generate diagnostic plots (fit, corner and traces).
progress : bool, optional, default True
Show emcee progress output.
A_min : float, optional, default 1e4
Minimum Einstein A threshold for line lists. User provided line lists are not filtered by A_ul
a : float, optional, default 3
Stretch-move parameter for emcee.
velocity_kms : float, optional, default 0.0
Velocity shift used when evaluating pumping J_nu.
delta_lambda_A : float, optional, default 0.0
Additive wavelength shift used when evaluating pumping J_nu.
init_logQ : float, optional, default None
Fallback ``logQ`` value used by every isotopologue when no
``logQ`` prior is sampled and no per-iso entry is given. ``None``
disables collisions for any isotopologue not covered by ``init_logQ_by_iso``.
init_logQ_by_iso : dict[str, float or None], optional, default None
Per-isotopologue fallback ``logQ`` map. Each value may
be ``None`` to force that isotopologue to be collisionless. Isotopologues
not present in the map fall back to ``init_logQ``.
init_T : float, optional, default 300.0
Global fallback temperature when ``T`` is not sampled and no per-iso entry is given.
init_T_by_iso : dict[str, float], optional, default None
Per-isotopologue fallback temperature map. For each isotopologue, ``init_T_by_iso[iso]``
is used when neither ``T_{iso}`` nor ``T`` appears in the sampled parameters. Isotopologues
not in this map fall back to ``init_T``.
init_v_kms : float, optional, default 0.0
Global fallback emission velocity shift when ``v_kms`` is not sampled and no per-iso entry is given.
init_v_kms_by_iso : dict[str, float], optional, default None
Per-isotopologue fallback emission velocity shift. Falls back to ``init_v_kms`` for
isotopologues not present in the map.
init_dlam : float, optional, default 0.0
Global fallback emission wavelength shift when ``dlam`` is not sampled and no per-iso entry is given.
init_dlam_by_iso : dict[str, float], optional, default None
Per-isotopologue fallback additive wavelength shift. Falls back to ``init_dlam`` for
isotopologues not present in the map.
init_logN : float, optional, default None
Global fallback ``logN`` when neither ``logN_{iso}`` nor ``logN`` appears in the
sampled parameters and no per-iso entry is given.
init_logN_by_iso : dict[str, float], optional, default None
Per-isotopologue fallback ``logN`` map. For each isotopologue, ``init_logN_by_iso[iso]``
is used when neither ``logN_{iso}`` nor ``logN`` are sampled. Falls back to ``init_logN``
for isotopologues not present in the map.
init_sigma : float, optional, default None
Fallback Gaussian sigma when ``"sigma"`` is not sampled.
init_sigma1 : float, optional, default None
Fallback ``sigma1`` for ``"2Gauss"`` when not sampled.
init_sigma2 : float, optional, default None
Fallback ``sigma2`` for ``"2Gauss"`` when not sampled.
init_sigma_G : float, optional, default None
Fallback ``sigma_G`` for ``"Gauss_Lorentz"`` when not sampled.
init_fwhm_L : float, optional, default None
Fallback Lorentzian FWHM when not sampled.
init_ratio : float, optional, default None
Fallback mixture ratio when not sampled.
fig_file : str, optional, default None
Base path for output figures.
wave_col : str, optional, default "WAVE"
Wavelength column in ``data``.
flux_col : str, optional, default "FLUX_STACK"
Flux column in ``data``.
error_col : str, optional, default "ERR_STACK"
Uncertainty column in ``data``.
continuum_col : str, optional, default "CONTINUUM"
Continuum column in ``data``.
verbose : bool, optional, default True
Print diagnostics.
pruning : bool, optional, default True
Apply posterior pruning.
N_Model : int, optional, default 20000
Number of elements in the model grid.
Returns
-------
dict[str, Any]
Dictionary with posterior summaries, samples, and model envelopes. Keys:
* ``param_keys`` (:class:`list` of :class:`str`) -- Ordered list of sampled parameter names, matching the column order of ``samples_pruned``.
* ``median_params`` (:class:`dict` [:class:`str`, :class:`float`]) -- Posterior median for each parameter.
* ``up_errors_params`` (:class:`dict` [:class:`str`, :class:`float`]) -- Upper 1-sigma error (p84 - p50) for each parameter.
* ``low_errors_params`` (:class:`dict` [:class:`str`, :class:`float`]) -- Lower 1-sigma error (p50 - p16) for each parameter.
* ``samples_pruned`` (:class:`numpy.ndarray` of :class:`float`, shape ``(N_samples, ndim)``) -- Pruned and burn-in-removed posterior samples.
* ``lnprob_pruned`` (:class:`numpy.ndarray` of :class:`float`, shape ``(N_samples,)``) -- Log-probability for each pruned sample.
* ``model_wave`` (:class:`numpy.ndarray` of :class:`float`, shape ``(N_Model,)``) -- Wavelength grid over the fitting window.
* ``median_model`` (:class:`numpy.ndarray` of :class:`float`, shape ``(N_Model,)``) -- Model flux evaluated at the posterior median parameters.
* ``model_p16`` (:class:`numpy.ndarray` of :class:`float`, shape ``(N_Model,)``) -- 16th percentile of the model ensemble (lower 1-sigma envelope).
* ``model_p84`` (:class:`numpy.ndarray` of :class:`float`, shape ``(N_Model,)``) -- 84th percentile of the model ensemble (upper 1-sigma envelope).
* ``best_model`` (:class:`numpy.ndarray` of :class:`float`, shape ``(N_Model,)``) -- Model flux evaluated at the highest-likelihood sample.
Raises
------
ValueError
If priors or required parameters are inconsistent.
ValueError
If required columns are missing from the data.
Notes
-----
If extreme prior values cause the rate-matrix solver to fail (``LinAlgError``
from a numerically degenerate matrix, typically due to NaN/Inf collision rates
at very high Q), ``solve_with_normalization_fast`` returns NaN populations.
``lnlike`` then detects a non-finite model and returns ``-inf``, so the walker
step is rejected gracefully rather than crashing.
"""
if config is not None:
if config.data is not UNSET:
data = config.data
if config.window is not UNSET:
window = config.window
if config.pumping is not UNSET:
pumping = config.pumping
if config.priors is not UNSET:
priors = config.priors
if config.isotopologues is not UNSET:
isotopologues = config.isotopologues
if config.systems is not UNSET:
systems = config.systems
if config.linelists is not UNSET:
linelists = config.linelists
if config.include_rotations is not UNSET:
include_rotations = config.include_rotations
if config.include_deltaJ0_parity_mix is not UNSET:
include_deltaJ0_parity_mix = config.include_deltaJ0_parity_mix
if config.require_X_only_for_rot is not UNSET:
require_X_only_for_rot = config.require_X_only_for_rot
if config.nwalkers is not UNSET:
nwalkers = config.nwalkers
if config.nsteps is not UNSET:
nsteps = config.nsteps
if config.n_cores is not UNSET:
n_cores = config.n_cores
if config.lsf is not UNSET:
lsf = config.lsf
if config.lsf_method is not UNSET:
lsf_method = config.lsf_method
if config.make_plots is not UNSET:
make_plots = config.make_plots
if config.progress is not UNSET:
progress = config.progress
if config.A_min is not UNSET:
A_min = config.A_min
if config.a is not UNSET:
a = config.a
if config.velocity_kms is not UNSET:
velocity_kms = config.velocity_kms
if config.delta_lambda_A is not UNSET:
delta_lambda_A = config.delta_lambda_A
if config.init_logQ is not UNSET:
init_logQ = config.init_logQ
if config.init_logQ_by_iso is not UNSET:
init_logQ_by_iso = config.init_logQ_by_iso
if config.init_T is not UNSET:
init_T = config.init_T
if config.init_T_by_iso is not UNSET:
init_T_by_iso = config.init_T_by_iso
if config.init_v_kms is not UNSET:
init_v_kms = config.init_v_kms
if config.init_v_kms_by_iso is not UNSET:
init_v_kms_by_iso = config.init_v_kms_by_iso
if config.init_dlam is not UNSET:
init_dlam = config.init_dlam
if config.init_dlam_by_iso is not UNSET:
init_dlam_by_iso = config.init_dlam_by_iso
if config.init_logN is not UNSET:
init_logN = config.init_logN
if config.init_logN_by_iso is not UNSET:
init_logN_by_iso = config.init_logN_by_iso
if config.init_sigma is not UNSET:
init_sigma = config.init_sigma
if config.init_sigma1 is not UNSET:
init_sigma1 = config.init_sigma1
if config.init_sigma2 is not UNSET:
init_sigma2 = config.init_sigma2
if config.init_sigma_G is not UNSET:
init_sigma_G = config.init_sigma_G
if config.init_fwhm_L is not UNSET:
init_fwhm_L = config.init_fwhm_L
if config.init_ratio is not UNSET:
init_ratio = config.init_ratio
if config.fig_file is not UNSET:
fig_file = config.fig_file
if config.wave_col is not UNSET:
wave_col = config.wave_col
if config.flux_col is not UNSET:
flux_col = config.flux_col
if config.error_col is not UNSET:
error_col = config.error_col
if config.continuum_col is not UNSET:
continuum_col = config.continuum_col
if config.omega is not UNSET:
omega = config.omega
if config.verbose is not UNSET:
verbose = config.verbose
if config.pruning is not UNSET:
pruning = config.pruning
if config.N_Model is not UNSET:
N_Model = config.N_Model
if data is None:
raise ValueError("data must be provided (argument or config.data).")
if window is None:
raise ValueError("window must be provided (argument or config.window).")
if pumping is None:
raise ValueError("pumping must be provided (argument or config.pumping).")
if priors is None:
raise ValueError("priors must be provided (argument or config.priors).")
required_cols = {wave_col, flux_col, error_col, continuum_col}
missing_cols = required_cols - set(data.columns)
if missing_cols:
raise ValueError(
f"Data is missing required columns: {missing_cols}. Please specify the correct columns using wave_col, flux_col, error_col, and continuum_col parameters."
)
iso_list = _as_list(isotopologues)
sys_tokens = normalize_cn_systems_arg(systems)
param_keys = list(priors.keys())
if lsf is not None:
drop = {"sigma_G", "fwhm_L", "sigma", "sigma1", "sigma2", "ratio"}
param_keys = [k for k in param_keys if k not in drop]
priors = {k: priors[k] for k in param_keys}
else:
if lsf_method == "2Gauss":
drop = {"sigma_G", "fwhm_L", "sigma"}
elif lsf_method == "Gauss_Lorentz":
drop = {"sigma1", "sigma2", "sigma"}
elif lsf_method == "Gauss":
drop = {"sigma_G", "fwhm_L", "sigma1", "sigma2", "ratio"}
elif lsf_method == "Lorentz":
drop = {"sigma_G", "sigma1", "sigma2", "sigma", "ratio"}
else:
raise ValueError("Provide `lsf` or lsf_method in {'2Gauss','Gauss_Lorentz','Gauss','Lorentz'}.")
param_keys = [k for k in param_keys if k not in drop]
priors = {k: priors[k] for k in param_keys}
for name in param_keys:
lo, hi = priors[name]
if not (np.isfinite(lo) and np.isfinite(hi) and hi > lo):
raise ValueError(f"Bad prior for {name!r}: {priors[name]}")
trans_by_iso = resolve_linelists_with_defaults(
linelists,
iso_list,
systems=sys_tokens,
A_min=A_min,
use_omega_labels=False,
lambda_min_A=pumping["WAVE"].min(),
lambda_max_A=pumping["WAVE"].max(),
)
def _iso_can_collide(iso: str) -> bool:
    """Return whether collisional rates apply to isotopologue ``iso``.

    Atomic species (as reported by ``is_atomic_species``) never collide.
    Otherwise collisions are enabled when ``priors`` has a ``logQ`` entry
    (the per-species ``logQ_<iso>`` or the shared ``logQ``). Failing that,
    they are enabled when a fixed ``logQ`` value is supplied: an entry for
    ``iso`` in ``init_logQ_by_iso`` takes precedence over the global
    ``init_logQ``, and a value of ``None`` disables collisions.
    """
    if is_atomic_species(iso):
        return False
    if f"logQ_{iso}" in priors or "logQ" in priors:
        return True
    per_iso = init_logQ_by_iso
    if per_iso is not None and iso in per_iso:
        fixed_logq = per_iso[iso]
    else:
        fixed_logq = init_logQ
    return fixed_logq is not None
iso_collides = {iso: _iso_can_collide(iso) for iso in trans_by_iso.keys()}
req = {"lower_es", "lower_v", "lower_J", "lower_sym", "E_lower_cm1"}
for iso, df_trans in trans_by_iso.items():
if not iso_collides[iso]:
continue
missing = sorted(list(req - set(df_trans.columns)))
if missing:
raise ValueError(
f"Isotopologue {iso!r} would use collisions (logQ provided) but its linelist "
f"is missing required columns: {missing}. Provide them via "
"from_user_linelist(... lower_*_col=..., E_lower_cm1_col=...), or set its "
"logQ to None to disable collisions for this isotopologue."
)
cache: dict[str, dict[str, Any]] = {}
for iso, df_trans in trans_by_iso.items():
lines_theta = attach_pumping_and_labels(
df_trans,
pumping,
line_v_kms=float(velocity_kms),
line_dlam_A=float(delta_lambda_A),
lsf_for_Jnu=None,
lam_col="lambda_vac_A",
)
M_rad, idx_to_level, lines_out = build_rate_matrix_nbar(
lines_theta,
include_stim_emission=True,
verbose=False,
A_col="A_ul",
upper_id_col="upper_id",
lower_id_col="lower_id",
g_upper_col="g_upper",
g_lower_col="g_lower",
)
ui = np.asarray(lines_out["__upper_idx"], dtype=np.int64)
A_ul = np.asarray(lines_out["A_ul"], dtype=np.float64)
nu = np.asarray(lines_out["__nu_Hz"], dtype=np.float64)
hnu = H_CGS * nu
gph_work = np.empty_like(A_ul, dtype=float)
gen_work = np.empty_like(A_ul, dtype=float)
if iso_collides[iso]:
coll_scaf = precompute_collision_scaffold_fast(
lines_out, idx_to_level,
include_deltaJ0_parity_mix=include_deltaJ0_parity_mix,
require_X_only=require_X_only_for_rot,
iso_name=iso,
)
else:
coll_scaf = _empty_scaffold()
M_work = np.empty_like(M_rad)
A_work = np.empty_like(M_rad)
b_work = np.zeros(M_rad.shape[0], float)
Cup_work = np.empty_like(coll_scaf.get("iu", np.array([], dtype=int)), dtype=float)
cache[iso] = dict(
M_rad=M_rad,
idx_to_level=idx_to_level,
lines_out=lines_out,
coll_scaf=coll_scaf,
M_work=M_work,
A_work=A_work,
b_work=b_work,
ui=ui,
A_ul=A_ul,
hnu=hnu,
Cup_work=Cup_work,
gph_work=gph_work,
gen_work=gen_work,
)
def _col(obj, name: str) -> np.ndarray:
    """Fetch column ``name`` from a tabular object as a NumPy array.

    Objects exposing ``colnames`` (presumably astropy ``Table``) and plain
    mappings are indexed directly. Objects exposing ``columns`` but not
    ``colnames`` (e.g. a pandas ``DataFrame``) are unwrapped through
    ``.values`` before conversion.
    """
    # Same hasattr order as before: ``colnames`` is probed first.
    via_values = not hasattr(obj, "colnames") and hasattr(obj, "columns")
    column = obj[name]
    if via_values:
        column = column.values
    return np.asarray(column)
x_data = _col(data, wave_col)
y_data = _col(data, flux_col)
y_err = _col(data, error_col)
cont = _col(data, continuum_col)
mwin = (x_data >= window[0]) & (x_data <= window[1])
x_fit = x_data[mwin]
y_fit = y_data[mwin] - cont[mwin]
y_err_fit = y_err[mwin]
lnprob_callable = _LnProbCallable(
priors=priors,
param_keys=param_keys,
cache=cache,
iso_list=iso_list,
x_fit=x_fit,
y_fit=y_fit,
y_err_fit=y_err_fit,
lsf=lsf,
lsf_method=lsf_method,
omega=omega,
include_rotations=include_rotations,
init_logQ=init_logQ,
init_logQ_by_iso=init_logQ_by_iso,
init_T=init_T,
init_T_by_iso=init_T_by_iso,
init_v_kms=init_v_kms,
init_v_kms_by_iso=init_v_kms_by_iso,
init_dlam=init_dlam,
init_dlam_by_iso=init_dlam_by_iso,
init_logN=init_logN,
init_logN_by_iso=init_logN_by_iso,
init_sigma=init_sigma,
init_sigma1=init_sigma1,
init_sigma2=init_sigma2,
init_sigma_G=init_sigma_G,
init_fwhm_L=init_fwhm_L,
init_ratio=init_ratio,
)
ndim = len(param_keys)
nburn = nsteps // 2
print("Number of iterations:", ndim * nwalkers * nsteps)
p0 = np.array([[np.random.uniform(*priors[name]) for name in param_keys] for _ in range(nwalkers)])
move = emcee.moves.StretchMove(a=a)
global _LNPROB_CALLABLE
prev_lnprob_callable = _LNPROB_CALLABLE
_LNPROB_CALLABLE = lnprob_callable
try:
with _parent_blas_limiter(n_cores):
pool = _build_mcmc_pool(n_cores, lnprob_callable)
log_prob_fn = _lnprob_worker if pool is not None else lnprob_callable
try:
sampler = emcee.EnsembleSampler(
nwalkers, ndim, log_prob_fn, moves=move, pool=pool
)
sampler.run_mcmc(p0, nsteps, progress=progress)
finally:
if pool is not None:
pool.close()
pool.join()
finally:
_LNPROB_CALLABLE = prev_lnprob_callable
chain = sampler.get_chain()
lnprob_full = sampler.get_log_prob()
flat_chain = chain.reshape(-1, ndim)
flat_lnprob = lnprob_full.reshape(-1)
best_idx = int(np.argmax(flat_lnprob))
best_theta = flat_chain[best_idx]
best_params = lnprob_callable.theta_to_params(best_theta)
if verbose:
print("#" * 50)
print("*** Best fit (no pruning) ***")
for name in param_keys:
print(f"{name}: {best_params[name]:.6g}")
af = sampler.acceptance_fraction
if verbose:
print("#" * 50)
print("*** Acceptance Fraction ***")
print("Mean acceptance fraction:", np.mean(af))
af_msg = '''As a rule of thumb, the acceptance fraction (af) should be
between 0.2 and 0.5
If af < 0.2 decrease the MCMC a parameter
If af > 0.5 increase the MCMC a parameter
actual af = {}'''.format(np.mean(af))
if verbose:
print("Mean acceptance fraction:", np.mean(af))
if np.mean(af)<0.2 or np.mean(af)>0.5:
print(af_msg)
warnings.warn("Acceptance fraction out of bounds.", UserWarning)
samples = chain[nburn:, :, :].reshape(-1, ndim)
lnprob_burn = lnprob_full[nburn:, :].reshape(-1)
def prune(samples: np.ndarray,
          lnprob_arr: np.ndarray,
          scaler: float = 5.0,
          quiet: bool = False):
    """Iteratively discard low-probability outliers from MCMC samples.

    Each sample's distance from the best log-probability,
    ``dln = |lnprob - max(lnprob)|``, is compared against ``scaler`` times
    the standard deviation of ``dln``. Samples at or beyond that threshold
    are dropped. The cut is repeated while the ``dln`` distribution stays
    skewed (``|mean - median| > 0.1 * median``). If a pass would keep
    every sample, its threshold is halved. Whenever the median of ``dln``
    is unchanged from the previous pass, ``scaler`` is divided by 1.5 so
    the loop keeps making progress.

    Parameters
    ----------
    samples : ndarray
        Samples indexed along the first axis, one row per entry of
        ``lnprob_arr``.
    lnprob_arr : ndarray
        1-D log-probabilities matching ``samples``. Non-finite entries
        (``nan``, ``+/-inf``) and their samples are discarded up front.
    scaler : float
        Initial threshold multiplier; must be positive.
    quiet : bool
        If False, print ``median, mean, skew`` of ``dln`` after each pass.

    Returns
    -------
    (ndarray, ndarray)
        The retained samples and their log-probabilities (new arrays).

    Raises
    ------
    ValueError
        If ``scaler`` is not positive, or if ``lnprob_arr`` has no finite
        value (this includes an empty input).
    """
    if not scaler > 0:
        raise ValueError(f"prune: scaler must be positive, got {scaler!r}.")
    samples = np.asarray(samples)
    lnprob_arr = np.asarray(lnprob_arr)
    # Non-finite log-probabilities (e.g. walkers stuck at -inf) turn the
    # standard deviation into nan and would silently empty the selection.
    finite = np.isfinite(lnprob_arr)
    if not finite.any():
        raise ValueError("prune: no finite log-probability values to prune.")
    samples = samples[finite]
    lnprob_arr = lnprob_arr[finite]

    best = lnprob_arr.max()
    dln = np.abs(lnprob_arr - best)
    rms = np.std(dln)
    if rms == 0.0:
        # Every sample shares the best log-probability, so there is nothing
        # to prune. The strict ``dln < scaler * 0`` cut would drop them all.
        return samples, lnprob_arr
    med = np.median(dln)
    skew = abs(np.mean(dln) - med)
    keep = dln < scaler * rms
    ln2 = lnprob_arr[keep]
    s2 = samples[keep]
    prev_med = 0.0
    while skew > 0.1 * med and ln2.size > 0:
        best = ln2.max()
        dln = np.abs(ln2 - best)
        rms = np.std(dln)
        keep = dln < scaler * rms
        if keep.all():
            # Nothing would be cut: tighten the threshold for this pass.
            keep = dln < (scaler / 2.0) * rms
        ln2 = ln2[keep]
        s2 = s2[keep]
        dln = np.abs(ln2 - best)
        med = np.median(dln)
        avg = np.mean(dln)
        skew = abs(avg - med)
        if not quiet:
            print(med, avg, skew)
        if med == prev_med:
            # Median stalled between passes: shrink the multiplier.
            scaler /= 1.5
        prev_med = med
    return s2, ln2
if pruning:
if verbose:
print("#" * 50)
print("*** Pruning... ***")
try:
samples_pruned, lnprob_pruned = prune(samples, lnprob_burn, quiet=not progress)
except Exception as exc:
print("Pruning failed:", exc)
samples_pruned, lnprob_pruned = samples, lnprob_burn
else:
samples_pruned, lnprob_pruned = samples, lnprob_burn
median_params: Dict[str, float] = {}
up_errors: Dict[str, float] = {}
low_errors: Dict[str, float] = {}
for i, name in enumerate(param_keys):
p16, p50, p84 = np.percentile(samples_pruned[:, i], [16, 50, 84])
median_params[name] = float(p50)
up_errors[name] = float(p84 - p50)
low_errors[name] = float(p50 - p16)
err = 0.5 * ((p84 - p50) + (p50 - p16))
print(f"{name}: {p50:.4f} +/- {err:.4f} [{p16:.4f}, {p84:.4f}]")
x_model = np.linspace(window[0], window[1], N_Model)
n_draw = min(200, samples_pruned.shape[0])
model_stack = np.empty((n_draw, x_model.size))
for i in range(n_draw):
model_stack[i] = lnprob_callable.model_flux(samples_pruned[i], x_model)
theta_med = [median_params[k] for k in param_keys]
best_model = lnprob_callable.model_flux(theta_med, x_model)
p16_m, p50_m, p84_m = np.percentile(model_stack, [16, 50, 84], axis=0)
median_model = p50_m
model_p16 = p16_m
model_p84 = p84_m
param_labels = {
"logN": r"$\mathrm{log}_{10}(N)$",
"logQ": r"$\log_{10}$(Q$_{\rm{col}}$ / [s$^{-1}$])",
"T": r"$T_{kin}$ [K]",
"v_kms": r"$\Delta$v [km s$^{-1}$]",
"dlam": r"$\Delta \lambda$ [Å]",
"sigma": r"$\sigma$ [Å]",
"sigma1": r"$\sigma_1$ [Å]",
"sigma2": r"$\sigma_2$ [Å]",
"sigma_G": r"$\sigma_G$ [Å]",
"fwhm_L": r"FWHM$_L$ [Å]",
"ratio": r"Ratio",
}
if iso_list and len(iso_list) > 1:
param_lab_by_iso = {
f'{k}_{iso}': (rf"$\mathrm{{log}}_{{10}}(N_{{{iso}}})$" if k == "logN" else f"{param_labels[k]}_{iso}")
for k in param_labels for iso in iso_list
}
param_labels = {**param_labels, **param_lab_by_iso}
_logN_units_note = r"$N\ [\mathrm{molecules\ cm}^{-2}]$"
_has_logN = any(k == "logN" or k.startswith("logN_") for k in param_keys)
if make_plots:
trace_ylabel_fontsize = max(13 - 0.4 * ndim, 8)
trace_tick_fontsize = max(11 - 0.3 * ndim, 7)
fig, axes = plt.subplots(ndim, 1, figsize=(10, max(1.8 * ndim, 4)), sharex=True)
if ndim == 1:
axes = [axes]
nsteps_chain, nwalkers_chain, _ = chain.shape
steps = np.arange(nsteps_chain)
for j, name in enumerate(param_keys):
for w in range(nwalkers_chain):
axes[j].plot(steps, chain[:, w, j], alpha=0.7, lw=0.8)
axes[j].set_ylabel(param_labels[name], fontsize=trace_ylabel_fontsize, labelpad=2)
axes[j].tick_params(axis='both', labelsize=trace_tick_fontsize)
if name == "logN" or name.startswith("logN_"):
axes[j].text(0.98, 0.97, _logN_units_note, fontsize=trace_ylabel_fontsize,
va='top', ha='right', transform=axes[j].transAxes,
bbox=dict(boxstyle='round,pad=0.2', fc='white', alpha=0.7, ec='none'))
axes[-1].set_xlabel("iteration", fontsize=trace_ylabel_fontsize)
fig.tight_layout(h_pad=0.3)
plt.savefig(f"{fig_file}_mcmc_traces.pdf", dpi=300, format='pdf')
plt.show()
fig_size = max(3.0 * ndim, 10)
label_fontsize = max(16 - 0.7 * ndim, 9)
title_fontsize = max(14 - 0.7 * ndim, 8)
tick_fontsize = max(12 - 0.4 * ndim, 7)
fig = corner.corner(
samples_pruned,
labels=[param_labels[k] for k in param_keys],
title_kwargs={'fontsize': title_fontsize, 'y': 1.05},
title_fmt=".2f",
use_math_text=True,
bins=15,
quantiles=[0.16, 0.5, 0.84],
show_titles=True,
color='lightseagreen',
hist_kwargs={'color': 'black', 'linewidth': 1.5},
contour_kwargs={'linewidths': 1, 'colors': 'black'},
label_kwargs={'fontsize': label_fontsize},
fig=plt.figure(figsize=(fig_size, fig_size)),
)
margin = max(0.08 + 0.008 * ndim, 0.12)
fig.subplots_adjust(wspace=0.05, hspace=0.05, left=margin, bottom=margin, right=0.97, top=0.95)
elements = [param_labels[k] for k in param_keys]
axes = np.array(fig.axes).reshape((ndim, ndim))
for i in range(ndim):
ax = axes[i, i]
name_i = param_keys[i]
q16, q50, q84 = np.quantile(samples[:, i], [0.16, 0.5, 0.84])
q_minus, q_plus = q50 - q16, q84 - q50
value_str = rf"${q50:.2f}_{{-{q_minus:.2f}}}^{{+{q_plus:.2f}}}$"
if name_i == "logN" or name_i.startswith("logN_"):
title = f"{_logN_units_note}\n{elements[i]}\n{value_str}"
ax.set_title(title, fontsize=title_fontsize, y=1.12)
else:
title = f"{elements[i]}\n{value_str}"
ax.set_title(title, fontsize=title_fontsize, y=1.05)
for ax in fig.get_axes():
ax.tick_params(axis='both', labelsize=tick_fontsize)
for formatter in (ax.xaxis.get_major_formatter(), ax.yaxis.get_major_formatter()):
if hasattr(formatter, 'set_useOffset'):
formatter.set_useOffset(False)
if ndim > 3:
for ax in fig.get_axes():
for label in ax.get_xticklabels():
label.set_rotation(45)
label.set_ha('right')
plt.savefig(f"{fig_file}_corner.pdf", dpi=300, format='pdf')
plt.show()
plt.figure(figsize=(10, 6))
plt.plot(x_fit, y_fit, label="Data (cont-sub)", color="black", alpha=0.8)
plt.fill_between(
x_fit,
y_fit - y_err_fit,
y_fit + y_err_fit,
color="k",
alpha=0.25,
label=r"1$\sigma$ error",
)
plt.plot(x_model, median_model, label="Median Model", color="crimson", alpha=0.9)
plt.xlabel("Wavelength [Å]")
plt.ylabel(r"F$_{\lambda}$ [erg cm$^{-2}$ s$^{-1}$ Å$^{-1}$]")
plt.legend()
plt.tight_layout()
plt.savefig(f"{fig_file}_fit.pdf", dpi=300, format='pdf')
plt.show()
return {
"param_keys": param_keys,
"median_params": median_params,
"up_errors_params": up_errors,
"low_errors_params": low_errors,
"samples_pruned": samples_pruned,
"lnprob_pruned": lnprob_pruned,
"model_wave": x_model,
"median_model": median_model,
"model_p16": model_p16,
"model_p84": model_p84,
"best_model": best_model,
}