# Source code for datapipe_testbench.benchmarks.dl2.classification

"""
Classification benchmarks.

To run the benchmark requires
* A dl2 file with gammas
* A dl2 file with protons

Currently these are assumed to be two separate files.
"""

import itertools
from pathlib import Path
from typing import override

import matplotlib.pyplot as plt
import matplotlib.scale as scl
import numpy as np
from ctapipe.io import TableLoader
from hist import axis
from tqdm import tqdm

from ...benchmark import Benchmark
from ...compare_1d import CompareByCat1D
from ...comparers import ComparisonResult
from ...constants import telescope_options
from ...metric import Metric
from ...plotting import SqrtScale
from ...store import MetricsStore, ResultStore

# Make the custom square-root axis scale available to matplotlib; the
# comparison plots below select it via ``xscale="sqrt"``.
scl.register_scale(SqrtScale)

# NumPy 2.0 renamed ``np.trapz`` to ``np.trapezoid``; pick whichever
# trapezoidal-integration routine the installed NumPy provides.
_NUMPY_PRE_2 = np.lib.NumpyVersion(np.__version__) < "2.0.0"
integrator = np.trapz if _NUMPY_PRE_2 else np.trapezoid  # noqa:NPY201


class ClassifierPerTelMetric(Metric):
    """
    Metric for particle classification.

    Per-telescope-type histogram of the classifier score ("gammaness"),
    binned in true energy and true impact distance, filled separately for
    signal and background events.

    Must be constructed the first time using ``ClassifierPerTelMetric.setup``
    """

    # Short name identifying this metric type.
    title = "ClassifierPerTel"

    @classmethod
    def setup(
        cls,
        classifier_column: str,
        telescope_types: list[str],
        label: str | None = None,
        output_store: ResultStore | None = None,
    ):
        """Set up a ClassifierPerTelMetric.

        Parameters
        ----------
        classifier_column : str
            column name to use for the classifier prediction

        telescope_types : list[str]
            List of telescope type names. Should be extracted from
            the SubarrayDescription.

        label : str
            From Metric: If not specified, will be auto-generated

        output_store : ResultStore | None
            From Metric: If specified, used to set the metric name.
        """
        # Histogram axes: signal flag, telescope type, true energy,
        # true telescope impact distance, and the classifier score.
        axis_list = [
            axis.Boolean(name="signal"),
            axis.StrCategory(categories=telescope_types, name="tel_type"),
            # 13 log-spaced true-energy bins from 0.01 to 350 (presumably
            # TeV -- the comparison plots label this axis in TeV)
            axis.Regular(
                bins=13,
                start=0.01,
                stop=350,
                name="E_true",
                label="True Energy",
                transform=axis.transform.log,
                overflow=True,
            ),
            # sqrt-spaced impact-distance bins from 0 to 1400 (presumably
            # meters -- the comparison plots use unit "m")
            axis.Regular(
                bins=6,
                start=0,
                stop=1400,
                name="I_true",
                label="Tel Impact",
                transform=axis.transform.sqrt,
                overflow=True,
            ),
            # gammaness score is bounded in [0, 1], hence no overflow bin
            axis.Regular(
                bins=43,
                start=0.0,
                stop=1,
                name="gammaness",
                label="gammaness score",
                overflow=False,
            ),
        ]

        # call the default Metric constructor:
        return cls(
            axis_list=axis_list,
            unit="",
            label=label,
            output_store=output_store,
            # extra state carried along with the metric; the shown_* lists
            # are the bin indices displayed in the comparison figures
            other_attributes=dict(
                classifier_column=classifier_column,
                telescope_types=telescope_types,
                shown_energy_bins=[2, 7, 12],
                shown_impact_bins=[1, 2, 5],
            ),
        )

    def accumulate(self, ev_dict, kind=None):
        """
        Fill axis based on input events.

        Parameters
        ----------
        ev_dict: dict[str, astropy.table.Table]
            Event data coming from table_loader methods

        kind: str, optional
            If "signal" or "background", all events in this call are filled
            with that signal flag; otherwise the flag is derived per event
            from ``true_shower_primary_id == 0``.
        """
        for tel_type, evs in ev_dict.items():
            if kind and kind in ("signal", "background"):
                signal = kind == "signal"
            else:
                # fall back to the simulated primary id (0 == gamma)
                signal = evs["true_shower_primary_id"] == 0

            self.fill(
                signal,
                tel_type,
                evs["true_energy"],
                evs["true_impact_distance"],
                evs[self.other_attributes["classifier_column"]],
            )

    def compare(self, others: list):
        """Compare several instances of the class.

        Uses `self` as reference when comparing with `others`.

        Parameters
        ----------
        others: list[ClassifierPerTelMetric]
            list containing other instances of this class.

        Returns
        -------
        ComparisonResult:
            results of comparison of this reference to the others list
        """
        comparer = CompareClassifiers(self.axes["tel_type"])

        # Score distributions per telescope type in selected energy bins;
        # the tuples below are (signal, background) slices of the bool axis.
        ref = self.project("signal", "tel_type", "E_true", "gammaness")
        others_proj = [
            met.project("signal", "tel_type", "E_true", "gammaness") for met in others
        ]
        others_proj = [(met[1, ...], met[0, ...]) for met in others_proj]
        fig1 = comparer.plot_compare_tuples_by_category_in_bins(
            (ref[1, ...], ref[0, ...]),
            others_proj,
            bins=self.other_attributes["shown_energy_bins"],
        )

        # Same score distributions, but in selected impact-distance bins.
        ref = self.project("signal", "tel_type", "I_true", "gammaness")
        others_proj = [
            met.project("signal", "tel_type", "I_true", "gammaness") for met in others
        ]
        others_proj = [(met[1, ...], met[0, ...]) for met in others_proj]
        fig2 = comparer.plot_compare_tuples_by_category_in_bins(
            (ref[1, ...], ref[0, ...]),
            others_proj,
            bins=self.other_attributes["shown_impact_bins"],
            ybin_label="impact",
            ybin_unit="m",
        )

        # Need to reorder the axes for the plotting command
        ref = self.project("signal", "tel_type", "I_true", "E_true", "gammaness")
        others_proj = [
            met.project("signal", "tel_type", "I_true", "E_true", "gammaness")
            for met in others
        ]
        # AUC vs energy (default x_axes), one plot row per shown impact bin.
        fig3 = comparer.plot_auc_by_category_by_bin(
            ref,
            others_proj,
            bins=self.other_attributes["shown_impact_bins"],
        )

        # AUC vs impact, one row per shown energy bin; the unprojected
        # metrics already have the (E_true, I_true) axis order needed here.
        fig4 = comparer.plot_auc_by_category_by_bin(
            self,
            others,
            bins=self.other_attributes["shown_energy_bins"],
            bin_axes="E_true",
            x_axes="I_true",
            x_scale="sqrt",
            x_label="tel impact [m]",
            ybin_label="True Energy",
            ybin_unit="[TeV]",
        )
        # ROC curves in the shown energy bins ...
        fig5 = comparer.plot_roc_by_category_by_bin(
            ref,
            others_proj,
            bins=self.other_attributes["shown_energy_bins"],
            bin_axes="E_true",
            ybin_label="True Energy",
            ybin_unit="[TeV]",
        )
        # ... and in the shown impact bins.
        fig6 = comparer.plot_roc_by_category_by_bin(
            ref,
            others_proj,
            bins=self.other_attributes["shown_impact_bins"],
            bin_axes="I_true",
        )
        # Numeric measures: AUC vs energy for reference and each comparison.
        aucs = {}
        aucs[f"{self.name}_auc"] = comparer.auc_by_category_energy(ref)
        for other in others:
            aucs[f"{other.name}_auc"] = comparer.auc_by_category_energy(other)

        # TODO: add distance between AUC curves measurement, will
        # be somewhat involved though
        # measures = comparer.distance_by_category(self, others)

        return ComparisonResult(
            self.__class__.__name__,
            {
                "Cls_Separation_Per_Energy": fig1,
                "Cls_Separation_Per_Tel_Impact": fig2,
                "Cls_AUC_Per_Energy": fig3,
                "Cls_AUC_Per_TelImpact": fig4,
                "Cls_ROC_Per_Energy": fig5,
                "Cls_ROC_Per_TelImpact": fig6,
            },
            aucs,
            None,
        )


