# Source code for datapipe_testbench.benchmarks.dl2.energy_regression

"""
Energy regression benchmarks.

Running this benchmark requires a DL2 file with gammas.

Currently these are assumed to be two separate files.
"""

import itertools
from pathlib import Path
from typing import override

import matplotlib.scale as scl
import numpy as np
from ctapipe.io import TableLoader
from hist import axis
from tqdm import tqdm

from datapipe_testbench.compare_2d import CompareByCat2D, rms_profile_from_2d_hist

from ...benchmark import Benchmark
from ...comparers import ComparisonResult
from ...constants import telescope_options
from ...metric import Metric
from ...plotting import SqrtScale, crude_median_plot, plt
from ...store import MetricsStore, ResultStore

scl.register_scale(SqrtScale)


class EnergyPredictionPerTelMetric(Metric):
    """Metric for energy regression.

    Accumulates a 4D histogram of (telescope type, true energy, true
    impact distance, relative reconstruction error) per telescope event.

    Parameters
    ----------
    predicted_energy_column: str
        name of the column containing the per-tel energy prediction,
        must be given for accumulation to work.
    telescope_types: list | tuple
        List of telescope types to categorize by
    output_store: ResultStore
      Where results will be saved. Also used to update the name of the object
      instance.
    """

    title = "EnergyPredictionPerTelMetric"

    def __init__(
        self,
        predicted_energy_column,
        telescope_types: list[str],
        output_store=None,
        **kwargs,
    ):
        axis_list = [
            axis.StrCategory(categories=telescope_types, name="tel_type"),
            axis.Regular(
                bins=33,
                start=0.01,
                stop=350,
                name="E_true",
                label="True Energy",
                transform=axis.transform.log,
                overflow=True,
            ),
            axis.Regular(
                bins=6,
                start=0,
                stop=1400,
                name="I_true",
                label="Tel Impact",
                transform=axis.transform.sqrt,
                overflow=True,
            ),
            # Hand chosen bins, 11 linear bins around from -1.5 to 1.5,
            # then 13 log10 bins from 1.5 to 16
            # cust_bin = linspace(-1.5, 1.5, num=11), logspace(log10(1.5), log10(16), num=13)[1:]
            axis.Variable(
                [
                    -1.5,
                    -1.2,
                    -0.9,
                    -0.6,
                    -0.3,
                    0.0,
                    0.3,
                    0.6,
                    0.9,
                    1.2,
                    1.5,
                    1.82709159,
                    2.22550913,
                    2.71080601,
                    3.30192725,
                    4.02194901,
                    4.89897949,
                    5.96725616,
                    7.26848237,
                    8.85345536,
                    10.78404924,
                    13.13563047,
                    16.0,
                ],
                name="rel_err",
                label="Relative reconstruction error",
                overflow=True,
                underflow=True,
            ),
        ]
        # TODO: make energy column name configurable at creation
        super().__init__(axis_list, **kwargs)
        self.update_name_with_store(output_store)
        self.reg_name = predicted_energy_column

    def accumulate(self, ev_dict):
        """
        Fill axis based on input events.

        Parameters
        ----------
        ev_dict: dict[str, astropy.table.Table]
            Event data coming from table_loader methods
        """
        for tel_type, evs in ev_dict.items():
            self.fill(
                tel_type,
                evs["true_energy"],
                evs["true_impact_distance"],
                # Relative error: (E_reco - E_true) / E_true
                (evs[self.reg_name] - evs["true_energy"]) / evs["true_energy"],
            )

    @staticmethod
    def _labeled_projection(metric, *axis_names):
        """Project ``metric`` onto ``axis_names`` and copy its name to the plot label."""
        proj = metric.project(*axis_names)
        proj.label = proj.name
        return proj

    def compare(self, others: list):
        """Compare several instances of the class.

        Uses `self` as reference when comparing with `others`.

        Parameters
        ----------
        others: list[EnergyPredictionPerTelMetric]
            list containing other instances of this class.

        Returns
        -------
        ComparisonResult:
            results of comparison of this reference to the others list
        """
        comp = CompareRegs(self.axes["tel_type"])

        # Bias / resolution vs. true energy.
        ref = self._labeled_projection(self, "tel_type", "E_true", "rel_err")
        others_proj = [
            self._labeled_projection(met, "tel_type", "E_true", "rel_err")
            for met in others
        ]
        fig1 = comp.plot_compare_bias_threshold(ref, others_proj, threshold=0.2)
        fig2 = comp.plot_compare_2dhist_rms(ref, others_proj, sharey="row")

        # Bias / resolution vs. true impact distance (sqrt-scaled axis).
        ref = self._labeled_projection(self, "tel_type", "I_true", "rel_err")
        others_proj = [
            self._labeled_projection(met, "tel_type", "I_true", "rel_err")
            for met in others
        ]
        fig3 = comp.plot_compare_2dhist_mean(
            ref, others_proj, xscale="sqrt", sharey="row"
        )
        fig4 = comp.plot_compare_2dhist_rms(
            ref, others_proj, xscale="sqrt", sharey="row"
        )

        # Median error as a function of both energy and impact.
        ref = self._labeled_projection(self, "tel_type", "E_true", "I_true", "rel_err")
        others_proj = [
            self._labeled_projection(met, "tel_type", "E_true", "I_true", "rel_err")
            for met in others
        ]
        fig5 = comp.plot_compare_median_err(ref, others_proj)

        return ComparisonResult(
            self.__class__.__name__,
            {
                "EnergyBiasPerTel_with_threshold": fig1,
                "EnergyResolutionPerTel": fig2,
                "EnergyMeanRecoErrorByImpactPerTel": fig3,
                "EnergyMeanRMSErrorByImpactPerTel": fig4,
                "EnergyMedianErrByEnergyByImpactPerTel": fig5,
            },
            None,
            None,
        )


