#!/usr/bin/env python3
"""
Defines a base Benchmark and related Metrics for dispersion-related measurements.
"""
import enum
import itertools
import warnings
from collections.abc import Iterable
from importlib import resources
from typing import Self, override
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from ctapipe.io import TableLoader
from hist import axis
from tqdm.auto import tqdm
from ..benchmark import Benchmark
from ..compare_2d import CompareByCat2D
from ..comparers import ComparisonResult
from ..metric import Metric
from ..store import MetricsStore, ResultStore
from ..utils import normalize_category_name
SINGLE_CATEGORY_NAME = "subarray"


def relative_difference(x_reco, x_true):
    """Return relative difference between x_true and x_reco."""
    return x_reco / x_true - 1.0


def single_category_chunks(event_chunks: Iterable, category_name=SINGLE_CATEGORY_NAME):
    """
    Return an iterator that mimics a per-telescope-type chunk iterator.

    But for a single category, named by category_name. This way, per-telescope
    category, and per subarray filling can be done in the same way.
    """
    # Wrap each chunk's payload in a one-entry mapping keyed by the category.
    yield from (
        (start, stop, {category_name: data}) for start, stop, data in event_chunks
    )
class DispersionMetric(Metric):
    """The resolution of a reconstructed property compared to truth."""

    @classmethod
    def setup(
        cls,
        reco_column: str,
        true_axis,
        unit: u.Unit | str,
        reco_symbol: str | None = None,
        categories: list[str] | None = None,
        num_delta_bins: int = 400,
        delta_range: tuple[float, float] = (-20.0, 20.0),
        other_attributes: dict | None = None,
    ) -> Self:
        """Create a default DispersionMetric.

        Parameters
        ----------
        reco_column : str
            name of the reconstructed-quantity column, stored in the metric
            attributes and used in the default delta-axis label.
        true_axis : axis
            histogram axis for the true quantity; its ``name`` must be a column
            in the input events data.
        unit : u.Unit or str
            unit of the metric contents.
        reco_symbol : str, optional
            LaTeX symbol to use for the delta-axis label instead of the
            default "reco / true - 1.0" text.
        categories : list[str], optional
            if given, a leading StrCategory axis with these categories is
            prepended to the axis list.
        num_delta_bins : int
            number of bins in the relative-difference (delta) axis.
        delta_range : tuple[float, float]
            bin range of the delta axis.
        other_attributes : dict, optional
            extra attributes to attach to the metric. The given dict is not
            modified.

        Returns
        -------
        Self
            the configured metric, with no output store attached.
        """
        # Work on a copy so the key insertions below never mutate the
        # caller's dict (previously the argument was modified in place).
        other_attributes = dict(other_attributes) if other_attributes else {}
        other_attributes["reco_column"] = reco_column
        other_attributes["true_column"] = true_axis.name
        # the deltaX/X column:
        delta_label = f"{reco_column} / {true_axis.name} - 1.0 "
        if reco_symbol:
            delta_label = rf"$\Delta {reco_symbol} / {reco_symbol}_{{true}}$"
        delta_axis = axis.Regular(
            bins=num_delta_bins,
            start=delta_range[0],
            stop=delta_range[1],
            name="delta",
            label=delta_label,
        )
        axis_list = [true_axis, delta_axis]
        if categories:
            axis_list.insert(0, axis.StrCategory(categories, name="category"))
            other_attributes["categories"] = categories
        return cls(
            axis_list=axis_list,
            unit=unit,
            label=delta_label,
            output_store=None,
            other_attributes=other_attributes,
        )
class ResolutionMethod(enum.StrEnum):
"""How to define resolution."""
RMS = enum.auto() #: Root-mean squared
STD = enum.auto() #: Standard deviation
def compute_resolution(method: ResolutionMethod, bias_metric: Metric):
    """Return resolution based on choice of algorithm."""
    # The bias metric's storage holds variances on the mean; var * N
    # converts those back to the sample spread used by both definitions.
    storage = bias_metric._storage
    spread = storage.variances() * storage.counts()
    if method == ResolutionMethod.STD:
        # std = sqrt(var * N)
        return np.sqrt(spread)
    if method == ResolutionMethod.RMS:
        # rms = sqrt(var * N + mean**2)
        return np.sqrt(spread + storage.values() ** 2)
    raise ValueError(f"Unsupported resolution method: {method}")