class CompareClassifiers(CompareByCat1D):
    """Helper Class for comparing ClassifierPerTelMetric instances.

    The categories (one plot column each) are the entries of the
    ``tel_type`` axis handed to the ``CompareByCat1D`` constructor.
    """

    def plot_auc_by_category_by_bin(
        self,
        ref,
        others,
        bins,
        bin_axes="I_true",
        x_axes="E_true",
        x_scale="log",
        x_label="True Energy [TeV]",
        ybin_label="tel impact",
        ybin_unit="m",
        prop_axes="tel_type",
    ):
        """Plot auc by category in bins as function of property.

        Parameters
        ----------
        ref:
            Reference metric with axes ordered
            (signal, category, ``bin_axes``, ``x_axes``, gammaness).
        others: list
            Metrics with the same axis layout to draw for comparison.
        bins: list[int]
            Indices of the ``bin_axes`` bins to show, one plot row each.
        bin_axes, x_axes: str
            Axis names for the row selection and the x direction.
        x_scale, x_label: str
            Matplotlib scale name and label for the x axis.
        ybin_label, ybin_unit: str
            Used to annotate each row with its ``bin_axes`` bin range.
        prop_axes: str
            Unused; kept for backward compatibility of the signature.

        Returns
        -------
        matplotlib.figure.Figure
        """
        n_rows = len(bins)
        n_cols = len(self.categories)
        x_bins = ref.axes[x_axes].edges

        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=(1.3 * 4 * n_cols, 4 * n_rows), sharex=True
        )

        # Pre-compute the AUC arrays once; shape (category, bin_axes, x_axes).
        _, _, ref_auc = roc_from_metric(ref)
        oth_aucs = [roc_from_metric(oth)[-1] for oth in others]

        for col, cat in enumerate(self.categories):
            for row, bin_idx in enumerate(bins):
                ax = axs[row, col]
                ax.stairs(ref_auc[col, bin_idx, :], x_bins, label=ref.name, lw=2)
                for oth, auc in zip(others, oth_aucs):
                    ax.stairs(auc[col, bin_idx, :], x_bins, label=oth.name)
                low, hig = ref.axes[bin_axes].bin(bin_idx)
                ax.set(
                    ylabel=f"AUC\n{ybin_label} {low:.2f} to {hig:.2f} {ybin_unit}",
                    xlabel=x_label,
                    xscale=x_scale,
                )
                # BUG FIX: title the top row of each column with its
                # category, consistent with plot_roc_by_category_by_bin.
                # Previously this check sat outside the bin loop (so it
                # only ran after the last row, and only fired when that
                # row's index was 0) and used the axis name as the title.
                if row == 0:
                    ax.set_title(cat)
        axs[0, 0].legend()
        return fig

    def plot_roc_by_category_by_bin(
        self,
        ref,
        others,
        bins=(1, 2, 5),  # tuple: avoid the mutable-default pitfall
        bin_axes="I_true",
        ybin_label="tel impact",
        ybin_unit="m",
        prop_axes="tel_type",
    ):
        """Plot ROC curves per category in bins as function of property.

        Parameters
        ----------
        ref:
            Reference metric; must be projectable onto
            (signal, ``bin_axes``, gammaness).
        others: list
            Metrics with the same axis layout to draw for comparison.
        bins: sequence[int]
            Indices of the ``bin_axes`` bins to show, one plot row each.
        bin_axes: str
            Axis name the ``bins`` indices refer to.
        ybin_label, ybin_unit: str
            Used to annotate each row with its ``bin_axes`` bin range.
        prop_axes: str
            Unused; kept for backward compatibility of the signature.

        Returns
        -------
        matplotlib.figure.Figure
        """
        n_rows = len(bins)
        n_cols = len(self.categories)

        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=(1.3 * 4 * n_cols, 4 * n_rows), sharex=True
        )

        # NOTE(review): the category axis is projected away here, so every
        # column shows identical curves (only the title differs) -- confirm
        # whether a per-category slice was intended.
        ref_tpr, ref_fpr, ref_auc = roc_from_metric(
            ref.project("signal", bin_axes, "gammaness")
        )
        oth_rocs = [
            roc_from_metric(oth.project("signal", bin_axes, "gammaness"))
            for oth in others
        ]

        for col, cat in enumerate(self.categories):
            for row, bin_idx in enumerate(bins):
                ax = axs[row, col]
                ax.plot(
                    ref_fpr[bin_idx, :],
                    ref_tpr[bin_idx, :],
                    label=f"auc {ref_auc[bin_idx]:.3f} {ref.name}",
                    lw=2,
                )
                for oth, roc in zip(others, oth_rocs):
                    tpr, fpr, auc = roc  # codespell:ignore fpr
                    ax.plot(
                        fpr[bin_idx, :],  # codespell:ignore fpr
                        tpr[bin_idx, :],
                        label=f"auc {auc[bin_idx]:.3f} {oth.name}",
                        lw=2,
                    )
                low, hig = ref.axes[bin_axes].bin(bin_idx)
                ax.set(
                    xlabel="FPR",  # codespell:ignore fpr
                    ylabel=f"TPR\n{ybin_label} {low:.2f} to {hig:.2f} {ybin_unit}",
                )
                ax.legend()
                if row == 0:
                    ax.set_title(cat)

        return fig

    def auc_by_category_energy(self, ref):
        """Compute AUC per energy bin and category.

        Parameters
        ----------
        ref:
            Metric with axes (signal, tel_type, ..., E_true, gammaness).

        Returns
        -------
        dict
            One list of AUC values per category, plus the energy bin
            edges under the key ``"E_bins"``.
        """
        aucs = {}
        for cat in self.categories:
            # Slice out this category, then keep only the AUC array from
            # the (tpr, fpr, auc) tuple returned by roc_from_metric.
            arr = roc_from_metric(
                ref[:, cat, ...].project("signal", "E_true", "gammaness")
            )[-1]
            aucs[cat] = list(arr)
        aucs["E_bins"] = list(ref.axes["E_true"].edges)

        return aucs


