import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
import emcee
import corner
from .models import UniformPrior, CompanionShocking, BaseCompanionShocking
from .lightcurve import filter_legend, flux2mag
from .filters import filtdict
from pkg_resources import resource_filename
import warnings
# Messages used both as warnings and as exception text for deprecated keyword arguments.
PRIOR_WARNING = 'The p_max/p_min keywords are deprecated. Use the priors keyword instead.'
MODEL_KWARGS_WARNING = 'The model_kwargs keyword is deprecated. These are now included in the model initialization.'
def _as_param_array(values, ndim, default, error_message):
    """Return `values` as a float array of length `ndim` (or `default` if None); raise `error_message` otherwise."""
    if values is None:
        return default
    if len(values) != ndim:
        raise Exception(error_message)
    return np.array(values, float)


def _plot_chain_histories(axes, chain, axis_labels, title, labels_on_right=False):
    """Plot every walker's trace of parameter i on axes[i]; `chain` is indexed (step, walker, parameter)."""
    for i, param_ax in enumerate(axes):
        param_ax.plot(chain[:, :, i], 'k', alpha=0.2)
        param_ax.set_ylabel(axis_labels[i])
        if labels_on_right:
            param_ax.yaxis.set_label_position('right')
            param_ax.yaxis.tick_right()
    axes[0].set_title(title)
    axes[-1].set_xlabel('Step Number')