class ResolutionBenchmark(Benchmark):
    """Base for benchmarks that measure the resolution of a measurement."""

    def __init__(
        self,
        input_data_level: str,
        reco_column: str,
        true_axis,
        per_tel_type: bool,
        resolution_method: ResolutionMethod = ResolutionMethod.RMS,
        reco_symbol: str | None = None,
        num_delta_bins: int = 100,
        delta_range: tuple[float, float] = (-10.0, 10.0),
        chunk_size=50_000,
        max_chunks=None,
        filter_function=lambda true, reco: np.isfinite(reco),
        resolution_requirement_table: str | None = None,
        bias_requirement_table: str | None = None,
        bias_plot_range: tuple[float, float] | None = None,
    ):
        """Initialize Resolution Benchmark.

        Parameters
        ----------
        input_data_level : str
            name of input data in the InputDataset. Must be one of the fields in that structure.
        reco_column : str
            name of column in the event data file to use as the reconstructed quantity
        true_axis : axis
            Histogram axis for the true axis of the dispersion. The name of the axis
            must be the name of a column in the input events data
        per_tel_type : bool
            True if this should produce results per telescope type
        resolution_method : ResolutionMethod
            How to compute resolution
        reco_symbol : str, optional
            Alternate label for the quantity, which can be in latex. If specified, this
            will be used in the plot labels.
        num_delta_bins : int
            how many bins to use in the relative difference (delta) axis
        delta_range : tuple[float, float]
            bin range for the delta axis
        chunk_size : int
            how many events to load per chunk. Higher is faster, but uses more memory.
        max_chunks : int
            maximum number of chunks to process, or None to process all.
        filter_function : Callable
            function of f(true_values, reco_values) -> bool to filter what gets filled.
        resolution_requirement_table : str, optional
            name of table specifying the requirement in the datapipe_testbench resources.
            If specified, it will be plotted
        bias_requirement_table : str, optional
            name of table specifying the requirement in the datapipe_testbench resources.
            If specified, it will be plotted
        bias_plot_range : tuple[float, float], optional
            range to use when plotting bias (in units of the reco_column). If not specified
            it will be automatic.
        """
        super().__init__()
        self.input_data_level = input_data_level
        self.reco_column = reco_column
        self.true_axis = true_axis
        self.chunk_size = chunk_size
        self.max_chunks = max_chunks
        self.reco_symbol = reco_symbol
        self.per_tel_type = per_tel_type
        # coerce so callers may pass either a ResolutionMethod or its string value
        self.resolution_method = ResolutionMethod(resolution_method)
        self.filter_function = filter_function
        self.resolution_requirement_table = resolution_requirement_table
        self.bias_requirement_table = bias_requirement_table
        self.num_delta_bins = num_delta_bins
        self.delta_range = delta_range
        self.bias_plot_range = bias_plot_range

    @property
    def required_inputs(self) -> set[str]:  # noqa: D102
        return {self.input_data_level}

    def output_path(self, metric_name):
        """Return the output path of this metric based on its name."""
        return f"{self.input_data_level}/{self.reco_column}_{metric_name.replace(' ', '_')}.asdf"

    @override
    def outputs(self):
        """Return names and paths of the metrics produced by this Benchmark."""
        return {
            "dispersion": self.output_path("dispersion"),
            "bias": self.output_path("bias"),
            "resolution": self.output_path("resolution"),
        }

    @property
    def name(self):
        """Return friendly name of this benchmark."""
        return f"{self.reco_column} resolution"

    @override
    def generate_metrics(self, metric_store: MetricsStore) -> dict | None:
        """Fill the dispersion histogram and derive bias and resolution metrics."""
        events_file = getattr(metric_store.get_inputdata(), self.input_data_level)
        # enable only the table groups that the requested data level needs
        options = dict(
            dl1_images="image" in self.input_data_level,
            true_images="image" in self.input_data_level,
            dl1_parameters="dl1" in self.input_data_level,
            true_parameters="dl1" in self.input_data_level,
            dl2="dl2" in self.input_data_level,
            simulated=True,
            instrument=False,
        )
        self._log.info("%s vs %s", self.reco_column, self.true_axis.name)
        self._log.info(" * Filling from: %s", events_file)
        self._log.info(" * per_tel: %s", self.per_tel_type)
        with TableLoader(events_file, **options) as loader:
            if self.per_tel_type:
                categories = list(
                    {
                        normalize_category_name(str(t))
                        for t in loader.subarray.telescope_types
                    }
                )
                event_chunks = loader.read_telescope_events_by_type_chunked(
                    chunk_size=self.chunk_size
                )
            else:
                # single "subarray" category, shaped like the per-type iterator
                categories = [SINGLE_CATEGORY_NAME]
                event_chunks = single_category_chunks(
                    loader.read_subarray_events_chunked(chunk_size=self.chunk_size)
                )
            dispersion_metric = DispersionMetric.setup(
                reco_column=self.reco_column,
                true_axis=self.true_axis,
                categories=categories,
                unit=u.dimensionless_unscaled,  # these are counts
                reco_symbol=self.reco_symbol,
                num_delta_bins=self.num_delta_bins,
                delta_range=self.delta_range,
            )
            dispersion_metric.update_name_with_store(metric_store)
            num_iters = None
            if self.max_chunks:
                event_chunks = itertools.islice(event_chunks, self.max_chunks)
                num_iters = self.max_chunks
            # loop over chunks
            for _, _, events_by_type in tqdm(
                event_chunks,
                desc=f" * Filling {self.reco_column:.<20.20}",
                unit="chunks",
                total=num_iters,
            ):
                for category, events in events_by_type.items():
                    # we use ravel() here so that if the column is an image,
                    # it still works. If it is a scalar, ravel() has no
                    # effect.
                    samples_true = events[self.true_axis.name].ravel()
                    samples_reco = events[self.reco_column].ravel()
                    mask = self.filter_function(samples_true, samples_reco)
                    delta = relative_difference(samples_reco[mask], samples_true[mask])
                    dispersion_metric.fill(
                        normalize_category_name(category), samples_true[mask], delta
                    )
                    self._log.info(
                        "Filled with %s samples for %s", len(samples_true), category
                    )
        # now compute the derived metrics: the bias and resolution. For that we
        # can use hist's profile method to generate a profile along the delta
        # direction.
        bias_hist = dispersion_metric._storage.profile("delta")
        bias_metric = Metric(
            axis_list=bias_hist.axes,
            hist=bias_hist,
            label=f"{self.reco_column} bias",
            unit=dispersion_metric.unit,
        )
        bias_metric.update_name_with_store(metric_store)
        # for the resolution, we have to compute it from the bias_hist's
        # variances on the mean, which need to be converted to standard deviations:
        # std = sqrt(var*N)
        # rms = sqrt(var*N + mean**2)
        resolution = compute_resolution(self.resolution_method, bias_metric)
        resolution_metric = Metric(
            axis_list=bias_metric.axes,
            label=f"{self.reco_column} resolution ({self.resolution_method})",
            data=resolution,
            unit=dispersion_metric.unit,
        )
        resolution_metric.update_name_with_store(metric_store)
        outputs = self.outputs()
        metric_store.store_data(dispersion_metric, outputs["dispersion"])
        metric_store.store_data(bias_metric, outputs["bias"])
        metric_store.store_data(resolution_metric, outputs["resolution"])
        return (dispersion_metric, bias_metric, resolution_metric)

    @override
    def compare_to_reference(
        self, metric_store_list: list[MetricsStore], result_store: ResultStore
    ) -> dict | None:
        """Plot dispersion, resolution, and bias of each store against the reference."""
        # ========================================
        # Compare dispersion
        # ========================================
        disp_ref, *disp_others = [
            m.retrieve_data(self.output_path("dispersion")) for m in metric_store_list
        ]
        comp = CompareByCat2D(disp_ref.axes[0])
        scale = "log" if str(disp_ref.axes[1].transform) == "log" else "linear"
        fig1 = comp.plot_compare_disp_by_category(
            disp_ref, disp_others, normed=False, xscale=scale
        )
        # define some plot options to make resolution and bias not look like
        # histograms (since they are not).
        profile_plot_opts = dict(
            xerr=True, histtype="errorbar", markersize=0, elinewidth=2
        )
        # ========================================
        # compare resolution
        # ========================================
        res_ref, *res_others = [
            m.retrieve_data(self.output_path("resolution")) for m in metric_store_list
        ]
        categories = list(res_ref.axes[0])
        figsize = (0.8 * 4 * len(categories), 5)
        fig2 = plt.figure(figsize=figsize, layout="constrained")
        fig2.suptitle(f"{self.reco_column} resolution ({self.resolution_method})")
        subfigs = fig2.subfigures(nrows=1, ncols=len(res_ref.axes[0]))
        subfigs = np.atleast_1d(subfigs)
        sharex, sharey, sharey_diff = None, None, None
        if self.resolution_requirement_table:
            req = np.loadtxt(
                resources.files("datapipe_testbench.requirements_data")
                / self.resolution_requirement_table
            )
        for ii, category in enumerate(categories):
            # NOTE(review): unlike the bias loop below, sharex/sharey are not
            # forwarded to plot_compare_1d here — confirm whether that is
            # intentional.
            _, ax = res_ref[ii, ...].plot_compare_1d(
                [r[category, ...] for r in res_others if category in r.axes[0]],
                fig=subfigs[ii],
                legend=False,
                yerr=0.0,  # y-error is meaningless here
                **profile_plot_opts,
            )
            ax["value"].set_title(category)
            if self.resolution_requirement_table:
                ax["value"].plot(
                    req[:, 0],
                    req[:, 1],
                    label="requirement",
                    linestyle="dotted",
                    color="black",
                )
            ax["value"].legend(loc="best")
            # also fix the x-range, which should be the same as the original
            # disp axis. If we don't set it here, it gets autoscaled to where there is
            # data, which then doesn't match the disp plots.
            ax["value"].set_xlim(disp_ref.axes[1].edges[0], disp_ref.axes[1].edges[-1])
            # ensure subsequent plots share an axis
            sharex = ax["value"]
            sharey = ax["value"]
            sharey_diff = ax["diff"]
        # ========================================
        # compare bias
        # ========================================
        bias_ref, *bias_others = [
            m.retrieve_data(self.output_path("bias")) for m in metric_store_list
        ]
        fig3 = plt.figure(figsize=figsize, layout="constrained")
        fig3.suptitle(f"{self.reco_column} bias")
        subfigs = fig3.subfigures(nrows=1, ncols=len(bias_ref.axes[0]))
        subfigs = np.atleast_1d(subfigs)
        sharex, sharey, sharey_diff = None, None, None
        for ii, category in enumerate(bias_ref.axes[0]):
            with warnings.catch_warnings(
                action="ignore", category=RuntimeWarning
            ):  # ignore div by zero
                # index the other stores by category *name* (as in the
                # resolution loop above): a positional index would pick the
                # wrong histogram when a store orders its categories
                # differently.
                _, ax = bias_ref[ii, ...].plot_compare_1d(
                    [r[category, ...] for r in bias_others if category in r.axes[0]],
                    fig=subfigs[ii],
                    yerr=0.0,  # y-error is meaningless, Mean storage type is lost on load
                    sharex=sharex,
                    sharey=sharey,
                    sharey_diff=sharey_diff,
                    **profile_plot_opts,
                )
            ax["value"].axhline(0.0, ls="dotted", color="black")
            ax["value"].set_title(category)
            if self.bias_plot_range:
                ax["value"].set_ylim(*self.bias_plot_range)
            # also fix the x-range, which should be the same as the original
            # disp axis. If we don't set it here, it gets autoscaled to where there is
            # data, which then doesn't match the disp plots.
            ax["value"].set_xlim(disp_ref.axes[1].edges[0], disp_ref.axes[1].edges[-1])
            # ensure subsequent plots share an axis
            sharex = ax["value"]
            sharey = ax["value"]
            sharey_diff = ax["diff"]
        # Save the results
        result = ComparisonResult(
            name=self.reco_column,
            plots={
                f"{self.reco_column}_dispersion": fig1,
                f"{self.reco_column}_resolution": fig2,
                f"{self.reco_column}_bias": fig3,
            },
        )
        result.store(store=result_store, metric_id=self.input_data_level)
        return result