import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import logging
from astropy.table import Table, join
from astropy.io.ascii import masked
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
from sklearn.inspection import permutation_importance
from imblearn.over_sampling.base import BaseOverSampler
from imblearn.utils._docstring import Substitution, _random_state_docstring
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
import pickle
from .util import meta_columns, plot_histograms, filter_colors, load_data, CLASS_KEYWORDS
import itertools
from tqdm import tqdm
from argparse import ArgumentParser
import json
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO)
def plot_confusion_matrix(confusion_matrix, classes, cmap='Blues', purity=False, title='',
                          xlabel='Photometric Classification', ylabel='Spectroscopic Classification', ax=None):
    """
    Plot a confusion matrix with each cell labeled by its fraction and absolute number.

    Based on tutorial: https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    Parameters
    ----------
    confusion_matrix : array-like
        The confusion matrix as a square array of integers (rows = true class, columns = predicted class).
    classes : list
        List of class labels for the axes of the confusion matrix.
    cmap : str, optional
        Name of a Matplotlib colormap to color the matrix.
    purity : bool, optional
        If False (default), aggregate by row (spec. class). If True, aggregate by column (phot. class).
    title : str, optional
        Text to go above the plot. Default: no title.
    xlabel, ylabel : str, optional
        Labels for the x- and y-axes. Default: "Photometric Classification" and "Spectroscopic Classification".
    ax : matplotlib.pyplot.axes, optional
        Axis on which to plot the confusion matrix. Default: new axis.
    """
    # Accept any array-like (e.g. nested lists), as documented, without copying real arrays.
    confusion_matrix = np.asanyarray(confusion_matrix)
    n_per_true_class = confusion_matrix.sum(axis=1)
    n_per_pred_class = confusion_matrix.sum(axis=0)
    # A class with no members gives 0/0 = NaN for its row/column. That is the intended
    # result (the cell is labeled "nan"), so do not emit a RuntimeWarning for it.
    with np.errstate(divide='ignore', invalid='ignore'):
        if purity:
            cm = confusion_matrix / n_per_pred_class[np.newaxis, :]
        else:
            cm = confusion_matrix / n_per_true_class[:, np.newaxis]
    if ax is None:
        ax = plt.axes()
    ax.imshow(cm, interpolation='nearest', cmap=cmap, aspect='equal')
    ax.set_title(title)
    nclasses = len(classes)
    ax.set_xticks(range(nclasses))
    ax.set_yticks(range(nclasses))
    # Tick labels carry the class name plus the number of objects in that column/row.
    ax.set_xticklabels(['{}\n({:.0f})'.format(label, n) for label, n in zip(classes, n_per_pred_class)])
    ax.set_yticklabels(['{}\n({:.0f})'.format(label, n) for label, n in zip(classes, n_per_true_class)])
    ax.set_xticks([], minor=True)
    ax.set_yticks([], minor=True)
    # Put the first class at the top so the matrix reads like a table.
    ax.set_ylim(nclasses - 0.5, -0.5)
    # Cells darker than half of the maximum fraction get white text for contrast.
    thresh = np.nanmax(cm) / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, f'{cm[i, j]:.2f}\n({confusion_matrix[i, j]:.0f})', ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
@Substitution(
    sampling_strategy=BaseOverSampler._sampling_strategy_docstring.replace('dict or callable', 'dict, callable or int'),
    random_state=_random_state_docstring)
class MultivariateGaussian(BaseOverSampler):
    """Class to perform over-sampling using a multivariate Gaussian (``numpy.random.multivariate_normal``).

    Parameters
    ----------
    {sampling_strategy}

        - When ``int``, it corresponds to the total number of samples in each
          class (including the real samples). Can be used to oversample even
          the majority class. If ``sampling_strategy`` is smaller than the
          existing size of a class, that class will not be oversampled and
          the classes may not be balanced.

    {random_state}
    """
    # NOTE: the class docstring above is passed through ``str.format`` by ``Substitution``,
    # so it must not contain any other curly braces.

    def __init__(self, sampling_strategy='all', random_state=None):
        self.random_state = random_state
        # Per-class mean vectors and covariance matrices, filled in by ``_fit_resample``.
        self.mean_ = dict()
        self.cov_ = dict()
        if isinstance(sampling_strategy, int):
            # An integer is handled by this class itself: remember the target size per class
            # and hand 'all' to the base class so that every class is considered.
            self.samples_per_class = sampling_strategy
            sampling_strategy = 'all'
        else:
            self.samples_per_class = None
        super().__init__(sampling_strategy=sampling_strategy)

    def _fit_resample(self, X, y):
        """Fit a Gaussian to each class in (X, y) and append samples drawn from it.

        Returns the original samples followed by the synthetic ones, and the matching labels.
        """
        self.fit(X, y)
        # Create the generator once per call. Re-creating it inside the loop would restart an
        # integer-seeded stream for every class (identical underlying random numbers for each
        # class) and would leave ``rs_`` undefined if every class were skipped.
        self.rs_ = check_random_state(self.random_state)
        X_parts = [X]
        y_parts = [y]
        for class_sample, n_samples in self.sampling_strategy_.items():
            X_class = X[y == class_sample]
            self.mean_[class_sample] = np.mean(X_class, axis=0)
            self.cov_[class_sample] = np.cov(X_class, rowvar=False)
            if self.samples_per_class is not None:
                # integer strategy: top each class up to the requested total
                n_samples = self.samples_per_class - X_class.shape[0]
            if n_samples <= 0:
                continue  # class is already large enough; keep its statistics but add nothing
            X_parts.append(self.rs_.multivariate_normal(self.mean_[class_sample], self.cov_[class_sample],
                                                        n_samples))
            y_parts.append(np.repeat(class_sample, n_samples))
        # Stack once at the end instead of re-copying the growing arrays for every class.
        return np.vstack(X_parts), np.hstack(y_parts)

    def more_samples(self, n_samples):
        """Draw more samples from the same distribution of an already fitted sampler.

        Parameters
        ----------
        n_samples : int
            Number of samples to draw for each class.

        Returns
        -------
        X : numpy.ndarray
            The drawn samples, grouped by class in sorted class order.
        y : numpy.ndarray
            The class label of each row of `X`.

        Raises
        ------
        Exception
            If the sampler has not been fitted with ``fit_resample(X, y)`` yet.
        """
        if not self.mean_ or not self.cov_:
            raise Exception('Mean and covariance not set. You must first run fit_resample(X, y).')
        classes = sorted(self.sampling_strategy_.keys())
        X = np.vstack([self.rs_.multivariate_normal(self.mean_[class_sample], self.cov_[class_sample], n_samples)
                       for class_sample in classes])
        y = np.repeat(classes, n_samples)
        return X, y