def lightcurve_mcmc(lc, model, priors=None, p_min=None, p_max=None, p_lo=None, p_up=None,
                    nwalkers=100, nsteps=1000, nsteps_burnin=1000, model_kwargs=None,
                    show=False, save_plot_as='', save_sampler_as='', use_sigma=False, sigma_type='relative'):
    """
    Fit an analytical model to observed photometry using a Markov-chain Monte Carlo routine

    Parameters
    ----------
    lc : lightcurve_fitting.lightcurve.LC
        Table of broadband photometry including columns "MJD", "mag", "dmag", "filter"
    model : lightcurve_fitting.models.Model
        The model to fit to the light curve. Available models: :class:`.models.ShockCooling`,
        :class:`.models.ShockCooling2`, :class:`.models.ShockCooling3`, :class:`.models.CompanionShocking`,
        :class:`.models.CompanionShocking2`, :class:`.models.CompanionShocking3`
    priors : list, optional
        Prior probability distributions for each model parameter. Available priors:
        :class:`.models.UniformPrior` (default), :class:`.models.LogUniformPrior`, :class:`.models.GaussianPrior`
    p_min, p_max : list, optional
        DEPRECATED: Use `priors` instead
    p_lo : list, optional
        Lower bounds on the starting guesses for each parameter. Default: `p_min` (-inf if not given).
    p_up : list, optional
        Upper bounds on the starting guesses for each parameter. Default: `p_max` (+inf if not given).
        The starting-guess bounds must be finite and must lie within the priors.
    nwalkers : int, optional
        Number of walkers (chains) for the MCMC routine. Default: 100
    nsteps : int, optional
        Number of steps (iterations) for the MCMC routine, excluding burn-in. Default: 1000
    nsteps_burnin : int, optional
        Number of steps (iterations) for the MCMC routine during burn-in. Default: 1000
    model_kwargs : dict, optional
        DEPRECATED: Keyword arguments are now included in the model initialization
    show : bool, optional
        If True, plot and display the chain histories
    save_plot_as : str, optional
        Save a plot of the chain histories to this filename
    save_sampler_as : str, optional
        Save the aggregated chain histories to this filename
    use_sigma : bool, optional
        If True, treat the last parameter as an intrinsic scatter parameter that does not get passed to the model
    sigma_type : str, optional
        If 'relative' (default), sigma will be in units of the individual photometric uncertainties.
        If 'absolute', sigma will be in units of the median photometric uncertainty.

    Returns
    -------
    sampler : emcee.EnsembleSampler
        EnsembleSampler object containing the results of the fit

    Raises
    ------
    Exception
        If a deprecated keyword is used incorrectly, if any per-parameter list has the wrong length,
        or if the starting-guess bounds are outside the priors or not finite.
    """
    if model_kwargs is not None:
        raise Exception(MODEL_KWARGS_WARNING)

    # Make sure the photometry column the model is compared against exists
    if model.output_quantity == 'flux':
        lc.calcFlux()
    elif model.output_quantity == 'lum':
        lc.calcAbsMag()
        lc.calcLum()

    # Register the intrinsic-scatter parameter on the model (only once)
    if use_sigma and model.input_names[-1] != '\\sigma':
        model.input_names.append('\\sigma')
        model.units.append(u.dimensionless_unscaled)
    ndim = model.nparams

    # DEPRECATED: hard parameter bounds are now expressed through `priors`
    p_min_array = _as_param_array(p_min, ndim, np.tile(-np.inf, ndim), PRIOR_WARNING)
    if p_min is not None:
        warnings.warn(PRIOR_WARNING)
    p_max_array = _as_param_array(p_max, ndim, np.tile(np.inf, ndim), PRIOR_WARNING)
    if p_max is not None:
        warnings.warn(PRIOR_WARNING)

    # Box from which the walkers' starting positions are drawn; p_up previously had no default
    # and crashed with TypeError when omitted
    p_lo = _as_param_array(p_lo, ndim, p_min_array, 'p_lo must have length {:d}'.format(ndim))
    p_up = _as_param_array(p_up, ndim, p_max_array, 'p_up must have length {:d}'.format(ndim))

    if priors is None:
        priors = [UniformPrior(lo, hi) for lo, hi in zip(p_min_array, p_max_array)]
    elif len(priors) != ndim:
        raise Exception('priors must have length {:d}'.format(ndim))

    for param, prior, lo, hi in zip(model.input_names, priors, p_lo, p_up):
        if lo < prior.p_min:
            raise Exception(f'starting guess for {param} (p_lo = {lo}) is outside prior (p_min = {prior.p_min})')
        if hi > prior.p_max:
            raise Exception(f'starting guess for {param} (p_up = {hi}) is outside prior (p_max = {prior.p_max})')
    # Infinite bounds would produce inf/nan starting positions for the walkers
    if not (np.all(np.isfinite(p_lo)) and np.all(np.isfinite(p_up))):
        raise Exception('p_lo and p_up must be finite; set starting-guess bounds for every parameter')

    def log_posterior(p):
        log_prior = 0.
        for prior, p_i in zip(priors, p):
            log_prior += prior(p_i)
        # Skip the (expensive) likelihood when the prior already rules the point out
        if np.isinf(log_prior):
            return log_prior
        return log_prior + model.log_likelihood(lc, p, use_sigma=use_sigma, sigma_type=sigma_type)

    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_posterior)
    starting_guesses = np.random.rand(nwalkers, ndim) * (p_up - p_lo) + p_lo
    state = sampler.run_mcmc(starting_guesses, nsteps_burnin, progress=True, progress_kwargs={'desc': ' Burn-in'})

    make_plot = show or save_plot_as
    if make_plot:
        # squeeze=False keeps the axes array 2D even when ndim == 1
        fig, chain_axes = plt.subplots(ndim, 2, figsize=(12., 2. * ndim), squeeze=False)
        _plot_chain_histories(chain_axes[:, 0], sampler.get_chain(), model.axis_labels, 'During Burn In')

    # Discard the burn-in samples and continue from the final burn-in positions
    sampler.reset()
    sampler.run_mcmc(state, nsteps, progress=True, progress_kwargs={'desc': 'Sampling'},
                     skip_initial_state_check=True)

    if save_sampler_as:
        np.save(save_sampler_as, sampler.get_chain(flat=True))
        print('saving sampler.flatchain as ' + save_sampler_as)

    if make_plot:
        _plot_chain_histories(chain_axes[:, 1], sampler.get_chain(), model.axis_labels, 'After Burn In',
                              labels_on_right=True)
        fig.tight_layout()
        if save_plot_as:
            print('saving chain plot as ' + save_plot_as)
            fig.savefig(save_plot_as)
        if show:
            plt.show()

    return sampler
