"""
Energy regression benchmarks.
To run the benchmark, a DL2 file with gammas is required.
Currently the inputs are assumed to be separate files.
"""
import itertools
from pathlib import Path
from typing import override
import matplotlib.scale as scl
import numpy as np
from ctapipe.io import TableLoader
from hist import axis
from tqdm import tqdm
from datapipe_testbench.compare_2d import CompareByCat2D, rms_profile_from_2d_hist
from ...benchmark import Benchmark
from ...comparers import ComparisonResult
from ...constants import telescope_options
from ...metric import Metric
from ...plotting import SqrtScale, crude_median_plot, plt
from ...store import MetricsStore, ResultStore
scl.register_scale(SqrtScale)
class EnergyPredictionPerTelMetric(Metric):
    """Metric for per-telescope energy regression.

    Accumulates a 4D histogram over telescope type, true energy,
    true impact distance and the relative reconstruction error
    ``(E_reco - E_true) / E_true``.

    Parameters
    ----------
    predicted_energy_column: str
        name of the column containing the per-tel energy prediction,
        must be given for accumulation to work.
    telescope_types: list | tuple
        List of telescope types to categorize by
    output_store: ResultStore
        Where results will be saved. Also used to update the name of the object
        instance.
    """

    title = "EnergyPredictionPerTelMetric"

    # Hand chosen bins: 11 linear bins from -1.5 to 1.5,
    # then 13 log10 bins from 1.5 to 16
    # cust_bin = linspace(-1.5, 1.5, num=11), logspace(log10(1.5), log10(16), num=13)[1:]
    _REL_ERR_EDGES = [
        -1.5,
        -1.2,
        -0.9,
        -0.6,
        -0.3,
        0.0,
        0.3,
        0.6,
        0.9,
        1.2,
        1.5,
        1.82709159,
        2.22550913,
        2.71080601,
        3.30192725,
        4.02194901,
        4.89897949,
        5.96725616,
        7.26848237,
        8.85345536,
        10.78404924,
        13.13563047,
        16.0,
    ]

    def __init__(
        self,
        predicted_energy_column,
        telescope_types: list[str],
        output_store=None,
        **kwargs,
    ):
        axis_list = [
            axis.StrCategory(categories=telescope_types, name="tel_type"),
            # Log-spaced true-energy axis (units not visible here; presumably
            # TeV -- TODO confirm against the dl2 schema).
            axis.Regular(
                bins=33,
                start=0.01,
                stop=350,
                name="E_true",
                label="True Energy",
                transform=axis.transform.log,
                overflow=True,
            ),
            # Sqrt-spaced impact-distance axis.
            axis.Regular(
                bins=6,
                start=0,
                stop=1400,
                name="I_true",
                label="Tel Impact",
                transform=axis.transform.sqrt,
                overflow=True,
            ),
            axis.Variable(
                self._REL_ERR_EDGES,
                name="rel_err",
                label="Relative reconstruction error",
                overflow=True,
                underflow=True,
            ),
        ]
        # TODO: make energy column name configurable at creation
        super().__init__(axis_list, **kwargs)
        self.update_name_with_store(output_store)
        # Column read during accumulate() for the reconstructed energy.
        self.reg_name = predicted_energy_column

    def accumulate(self, ev_dict):
        """
        Fill axis based on input events.

        Parameters
        ----------
        ev_dict: dict[str, astropy.table.Table]
            Event data coming from table_loader methods, keyed by telescope
            type.
        """
        for tel_type, evs in ev_dict.items():
            rel_err = (evs[self.reg_name] - evs["true_energy"]) / evs["true_energy"]
            self.fill(
                tel_type,
                evs["true_energy"],
                evs["true_impact_distance"],
                rel_err,
            )

    @staticmethod
    def _project_labeled(metric, axes):
        """Project ``metric`` onto ``axes`` and copy its name into the label."""
        proj = metric.project(*axes)
        proj.label = proj.name
        return proj

    def compare(self, others: list):
        """Compare several instances of the class.

        Uses `self` as reference when comparing with `others`.

        Parameters
        ----------
        others: list[EnergyPredictionPerTelMetric]
            list containing other instances of this class.

        Returns
        -------
        ComparisonResult:
            results of comparison of this reference to the others list
        """
        comp = CompareRegs(self.axes["tel_type"])

        # Bias / resolution as a function of true energy.
        axes = ("tel_type", "E_true", "rel_err")
        ref = self._project_labeled(self, axes)
        others_proj = [self._project_labeled(met, axes) for met in others]
        fig1 = comp.plot_compare_bias_threshold(ref, others_proj, threshold=0.2)
        fig2 = comp.plot_compare_2dhist_rms(ref, others_proj, sharey="row")

        # Mean / RMS error as a function of impact distance.
        axes = ("tel_type", "I_true", "rel_err")
        ref = self._project_labeled(self, axes)
        others_proj = [self._project_labeled(met, axes) for met in others]
        fig3 = comp.plot_compare_2dhist_mean(
            ref, others_proj, xscale="sqrt", sharey="row"
        )
        fig4 = comp.plot_compare_2dhist_rms(
            ref, others_proj, xscale="sqrt", sharey="row"
        )

        # Median error as a function of both energy and impact distance.
        axes = ("tel_type", "E_true", "I_true", "rel_err")
        ref = self._project_labeled(self, axes)
        others_proj = [self._project_labeled(met, axes) for met in others]
        fig5 = comp.plot_compare_median_err(ref, others_proj)

        return ComparisonResult(
            self.__class__.__name__,
            {
                "EnergyBiasPerTel_with_threshold": fig1,
                "EnergyResolutionPerTel": fig2,
                "EnergyMeanRecoErrorByImpactPerTel": fig3,
                "EnergyMeanRMSErrorByImpactPerTel": fig4,
                "EnergyMedianErrByEnergyByImpactPerTel": fig5,
            },
            None,
            None,
        )