def train_classifier(pipeline, train_data):
    """
    Train a classification pipeline on `train_data`.

    The pipeline is fitted in place; nothing is returned.

    Parameters
    ----------
    pipeline : imblearn.pipeline.Pipeline
        The full classification pipeline, including rescaling, resampling, and classification.
    train_data : astropy.table.Table
        Astropy table containing the training data. Must have a 'features' and a 'type' column.
    """
    # Collapse all trailing feature axes so that each table row becomes one flat feature vector.
    n_rows = len(train_data)
    flat_features = train_data['features'].reshape(n_rows, -1)
    labels = train_data['type']
    pipeline.fit(flat_features, labels)
def classify(pipeline, test_data, aggregate=True):
    """
    Classify `test_data` with an already trained classification pipeline.

    Parameters
    ----------
    pipeline : imblearn.pipeline.Pipeline
        The full classification pipeline, including rescaling, resampling, and classification.
    test_data : astropy.table.Table
        Astropy table containing the test data. Must have a 'features' column. The table is copied, not modified.
    aggregate : bool, optional
        If True (default), average the probabilities for a given supernova across the multiple model light curves.

    Returns
    -------
    results : astropy.table.Table
        Astropy table containing the supernova metadata and classification probabilities for each supernova,
        plus the most likely class ('prediction') and its probability ('confidence').
    """
    results = test_data.copy()
    # one flat feature vector per table row
    flat_features = results['features'].reshape(len(results), -1)
    results['probabilities'] = pipeline.predict_proba(flat_features)
    if aggregate:
        results = aggregate_probabilities(results)
    probabilities = results['probabilities']
    best_class = probabilities.argmax(axis=1)
    results['prediction'] = pipeline.classes_[best_class]
    results['confidence'] = probabilities.max(axis=1)
    return results
def mean_axis0(x, axis=0):
    """
    Equivalent to the numpy.mean function but with axis=0 by default.

    Parameters
    ----------
    x : array-like
        Values to average. Arrays and array subclasses are used as they are (their own ``mean`` method is
        called); other sequences such as nested lists are converted to an array first.
    axis : int, optional
        Axis along which to average. Default: 0.

    Returns
    -------
    The mean of `x` along `axis`.
    """
    # asanyarray passes ndarray subclasses through untouched, so their behavior is unchanged,
    # while plain sequences (which have no .mean method) are now accepted as well.
    return np.asanyarray(x).mean(axis=axis)
def aggregate_probabilities(table):
    """
    Average the classification probabilities for a given supernova across the multiple model light curves.

    Rows are grouped on every metadata column (those listed in `meta_columns`) and the 'probabilities'
    of each group are averaged. All other columns are dropped.

    Parameters
    ----------
    table : astropy.table.Table
        Astropy table containing the metadata for a supernova and the classification probabilities ('probabilities')

    Returns
    -------
    results : astropy.table.Table
        Astropy table containing the supernova metadata and average classification probabilities for each supernova
    """
    metadata_names = [name for name in table.colnames if name in meta_columns]
    table = table[metadata_names + ['probabilities']]
    # Group on everything except the last column ('probabilities'), using filled (unmasked) values.
    key_names = table.colnames[:-1]
    grouped = table.filled().group_by(key_names)
    results = grouped.groups.aggregate(mean_axis0)
    if 'type' in results.colnames:
        # store the type as a masked array again
        results['type'] = np.ma.array(results['type'])
    return results
