# Source code for superphot.util

import logging
import os

import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table, hstack, join
from astropy.stats import mad_std, sigma_clip

# Color used for each single-letter key (presumably photometric filter names — confirm against callers).
# Values are hex colors, except 'y', which uses the Matplotlib single-letter color code 'y'.
filter_colors = {
    'g': '#00CCFF',
    'r': '#FF7D00',
    'i': '#90002C',
    'z': '#000000',
    'y': 'y',
    'U': '#3C0072',
    'B': '#0057FF',
    'V': '#79FF00',
    'R': '#FF7000',
    'I': '#66000B',
}

# Names of the metadata columns (as opposed to stored array columns).
meta_columns = [
    'filename',
    'type',
    'MWEBV',
    'redshift',
    'prediction',
    'confidence',
]

# Human-readable legend titles for the columns that `plot_histograms` can group by.
CLASS_KEYWORDS = {
    'type': 'Spectroscopic Classification',
    'prediction': 'Photometric Classification',
}

# Import-time side effect: turn on minor ticks on both axes for every figure made in this process.
plt.rcParams['xtick.minor.visible'] = True
plt.rcParams['ytick.minor.visible'] = True

def load_data(meta_file, data_file=None):
    """
    Read input from a text file (the metadata table) and a Numpy file (the features) and return as an Astropy table.

    Parameters
    ----------
    meta_file : str
        Filename of the input metadata table. Must be in an ASCII format readable by Astropy.
    data_file : str, optional
        Filename where the features are saved. Must be in Numpy binary (.npz) format. If None, replace the extension
        of `meta_file` with .npz (or append .npz if `meta_file` has no extension).

    Returns
    -------
    data_table : astropy.table.Table
        Table containing the metadata along with one column per stored array (e.g., 'features'). The stored arrays
        named 'filters', 'ndraws', 'paramnames', and 'featnames' are placed in `data_table.meta` instead.
    """
    if data_file is None:
        # splitext only touches the final extension of the basename, so dots in directory names
        # and extension-less filenames are handled correctly (unlike splitting on every '.').
        data_file = os.path.splitext(meta_file)[0] + '.npz'

    t = Table.read(meta_file, format='ascii', fill_values=('', ''))
    if 'type' in t.colnames:
        # NOTE(review): the callee on this line was lost in the garbled source; `np.ma.array` is the
        # presumed original (it makes 'type' a masked column) — confirm against upstream.
        t['type'] = np.ma.array(t['type'])

    # NpzFile keeps the archive open until closed, so read everything inside a context manager.
    with np.load(data_file) as stored:
        stored_keys = list(stored.keys())
        reserved_keys = {'filters', 'ndraws', 'paramnames', 'featnames'}
        # Keep the archive's own key order so the column order is reproducible between runs.
        meta_keys = [key for key in stored_keys if key in reserved_keys]
        column_keys = [key for key in stored_keys if key not in reserved_keys]

        t_stored = Table({key: stored[key] for key in column_keys})
        if set(t.colnames).intersection(column_keys):
            # The two tables share at least one column name: match rows on the shared column(s).
            data_table = join(t, t_stored)
        else:
            # No shared columns: repeat each metadata row according to 'ndraws' and stack side by side.
            repeated_rows = np.repeat(np.arange(len(t)), stored['ndraws'])
            data_table = hstack([t[repeated_rows], t_stored])

        for key in meta_keys:
            data_table.meta[key] = stored[key]

    # Use the empty string as the fill value for every string column.
    for col in data_table.colnames:
        if data_table[col].dtype.type is np.str_:
            data_table[col].fill_value = ''

    logging.info(f'data loaded from {meta_file} and {data_file}')
    return data_table