class CompareRegs(CompareByCat2D):
    """Helper Class for comparing EnergyPredictionPerTelMetric instances."""

    def plot_compare_bias_threshold(
        self, ref, others, threshold=0.1, xscale="log", yscale="linear"
    ):
        """Plot energy bias and draw line for roughly where the energy threshold lies.

        Parameters
        ----------
        ref:
            reference projection with a leading category axis.
        others: list
            other projections compared against the reference.
        threshold: float
            absolute bias value defining the energy threshold.
        xscale, yscale: str
            matplotlib scale names passed to the underlying 2d-mean plot.
        """
        fig = self.plot_compare_2dhist_mean(ref, others, xscale, yscale, sharey="row")
        for ax, cat in zip(fig.axes, self.categories):
            self._add_threshold_line(ref[cat, ...], ax, threshold)
            for oth in others:
                self._add_threshold_line(oth[cat, ...], ax, threshold)
        return fig

    def _add_threshold_line(self, met, ax, threshold):
        """Draw a vertical line where |bias| first drops below ``threshold``.

        Does nothing when the bias never crosses the threshold (previously
        this raised an IndexError on the empty ``breaks`` array).
        """
        axes = met.axes[0]
        _, mean, _ = rms_profile_from_2d_hist(met)
        # Empty bins give NaN; map them to a large value so they count as
        # "above threshold".
        mean = np.nan_to_num(mean, nan=10)
        sel = np.abs(mean) > threshold
        breaks = np.where(np.diff(sel))[0]
        if breaks.size == 0:
            # Bias never crosses the threshold -- no sensible line to draw.
            return
        # If there are more than two sections where the bias goes beyond the
        # threshold things will be weird, for now just let them be weird
        thresh = axes.bin(breaks[0])[-1]
        ax.axvline(thresh, ls=":", label=f"{met.name} E_th: {thresh:.2f}")
        ax.legend()

    def plot_compare_median_err(self, ref, others):
        """Plot median estimate as function of two axes.

        Takes a ``[x,y,z]`` shape histogram, finds the bin containing the median,
        and plots the center value of that bin in a 2d plot per category,
        for each metric provided.
        """
        fig = plt.figure(
            figsize=(1.1 * 4 * len(self.categories), (1 + len(others)) * 4),
            layout="constrained",
        )
        # atleast_1d: subfigures()/subplots() squeeze to a scalar when only
        # one row/column is requested, which would break the indexing below.
        subfigs = np.atleast_1d(fig.subfigures(nrows=len(others) + 1, ncols=1))
        for idy, met in enumerate([ref] + others):
            axs = np.atleast_1d(
                subfigs[idy].subplots(nrows=1, ncols=len(self.categories))
            )
            for idx, cat in enumerate(self.categories):
                crude_median_plot(axs[idx], met[cat, ...])
                axs[idx].set_title(cat)
            subfigs[idy].suptitle(met.name)
        return fig