def validate_classifier(pipeline, train_data, test_data=None, aggregate=True):
    """
    Validate the performance of a machine-learning classifier using leave-one-out cross-validation.

    The pipeline is first trained on all of `train_data` and applied to all of `test_data`. Then, for each
    unique 'filename' in the training data, the pipeline is retrained without that file and the probabilities
    of the matching test rows are overwritten with the held-out prediction.

    Note that a 'probabilities' column is added to the `test_data` table that is passed in (or to
    `train_data` if no test data are given), and that the pipeline is left fitted to the last subset.

    Parameters
    ----------
    pipeline : imblearn.pipeline.Pipeline
        The full classification pipeline, including rescaling, resampling, and classification.
    train_data : astropy.table.Table
        Astropy table containing the training data. Must have 'features', 'type' and 'filename' columns, and
        'ndraws' in its metadata.
    test_data : astropy.table.Table, optional
        Astropy table containing the test data. Must have 'features', 'type' and 'filename' columns.
        If None, use the training data itself for validation.
    aggregate : bool, optional
        If True (default), average the probabilities for a given supernova across the multiple model light curves.

    Returns
    -------
    results : astropy.table.Table
        Astropy table containing the supernova metadata and classification probabilities for each supernova,
        plus the columns 'prediction', 'confidence' and 'correct'.

    Raises
    ------
    ValueError
        If any class has no more than `train_data.meta['ndraws']` rows.
    """
    _, n_per_class = np.unique(train_data['type'], return_counts=True)
    too_small = n_per_class <= train_data.meta['ndraws']
    if np.any(too_small):
        raise ValueError('Training data must have at least two samples per class for cross-validation')
    if test_data is None:
        test_data = train_data

    # Baseline: probabilities from the classifier trained on the complete training set.
    train_classifier(pipeline, train_data)
    all_features = test_data['features'].reshape(len(test_data), -1)
    test_data['probabilities'] = pipeline.predict_proba(all_features)

    # Leave one file out at a time and replace the probabilities of the held-out rows.
    for filename in tqdm(np.unique(train_data['filename']), desc='Cross-validation'):
        keep = train_data['filename'] != filename
        held_out = test_data['filename'] == filename
        train_classifier(pipeline, train_data[keep])
        held_out_features = test_data['features'][held_out].reshape(np.count_nonzero(held_out), -1)
        test_data['probabilities'][held_out] = pipeline.predict_proba(held_out_features)

    if aggregate:
        test_data = aggregate_probabilities(test_data)
    probabilities = test_data['probabilities']
    test_data['prediction'] = pipeline.classes_[probabilities.argmax(axis=1)]
    test_data['confidence'] = probabilities.max(axis=1)
    test_data['correct'] = test_data['prediction'] == test_data['type']
    return test_data
def make_confusion_matrix(results, classes=None, p_min=0., saveto=None, purity=False, binary=False, title=None):
    """
    Given a data table with classification probabilities, calculate and plot the confusion matrix.

    Parameters
    ----------
    results : astropy.table.Table
        Astropy table containing the supernova metadata and classification probabilities (column name = 'probabilities')
        Rows with a masked 'type' are ignored.
    classes : array-like, optional
        Labels corresponding to the 'probabilities' column. If None, use the sorted entries in the 'type' column.
    p_min : float, optional
        Minimum confidence to be included in the confusion matrix. Default: include all samples.
    saveto : str, optional
        Save the plot to this filename. If None, the plot is displayed and not saved.
    purity : bool, optional
        If False (default), aggregate by row (true label). If True, aggregate by column (predicted label).
    binary : bool, optional
        If True, plot a SNIa vs non-SNIa (CCSN) binary confusion matrix.
    title : str, optional
        A title for the plot. If the plot is big enough, statistics ($N$, $A$, $F_1$) are appended in parentheses.
        Default: 'Completeness' or 'Purity' depending on `purity`.
    """
    results = results[~results['type'].mask]
    if classes is None:
        classes = np.unique(results['type'])
    # The comparison and fancy indexing below need a real array, not a list.
    classes = np.asarray(classes)
    if binary:
        results['type'] = ['SNIa' if sntype == 'SNIa' else 'CCSN' for sntype in results['type']]
        SNIa_probs = results['probabilities'][:, np.where(classes == 'SNIa')[0][0]]
        classes = np.array(['CCSN', 'SNIa'])
        # index 0 (CCSN) if p(SNIa) rounds to 0, index 1 (SNIa) if it rounds to 1
        predicted_types = np.choose(np.round(SNIa_probs).astype(int), classes)
        # the confidence of a binary prediction is max(p, 1 - p)
        include = (SNIa_probs > p_min) | (SNIa_probs < 1. - p_min)
    else:
        predicted_types = classes[np.argmax(results['probabilities'], axis=1)]
        include = results['probabilities'].max(axis=1) > p_min
    # Pin the row/column order to `classes`. Without `labels`, a class that drops out after the
    # confidence cut would shrink the matrix and shift the tick labels onto the wrong rows.
    cnf_matrix = confusion_matrix(results['type'][include], predicted_types[include], labels=classes)
    if title is None:
        title = 'Purity' if purity else 'Completeness'
    size = (len(classes) + 1.) * 5. / 6.
    if size > 3.:  # only add stats to title if figure is big enough
        accuracy = accuracy_score(results['type'][include], predicted_types[include])
        f1 = f1_score(results['type'][include], predicted_types[include], average='macro')
        title += f' ($N={include.sum():d}$, $A={accuracy:.2f}$, $F_1={f1:.2f}$)'
        xlabel = 'Photometric Classification'
        ylabel = 'Spectroscopic Classification'
    else:
        xlabel = 'Phot. Class.'
        ylabel = 'Spec. Class.'
    fig = plt.figure(figsize=(size, size))
    plot_confusion_matrix(cnf_matrix, classes, purity=purity, title=title, xlabel=xlabel, ylabel=ylabel)
    fig.tight_layout()
    if saveto is None:
        plt.show()
    else:
        fig.savefig(saveto)