class CompareRegs(CompareByCat2D):
    """Helper class for comparing EnergyPredictionPerTelMetric instances."""

    def plot_compare_bias_threshold(
        self, ref, others, threshold=0.1, xscale="log", yscale="linear"
    ):
        """Plot energy bias and draw line for roughly where the energy threshold lies."""
        fig = self.plot_compare_2dhist_mean(ref, others, xscale, yscale, sharey="row")

        for ax, cat in zip(fig.axes, self.categories):
            self._add_threshold_line(ref[cat, ...], ax, threshold)
            for oth in others:
                self._add_threshold_line(oth[cat, ...], ax, threshold)
        return fig

    def _add_threshold_line(self, met, ax, threshold):
        """Draw a vertical line at the first energy where |bias| drops below threshold."""
        axes = met.axes[0]
        _, mean, _ = rms_profile_from_2d_hist(met)
        # Empty bins produce NaN means; map them to a large value so they
        # always count as "beyond threshold".
        mean = np.nan_to_num(mean, nan=10)
        sel = np.abs(mean) > threshold
        breaks = np.where(np.diff(sel))[0]

        # If there are more than two sections where the bias goes beyond the
        # threshold things will be weird, for now just let them be weird
        if breaks.size == 0:
            # Bias never crosses the threshold (original code raised
            # IndexError here) -- nothing meaningful to draw.
            return
        thresh = axes.bin(breaks[0])[-1]
        ax.axvline(thresh, ls=":", label=f"{met.name} E_th: {thresh:.2f}")
        ax.legend()

    def plot_compare_median_err(self, ref, others):
        """Plot median estimate as function of two axes.

        Takes a ``[x,y,z]`` shape histogram, finds the bin containing the median,
        and plots the center value of that bin in a 2d plot per category,
        for each metric provided.
        """
        fig = plt.figure(
            figsize=(1.1 * 4 * len(self.categories), (1 + len(others)) * 4),
            layout="constrained",
        )
        # atleast_1d: with a single row/column matplotlib squeezes the
        # returned SubFigure/Axes to a scalar, which would break indexing.
        subfigs = np.atleast_1d(fig.subfigures(nrows=len(others) + 1, ncols=1))
        for idy, met in enumerate([ref] + others):
            axs = np.atleast_1d(
                subfigs[idy].subplots(nrows=1, ncols=len(self.categories))
            )
            for idx, cat in enumerate(self.categories):
                crude_median_plot(axs[idx], met[cat, ...])
                axs[idx].set_title(cat)
            subfigs[idy].suptitle(met.name)
        return fig


class EnergyRecoBenchmark(Benchmark):
    """
    Benchmark performance of the energy regression.

    For generation, requires:

    * A dl2 file with gammas
    """

    # Number of events per chunk when reading telescope-event tables.
    chunk_size = 10000
    # Sub-path (inside the metrics store) where outputs are written.
    metric_path = Path("dl2/energy_regression")

    def __init__(self, reg_name="", max_chunks=None):
        """Set up output locations for the benchmark.

        Parameters
        ----------
        reg_name: str
            Name of the column holding the per-telescope energy prediction.
        max_chunks: int | None
            If given, stop accumulating after this many chunks (useful for
            quick test runs); ``None`` processes the whole file.
        """
        super().__init__()
        self.max_chunks = max_chunks
        # NOTE(review): "EnergyRegeression" is misspelled, but the name is
        # part of the stored on-disk layout -- only fix it together with
        # anything that reads these files back.
        self.output_names = {
            "reg": self.metric_path / "EnergyRegeression.asdf",
        }
        self.reg_name = reg_name
[docs] @override def generate_metrics(self, metric_store: MetricsStore) -> None: dl2_signal = metric_store.get_inputdata().dl2_signal if not dl2_signal: dl2_signal = metric_store.get_inputdata().dl2 sig_file = dl2_signal options = dict.fromkeys(telescope_options, False) options["dl2"] = True options["simulated"] = True # accumulate the necessary information print("Starting accumulation") with TableLoader(sig_file, **options) as loader: telescope_types = list({str(t) for t in loader.subarray.telescope_types}) reg_tel = EnergyPredictionPerTelMetric( predicted_energy_column=self.reg_name, telescope_types=telescope_types, output_store=metric_store, ) events = loader.read_telescope_events_by_type_chunked(chunk_size=40_000) if self.max_chunks: events = itertools.islice(events, self.max_chunks) for _, _, events_by_type in tqdm(events): reg_tel.accumulate(events_by_type) print("Accumulated") # store the output data metric_store.store_data(reg_tel, self.output_names["reg"])
[docs] @override def compare_to_reference( self, metric_store_list: list[MetricsStore], result_store: ResultStore ): bench_header = {} bench_header["name"] = "Classification Performance" bench_header["inputs"] = [ itm.metadata["input_dataset"].to_dict() for itm in metric_store_list ] result_store.metadata["benchmarks"][self.metric_path] = bench_header return self._compare_metrics_in_stores(metric_store_list, result_store)