class EnergyRecoBenchmark(Benchmark):
    """
    Benchmark performance of the energy regression.

    For generation, requires:

    * A dl2 file with gammas
    """

    # Events per chunk when reading telescope events; previously this
    # attribute said 10000 but the loader call hard-coded 40_000 -- the
    # attribute now holds the value that is actually used.
    chunk_size = 40_000
    metric_path = Path("dl2/energy_regression")

    def __init__(self, reg_name="", max_chunks=None):
        """
        Parameters
        ----------
        reg_name: str
            name of the dl2 column holding the per-tel energy prediction.
        max_chunks: int | None
            if set, limit accumulation to this many chunks (useful for tests).
        """
        super().__init__()
        self.max_chunks = max_chunks
        # NOTE(review): "Regeression" is a typo, but the file name is part of
        # the stored-output contract -- renaming would orphan existing stores.
        self.output_names = {
            "reg": self.metric_path / "EnergyRegeression.asdf",
        }
        self.reg_name = reg_name

    @override
    def generate_metrics(self, metric_store: MetricsStore) -> None:
        """Accumulate the energy-regression metric from the signal dl2 file."""
        dl2_signal = metric_store.get_inputdata().dl2_signal
        if not dl2_signal:
            # Fall back to the combined dl2 input when no dedicated signal
            # file is provided.
            dl2_signal = metric_store.get_inputdata().dl2
        sig_file = dl2_signal
        options = dict.fromkeys(telescope_options, False)
        options["dl2"] = True
        options["simulated"] = True
        # accumulate the necessary information
        print("Starting accumulation")
        with TableLoader(sig_file, **options) as loader:
            telescope_types = list({str(t) for t in loader.subarray.telescope_types})
            reg_tel = EnergyPredictionPerTelMetric(
                predicted_energy_column=self.reg_name,
                telescope_types=telescope_types,
                output_store=metric_store,
            )
            events = loader.read_telescope_events_by_type_chunked(
                chunk_size=self.chunk_size
            )
            if self.max_chunks:
                events = itertools.islice(events, self.max_chunks)
            for _, _, events_by_type in tqdm(events):
                reg_tel.accumulate(events_by_type)
        print("Accumulated")
        # store the output data
        metric_store.store_data(reg_tel, self.output_names["reg"])

    @override
    def compare_to_reference(
        self, metric_store_list: list[MetricsStore], result_store: ResultStore
    ):
        """Record benchmark metadata and compare the stored metrics."""
        bench_header = {}
        # Was "Classification Performance" -- a copy-paste error from the
        # classifier benchmark; this benchmark measures energy regression.
        bench_header["name"] = "Energy Regression Performance"
        bench_header["inputs"] = [
            itm.metadata["input_dataset"].to_dict() for itm in metric_store_list
        ]
        result_store.metadata["benchmarks"][self.metric_path] = bench_header
        return self._compare_metrics_in_stores(metric_store_list, result_store)