def load_results(filename):
    """
    Read a table of classification results from a text file.

    Every column that is not listed in `meta_columns` is treated as the probability of one class. These
    columns are combined into a single two-dimensional 'probabilities' column and then removed.

    Parameters
    ----------
    filename : str
        Name of the file to read (Astropy 'ascii' format)

    Returns
    -------
    results : astropy.table.Table
        Table of metadata plus a 'probabilities' column. The class labels, in the order of the columns of
        'probabilities', are stored in `results.meta['classes']`.
    """
    results = Table.read(filename, format='ascii')
    if 'type' in results.colnames:
        results['type'] = np.ma.array(results['type'])
    class_names = [name for name in results.colnames if name not in meta_columns]
    classes = np.array(class_names)
    # one row per supernova, one column per class
    per_class = [results[sntype].data for sntype in classes]
    results['probabilities'] = np.stack(per_class).T
    results.meta['classes'] = classes
    results.remove_columns(classes)
    return results
def write_results(test_data, classes, filename, max_lines=None, latex=False, latex_title='Classification Results',
                  latex_label='tab:results'):
    """
    Write the classification results to a text file.

    The output contains the metadata columns (those listed in `meta_columns`) followed by one probability
    column per class.

    Parameters
    ----------
    test_data : astropy.table.Table
        Astropy table containing the supernova metadata and the classification probabilities ('probabilities').
    classes : list
        The labels that correspond to the columns in 'probabilities'
    filename : str
        Name of the output file
    max_lines : int, optional
        Maximum number of table rows to write to the file
    latex : bool, optional
        If False (default), write in the Astropy 'ascii.fixed_width_two_line' format. If True, write in the Astropy
        'ascii.aastex' format and add fancy table headers, etc.
    latex_title : str, optional
        Table caption if written in AASTeX format. Default: 'Classification Results'
    latex_label : str, optional
        LaTeX label if written in AASTeX format. Default: 'tab:results'
    """
    test_data = test_data[:max_lines]
    metadata_names = [name for name in test_data.colnames if name in meta_columns]
    output = test_data[metadata_names]
    for name, fmt in (('MWEBV', '%.4f'), ('redshift', '%.4f'), ('confidence', '%.3f')):
        output[name].format = fmt

    # one probability column per class
    for i, classname in enumerate(classes):
        if latex:
            col = f'$p_\\mathrm{{{classname}}}$'
        else:
            col = classname
        output[col] = test_data['probabilities'][:, i]
        output[col].format = '%.3f'

    if not latex:
        output.write(filename, format='ascii.fixed_width_two_line', overwrite=True)
    else:
        # latex formatting for data
        output['filename'] = [name.replace('_', '\\_') for name in output['filename']]
        if 'type' in output.colnames:
            output['type'] = [sntype.replace("SNI", "SN~I") for sntype in output['type']]
        output['prediction'] = [sntype.replace("SNI", "SN~I") for sntype in output['prediction']]

        # AASTeX header and footer
        latexdict = {'tabletype': 'deluxetable*'}
        if latex_title:
            if latex_label:
                latex_title += f'\\label{{{latex_label}}}'
            latexdict['caption'] = latex_title
        if max_lines is not None:
            # the table may have been truncated above
            latexdict['tablefoot'] = '\\tablecomments{The full table is available in machine-readable form.}'

        # human-readable column headers (columns not listed here keep their name)
        column_headers = {
            'filename': 'Transient Name',
            'redshift': 'Redshift',
            'MWEBV': '$E(B-V)$',
            'type': 'Spec. Class.',
            'prediction': 'Phot. Class.',
            'confidence': 'Confidence',
        }
        for column in output.colnames:
            output.rename_column(column, column_headers.get(column, column))

        output.write(filename, format='ascii.aastex', overwrite=True, latexdict=latexdict,
                     fill_values=[(masked, '\\nodata'), ('inf', '\\infty')])
    logging.info(f'classification results saved to {filename}')
def plot_feature_importance(pipeline, train_data, width=0.8, nsamples=1000, saveto=None):
    """
    Plot a bar chart of feature importance using mean decrease in impurity, with permutation importances overplotted.

    Mean decrease in impurity is read from the `feature_importances_` attribute of the pipeline's 'classifier'
    step. If the classifier does not have this attribute (e.g., SVM, MLP), only permutation importance is
    calculated.

    For each filter, the pipeline is refitted on that filter's features plus one column of random numbers,
    so the pipeline passed in is modified. The validation set for the permutation importance is drawn with
    the `more_samples` method of the pipeline's 'sampler' step.

    Parameters
    ----------
    pipeline : sklearn.pipeline.Pipeline or imblearn.pipeline.Pipeline
        The trained pipeline for which to plot feature importances. Steps should be named 'classifier' and 'sampler'.
    train_data : astropy.table.Table
        Data table containing 'features' and 'type' for training when calculating permutation importances. Must also
        include 'featnames' and 'filters' in `train_data.meta`.
    width : float, optional
        Total width of the bars in units of the separation between bars. Default: 0.8.
    nsamples : int, optional
        Number of samples (per class) to draw for the fake validation data set. Default: 1000.
    saveto : str, optional
        Filename to which to save the plot. Default: show instead of saving.
    """
    logging.info('calculating feature importance')
    filters = train_data.meta['filters']
    # the extra "feature" is pure noise and serves as a reference level
    featnames = np.append(train_data.meta['featnames'], 'Random Number')
    nfilters = filters.size

    # vertical positions: one row per filter, one column per feature, filters offset around each integer
    offsets = 0.5 * width / nfilters * np.linspace(1 - nfilters, nfilters - 1, nfilters)
    positions = np.arange(featnames.size) + offsets[:, np.newaxis]

    random_feature_train = np.random.random(len(train_data))
    random_feature_validate = np.random.random(nsamples * pipeline.classes_.size)
    classifier_has_mdi = hasattr(pipeline.named_steps['classifier'], 'feature_importances_')

    fig, ax = plt.subplots(1, 1)
    features_by_filter = np.moveaxis(train_data['features'], 1, 0)
    for filter_features, ypos, fltr in zip(features_by_filter, positions, filters):
        X = np.hstack([filter_features, random_feature_train[:, np.newaxis]])
        pipeline.fit(X, train_data['type'])
        X_val, y_val = pipeline.named_steps['sampler'].more_samples(nsamples)
        X_val[:, -1] = random_feature_validate
        result = permutation_importance(pipeline.named_steps['classifier'], X_val, y_val, n_jobs=-1)
        color = filter_colors.get(fltr)
        if classifier_has_mdi:
            mdi = pipeline.named_steps['classifier'].feature_importances_
            # no bar for the random-number column
            ax.barh(ypos[:-1], mdi[:-1], width / nfilters, color=color)
        ax.errorbar(result.importances_mean, ypos, xerr=result.importances_std, fmt='o', color=color, mfc='w')

    if classifier_has_mdi:
        # gray stand-ins so the legend is independent of the filter colors
        proxy_artists = [Patch(color='gray'), ax.errorbar([], [], xerr=[], fmt='o', color='gray', mfc='w')]
        ax.legend(proxy_artists, ['Mean Decrease in Impurity', 'Permutation Importance'], loc='best')

    # feature names go to the left of the axis; the minor ticks carry the filter names
    for row, featname in enumerate(featnames):
        ax.text(-0.03, row, featname, ha='right', va='center', transform=ax.get_yaxis_transform())
    ax.set_yticks([])
    ax.set_yticks(positions.flatten(), minor=True)
    ax.set_yticklabels(np.repeat(filters, featnames.size), minor=True, size='x-small', ha='center')
    ax.invert_yaxis()
    ax.set_xlabel('Feature Importance')
    ax.set_xlim(0., ax.get_xlim()[1])
    fig.tight_layout()
    if saveto is None:
        plt.show()
    else:
        fig.savefig(saveto)
    plt.close(fig)
