from collections import deque
from typing import Union, Sequence, Optional
import logging
from flatdict import FlatterDict
import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.optimize import minimize
import toml
from Starfish import Spectrum
from Starfish.emulator import Emulator
from Starfish.transforms import (
rotational_broaden,
resample,
doppler_shift,
extinct,
rescale,
_get_renorm_factor,
)
from Starfish.utils import calculate_dv, create_log_lam_grid
from .kernels import global_covariance_matrix, local_covariance_matrix
class SpectrumModel:
    """
    A single-order spectrum model.

    Parameters
    ----------
    emulator : :class:`Starfish.emulator.Emulator` or str
        The emulator to use for this model. If a string is given it is treated as a
        filename and loaded with ``Emulator.load``.
    data : :class:`Starfish.spectrum.Spectrum` or str
        The data to use for this model. If a string is given it is treated as a
        filename and loaded with ``Spectrum.load``. Must contain exactly one order.
    grid_params : array-like
        The parameters that are used with the associated emulator, in the order of
        ``emulator.param_names``
    max_deque_len : int, optional
        The maximum number of residuals to retain in a deque of residuals. Default is
        100
    name : str, optional
        A name for the model. Default is 'SpectrumModel'

    Keyword Arguments
    -----------------
    params : dict
        Any remaining keyword arguments will be interpreted as parameters.

    Here is a table describing the available parameters and their related functions

    =========== ===============================================
    Parameter   Function
    =========== ===============================================
    vsini       :func:`~Starfish.transforms.rotational_broaden`
    vz          :func:`~Starfish.transforms.doppler_shift`
    Av          :func:`~Starfish.transforms.extinct`
    Rv          :func:`~Starfish.transforms.extinct`
    log_scale   :func:`~Starfish.transforms.rescale`
    =========== ===============================================

    .. note::
        If :attr:`log_scale` is not specified, the model will automatically scale the
        spectrum to the data using the ratio of integrated fluxes.

    .. note::
        ``Rv`` is only used when ``Av`` is also given; otherwise no extinction is
        applied at all.

    The ``global_cov`` keyword argument must be a dictionary defining the
    hyperparameters for the global covariance kernel,
    :meth:`kernels.global_covariance_matrix`

    ================ =============================================================
    Global Parameter Description
    ================ =============================================================
    log_amp          The natural logarithm of the amplitude of the Matern kernel
    log_ls           The natural logarithm of the lengthscale of the Matern kernel
    ================ =============================================================

    The ``local_cov`` keyword argument must be a list of dictionaries defining
    hyperparameters for many Gaussian kernels, :meth:`kernels.local_covariance_matrix`.
    The contributions of all local kernels are summed.

    ================ =============================================================
    Local Parameter  Description
    ================ =============================================================
    log_amp          The natural logarithm of the amplitude of the kernel
    mu               The location of the local kernel
    log_sigma        The natural logarithm of the standard deviation of the kernel
    ================ =============================================================

    Attributes
    ----------
    params : dict
        The dictionary of parameters that are used for doing the modeling.
    frozen : list
        A list of strings corresponding to frozen parameters
    residuals : deque
        A deque containing residuals from calling :meth:`SpectrumModel.log_likelihood`

    Raises
    ------
    ValueError
        If ``data`` contains more than one order
    """

    _PARAMS = ["vz", "vsini", "Av", "Rv", "log_scale", "global_cov", "local_cov"]
    _GLOBAL_PARAMS = ["log_amp", "log_ls"]
    _LOCAL_PARAMS = ["mu", "log_amp", "log_sigma"]
    # Names of the covariance parameter groups that can be frozen/thawed as a unit
    _COV_GROUPS = ("global_cov", "local_cov")

    def __init__(
        self,
        emulator: Union[str, Emulator],
        data: Union[str, Spectrum],
        grid_params: Sequence[float],
        max_deque_len: int = 100,
        name: str = "SpectrumModel",
        **params,
    ):
        if isinstance(emulator, str):
            emulator = Emulator.load(emulator)
        if isinstance(data, str):
            data = Spectrum.load(data)

        if len(data) > 1:
            raise ValueError(
                "Multiple orders detected in data, please use EchelleModel"
            )

        self.emulator: Emulator = emulator
        self.data_name = data.name
        self.data = data[0]

        # Resample the emulator's bulk fluxes once onto a log-lambda grid whose
        # velocity spacing comes from the data, so __call__ only has to transform them.
        dv = calculate_dv(self.data.wave)
        self.min_dv_wave = create_log_lam_grid(
            dv, self.emulator.wl.min(), self.emulator.wl.max()
        )["wl"]
        self.bulk_fluxes = resample(
            self.emulator.wl, self.emulator.bulk_fluxes, self.min_dv_wave
        )

        self.residuals = deque(maxlen=max_deque_len)

        self.params = FlatterDict(params)
        self.frozen = []
        self.name = name

        # Unpack the grid parameters (stored in self.params via the setter)
        self.n_grid_params = len(grid_params)
        self.grid_params = grid_params

        self._lnprob = None
        # Cached covariance matrices, only reused while the group is frozen
        self._glob_cov = None
        self._loc_cov = None

        self.log = logging.getLogger(self.__class__.__name__)

    @property
    def grid_params(self):
        """
        numpy.ndarray : The parameters used for the spectral emulator.

        :setter: Sets the values in the order of ``Emulator.param_names``
        """
        values = []
        for key in self.emulator.param_names:
            values.append(self.params[key])
        return np.array(values)

    @grid_params.setter
    def grid_params(self, values):
        for key, value in zip(self.emulator.param_names, values):
            if key not in self.frozen:
                self.params[key] = value

    @property
    def labels(self):
        """
        tuple of str : The thawed parameter names
        """
        keys = self.get_param_dict(flat=True).keys()
        return tuple(keys)

    def __getitem__(self, key):
        return self.params[key]

    def __setitem__(self, key, value):
        """
        Set a single parameter, validating its name.

        Setting any covariance hyperparameter discards the corresponding cached
        covariance matrix so an explicitly assigned value is honoured even when the
        group is frozen.

        Raises
        ------
        ValueError
            If ``key`` is not a recognized parameter name
        """
        if ":" in key:
            group = key.split(":", 1)[0]
            leaf = key.rsplit(":", 1)[-1]
            if group == "global_cov" and leaf in self._GLOBAL_PARAMS:
                self.params[key] = value
                self._glob_cov = None
            elif group == "local_cov" and leaf in self._LOCAL_PARAMS:
                self.params[key] = value
                self._loc_cov = None
            else:
                raise ValueError(f"{key} not recognized")
        else:
            if key in [*self._PARAMS, *self.emulator.param_names]:
                self.params[key] = value
                if key == "global_cov":
                    self._glob_cov = None
                elif key == "local_cov":
                    self._loc_cov = None
            else:
                raise ValueError(f"{key} not recognized")

    def __delitem__(self, key):
        """
        Delete a parameter (or a whole covariance group) and unfreeze it.

        Raises
        ------
        ValueError
            If ``key`` is not in ``self.params``
        """
        if key not in self.params:
            raise ValueError(f"{key} not in params")
        elif key == "global_cov":
            self._glob_cov = None
            self.frozen = [k for k in self.frozen if not k.startswith("global_cov")]
        elif key == "local_cov":
            self._loc_cov = None
            self.frozen = [k for k in self.frozen if not k.startswith("local_cov")]

        del self.params[key]
        if key in self.frozen:
            self.frozen.remove(key)

    def __call__(self):
        """
        Performs the transformations according to the parameters available in
        ``self.params``

        Returns
        -------
        flux, cov : tuple
            The transformed flux and covariance matrix from the model
        """
        wave = self.min_dv_wave
        fluxes = self.bulk_fluxes

        if "vsini" in self.params:
            fluxes = rotational_broaden(wave, fluxes, self.params["vsini"])

        if "vz" in self.params:
            wave = doppler_shift(wave, self.params["vz"])

        fluxes = resample(wave, fluxes, self.data.wave)

        if "Av" in self.params:
            # NOTE(review): extinction is applied to every row of ``fluxes``
            # (eigenspectra, flux_mean and flux_std), whereas the log_scale rescaling
            # below deliberately touches only flux_mean/flux_std. Because the
            # reconstruction multiplies the eigenspectra by flux_std, this may apply
            # the extinction curve twice to that term -- confirm against the
            # emulator's normalization.
            extinct_kwargs = {}
            if "Rv" in self.params:
                extinct_kwargs["Rv"] = self.params["Rv"]
            fluxes = extinct(
                self.data.wave, fluxes, self.params["Av"], **extinct_kwargs
            )

        # Only rescale flux_mean and flux_std
        if "log_scale" in self.params:
            scale = np.exp(self.params["log_scale"])
            fluxes[-2:] = rescale(fluxes[-2:], scale)

        weights, weights_cov = self.emulator(self.grid_params)

        L, flag = cho_factor(weights_cov, overwrite_a=True)

        # Decompose the bulk_fluxes (see emulator/emulator.py for the ordering)
        *eigenspectra, flux_mean, flux_std = fluxes

        # Complete the reconstruction
        X = eigenspectra * flux_std
        flux = weights @ X + flux_mean

        # Renorm to data flux if no "log_scale" provided
        if "log_scale" not in self.params:
            factor = _get_renorm_factor(self.data.wave, flux, self.data.flux)
            flux = rescale(flux, factor)
            X = rescale(X, factor)

        # Propagate the emulator's weight uncertainty into flux space
        cov = X.T @ cho_solve((L, flag), X)

        # Trivial covariance
        np.fill_diagonal(cov, cov.diagonal() + self.data.sigma ** 2)

        # Global covariance (cached while the group is frozen)
        if "global_cov" in self.params:
            if "global_cov" not in self.frozen or self._glob_cov is None:
                ag = np.exp(self.params["global_cov:log_amp"])
                lg = np.exp(self.params["global_cov:log_ls"])
                # assumes the emulator exposes a "T" grid parameter -- confirm
                T = self.params["T"]
                self._glob_cov = global_covariance_matrix(self.data.wave, T, ag, lg)

        if self._glob_cov is not None:
            cov += self._glob_cov

        # Local covariance (cached while the group is frozen): the contributions of
        # all local kernels are summed rather than keeping only the last one.
        if "local_cov" in self.params:
            if "local_cov" not in self.frozen or self._loc_cov is None:
                loc_cov = 0
                for kernel in self.params.as_dict()["local_cov"]:
                    mu = kernel["mu"]
                    amplitude = np.exp(kernel["log_amp"])
                    sigma = np.exp(kernel["log_sigma"])
                    loc_cov = loc_cov + local_covariance_matrix(
                        self.data.wave, amplitude, mu, sigma
                    )
                self._loc_cov = loc_cov

        if self._loc_cov is not None:
            cov += self._loc_cov

        return flux, cov

    def log_likelihood(self, priors: Optional[dict] = None) -> float:
        """
        Returns the log probability of a multivariate normal distribution

        The residual vector (model - data) of every call is appended to
        :attr:`residuals`.

        Parameters
        ----------
        priors : dict, optional
            If provided, will use these priors in the MLE. Should contain keys that
            match the model's keys and values that have a `logpdf` method that takes
            one value (like ``scipy.stats`` distributions). Default is None.

        Warning
        -------
        No checks will be done on the :attr:`priors` for speed.

        Returns
        -------
        float
            ``-inf`` if the priors evaluate to a non-finite value, otherwise the
            Gaussian log-likelihood plus the prior log-probability
        """
        # Priors
        prior_lp = 0
        if priors is not None:
            for key, prior in priors.items():
                if key in self.params:
                    prior_lp += prior.logpdf(self[key])

            if not np.isfinite(prior_lp):
                return -np.inf

        # Likelihood
        flux, cov = self()
        # Small jitter for numerical stability of the Cholesky decomposition
        np.fill_diagonal(cov, cov.diagonal() + 1e-10)
        factor, flag = cho_factor(cov, overwrite_a=True)
        logdet = 2 * np.sum(np.log(factor.diagonal()))
        R = flux - self.data.flux
        self.residuals.append(R)
        sqmah = R @ cho_solve((factor, flag), R)
        self._lnprob = -(logdet + sqmah) / 2
        return self._lnprob + prior_lp

    def get_param_dict(self, flat: bool = False) -> dict:
        """
        Gets the dictionary of thawed parameters.

        Parameters
        ----------
        flat : bool, optional
            If True, returns the parameters completely flat. For example,
            ``['local_cov']['0']['mu']`` would have the key ``'local_cov:0:mu'``.
            Default is False

        Returns
        -------
        dict

        See Also
        --------
        :meth:`set_param_dict`
        """
        params = FlatterDict()
        for key, val in self.params.items():
            if key not in self.frozen:
                params[key] = val

        return params if flat else params.as_dict()

    def set_param_dict(self, params):
        """
        Sets the parameters with a dictionary. Note that this should not be used to add
        new parameters

        Parameters
        ----------
        params : dict
            The new parameters. If a key is present in ``self.frozen`` it will not be
            changed

        See Also
        --------
        :meth:`get_param_dict`
        """
        params = FlatterDict(params)
        for key, val in params.items():
            if key not in self.frozen:
                self.params[key] = val

    def get_param_vector(self):
        """
        Get a numpy array of the thawed parameters, in the order of :attr:`labels`

        Returns
        -------
        numpy.ndarray

        See Also
        --------
        :meth:`set_param_vector`
        """
        return np.array(list(self.get_param_dict(flat=True).values()))

    def set_param_vector(self, params):
        """
        Sets the parameters based on the current thawed state. The values will be
        inserted according to the order of :obj:`SpectrumModel.labels`.

        Parameters
        ----------
        params : array_like
            The parameters to set in the model

        Raises
        ------
        ValueError
            If the `params` do not match the length of the current thawed parameters.

        See Also
        --------
        :meth:`get_param_vector`
        """
        labels = self.labels
        if len(params) != len(labels):
            raise ValueError("Param Vector does not match length of thawed parameters")
        param_dict = dict(zip(labels, params))
        self.set_param_dict(param_dict)

    def _freeze_group(self, group):
        """Freeze a covariance group marker together with all of its flat sub-keys."""
        if group not in self.params:
            # Unknown names are ignored, as for plain parameters
            return
        if group not in self.frozen:
            self.frozen.append(group)
        # Drop the cached matrix so the frozen one is rebuilt from current values
        if group == "global_cov":
            self._glob_cov = None
        else:
            self._loc_cov = None
        prefix = f"{group}:"
        for flat_key in self.params.keys():
            if flat_key.startswith(prefix) and flat_key not in self.frozen:
                self.frozen.append(flat_key)

    def freeze(self, names):
        """
        Freeze the given parameter such that :meth:`get_param_dict` and
        :meth:`get_param_vector` no longer include this parameter, however it will
        still be used when calling the model.

        Parameters
        ----------
        names : str or array-like
            The parameter(s) to freeze. If ``'all'``, will freeze all parameters. If
            ``'global_cov'`` will freeze all global covariance parameters. If
            ``'local_cov'`` will freeze all local covariance parameters. Names that
            are not present in the model are ignored.

        See Also
        --------
        :meth:`thaw`
        """
        names = np.atleast_1d(names)
        if names.size and str(names[0]) == "all":
            for key in self.labels:
                if key not in self.frozen:
                    self.frozen.append(key)
            for group in self._COV_GROUPS:
                if group in self.params and group not in self.frozen:
                    self.frozen.append(group)
            # Same as freezing a group directly: rebuild caches from current values
            self._glob_cov = None
            self._loc_cov = None
            return

        for _name in names:
            # Avoid kookyness of numpy.str type
            name = str(_name)
            if name in self._COV_GROUPS:
                self._freeze_group(name)
            elif name not in self.frozen and name in self.params:
                self.frozen.append(name)

    def thaw(self, names):
        """
        Thaws the given parameter. Opposite of freezing

        Parameters
        ----------
        names : str or array-like
            The parameter(s) to thaw. If ``'all'``, will thaw all parameters. If
            ``'global_cov'`` will thaw all global covariance parameters. If
            ``'local_cov'`` will thaw all local covariance parameters. Names that are
            not currently frozen are ignored. Thawing a single covariance
            hyperparameter (e.g. ``'global_cov:log_amp'``) also unfreezes the group's
            cached covariance matrix so it is recomputed on every call.

        See Also
        --------
        :meth:`freeze`
        """
        names = np.atleast_1d(names)
        if names.size and str(names[0]) == "all":
            self.frozen = []
            return

        for _name in names:
            # Avoid kookyness of numpy.str type
            name = str(_name)
            if name in self._COV_GROUPS:
                prefix = f"{name}:"
                self.frozen = [
                    k for k in self.frozen if k != name and not k.startswith(prefix)
                ]
            elif name in self.frozen:
                self.frozen.remove(name)
                group = name.split(":", 1)[0]
                # A thawed sub-key will change during fitting, so the group's cached
                # matrix must no longer be reused.
                if ":" in name and group in self._COV_GROUPS:
                    self.frozen = [k for k in self.frozen if k != group]

    def save(self, filename, metadata=None):
        """
        Saves the model as a set of parameters into a TOML file

        Parameters
        ----------
        filename : str or path-like
            The TOML filename to save to.
        metadata : dict, optional
            If provided, will save the provided dictionary under a 'metadata' key. This
            will not be read in when loading models but provides a way of providing
            information in the actual TOML files. Default is None.
        """
        output = {"parameters": self.params.as_dict(), "frozen": self.frozen}
        meta = {}
        meta["name"] = self.name
        meta["data"] = self.data_name
        if self.emulator.name is not None:
            meta["emulator"] = self.emulator.name
        if metadata is not None:
            meta.update(metadata)
        output["metadata"] = meta

        with open(filename, "w") as handler:
            encoder = toml.TomlNumpyEncoder(output.__class__)
            toml.dump(output, handler, encoder=encoder)

        self.log.info(f"Saved current state at {filename}")

    def load(self, filename):
        """
        Load a saved model state from a TOML file

        Any cached covariance matrices are discarded so that frozen covariance groups
        are rebuilt from the loaded parameter values.

        Parameters
        ----------
        filename : str or path-like
            The saved state to load
        """
        with open(filename, "r") as handler:
            data = toml.load(handler)

        self.params = FlatterDict(data["parameters"])
        self.frozen = data["frozen"]
        self._glob_cov = None
        self._loc_cov = None

    def train(self, priors: Optional[dict] = None, **kwargs):
        """
        Given a :class:`SpectrumModel` and a dictionary of priors, will perform
        maximum-likelihood estimation (MLE). This will use ``scipy.optimize.minimize`` to
        find the maximum a-posteriori (MAP) estimate of the current model state. Note
        that this alters the state of the model. This means that you can run this
        method multiple times until the optimization succeeds. By default, we use the
        "Nelder-Mead" method in `minimize` to avoid approximating any derivatives.

        Parameters
        ----------
        priors : dict, optional
            Priors to pass to :meth:`log_likelihood`
        **kwargs : dict, optional
            These keyword arguments will be passed to `scipy.optimize.minimize`

        Returns
        -------
        soln : `scipy.optimize.minimize_result`
            The output of the minimization.

        Raises
        ------
        ValueError
            If the priors are poorly specified
        RuntimeError
            If any priors evaluate to non-finite values

        See Also
        --------
        :meth:`log_likelihood`
        """
        if priors is None:
            priors = {}

        # Check priors for validity
        for key, val in priors.items():
            # Key exists
            if key not in self.params:
                raise ValueError(f"Invalid priors. {key} not a vlid key.")

            # has logpdf method
            if not callable(getattr(val, "logpdf", None)):
                raise ValueError(
                    f"Invalid priors. {key} does not have a `logpdf` method"
                )

            # Evaluates to a finite number in current state
            log_prob = val.logpdf(self[key])
            if not np.isfinite(log_prob):
                raise RuntimeError(f"{key}'s logpdf evaluated to {log_prob}")

        def nll(P):
            self.set_param_vector(P)
            return -self.log_likelihood(priors)

        p0 = self.get_param_vector()

        params = {"method": "Nelder-Mead"}
        params.update(kwargs)

        soln = minimize(nll, p0, **params)

        # On failure the model keeps the last evaluated parameters
        if soln.success:
            self.set_param_vector(soln.x)

        return soln

    def plot(self, axes=None, plot_kwargs=None, resid_kwargs=None):
        """
        Plot the model.

        This will create three subplots. The first (left column) shows the current
        model against the data. The second shows the current residuals with 1, 2 and
        3 :math:`\\sigma` contours from the diagonal of the covariance matrix. The
        third shows the residuals relative to the data flux. Note this requires
        matplotlib to be installed, which is not installed by default with Starfish.

        Parameters
        ----------
        axes : iterable of matplotlib.Axes, optional
            If provided, will use the first three axes to plot (comparison,
            residuals, relative residuals), otherwise will create new axes, by
            default None
        plot_kwargs : dict, optional
            If provided, will use these kwargs for the comparison plot, by default None
        resid_kwargs : dict, optional
            If provided, will use these kwargs for the residuals plots, by default None

        Returns
        -------
        list of matplotlib.Axes
            The returned axes, for the user to edit as they please
        """
        import matplotlib.pyplot as plt
        from matplotlib import rcParams

        if axes is None:
            # Set up a 2x2 grid with the main plot taking the whole left column
            figsize = rcParams["figure.figsize"]
            plt.figure(figsize=(figsize[0] * 1.75, figsize[1] * 1.1))
            grid = plt.GridSpec(2, 2, width_ratios=(1.25, 1))
            axes = [
                plt.subplot(grid[:, 0]),
                plt.subplot(grid[0, 1]),
                plt.subplot(grid[1, 1]),
            ]
            axes[1].tick_params(labelbottom=False)

        if plot_kwargs is None:
            plot_kwargs = {}
        if resid_kwargs is None:
            resid_kwargs = {}

        model_flux, model_cov = self()

        # Comparison plot
        plot_params = {"lw": 0.7}
        plot_params.update(plot_kwargs)
        ax = axes[0]
        ax.plot(self.data.wave, self.data.flux, label="Data", **plot_params)
        ax.plot(self.data.wave, model_flux, label="Model", **plot_params)
        ax.set_yscale("log")
        ax.set_xlabel(r"$\lambda$ [$\AA$]")
        ax.set_ylabel(r"$f_\lambda$ [$erg/cm^2/s/cm$]")
        ax.legend()

        # Residuals plot
        R = self.data.flux - model_flux
        std = np.sqrt(model_cov.diagonal())
        resid_params = {"lw": 0.3}
        resid_params.update(resid_kwargs)
        ax = axes[1]
        ax.plot(self.data.wave, R, c="k", label="Data - Model", **resid_params)
        ax.fill_between(
            self.data.wave, -std, std, color="C2", alpha=0.6, label=r"$\sigma$"
        )
        ax.fill_between(
            self.data.wave, -2 * std, 2 * std, color="C2", alpha=0.4, label=r"$2\sigma$"
        )
        ax.fill_between(
            self.data.wave, -3 * std, 3 * std, color="C2", alpha=0.2, label=r"$3\sigma$"
        )
        ax.set_ylabel(r"$\Delta f_\lambda$")
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")
        ax.legend()

        # Relative Error plot
        R_f = R / self.data.flux
        ax = axes[2]
        ax.plot(self.data.wave, R_f, label="Data - Model", c="k", **resid_params)
        ax.set_xlabel(r"$\lambda$ [$\AA$]")
        ax.set_ylabel(r"$\Delta f_\lambda / f_\lambda$")
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")

        plt.suptitle(self.data_name)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        return axes

    def __repr__(self):
        output = f"{self.name}\n"
        output += "-" * len(self.name) + "\n"
        output += f"Data: {self.data_name}\n"
        output += f"Emulator: {self.emulator.name}\n"
        output += f"Log Likelihood: {self._lnprob}\n"
        output += "\nParameters\n"
        for key, value in self.get_param_dict().items():
            if key == "global_cov":
                output += "  global_cov:\n"
                for gkey, gval in value.items():
                    output += f"    {gkey}: {gval}\n"
            elif key == "local_cov":
                output += "  local_cov:\n"
                for i, kern in enumerate(value.values()):
                    output += f"    {i}: "
                    for lkey, lval in kern.items():
                        output += f"{lkey}: {lval}, "
                    # Remove trailing whitespace and comma
                    output = output[:-2]
                    output += "\n"
            else:
                output += f"  {key}: {value}\n"
        if len(self.frozen) > 0:
            output += "\nFrozen Parameters\n"
            for key in self.frozen:
                # Group markers have no single value to display
                if key in ["global_cov", "local_cov"]:
                    continue
                output += f"  {key}: {self[key]}\n"
        return output[:-1]  # No trailing newline