def lightcurve_corner(lc, model, sampler_flatchain, model_kwargs=None,
                      num_models_to_plot=100, lcaxis_posn=(0.7, 0.55, 0.2, 0.4),
                      filter_spacing=1., tmin=None, tmax=None, t0_offset=None, save_plot_as='', ycol=None,
                      textsize='medium', param_textsize='large', use_sigma=False, xscale='linear',
                      filters_to_model=None):
    """
    Plot the posterior distributions in a corner (pair) plot, with an inset showing the observed and model light curves.

    Parameters
    ----------
    lc : lightcurve_fitting.lightcurve.LC
        Table of broadband photometry including columns "MJD", "mag", "dmag", "filter"
    model : lightcurve_fitting.models.Model
        The model that was fit to the light curve.
    sampler_flatchain : array-like
        2D array containing the aggregated MCMC chain histories
    model_kwargs : dict, optional
        DEPRECATED: Keyword arguments are now included in the model initialization
    num_models_to_plot : int, optional
        Number of model realizations to plot in the light curve inset. Default: 100
    lcaxis_posn : tuple, optional
        Light curve inset position and size specification in figure units: (left, bottom, width, height)
    filter_spacing : float, optional
        Spacing between filters in the light curve inset, in units determined by the order of magnitude of the
        luminosities. Default: 1.
    tmin, tmax : float, optional
        Starting and ending times for which to plot the models in the light curve inset. Default: determined by the
        time range of the observed light curve.
    t0_offset : float, optional
        Reference time for the explosion time in the corner plot and the horizontal axis of the light curve inset.
        Default: the earliest explosion time in `sampler_flatchain`, rounded down.
    save_plot_as : str, optional
        Filename to which to save the resulting plot
    ycol : str, optional
        Quantity to plot on the light curve inset. Choices: "lum", "flux", or "absmag". Default: model.output_quantity
    textsize : str, optional
        Font size for the x- and y-axis labels, as well as the tick labels. Default: 'medium'
    param_textsize : str, optional
        Font size for the parameter text. Default: 'large'
    use_sigma : bool, optional
        If True, treat the last parameter as an intrinsic scatter parameter that does not get passed to the model
    xscale : str, optional
        Scale for the x-axis of the model plot. Choices: "linear" (default) or "log".
    filters_to_model : list, set, optional
        (Unique) list of filters for which to calculate the model light curves. Default: all filters in `lc`.

    Returns
    -------
    fig : matplotlib.pyplot.Figure
        Figure object containing the plot
    corner_ax : array-like
        Array of matplotlib.pyplot.Axes objects corresponding to the corner plot
    ax : matplotlib.pyplot.Axes
        Axes object for the light curve inset
    """
    if model_kwargs is not None:
        raise Exception(MODEL_KWARGS_WARNING)
    if ycol is None:
        ycol = model.output_quantity
    plt.style.use(resource_filename('lightcurve_fitting', 'serif.mplstyle'))
    if use_sigma and model.input_names[-1] != '\\sigma':
        model.input_names.append('\\sigma')
        model.units.append(u.dimensionless_unscaled)

    # Accept any array-like chain; `ndim` is the number of fitted parameters (columns)
    sampler_flatchain = np.asarray(sampler_flatchain)
    ndim = sampler_flatchain.shape[-1]

    # Work on copies so shifting the time columns touches neither the caller's chain nor the model's labels
    # (float dtype so the in-place subtraction below also works for integer input)
    sampler_flatchain_corner = np.array(sampler_flatchain, dtype=float)
    axis_labels_corner = list(model.axis_labels)

    # Express the explosion/peak time relative to a round reference epoch so the corner ticks stay readable
    for var in ['t_0', 't_\\mathrm{max}']:
        if var not in model.input_names:
            continue
        i_t0 = model.input_names.index(var)
        if t0_offset is None:
            t0_offset = np.floor(sampler_flatchain_corner[:, i_t0].min())
        if t0_offset != 0.:
            sampler_flatchain_corner[:, i_t0] -= t0_offset
            t0_offset_formatted = '{:f}'.format(t0_offset).rstrip('0').rstrip('.')
            axis_labels_corner[i_t0] = f'${var} - {t0_offset_formatted}$ (d)'

    fig = corner.corner(sampler_flatchain_corner, labels=axis_labels_corner, label_kwargs={'size': textsize})
    corner_axes = np.array(fig.get_axes()).reshape(ndim, ndim)
    for i in range(ndim):
        corner_axes[i, 0].tick_params(labelsize=textsize)
        corner_axes[-1, i].tick_params(labelsize=textsize)
    # Strip the frame from the 1D histograms on the diagonal
    for diag_ax in np.diag(corner_axes):
        diag_ax.spines['top'].set_visible(False)
        diag_ax.spines['left'].set_visible(False)
        diag_ax.spines['right'].set_visible(False)
        diag_ax.xaxis.set_ticks_position('bottom')
        diag_ax.yaxis.set_ticks_position('none')

    ax = fig.add_axes(lcaxis_posn)
    lightcurve_model_plot(lc, model, sampler_flatchain, num_models_to_plot=num_models_to_plot,
                          filter_spacing=filter_spacing, tmin=tmin, tmax=tmax, ycol=ycol, textsize=textsize,
                          ax=ax, mjd_offset=t0_offset, use_sigma=use_sigma, xscale=xscale,
                          filters_to_model=filters_to_model)

    paramtexts = format_credible_interval(sampler_flatchain, varnames=model.input_names, units=model.units)
    fig.text(0.45, 0.95, '\n'.join(paramtexts), va='top', ha='center', fontdict={'size': param_textsize})

    if save_plot_as:
        fig.savefig(save_plot_as)
        print('saving figure as ' + save_plot_as)

    return fig, corner_axes, ax