def cumhist(data, reverse=False, mark=None, ax=None, **kwargs):
    """
    Plot a cumulative histogram of `data`, optionally with certain points marked with an x.

    Parameters
    ----------
    data : array-like
        Data to include in the histogram
    reverse : bool, optional
        If False (default), the histogram increases with increasing `data`. If True, it decreases with increasing `data`
    mark : array-like, optional
        Boolean array with the same length as `data`; points where it is True are marked with an x.
        Default: no points are marked.
    ax : matplotlib.pyplot.axes, optional
        Axis on which to plot the histogram. Default: current axis.
    kwargs : dict, optional
        Keyword arguments to be passed to `matplotlib.pyplot.step`

    Returns
    -------
    p : list
        The list of `matplotlib.lines.Line2D` objects returned by `matplotlib.pyplot.step`
    """
    if mark is None:
        mark = np.zeros(len(data), bool)
    if ax is None:
        ax = plt.gca()
    order = np.argsort(data)
    sorted_data = data[order]
    mark = mark[order]
    # repeat the last value so that the final step is drawn
    x = np.append(sorted_data, sorted_data[-1])
    y = np.linspace(0., 1., x.size)
    if reverse:
        y = y[::-1]
    p = ax.step(x, y, **kwargs)
    # place each marker halfway up (or down) the step that belongs to its data point
    step_centers = y[:-1] + 0.5 * np.diff(y)
    ax.scatter(sorted_data[mark], step_centers[mark], marker='x')
    return p
def plot_results_by_number(results, xval='confidence', class_kwd='prediction', title=None, saveto=None):
    """
    Plot cumulative histograms of the results for each class against a specified table column.

    If `results` contains the column 'correct', incorrect classifications will be marked with an x.

    Parameters
    ----------
    results : astropy.table.Table
        Table of classification results. Must contain columns 'type'/'prediction' and the column specified by `xval`.
    xval : str, optional
        Table column to use as the horizontal axis of the histogram. Default: 'confidence'.
    class_kwd : str, optional
        Table column to use as the class grouping. Default: 'prediction'.
    title : str, optional
        Title for the plot. Default: "Training/Test Set, Grouped by {class_kwd}", where the first word is determined
        by the presence of the column 'correct' in `results`.
    saveto : str, optional
        Save the plot to this filename. If None, the plot is displayed and not saved.
    """
    has_correct = 'correct' in results.colnames
    if has_correct:
        label = '{ngroup:d} {snclass}, {correct:d} correct'
        default_set = 'Training'
    else:
        label = '{ngroup:d} {snclass}'
        default_set = 'Test'
    if title is None:
        title = f'{default_set} Set, Grouped by {CLASS_KEYWORDS.get(class_kwd, class_kwd)}'

    grouped = results.group_by(class_kwd)
    fig = plt.figure()
    for group, snclass in zip(grouped.groups, grouped.groups.keys[class_kwd]):
        if has_correct:
            correct = group['correct']
        else:
            # without a 'correct' column nothing gets marked
            correct = np.ones(len(group), bool)
        grouplabel = label.format(snclass=snclass, ngroup=len(group), correct=correct.sum())
        cumhist(group[xval], lw=1, label=grouplabel, mark=~correct)
    if has_correct:
        # empty artist that only provides the legend entry for the x markers
        plt.plot([], [], 'kx', label='incorrect')
    plt.legend(loc='best', frameon=False)
    plt.tick_params(labelsize='large')
    plt.xlabel(xval.title(), size='large')
    plt.ylabel('Cumulative Fraction', size='large')
    plt.ylim(0, 1)
    plt.title(title)
    fig.tight_layout()
    if saveto is None:
        plt.show()
    else:
        fig.savefig(saveto)
