"""
Classification benchmarks.
To run the benchmark requires
* A dl2 file with gammas
* A dl2 file with protons
Currently these are assumed to be two separate files.
"""
import itertools
from pathlib import Path
from typing import override
import matplotlib.pyplot as plt
import matplotlib.scale as scl
import numpy as np
from ctapipe.io import TableLoader
from hist import axis
from tqdm import tqdm
from ...benchmark import Benchmark
from ...compare_1d import CompareByCat1D
from ...comparers import ComparisonResult
from ...constants import telescope_options
from ...metric import Metric
from ...plotting import SqrtScale
from ...store import MetricsStore, ResultStore
# Make the custom square-root axis scale (from .plotting) available to
# matplotlib under its registered name, so plots below can use xscale="sqrt".
scl.register_scale(SqrtScale)
# numpy 2.0 renamed ``np.trapz`` to ``np.trapezoid``; pick whichever this
# installation provides so the module runs on both major versions.
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    integrator = np.trapezoid
else:
    # ``trapezoid`` did not exist before numpy 2.0
    integrator = np.trapz  # noqa:NPY201
class ClassifierPerTelMetric(Metric):
    """
    Metric for particle classification.

    Accumulates a histogram over (signal flag, telescope type, true energy,
    true telescope impact distance, classifier score).

    Must be constructed the first time using ``ClassifierPerTelMetric.setup``
    """

    # Display/identification title for this metric type.
    title = "ClassifierPerTel"

    @classmethod
    def setup(
        cls,
        classifier_column: str,
        telescope_types: list[str],
        label: str | None = None,
        output_store: ResultStore | None = None,
    ):
        """Set up a ClassifierPerTelMetric.

        Parameters
        ----------
        classifier_column : str
            column name to use for the classifier prediction
        telescope_types : list[str]
            List of telescope type names. Should be extracted from
            the SubarrayDescription.
        label : str
            From Metric: If not specified, will be auto-generated
        output_store : ResultStore | None
            From Metric: If specified, used to set the metric name.
        """
        axis_list = [
            # signal (True) vs background (False)
            axis.Boolean(name="signal"),
            axis.StrCategory(categories=telescope_types, name="tel_type"),
            # 13 log-spaced true-energy bins from 0.01 to 350
            # (TeV, per the axis labels used in the comparison plots)
            axis.Regular(
                bins=13,
                start=0.01,
                stop=350,
                name="E_true",
                label="True Energy",
                transform=axis.transform.log,
                overflow=True,
            ),
            # 6 sqrt-spaced telescope impact-distance bins, 0 to 1400
            # (m, per the plot labels)
            axis.Regular(
                bins=6,
                start=0,
                stop=1400,
                name="I_true",
                label="Tel Impact",
                transform=axis.transform.sqrt,
                overflow=True,
            ),
            # classifier score in [0, 1]; scores cannot leave this range,
            # so no overflow bins are kept
            axis.Regular(
                bins=43,
                start=0.0,
                stop=1,
                name="gammaness",
                label="gammaness score",
                overflow=False,
            ),
        ]
        # call the default Metric constructor:
        return cls(
            axis_list=axis_list,
            unit="",
            label=label,
            output_store=output_store,
            other_attributes=dict(
                classifier_column=classifier_column,
                telescope_types=telescope_types,
                # bin indices of E_true / I_true displayed in the
                # comparison figures produced by ``compare``
                shown_energy_bins=[2, 7, 12],
                shown_impact_bins=[1, 2, 5],
            ),
        )

    def accumulate(self, ev_dict, kind: str | None = None):
        """
        Fill axis based on input events.

        Parameters
        ----------
        ev_dict: dict[str, astropy.table.Table]
            Event data coming from table_loader methods
        kind: str | None
            If "signal" or "background", all events in ``ev_dict`` are
            labelled accordingly; otherwise the signal flag is derived
            per event from ``true_shower_primary_id`` (0 = gamma).
        """
        for tel_type, evs in ev_dict.items():
            if kind and kind in ("signal", "background"):
                signal = kind == "signal"
            else:
                # fall back to per-event truth: primary id 0 is a gamma
                signal = evs["true_shower_primary_id"] == 0
            self.fill(
                signal,
                tel_type,
                evs["true_energy"],
                evs["true_impact_distance"],
                evs[self.other_attributes["classifier_column"]],
            )

    def compare(self, others: list):
        """Compare several instances of the class.

        Uses `self` as reference when comparing with `others`.

        Parameters
        ----------
        others: list[ClassifierPerTelMetric]
            list containing other instances of this class.

        Returns
        -------
        ComparisonResult:
            results of comparison of this reference to the others list
        """
        # Comparisons are grouped by telescope type (the category axis).
        comparer = CompareClassifiers(self.axes["tel_type"])
        # Score-separation plots in the shown energy bins. The
        # ``(met[1, ...], met[0, ...])`` tuples are the (signal, background)
        # histograms of each projected metric.
        ref = self.project("signal", "tel_type", "E_true", "gammaness")
        others_proj = [
            met.project("signal", "tel_type", "E_true", "gammaness") for met in others
        ]
        others_proj = [(met[1, ...], met[0, ...]) for met in others_proj]
        fig1 = comparer.plot_compare_tuples_by_category_in_bins(
            (ref[1, ...], ref[0, ...]),
            others_proj,
            bins=self.other_attributes["shown_energy_bins"],
        )
        # Same separation plots, but in the shown telescope-impact bins.
        ref = self.project("signal", "tel_type", "I_true", "gammaness")
        others_proj = [
            met.project("signal", "tel_type", "I_true", "gammaness") for met in others
        ]
        others_proj = [(met[1, ...], met[0, ...]) for met in others_proj]
        fig2 = comparer.plot_compare_tuples_by_category_in_bins(
            (ref[1, ...], ref[0, ...]),
            others_proj,
            bins=self.other_attributes["shown_impact_bins"],
            ybin_label="impact",
            ybin_unit="m",
        )
        # Need to reorder the axes for the plotting command:
        # with I_true before E_true the AUC array computed downstream is
        # indexed as [tel_type, impact_bin, energy_bin].
        ref = self.project("signal", "tel_type", "I_true", "E_true", "gammaness")
        others_proj = [
            met.project("signal", "tel_type", "I_true", "E_true", "gammaness")
            for met in others
        ]
        # AUC vs energy, one figure row per shown impact bin.
        fig3 = comparer.plot_auc_by_category_by_bin(
            ref,
            others_proj,
            bins=self.other_attributes["shown_impact_bins"],
        )
        # AUC vs impact, one row per shown energy bin. The unprojected
        # metrics already have E_true before I_true, which is the axis
        # order this call needs.
        fig4 = comparer.plot_auc_by_category_by_bin(
            self,
            others,
            bins=self.other_attributes["shown_energy_bins"],
            bin_axes="E_true",
            x_axes="I_true",
            x_scale="sqrt",
            x_label="tel impact [m]",
            ybin_label="True Energy",
            ybin_unit="[TeV]",
        )
        # ROC curves in the shown energy bins ...
        fig5 = comparer.plot_roc_by_category_by_bin(
            ref,
            others_proj,
            bins=self.other_attributes["shown_energy_bins"],
            bin_axes="E_true",
            ybin_label="True Energy",
            ybin_unit="[TeV]",
        )
        # ... and in the shown impact bins.
        fig6 = comparer.plot_roc_by_category_by_bin(
            ref,
            others_proj,
            bins=self.other_attributes["shown_impact_bins"],
            bin_axes="I_true",
        )
        # Scalar results: AUC vs energy per telescope type, for the
        # reference and for every compared metric.
        aucs = {}
        aucs[f"{self.name}_auc"] = comparer.auc_by_category_energy(ref)
        for other in others:
            aucs[f"{other.name}_auc"] = comparer.auc_by_category_energy(other)
        # TODO: add distance between AUC curves measurement, will
        # be somewhat involved though
        # measures = comparer.distance_by_category(self, others)
        return ComparisonResult(
            self.__class__.__name__,
            {
                "Cls_Separation_Per_Energy": fig1,
                "Cls_Separation_Per_Tel_Impact": fig2,
                "Cls_AUC_Per_Energy": fig3,
                "Cls_AUC_Per_TelImpact": fig4,
                "Cls_ROC_Per_Energy": fig5,
                "Cls_ROC_Per_TelImpact": fig6,
            },
            aucs,
            None,
        )