def subplots_layout(n):
    """
    Calculate the number of rows and columns for a multi-panel plot, staying as close to a square as possible.

    Parameters
    ----------
    n : int
        The number of subplots required. Must be at least 1.

    Returns
    -------
    nrows, ncols : int
        The number of rows and columns in the layout, chosen so that `nrows * ncols >= n`.

    Raises
    ------
    ValueError
        If `n` is less than 1.
    """
    if n < 1:
        # Without this guard, n == 0 fails with an opaque ZeroDivisionError and n < 0 with a TypeError.
        raise ValueError(f'number of subplots must be at least 1, got {n}')
    nrows = round(n ** 0.5)
    ncols = -(-n // nrows)  # ceiling division: the fewest columns that fit all n panels
    return nrows, ncols


def plot_histograms(data_table, colname, class_kwd='type', var_kwd=None, row_kwd=None, saveto=None):
    """
    Plot a grid of histograms of the column `colname` of `data_table`, grouped by the column `class_kwd`.

    Parameters
    ----------
    data_table : astropy.table.Table
        Data table containing the columns `colname` and `class_kwd` for each supernova. The column `colname` must be
        three-dimensional; its second and third axes set the number of rows and columns of the plot grid.
    colname : str
        Column name of `data_table` to plot (e.g., 'params' or 'features').
    class_kwd : str, optional
        Column name of `data_table` to group by before plotting (e.g., 'type' or 'prediction'). Default: 'type'.
        If empty or None, all rows are plotted as a single group and no legend is drawn.
    var_kwd : str, optional
        Keyword in `data_table.meta` containing the parameter names to list on the x-axes. Default: no labels.
    row_kwd : str, optional
        Keyword in `data_table.meta` containing labels for the leftmost y-axes. Default: no labels.
    saveto : str, optional
        Filename to which to save the plot. Default: display the plot instead of saving it.
    """
    _, nrows, ncols = data_table[colname].shape
    if class_kwd:
        data_table = data_table.group_by(class_kwd)
        # Extra key column that will hold one legend handle per group.
        data_table.groups.keys['patch'] = None
    else:
        # Group by a constant so that the loop below sees exactly one group.
        data_table = data_table.group_by(np.ones(len(data_table)))
    ngroups = len(data_table.groups)

    fig, axarr = plt.subplots(nrows, ncols, sharex='col', squeeze=False)
    for j in range(ncols):
        xlims = []  # bin edges of every histogram in this column (the x-axis is shared per column)
        for i in range(nrows):
            ylims = []  # bin heights of every histogram in this panel
            for k in range(ngroups):
                histdata = data_table.groups[k][colname][:, i, j]
                # Restrict each histogram to the central 90% of its data.
                histrange = np.percentile(histdata, (5., 95.))
                n, b, p = axarr[i, j].hist(histdata, range=histrange, density=True, histtype='step')
                if class_kwd:
                    data_table.groups.keys['patch'][k] = p[0]
                xlims.append(b)
                ylims.append(n)
            axarr[i, j].set_ylim(0., 1.05 * np.max(ylims))
            axarr[i, j].set_yticks([])
        # Discard outlying bin edges before choosing the shared x-range for this column.
        xlims = sigma_clip(xlims, stdfunc=mad_std, masked=False)
        xmin, xmax = np.percentile(xlims, (0., 100.))
        # Pad the range by 5% of its width on each side.
        axarr[-1, j].set_xlim(1.05 * xmin - 0.05 * xmax, 1.05 * xmax - 0.05 * xmin)
        axarr[-1, j].xaxis.set_major_locator(plt.MaxNLocator(2))

    if class_kwd:
        fig.legend(data_table.groups.keys['patch'], data_table.groups.keys[class_kwd], loc='upper center',
                   ncol=ngroups, title=CLASS_KEYWORDS.get(class_kwd, class_kwd))
    if var_kwd is not None:
        for ax, var in zip(axarr[-1], data_table.meta[var_kwd]):
            ax.set_xlabel(var, size='small')
            ax.tick_params(labelsize='small')
            if 'mag' in var.lower():
                # Variables with 'mag' in their name get a reversed x-axis
                # (presumably magnitudes, where smaller is brighter — confirm).
                ax.invert_xaxis()
    if row_kwd is not None:
        for ax, filt in zip(axarr[:, 0], data_table.meta[row_kwd]):
            ax.set_ylabel(filt, rotation=0, va='center')

    # When a legend is drawn at the top of the figure, keep the panels out of the top 10%.
    fig.tight_layout(h_pad=0., w_pad=0., rect=(0., 0., 1., 0.9 if class_kwd else 1.))
    if saveto is None:
        # NOTE(review): this branch was empty in the garbled source; `plt.show()` is restored to match
        # the documented default of displaying the plot — confirm against upstream.
        plt.show()
    else:
        fig.savefig(saveto)
    plt.close(fig)