def calc_metrics(results, param_set, save=False):
    """
    Calculate completeness, purity, accuracy, and F1 score for a table of validation results.

    The metrics are added to `param_set` (which is modified in place and returned) and optionally saved in a
    json file named after the values of the input parameters.

    Parameters
    ----------
    results : astropy.table.Table
        Astropy table containing the results. Must have columns 'type' and 'prediction'. If `results.meta` has an
        entry 'classes', it sets the classes (and their order) for the per-class metrics; otherwise the unique
        entries of the 'prediction' column are used.
    param_set : dict
        A dictionary containing metadata to store along with the metrics.
    save : bool, optional
        If True, save the results to a json file in addition to returning the results. Default: False

    Returns
    -------
    param_set : dict
        A dictionary containing the input metadata and the calculated metrics. 'completeness' and 'purity' are
        lists with one entry per class; an entry is NaN if the class has no members.
    """
    # remember the input keys before the metrics are added; they make up the file name
    param_names = sorted(param_set.keys())
    classes = results.meta.get('classes', np.unique(results['prediction']))
    cnf_matrix = confusion_matrix(results['type'], results['prediction'], labels=classes)
    correct = np.diag(cnf_matrix)
    n_per_spec_class = cnf_matrix.sum(axis=1)
    n_per_phot_class = cnf_matrix.sum(axis=0)
    # A class without any true (or predicted) members gives 0/0 = NaN. This is expected, e.g. when
    # scanning confidence thresholds, so do not emit a RuntimeWarning for each such class.
    with np.errstate(divide='ignore', invalid='ignore'):
        completeness = correct / n_per_spec_class
        purity = correct / n_per_phot_class
    # tolist() yields plain Python floats, which json can always serialize
    param_set['completeness'] = completeness.tolist()
    param_set['purity'] = purity.tolist()
    param_set['accuracy'] = accuracy_score(results['type'], results['prediction'])
    param_set['f1_score'] = f1_score(results['type'], results['prediction'], average='macro', labels=classes)
    if save:
        filename = '_'.join([str(param_set[key]) for key in param_names]) + '.json'
        with open(filename, 'w') as f:
            json.dump(param_set, f)
    return param_set
def plot_metrics_by_number(validation, xval='confidence', classes=None, saveto=None):
    """
    Plot completeness, purity, accuracy, F1 score, and fractions remaining as a function of confidence threshold.

    Parameters
    ----------
    validation : astropy.table.Table
        Astropy table containing the results. Must have columns 'type', 'prediction', 'probabilities', and 'confidence'.
        The table is not modified.
    xval : str, optional
        Table column to use as the horizontal axis of the plot. Default: 'confidence'.
    classes : array-like, optional
        The classes for which to calculate completeness and purity, in the order used for the legend.
        Default: `validation.meta['classes']` if present, otherwise all classes in the 'type' column.
    saveto : str, optional
        Save the plot to this filename. If None, the plot is displayed and not saved.
    """
    if classes is None:
        classes = validation.meta.get('classes', np.unique(validation['type']))
    # Work on a sorted copy so that the caller's table keeps its row order and metadata.
    validation = validation.copy()
    validation.sort(xval)
    # calc_metrics takes its class list from the table metadata. Set it explicitly so that every
    # threshold yields one value per class, in the same order as the legend labels below.
    validation.meta['classes'] = classes
    # row i holds the metrics of all objects at or above the i-th lowest value of `xval`
    metrics = Table([calc_metrics(validation[i:], {xval: threshold})
                     for i, threshold in enumerate(validation[xval])])
    ccsne = validation[validation['type'] != 'SNIa']
    fig, (ax1, ax2) = plt.subplots(2, sharex=True, sharey=True, figsize=(6., 8.))
    lines1 = ax1.step(metrics[xval], metrics['completeness'])
    ax2.step(metrics[xval], metrics['purity'])
    for ax in [ax1, ax2]:
        # the class-independent curves are repeated in both panels
        lines2 = ax.step(metrics[xval], metrics['accuracy'], 'k-')
        lines2 += ax.step(metrics[xval], metrics['f1_score'], 'k--')
        lines2 += cumhist(validation[xval], reverse=True, ax=ax, color='k', ls='-.')
        lines2 += cumhist(ccsne[xval], reverse=True, ax=ax, color='k', ls=':')
        ax.grid(alpha=0.3)
        ax.tick_params(labelsize='large')
    fig.legend(lines2, ['accuracy', '$F_1$', 'frac.', 'CCSN frac.'], ncol=len(lines2), loc='upper center',
               bbox_to_anchor=(0.5, 0.975), frameon=False)
    ax1.legend(lines1, classes, ncol=5, loc='upper center', bbox_to_anchor=(0.5, 1.15), frameon=False)
    ax1.set_ylabel('Completeness', size='large')
    ax2.set_ylabel('Purity', size='large')
    ax2.set_ylim(0, 1)
    ax2.set_xlabel(f'Minimum {xval}', size='large')
    fig.tight_layout()
    if saveto is None:
        plt.show()
    else:
        fig.savefig(saveto)
