Source code for superphot.optimize

import numpy as np
from sklearn.model_selection import ParameterGrid, ParameterSampler
from .classify import _validate_args, validate_classifier, calc_metrics
from .util import subplots_layout
from astropy.table import Table, vstack, join
from argparse import ArgumentParser
import os
import logging
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.backends.backend_pdf import PdfPages
from itertools import product
import json
from multiprocessing import Pool


def titlecase(x):
    """Capitalize the first letter of each word in a string (where words are separated by whitespace)."""
    words = x.split()
    cap_words = [word[0].upper() + word[1:] for word in words]
    title = ' '.join(cap_words)
    return title

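# Example (illustrative, not part of the module): titlecase('f1 score') returns 'F1 Score'.
# It is used below to turn metric column names into panel titles.
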
def plot_hyperparameters_3d(t, ccols, xcol, ycol, zcol, cmap=None, cmin=None, cmax=None, figtitle=''):
    """
    Plot 3D scatter plots of the metrics against the hyperparameters.

    Parameters
    ----------
    t : astropy.table.Table
        Table of results from `test_hyperparameters`.
    ccols : list
        List of columns to plot as metrics.
    xcol, ycol, zcol : str
        Columns to plot on the x-, y-, and z-axes of the scatter plots.
    cmap : str, optional
        Name of the colormap to use to color the values in `ccols`.
    cmin, cmax : float, optional
        Data limits corresponding to the minimum and maximum colors in `cmap`.
    figtitle : str, optional
        Title text for the entire multipanel figure.
    """
    nrows, ncols = subplots_layout(len(ccols))
    gs = plt.GridSpec(nrows + 1, ncols, height_ratios=(8,) * nrows + (1,))
    fig = plt.figure(figsize=(11., 8.5))
    for (i, j), ccol in zip(product(range(nrows), range(ncols)), ccols):
        ax = fig.add_subplot(gs[i, j], projection='3d')
        ax.minorticks_off()
        ax.set_zticks([], minor=True)
        logz = np.log(t[zcol])  # 3D axes do not support log scaling
        cbar = ax.scatter(t[xcol], t[ycol], logz, marker='o', c=t[ccol], cmap=cmap, vmin=cmin, vmax=cmax, alpha=1.)
        vmin, vmax = np.percentile(t[ccol], [0., 100.])
        ax.set_xlabel(xcol.split('__')[1])
        ax.set_ylabel(ycol.split('__')[1])
        ax.set_zlabel(zcol.split('__')[1])
        ax.set_xticks(np.unique(t[xcol]))
        ax.set_yticks(np.unique(t[ycol]))
        ax.set_zticks(np.unique(logz))
        ax.set_zticklabels(np.unique(t[zcol]))
        title = titlecase(ccol.replace('_', ' '))
        if len(title) > 12:
            title = title[:-7] + '.'
        ax.set_title(f'${vmin:.2f}$ < {title} < ${vmax:.2f}$', size='medium')
    cax = fig.add_subplot(gs[-1, :])
    cax = fig.colorbar(cbar, cax, orientation='horizontal')
    if cmin is None or cmax is None:
        cax.set_ticks(cbar.get_clim())
        cax.set_ticklabels(['min', 'max'])
    fig.text(0.5, 0.99, figtitle, va='top', ha='center')
    fig.tight_layout(pad=3.)
    return fig

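# Illustrative usage sketch (the file name and the specific column names are hypothetical;
# any table with numeric 'classifier__*' hyperparameter columns and numeric metric columns
# works the same way):
#
#     from astropy.table import Table
#     t = Table.read('hyperparameters.txt', format='ascii')
#     fig = plot_hyperparameters_3d(t, ccols=['accuracy', 'f1_score'],
#                                   xcol='classifier__max_depth',
#                                   ycol='classifier__n_estimators',
#                                   zcol='classifier__max_features',
#                                   figtitle='random forest')
#     fig.savefig('hyperparameters_3d.pdf')
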
def plot_hyperparameters_with_diff(t, dcol=None, xcol=None, ycol=None, zcol=None, saveto=None, **criteria):
    """
    Plot the metrics for one value of `dcol` and the difference in the metrics for the second value of `dcol`.

    Parameters
    ----------
    t : astropy.table.Table
        Table of results from `test_hyperparameters`.
    dcol : str, optional
        Column to plot as a difference. Default: alphabetically first column starting with 'classifier'.
    xcol, ycol, zcol : str, optional
        Columns to plot on the x-, y-, and z-axes of the scatter plots.
        Default: alphabetically 2nd-4th columns starting with 'classifier'.
    saveto : str, optional
        Save the plot to this filename. If None, the plot is displayed and not saved.
    criteria : dict, optional
        Plot only a subset of the data that matches these keyword-value pairs, and add these criteria to the title.
        If any keywords do not correspond to table columns, all rows are assumed to match.
    """
    for key, val in criteria.items():
        if key in t.colnames:
            t = t[t[key] == val]
            t.remove_column(key)
    kcols = sorted({col for col in t.colnames if col.startswith('classifier')} - {dcol, xcol, ycol, zcol})[::-1]
    if dcol is None:
        dcol = kcols.pop()
    if xcol is None:
        xcol = kcols.pop()
    if ycol is None:
        ycol = kcols.pop()
    if zcol is None:
        zcol = kcols.pop()
    t = t.group_by(dcol)
    if len(t.groups) != 2:
        raise ValueError('`dcol` must have exactly two possible values')
    t.sort([dcol, xcol, ycol, zcol])
    t0, t1 = t.groups
    ccols = sorted(set(t.colnames) - {dcol, xcol, ycol, zcol})
    j = join(t0, t1, keys=[xcol, ycol, zcol])
    for ccol in ccols:
        j[ccol] = j[ccol + '_1'] - j[ccol + '_2']
    criteria_strings = [f'{key} = {val}' for key, val in criteria.items()]
    title1 = ', '.join(criteria_strings + [f'{dcol} = {t.groups.keys[dcol][0]}'])
    fig1 = plot_hyperparameters_3d(t0, ccols, xcol, ycol, zcol, figtitle=title1)
    title2 = ', '.join(criteria_strings + [f'{dcol} = {t.groups.keys[dcol][0]} $-$ {t.groups.keys[dcol][1]}'])
    fig2 = plot_hyperparameters_3d(j, ccols, xcol, ycol, zcol, cmap='coolwarm_r', cmin=-0.2, cmax=0.2,
                                   figtitle=title2)
    if saveto is None:
        plt.show()
    else:
        with PdfPages(saveto) as pdf:
            pdf.savefig(fig1)
            pdf.savefig(fig2)
        plt.close(fig1)
        plt.close(fig2)

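# Illustrative usage sketch (file names are hypothetical; `dcol`, `xcol`, `ycol`, and `zcol`
# are chosen automatically from the columns starting with 'classifier' when not given):
#
#     from astropy.table import Table
#     t = Table.read('hyperparameters.txt', format='ascii')
#     plot_hyperparameters_with_diff(t, saveto='hyperparameters.pdf')
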
def _plot_hyperparameters_from_file():
    parser = ArgumentParser()
    parser.add_argument('results', help='Table of results from superphot-optimize.')
    parser.add_argument('--saveto', help='Filename to which to save the plot.')
    # default to an empty list so the dictionary comprehension below works when --criteria is omitted
    parser.add_argument('--criteria', nargs='+', default=[],
                        help='Subset of data to plot and/or criteria to add to the title. '
                             'Format: kwd1=val1 kwd2=val2 etc.')
    args = parser.parse_args()
    t = Table.read(args.results, format='ascii')
    dcol, xcol, ycol, zcol = sorted([col for col in t.colnames if col.startswith('classifier__')])
    criteria = {criterion.split('=')[0]: criterion.split('=')[1] for criterion in args.criteria}
    plot_hyperparameters_with_diff(t, dcol, xcol, ycol, zcol, saveto=args.saveto, **criteria)

class ParameterOptimizer:
    """
    Class containing the pipeline and data sets for hyperparameter optimization.

    Parameters
    ----------
    pipeline : sklearn.pipeline.Pipeline or imblearn.pipeline.Pipeline
        The full classification pipeline, including rescaling, resampling, and classification.
    train_data : astropy.table.Table
        Astropy table containing the training data. Must have a 'features' column and a 'type' column.
    validation_data : astropy.table.Table
        Astropy table containing the validation data. Must have a 'features' column.

    Attributes
    ----------
    pipeline : sklearn.pipeline.Pipeline or imblearn.pipeline.Pipeline
        The full classification pipeline, including rescaling, resampling, and classification.
    train_data : astropy.table.Table
        Astropy table containing the training data. Must have a 'features' column and a 'type' column.
    validation_data : astropy.table.Table
        Astropy table containing the validation data. Must have a 'features' column.
    """

    def __init__(self, pipeline, train_data, validation_data):
        self.pipeline = pipeline
        self.train_data = train_data
        self.validation_data = validation_data

    def test_hyperparams(self, param_set):
        """
        Validates the pipeline for a set of hyperparameters.

        Measures F1 score and accuracy, as well as completeness and purity for each class.

        Parameters
        ----------
        param_set : dict
            A dictionary containing keywords that match the parameters of `pipeline` and values to which to set them.

        Returns
        -------
        param_set : dict
            The input `param_set` with the metrics added to the dictionary. These are also saved to a JSON file.
        """
        try:
            self.pipeline.set_params(classifier__n_jobs=1, **param_set)
            results = validate_classifier(self.pipeline, self.train_data, self.validation_data)
            param_set = calc_metrics(results, param_set, save=True)
        except Exception as e:
            logging.error(f'Problem testing {param_set}:\n{e}')
        return param_set

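# Illustrative usage sketch (assumes `pipeline`, `train_data`, and `validation_data` already exist,
# e.g. as returned by `_validate_args`; the hyperparameter name below is hypothetical, and any
# keyword accepted by pipeline.set_params() works):
#
#     optimizer = ParameterOptimizer(pipeline, train_data, validation_data)
#     metrics = optimizer.test_hyperparams({'classifier__n_estimators': 100})
#     # `metrics` is the input dict plus the F1 score, accuracy, completeness, and purity
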
def _main():
    parser = ArgumentParser()
    parser.add_argument('param_dist', help='JSON-encoded parameter grid/distribution to test.')
    parser.add_argument('pipeline', help='Filename of the pickled pipeline object.')
    parser.add_argument('validation_data', help='Filename of the metadata table for the validation set.')
    parser.add_argument('--train-data', help='Filename of the metadata table for the training set, if different from '
                                             'the validation set.')
    parser.add_argument('-i', '--n-iter', type=int, help='Number of hyperparameter combinations to try. Default: all.')
    parser.add_argument('-j', '--n-jobs', type=int, help='Number of parallel processes to use. Default: 1.')
    parser.add_argument('--saveto', default='hyperparameters.txt', help='Filename to which to write the results.')
    args = parser.parse_args()

    pipeline, train_data, validation_data = _validate_args(args)
    optimizer = ParameterOptimizer(pipeline, train_data, validation_data)
    with open(args.param_dist, 'r') as f:
        param_distributions = json.load(f)
    if args.n_iter is None:
        ps = ParameterGrid(param_distributions)
    else:
        ps = ParameterSampler(param_distributions, n_iter=args.n_iter)
    logging.info(f'Testing {len(ps):d} combinations...')

    if args.n_jobs is None:
        rows = [optimizer.test_hyperparams(param_set) for param_set in ps]
    else:
        with Pool(args.n_jobs) as p:
            rows = p.map(optimizer.test_hyperparams, ps)

    tfinal = Table(rows)
    for i, snclass in enumerate(pipeline.classes_):
        tfinal[snclass + ' completeness'] = tfinal['completeness'][:, i]
        tfinal[snclass + ' purity'] = tfinal['purity'][:, i]
    tfinal.remove_columns(['completeness', 'purity'])
    if os.path.exists(args.saveto):
        logging.warning(f'Appending results to {args.saveto}')
        tprev = Table.read(args.saveto, format='ascii')
        tfinal = vstack([tprev, tfinal])
    else:
        logging.info(f'Writing results to {args.saveto}')
    tfinal.write(args.saveto, format='ascii.fixed_width_two_line', overwrite=True)
    plot_hyperparameters_with_diff(tfinal, saveto=args.saveto.replace('.txt', '.pdf'))

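# Illustrative sketch of a parameter grid file for `param_dist` (the parameter names and values
# are hypothetical; keys must match set_params() names on the pipeline, and each maps to a list
# of values for ParameterGrid / ParameterSampler):
#
#     {"classifier__n_estimators": [100, 300],
#      "classifier__max_depth": [10, 20]}
#
# A run over all combinations with 4 parallel processes might then look like the following
# (assuming this function is installed as the superphot-optimize entry point; file names are
# placeholders):
#
#     superphot-optimize param_grid.json pipeline.pickle validation.txt -j 4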