class ClassifierBenchmark(Benchmark):
    """
    Benchmark performance of gammaness classifier.

    For generation, requires:
    * A dl2 file with gammas
    * A dl2 file with protons
    """

    def __init__(self, cls_name, max_chunks=None):
        """Set up the benchmark.

        Parameters
        ----------
        cls_name : str
            Column name holding the classifier prediction.
        max_chunks : int | None
            If given, limit reading to this many chunks per input file
            (useful for quick runs).
        """
        super().__init__()
        # NOTE(review): chunk_size is stored but the loader below uses a
        # hard-coded chunk_size=40_000 -- confirm which one is intended.
        self.chunk_size = 10000
        self.metric_path = Path("dl2/classification")
        self.output_names = {
            "cls": self.metric_path / "Classification.asdf",
        }
        self.cls_name = cls_name
        self.max_chunks = max_chunks

    @override
    def generate_metrics(self, output_store: MetricsStore) -> None:
        """Accumulate the classification metric from the signal and
        background DL2 files and store it in ``output_store``."""
        sig_file = output_store.get_inputdata().dl2_signal
        bak_file = output_store.get_inputdata().dl2_background

        # Only load the DL2 and simulation tables; all telescope-data
        # options are switched off.
        options = dict.fromkeys(telescope_options, False)
        options["dl2"] = True
        options["simulated"] = True

        # accumulate the necessary information
        print("Starting accumulation")
        with TableLoader(sig_file, **options) as loader:
            # ctap v 0.20 syntax
            # telescope_types = list({str(t) for t in loader.subarray.telescope_types})
            # NOTE(review): ``telescope_types`` is not defined in this
            # module's visible scope -- confirm where it comes from (see
            # the commented v0.20 line above).
            clf_tel = ClassifierPerTelMetric.setup(
                classifier_column=self.cls_name,
                telescope_types=telescope_types,
                output_store=output_store,
            )
            events = loader.read_telescope_events_by_type_chunked(chunk_size=40_000)
            if self.max_chunks:
                events = itertools.islice(events, self.max_chunks)
            for _, _, events_by_type in tqdm(events, desc="signal"):
                clf_tel.accumulate(events_by_type, kind="signal")

        with TableLoader(bak_file, **options) as loader:
            # ctap v 0.20 syntax
            events = loader.read_telescope_events_by_type_chunked(chunk_size=40_000)
            if self.max_chunks:
                events = itertools.islice(events, self.max_chunks)
            for _, _, events_by_type in tqdm(events, desc="background"):
                clf_tel.accumulate(events_by_type, kind="background")
        print("Accumulated")

        # store the output data
        output_store.store_data(clf_tel, self.output_names["cls"])

    @property
    @override
    def required_inputs(self) -> set[str]:
        """Names of the input datasets this benchmark needs."""
        # set literal instead of set([...]) -- same contents, clearer
        return {"dl2_signal", "dl2_background"}

    @override
    def compare_to_reference(
        self, stores: list[MetricsStore], rstore: ResultStore
    ) -> dict:
        """Compare the metrics in ``stores`` and register the benchmark
        header in ``rstore`` metadata."""
        bench_header = {}
        bench_header["name"] = "Classification Performance"
        bench_header["inputs"] = [
            itm.metadata["input_dataset"].to_dict() for itm in stores
        ]
        rstore.metadata["benchmarks"][self.metric_path] = bench_header
        return self._compare_metrics_in_stores(stores, rstore)