def bar_plot(vresults, tresults, saveto=None):
    """
    Make a stacked bar plot showing the class breakdown in the training set compared to the test set.

    Three bars are drawn: the spectroscopic classes of the training set, the photometric classes of the test
    set, and the latter corrected with the purity matrix measured on the training set.

    Parameters
    ----------
    vresults : astropy.table.Table
        Astropy table containing the training data. Must have a 'type' column and a 'prediction' column.
    tresults : astropy.table.Table
        Astropy table containing the test data. Must have a 'prediction' column.
    saveto : str, optional
        Save the plot to this filename. If None, the plot is displayed and not saved.

    Raises
    ------
    ValueError
        If the classes predicted for the test set are not the same as the classes in the training set.
    """
    labels, n_per_true_class = np.unique(vresults['type'], return_counts=True)
    labels_pred, n_per_pred_class = np.unique(tresults['prediction'], return_counts=True)
    # array_equal also covers label sets of different length, for which an elementwise `!=`
    # is not well defined.
    if not np.array_equal(labels_pred, labels):
        raise ValueError('photometric and spectroscopic class labels do not match')
    # Pin the matrix to `labels` so that its columns line up with `n_per_pred_class`.
    purity = confusion_matrix(vresults['type'], vresults['prediction'], labels=labels, normalize='pred')
    corrected = purity @ n_per_pred_class
    names = ['Spectroscopically\nClassified', 'Photometrically\nClassified', 'Phot. Class.\n(Corrected)']
    rawcounts = np.transpose([n_per_true_class, n_per_pred_class, corrected])
    fractions = rawcounts / rawcounts.sum(axis=0)
    cumulative_fractions = fractions.cumsum(axis=0)
    fig = plt.figure()
    ax = plt.axes()
    for counts, fracs, cumfracs, label in zip(rawcounts, fractions, cumulative_fractions, labels):
        # each segment hangs down from the cumulative fraction (the y-axis is inverted below)
        plt.bar(names, -fracs, bottom=cumfracs)
        heights = cumfracs - fracs / 2.
        for count, frac, name, height in zip(counts, fracs, names, heights):
            # keep the text inside the plot for thin segments at either end of a bar
            if height < 0.05:
                h = 0.005
                va = 'top'
            elif height > 0.95:
                h = 0.995
                va = 'bottom'
            else:
                h = height
                va = 'center'
            ax.text(name, h, f'{frac:.2f} ({count:.0f})', ha='center', va=va, color='w')
        # class name to the right of the last bar
        ax.text(2.5, heights[-1], label, ha='left', va='center')
    ax.set_ylim(1., 0.)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.tick_params(axis='both', which='both', labelbottom=False, labelleft=False, bottom=False, left=False,
                   labeltop=True, labelsize='large')
    fig.tight_layout()
    if saveto is None:
        plt.show()
    else:
        fig.savefig(saveto)
def _bar_plot_from_file():
    """Command-line entry point: load validation and test results from files and make the bar plot."""
    parser = ArgumentParser()
    parser.add_argument('validation_results', help='Filename of the validation results.')
    parser.add_argument('test_results', help='Filename of the classification results.')
    parser.add_argument('--saveto', help='Filename to which to save the bar plot.')
    args = parser.parse_args()

    validation_table = load_results(args.validation_results)
    test_table = load_results(args.test_results)
    bar_plot(validation_table, test_table, saveto=args.saveto)
def _plot_confusion_matrix_from_file():
    """Command-line entry point: load a table of classification results and plot its confusion matrix."""
    parser = ArgumentParser()
    parser.add_argument('filename', type=str, help='Filename containing the table of classification results.')
    parser.add_argument('--pmin', type=float, default=0.,
                        help='Minimum confidence to be included in the confusion matrix.')
    parser.add_argument('--saveto', type=str, help='If provided, save the confusion matrix to this file.')
    parser.add_argument('--purity', action='store_true', help='Aggregate by column instead of by row.')
    parser.add_argument('--binary', action='store_true', help='Plot a SNIa vs non-SNIa (CCSN) binary confusion matrix.')
    args = parser.parse_args()

    # classes are not passed, so they are taken from the 'type' column of the results
    table = load_results(args.filename)
    make_confusion_matrix(table, p_min=args.pmin, saveto=args.saveto, purity=args.purity, binary=args.binary)
def _train():
    """Command-line entry point: build a classification pipeline, train it, and pickle it to a file."""
    parser = ArgumentParser()
    parser.add_argument('train_data', help='Filename of the metadata table for the training set.')
    parser.add_argument('--classifier', choices=['rf', 'svm', 'mlp'], default='rf', help='The classification algorithm '
                        'to use. Current choices are "rf" (random forest; default), "svm" (support vector machine), or '
                        '"mlp" (multilayer perceptron).')
    parser.add_argument('--sampler', choices=['mvg', 'smote'], default='mvg', help='The resampling algorithm to use. '
                        'Current choices are "mvg" (multivariate Gaussian; default) or "smote" (synthetic minority '
                        'oversampling technique).')
    parser.add_argument('--random-state', type=int, help='Seed for the random number generator (for reproducibility).')
    parser.add_argument('--output', default='pipeline.pickle',
                        help='Filename to which to save the pickled classification pipeline.')
    args = parser.parse_args()

    logging.info('started training')
    train_data = load_data(args.train_data)
    # every training sample needs a label
    if train_data['type'].mask.any():
        raise ValueError('training data is missing values in the "type" column')

    seed = args.random_state
    if args.classifier == 'rf':
        clf = RandomForestClassifier(criterion='entropy', max_features=5, n_jobs=-1, random_state=seed)
    elif args.classifier == 'svm':
        clf = SVC(C=1000, gamma=0.1, probability=True, random_state=seed)
    elif args.classifier == 'mlp':
        clf = MLPClassifier(hidden_layer_sizes=(10, 5), alpha=1e-5, early_stopping=True, random_state=seed)
    else:
        raise NotImplementedError(f'{args.classifier} is not a recognized classifier type')

    if args.sampler == 'mvg':
        sampler = MultivariateGaussian(sampling_strategy=1000, random_state=seed)
    elif args.sampler == 'smote':
        sampler = SMOTE(random_state=seed)
    else:
        raise NotImplementedError(f'{args.sampler} is not a recognized sampler type')

    # rescale, then resample, then classify
    steps = [('scaler', StandardScaler()), ('sampler', sampler), ('classifier', clf)]
    pipeline = Pipeline(steps)
    train_classifier(pipeline, train_data)
    with open(args.output, 'wb') as f:
        pickle.dump(pipeline, f)
    logging.info('finished training')