class CompareClassifiers(CompareByCat1D):
    """Helper Class for comparing ClassifierPerTelMetric instances."""

    def plot_auc_by_category_by_bin(
        self,
        ref,
        others,
        bins,
        bin_axes="I_true",
        x_axes="E_true",
        x_scale="log",
        x_label="True Energy [TeV]",
        ybin_label="tel impact",
        ybin_unit="m",
        prop_axes="tel_type",
    ):
        """Plot auc by category in bins as function of property.

        One column per category, one row per entry of ``bins``; each panel
        shows the AUC of the reference and of every other metric as a step
        curve over the ``x_axes`` bins.

        Parameters
        ----------
        ref : ClassifierPerTelMetric
            reference metric; must expose the ``bin_axes`` and ``x_axes``
            histogram axes.
        others : list
            metrics compared against the reference.
        bins : sequence[int]
            indices of the ``bin_axes`` bins to show (one row each).
        bin_axes, x_axes : str
            names of the row-selection axis and the x axis.
        x_scale, x_label : str
            matplotlib x-axis scale and label.
        ybin_label, ybin_unit : str
            used in the per-row y-axis label.
        prop_axes : str
            name of the category axis (kept for interface compatibility).

        Returns
        -------
        matplotlib.figure.Figure
        """
        n_rows = len(bins)
        n_cols = len(self.categories)
        x_bins = ref.axes[x_axes].edges
        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=(1.3 * 4 * n_cols, 4 * n_rows), sharex=True
        )
        # Only the AUC values are needed here; TPR/FPR are discarded.
        ref_auc = roc_from_metric(ref)[-1]
        other_aucs = [roc_from_metric(oth)[-1] for oth in others]
        for idx, cat in enumerate(self.categories):
            for idy, bin_idx in enumerate(bins):
                ax = axs[idy, idx]
                ax.stairs(ref_auc[idx, bin_idx, :], x_bins, label=ref.name, lw=2)
                for oth, auc in zip(others, other_aucs):
                    ax.stairs(auc[idx, bin_idx, :], x_bins, label=oth.name)
                low, hig = ref.axes[bin_axes].bin(bin_idx)
                ax.set(
                    ylabel=f"AUC\n{ybin_label} {low:.2f} to {hig:.2f} {ybin_unit}",
                    xlabel=x_label,
                )
                ax.set(xscale=x_scale)
                if idy == 0:
                    # BUGFIX: previously the column title was the literal
                    # axis name (``prop_axes``) for every column; use the
                    # category itself, consistent with
                    # plot_roc_by_category_by_bin below.
                    ax.set_title(cat)
        axs[0, 0].legend()
        return fig

    def plot_roc_by_category_by_bin(
        self,
        ref,
        others,
        bins=(1, 2, 5),  # tuple: avoid the mutable-default-argument pitfall
        bin_axes="I_true",
        ybin_label="tel impact",
        ybin_unit="m",
        prop_axes="tel_type",
    ):
        """Plot ROC curves per category in bins as function of property.

        One column per category, one row per entry of ``bins``; each panel
        shows the ROC curve (TPR vs FPR) of the reference and of every
        other metric, with the AUC quoted in the legend.

        Parameters
        ----------
        ref : ClassifierPerTelMetric
            reference metric.
        others : list
            metrics compared against the reference.
        bins : sequence[int]
            indices of the ``bin_axes`` bins to show (one row each).
        bin_axes : str
            name of the row-selection axis.
        ybin_label, ybin_unit : str
            used in the per-row y-axis label.
        prop_axes : str
            name of the category axis (kept for interface compatibility).

        Returns
        -------
        matplotlib.figure.Figure
        """
        n_rows = len(bins)
        n_cols = len(self.categories)
        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=(1.3 * 4 * n_cols, 4 * n_rows), sharex=True
        )
        # NOTE(review): this projection sums over the category (tel_type)
        # axis, so every column of the figure shows the same array-wide
        # ROC curve — confirm whether a per-category selection (as done in
        # auc_by_category_energy) was intended.
        ref_tpr, ref_fpr, ref_auc = roc_from_metric(
            ref.project("signal", bin_axes, "gammaness")
        )
        oth_rocs = [
            roc_from_metric(oth.project("signal", bin_axes, "gammaness"))
            for oth in others
        ]
        for idx, cat in enumerate(self.categories):
            for idy, bin_idx in enumerate(bins):
                ax = axs[idy, idx]
                ax.plot(
                    ref_fpr[bin_idx, :],
                    ref_tpr[bin_idx, :],
                    label=f"auc {ref_auc[bin_idx]:.3f} {ref.name}",
                    lw=2,
                )
                for oth, roc in zip(others, oth_rocs):
                    tpr, fpr, auc = roc  # codespell:ignore fpr
                    ax.plot(
                        fpr[bin_idx, :],  # codespell:ignore fpr
                        tpr[bin_idx, :],
                        label=f"auc {auc[bin_idx]:.3f} {oth.name}",
                        lw=2,
                    )
                low, hig = ref.axes[bin_axes].bin(bin_idx)
                ax.set(
                    xlabel="FPR",  # codespell:ignore fpr
                    ylabel=f"TPR\n{ybin_label} {low:.2f} to {hig:.2f} {ybin_unit}",
                )
                ax.legend()
                if idy == 0:
                    ax.set_title(cat)
        return fig

    def auc_by_category_energy(self, ref):
        """Compute AUC per energy bin and category.

        Parameters
        ----------
        ref : ClassifierPerTelMetric
            metric to evaluate; must expose the ``E_true``, ``tel_type``,
            ``signal`` and ``gammaness`` axes.

        Returns
        -------
        dict
            one list of AUC values per category, plus the energy bin
            edges under the key ``"E_bins"``.
        """
        aucs = {}
        for cat in self.categories:
            # select the category first, then reduce to
            # (signal, E_true, gammaness) before computing the ROC
            auc_arr = roc_from_metric(
                ref[:, cat, ...].project("signal", "E_true", "gammaness")
            )[-1]
            aucs[cat] = list(auc_arr)
        aucs["E_bins"] = list(ref.axes["E_true"].edges)
        return aucs