def roc_from_metric(metric):
    """Calculate ROC and AUC from metric.

    Assumes a metric where first axis is binary signal or not-signal and
    the last axis is the classifier score.

    Parameters
    ----------
    metric:
        Histogram-like object; ``metric[1, ...]`` selects the signal
        counts, ``metric[0, ...]`` the background counts.

    Returns
    -------
    tuple
        ``(tpr, fpr, auc)``: true/false positive rates per score cut
        (last axis) and the area under the ROC curve (score axis
        integrated away).
    """
    # Raw bin counts for signal (index 1) and background (index 0).
    sig_h = metric[1, ...].to_numpy()[0]
    bak_h = metric[0, ...].to_numpy()[0]
    n_gam = sig_h.sum(axis=-1)
    n_pro = bak_h.sum(axis=-1)

    # Treat each score-bin edge as a cut: events below the cut are
    # classified as background, so the cumulative sums give FN and TN.
    fn = sig_h.cumsum(axis=-1)
    tp = n_gam[..., np.newaxis] - fn
    tn = bak_h.cumsum(axis=-1)
    fp = n_pro[..., np.newaxis] - tn

    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)  # codespell:ignore fpr
    # fpr decreases as the cut tightens, so the trapezoidal integral is
    # negative; negate to get a positive AUC.
    auc = -integrator(tpr, fpr)  # codespell:ignore fpr
    return tpr, fpr, auc  # codespell:ignore fpr