def _classify():
    """Command-line entry point: classify a test set with a pickled pipeline and write the results and plots."""
    parser = ArgumentParser()
    parser.add_argument('pipeline', help='Filename of the pickled classification pipeline.')
    parser.add_argument('test_data', help='Filename of the metadata table for the test set.')
    parser.add_argument('--output', default='test_data', help='Filename (without extension) to save the results.')
    args = parser.parse_args()

    logging.info('started classification')
    # NOTE(security): unpickling runs arbitrary code; only load pipeline files from a trusted source.
    with open(args.pipeline, 'rb') as f:
        pipeline = pickle.load(f)
    test_data = load_data(args.test_data)
    results = classify(pipeline, test_data)
    prefix = args.output
    write_results(results, pipeline.classes_, f'{prefix}_results.txt')
    plot_results_by_number(results, saveto=f'{prefix}_confidence.pdf')

    # attach the predictions to the input rows for the histograms
    test_data_full = join(test_data, results)
    plot_histograms(test_data_full, 'params', 'prediction', var_kwd='paramnames', row_kwd='filters',
                    saveto=f'{prefix}_parameters.pdf')
    plot_histograms(test_data_full, 'features', 'prediction', var_kwd='featnames', row_kwd='filters',
                    saveto=f'{prefix}_features.pdf')
    logging.info('finished classification')
def _validate_args(args):
    """
    Load the pipeline and the data tables named in the parsed command-line arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Must have the attributes `pipeline`, `validation_data`, and `train_data` (which may be None).

    Returns
    -------
    pipeline
        The unpickled classification pipeline
    train_data, validation_data
        The training and validation tables. They are the same object if no training data were given.

    Raises
    ------
    ValueError
        If the validation or the training data have masked entries in the 'type' column.
    """
    # NOTE(security): unpickling runs arbitrary code; only load pipeline files from a trusted source.
    with open(args.pipeline, 'rb') as f:
        pipeline = pickle.load(f)

    validation_data = load_data(args.validation_data)
    if validation_data['type'].mask.any():
        raise ValueError('validation data is missing values in the "type" column')

    if args.train_data is None:
        return pipeline, validation_data, validation_data

    train_data = load_data(args.train_data)
    if train_data['type'].mask.any():
        raise ValueError('training data is missing values in the "type" column')
    return pipeline, train_data, validation_data
def _validate():
    """Command-line entry point: cross-validate a pickled pipeline and write the validation results and plots."""
    parser = ArgumentParser()
    parser.add_argument('pipeline', help='Filename of the pickled classification pipeline.')
    parser.add_argument('validation_data', help='Filename of the metadata table for the validation set.')
    parser.add_argument('--train-data', help='Filename of the metadata table for the training set, if different than '
                                             'the validation set.')
    parser.add_argument('--pmin', type=float, default=0.,
                        help='Minimum confidence to be included in the confusion matrix.')
    args = parser.parse_args()

    pipeline, train_data, validation_data = _validate_args(args)
    logging.info('started validation')
    plot_feature_importance(pipeline, train_data, saveto='feature_importance.pdf')
    results_validate = validate_classifier(pipeline, train_data, validation_data)
    classes = pipeline.classes_
    write_results(results_validate, classes, 'validation_results.txt')
    # completeness first, then purity
    make_confusion_matrix(results_validate, classes=classes, p_min=args.pmin, saveto='confusion_matrix.pdf')
    make_confusion_matrix(results_validate, classes=classes, p_min=args.pmin, saveto='confusion_matrix_purity.pdf',
                          purity=True)
    plot_results_by_number(results_validate, class_kwd='type', saveto='validation_confidence_specclass.pdf')
    plot_results_by_number(results_validate, saveto='validation_confidence_photclass.pdf')
    plot_metrics_by_number(results_validate, classes=classes, saveto='threshold.pdf')
    logging.info('finished validation')
def _latex():
    """
    Command-line entry point: convert a table of classification results into an AASTeX table.

    The output is written next to the input file, with the file extension replaced by '.tex'.
    """
    import os  # local import: `os` is not among the module-level imports

    parser = ArgumentParser()
    parser.add_argument('filename', help='Filename of the results to format into a LaTeX table')
    parser.add_argument('-m', '--max-lines', type=int, help='Maximum number of table rows to write')
    parser.add_argument('-t', '--title', default='Classification Results', help='Table caption')
    parser.add_argument('-l', '--label', default='tab:results', help='LaTeX label')
    args = parser.parse_args()
    results = load_results(args.filename)
    # Strip only the final extension. Splitting on the first '.' would turn './results.txt'
    # into '.tex' and cut off everything after a dot in a directory or file name.
    tex_filename = os.path.splitext(args.filename)[0] + '.tex'
    write_results(results, results.meta['classes'], tex_filename, max_lines=args.max_lines,
                  latex=True, latex_title=args.title, latex_label=args.label)