class ClassifierBenchmark(Benchmark):
    """
    Benchmark performance of gammaness classifier.

    For generation, requires:

    * A dl2 file with gammas
    * A dl2 file with protons
    """

    def __init__(self, cls_name, max_chunks=None):
        """Set up the benchmark.

        Parameters
        ----------
        cls_name : str
            name of the classifier prediction column in the dl2 files.
        max_chunks : int | None
            if given, only this many chunks are read per input file
            (useful for quick partial runs).
        """
        super().__init__()
        # NOTE(review): this attribute is currently unused — the loader
        # calls below use a hard-coded chunk_size of 40_000.
        # TODO: unify the two.
        self.chunk_size = 10000
        self.metric_path = Path("dl2/classification")
        self.output_names = {
            "cls": self.metric_path / "Classification.asdf",
        }
        self.cls_name = cls_name
        self.max_chunks = max_chunks

    @override
    def generate_metrics(self, output_store: MetricsStore) -> None:
        """Accumulate the classification metric from the signal and
        background dl2 files and store it in ``output_store``."""
        sig_file = output_store.get_inputdata().dl2_signal
        bak_file = output_store.get_inputdata().dl2_background
        # load dl2 + simulation columns only; all telescope data
        # products are switched off
        options = dict.fromkeys(telescope_options, False)
        options["dl2"] = True
        options["simulated"] = True
        # accumulate the necessary information
        print("Starting accumulation")
        with TableLoader(sig_file, **options) as loader:  # ctap v 0.20 syntax
            # telescope types come from the subarray of the signal file
            telescope_types = list({str(t) for t in loader.subarray.telescope_types})
            clf_tel = ClassifierPerTelMetric.setup(
                classifier_column=self.cls_name,
                telescope_types=telescope_types,
                output_store=output_store,
            )
            events = loader.read_telescope_events_by_type_chunked(chunk_size=40_000)
            if self.max_chunks:
                events = itertools.islice(events, self.max_chunks)
            for _, _, events_by_type in tqdm(events, desc="signal"):
                clf_tel.accumulate(events_by_type, kind="signal")
        with TableLoader(bak_file, **options) as loader:  # ctap v 0.20 syntax
            events = loader.read_telescope_events_by_type_chunked(chunk_size=40_000)
            if self.max_chunks:
                events = itertools.islice(events, self.max_chunks)
            for _, _, events_by_type in tqdm(
                events,
                desc="background",
            ):
                clf_tel.accumulate(events_by_type, kind="background")
        print("Accumulated")
        # store the output data
        output_store.store_data(clf_tel, self.output_names["cls"])

    @property
    @override
    def required_inputs(self) -> set[str]:
        """Names of the input datasets this benchmark needs."""
        # set literal instead of set([...]) — same contents, idiomatic
        return {"dl2_signal", "dl2_background"}

    @override
    def compare_to_reference(
        self, stores: list[MetricsStore], rstore: ResultStore
    ) -> dict:
        """Compare the stored metrics of several runs and record the
        benchmark header in ``rstore``'s metadata."""
        bench_header = {}
        bench_header["name"] = "Classification Performance"
        bench_header["inputs"] = [
            itm.metadata["input_dataset"].to_dict() for itm in stores
        ]
        rstore.metadata["benchmarks"][self.metric_path] = bench_header
        return self._compare_metrics_in_stores(stores, rstore)