def lightcurve_model_plot(lc, model, sampler_flatchain, model_kwargs=None, num_models_to_plot=100, filter_spacing=1.,
                          tmin=None, tmax=None, ycol=None, textsize='medium', ax=None, mjd_offset=None, use_sigma=False,
                          xscale='linear', filters_to_model=None):
    """
    Plot the observed and model light curves.

    Parameters
    ----------
    lc : lightcurve_fitting.lightcurve.LC
        Table of broadband photometry including columns "MJD", "mag", "dmag", "filter"
    model : lightcurve_fitting.models.Model
        The model that was fit to the light curve.
    sampler_flatchain : array-like
        2D array containing the aggregated MCMC chain histories
    model_kwargs : dict, optional
        DEPRECATED: Keyword arguments are now included in the model initialization.
    num_models_to_plot : int, optional
        Number of model realizations to plot in the light curve inset. Default: 100
    filter_spacing : float, optional
        Spacing between filters in the light curve inset, in units determined by the order of magnitude of the
        luminosities. Default: 1.
    tmin, tmax : float, optional
        Starting and ending times for which to plot the models in the light curve inset. Default: determined by the
        time range of the observed light curve.
    ycol : str, optional
        Quantity to plot on the light curve inset. Choices: "lum", "flux", or "absmag". Default: model.output_quantity
    textsize : str, optional
        Font size for the x- and y-axis labels, as well as the tick labels. Default: 'medium'
    ax : matplotlib.pyplot.Axes
        Axis on which to plot the light curves
    mjd_offset : float, optional
        Reference time on the horizontal axis of the light curve inset. Default: determined by the starting time of
        the model light curve.
    use_sigma : bool, optional
        If True, treat the last parameter as an intrinsic scatter parameter that does not get passed to the model
    xscale : str, optional
        Scale for the x-axis. Choices: "linear" (default) or "log".
    filters_to_model : list, set, optional
        (Unique) list of filters for which to calculate the model light curves. Default: all filters in `lc`.

    Raises
    ------
    ValueError
        If `ycol` is not one of "lum", "absmag", or "flux".
    """
    if model_kwargs is not None:
        raise Exception(MODEL_KWARGS_WARNING)
    if ycol is None:
        ycol = model.output_quantity
    # Reject a bad ycol before creating axes or evaluating any models
    if ycol not in ('lum', 'absmag', 'flux'):
        raise ValueError(f'ycol="{ycol}" is not recognized. Use "lum", "absmag", "flux".')
    if ax is None:
        ax = plt.axes()
    if use_sigma and model.input_names[-1] != '\\sigma':
        model.input_names.append('\\sigma')
        model.units.append(u.dimensionless_unscaled)

    # Draw random posterior samples (with replacement); `ps` has one row per parameter
    sampler_flatchain = np.asarray(sampler_flatchain)
    choices = np.random.choice(sampler_flatchain.shape[0], num_models_to_plot)
    ps = sampler_flatchain[choices].T

    if tmin is None:
        tmin = np.min(lc['MJD'])
    if tmax is None:
        tmax = np.max(lc['MJD'])
    xfit = np.geomspace(tmin, tmax, 1000) if xscale == 'log' else np.linspace(tmin, tmax, 1000)

    if filters_to_model is None:
        ufilts = np.unique(lc['filter'])
    else:
        ufilts = np.array([filtdict[f] for f in filters_to_model])

    # The intrinsic-scatter parameter (last row) is not a model input
    model_params = ps[:-1] if use_sigma else ps
    y_fit = model(xfit, ufilts, *model_params)

    # for CompanionShocking, add SiFTO model as dashed lines
    if isinstance(model, CompanionShocking):
        y_fit1 = model.stretched_sifto(xfit, ufilts, *ps[3:5])
        y_fit1[ufilts == filtdict['r']] *= ps[5]
        y_fit1[ufilts == filtdict['i']] *= ps[6]
    elif isinstance(model, BaseCompanionShocking):
        y_fit1 = model.stretched_sifto(xfit, ufilts, *ps[3:7])
    else:
        y_fit1 = [None] * len(ufilts)

    if mjd_offset is None:
        mjd_offset = np.floor(tmin)

    if ycol == 'absmag':
        dycol = 'dmag'
        yscale = 1.
        ylabel = 'Absolute Magnitude + Offset'
        zp = [[[filt.M0]] for filt in ufilts]
        y_fit, _ = flux2mag(y_fit, zp=zp)
        if y_fit1[0] is not None:
            y_fit1, _ = flux2mag(y_fit1, zp=zp)
        ax.invert_yaxis()
    else:
        # nanmax: a single nan model value must not turn the whole scale (and plot) into nan
        yscale = 10. ** np.round(np.log10(np.nanmax(y_fit)))
        if ycol == 'lum':
            dycol = 'dlum'
            ylabel = 'Luminosity $L_\\nu$ (10$^{{{:.0f}}}$ erg s$^{{-1}}$ Hz$^{{-1}}$) + Offset'.format(
                np.log10(yscale) + 7)  # W --> erg / s
        else:  # ycol == 'flux'
            dycol = 'dflux'
            ylabel = 'Flux $F_\\nu$ (10$^{{{:.0f}}}$ erg s$^{{-1}}$ m$^{{-2}}$ Hz$^{{-1}}$) + Offset'.format(
                np.log10(yscale) + 7)  # W --> erg / s

    if xscale == 'log':
        ax.set_xscale('log')
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%g'))
        # Drop points before the reference epoch (their offset times would not be positive)
        lc = lc.where(MJD_min=mjd_offset)
    else:
        lc = lc.copy()
    lc['MJD'] -= mjd_offset
    lc[ycol] /= yscale
    lc[dycol] /= yscale

    # LC.plot is not given `ax`, so it presumably draws on pyplot's current axes; make that `ax`
    plt.sca(ax)
    lc.plot(xcol='MJD', ycol=ycol, offset_factor=filter_spacing, appmag_axis=False, tight_layout=False)
    # Freeze the limits set by the observations so the model curves do not rescale the view
    ax.autoscale(False)

    _, labels, _ = filter_legend(np.array(ufilts), filter_spacing)
    for yfit, yfit1, filt, txt in zip(y_fit, y_fit1, ufilts, labels):
        offset = -filt.offset * filter_spacing
        ax.plot(xfit - mjd_offset, yfit / yscale + offset, color=filt.linecolor, alpha=0.05)
        if yfit1 is not None:
            ax.plot(xfit - mjd_offset, np.median(yfit1, axis=1) / yscale + offset, color=filt.linecolor, ls='--')
        # Filter label just outside the right edge, at the height of the first realization's last point
        ax.text(1.03, yfit[-1, 0] / yscale + offset, txt, color=filt.textcolor, fontdict={'size': textsize},
                ha='left', va='center', transform=ax.get_yaxis_transform())

    ax.set_xlabel('MJD $-$ {:f}'.format(mjd_offset).rstrip('0').rstrip('.'), size=textsize)
    ax.set_ylabel(ylabel, size=textsize)
    ax.tick_params(labelsize=textsize)