def roc_from_metric(metric):
    """Calculate ROC and AUC from metric.

    Assumes a metric where first axis is binary signal or not-signal
    (index 1 = signal, index 0 = background) and the last axis is the
    classifier score; all cumulative sums run along that last axis.

    Returns
    -------
    tuple
        ``(tpr, fpr, auc)`` — true/false positive rates per score cut
        and the area under the ROC curve, reduced over the score axis.
    """
    # Raw bin counts of the score histograms.
    signal_counts = metric[1, ...].to_numpy()[0]
    background_counts = metric[0, ...].to_numpy()[0]
    # Events below a score cut are rejected: the cumulative counts give
    # the false negatives (signal) and true negatives (background).
    false_neg = signal_counts.cumsum(axis=-1)
    true_neg = background_counts.cumsum(axis=-1)
    # Everything else is accepted.
    true_pos = signal_counts.sum(axis=-1)[..., np.newaxis] - false_neg
    false_pos = background_counts.sum(axis=-1)[..., np.newaxis] - true_neg
    tpr = true_pos / (true_pos + false_neg)
    fpr = false_pos / (false_pos + true_neg)  # codespell:ignore fpr
    # fpr decreases as the cut rises, so the integral comes out negative.
    auc = -integrator(tpr, fpr)  # codespell:ignore fpr
    return tpr, fpr, auc  # codespell:ignore fpr