# Source code for datapipe_testbench.benchmarks.resolution

#!/usr/bin/env python3

"""
Defines a base Benchmark and related Metrics for dispersion-related measurements.
"""

import enum
import itertools
import warnings
from collections.abc import Iterable
from importlib import resources
from typing import Self, override

import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from ctapipe.io import TableLoader
from hist import axis
from tqdm.auto import tqdm

from ..benchmark import Benchmark
from ..compare_2d import CompareByCat2D
from ..comparers import ComparisonResult
from ..metric import Metric
from ..store import MetricsStore, ResultStore
from ..utils import normalize_category_name

SINGLE_CATEGORY_NAME = "subarray"


def relative_difference(x_reco, x_true):
    """Return relative difference between x_true and x_reco."""
    ratio = x_reco / x_true
    return ratio - 1.0


def single_category_chunks(event_chunks: Iterable, category_name=SINGLE_CATEGORY_NAME):
    """
    Return an iterator that mimics a per-telescope-type chunk iterator.

    But for a single category, named by category_name. This way, per-telescope
    category, and per subarray filling can be done in the same way.
    """
    yield from (
        (start, stop, {category_name: data})
        for start, stop, data in event_chunks
    )


class DispersionMetric(Metric):
    """The resolution of a reconstructed property compared to truth."""

    @classmethod
    def setup(
        cls,
        reco_column: str,
        true_axis,
        unit: u.Unit | str,
        reco_symbol: str | None = None,
        categories: list[str] | None = None,
        num_delta_bins: int = 400,
        delta_range: tuple[float, float] = (-20.0, 20.0),
        other_attributes: dict | None = None,
    ) -> Self:
        """Create a default DispersionMetric.

        Parameters
        ----------
        reco_column : str
            name of the reconstructed-quantity column
        true_axis : axis
            histogram axis for the true quantity; its name is the true column
        unit : u.Unit | str
            unit of the metric contents
        reco_symbol : str, optional
            latex-capable symbol used to build the delta-axis label
        categories : list[str], optional
            if given, a StrCategory axis is prepended and the category list is
            stored in the metric attributes
        num_delta_bins : int
            number of bins for the relative-difference (delta) axis
        delta_range : tuple[float, float]
            bin range of the delta axis
        other_attributes : dict, optional
            extra attributes to attach to the metric
        """
        # BUGFIX: copy instead of mutating the caller's dict in place — the
        # original wrote reco_column/true_column into the argument itself.
        other_attributes = dict(other_attributes) if other_attributes else {}

        other_attributes["reco_column"] = reco_column
        other_attributes["true_column"] = true_axis.name

        # label for the deltaX/X column:
        delta_label = f"{reco_column} / {true_axis.name} - 1.0 "
        if reco_symbol:
            delta_label = rf"$\Delta {reco_symbol} / {reco_symbol}_{{true}}$"

        delta_axis = axis.Regular(
            bins=num_delta_bins,
            start=delta_range[0],
            stop=delta_range[1],
            name="delta",
            label=delta_label,
        )

        axis_list = [true_axis, delta_axis]
        if categories:
            # category axis goes first so fills are (category, true, delta)
            axis_list.insert(0, axis.StrCategory(categories, name="category"))
            other_attributes["categories"] = categories

        return cls(
            axis_list=axis_list,
            unit=unit,
            label=delta_label,
            output_store=None,
            other_attributes=other_attributes,
        )


class ResolutionMethod(enum.StrEnum):
    """How to define resolution."""

    RMS = enum.auto()  #: Root-mean squared
    STD = enum.auto()  #: Standard deviation


def compute_resolution(method: ResolutionMethod, bias_metric: Metric):
    """Return resolution based on choice of algorithm."""
    if method == ResolutionMethod.RMS:
        resolution = np.sqrt(
            bias_metric._storage.variances() * bias_metric._storage.counts()
            + bias_metric._storage.values() ** 2
        )
    elif method == ResolutionMethod.STD:
        resolution = np.sqrt(
            bias_metric._storage.variances() * bias_metric._storage.counts()
        )
    else:
        raise ValueError(f"Unsupported resolution method: {method}")

    return resolution


class ResolutionBenchmark(Benchmark):
    """Base for benchmarks that measure the resolution of a measurement."""

    def __init__(
        self,
        input_data_level: str,
        reco_column: str,
        true_axis,
        per_tel_type: bool,
        resolution_method: ResolutionMethod = ResolutionMethod.RMS,
        reco_symbol: str | None = None,
        num_delta_bins: int = 100,
        delta_range: tuple[float, float] = (-10.0, 10.0),
        chunk_size=50_000,
        max_chunks=None,
        filter_function=lambda true, reco: np.isfinite(reco),
        resolution_requirement_table: str | None = None,
        bias_requirement_table: str | None = None,
        bias_plot_range: tuple[float, float] | None = None,
    ):
        """Initialize Resolution Benchmark.

        Parameters
        ----------
        input_data_level : str
            name of input data in the InputDataset. Must be one of the fields
            in that structure.
        reco_column : str
            name of column in the event data file to use as the reconstructed
            quantity
        true_axis : axis
            Histogram axis for the true axis of the dispersion. The name of
            the axis must be the name of a column in the input events data
        per_tel_type : bool
            True if this should produce results per telescope type
        resolution_method : ResolutionMethod
            How to compute resolution
        reco_symbol : str, optional
            Alternate label for the quantity, which can be in latex. If
            specified, this will be used in the plot labels.
        num_delta_bins : int
            how many bins to use in the relative difference (delta) axis
        delta_range : tuple[float, float]
            bin range for the delta axis
        chunk_size : int
            how many events to load per chunk. Higher is faster, but uses
            more memory.
        max_chunks : int
            maximum number of chunks to process, or None to process all.
        filter_function : Callable
            function of f(true_values, reco_values) -> bool to filter what
            gets filled.
        resolution_requirement_table : str, optional
            name of table specifying the requirement in the
            datapipe_testbench resources. If specified, it will be plotted
        bias_requirement_table : str, optional
            name of table specifying the requirement in the
            datapipe_testbench resources. If specified, it will be plotted
        bias_plot_range : tuple[float, float], optional
            range to use when plotting bias (in units of the reco_column).
            If not specified it will be automatic.
        """
        super().__init__()
        self.input_data_level = input_data_level
        self.reco_column = reco_column
        self.true_axis = true_axis
        self.chunk_size = chunk_size
        self.max_chunks = max_chunks
        self.reco_symbol = reco_symbol
        self.per_tel_type = per_tel_type
        # coerce so callers may pass a plain string as well as the enum
        self.resolution_method = ResolutionMethod(resolution_method)
        self.filter_function = filter_function
        self.resolution_requirement_table = resolution_requirement_table
        self.bias_requirement_table = bias_requirement_table
        self.num_delta_bins = num_delta_bins
        self.delta_range = delta_range
        self.bias_plot_range = bias_plot_range

    @property
    def required_inputs(self) -> set[str]:  # noqa: D102
        # set literal instead of set([...]) — same single-element set
        return {self.input_data_level}
[docs] def output_path(self, metric_name): """Return the output path of this metric based on its name.""" return f"{self.input_data_level}/{self.reco_column}_{metric_name.replace(' ', '_')}.asdf"
[docs] @override def outputs(self): """Return names and paths of the metrics produced by this Benchmark.""" return { "dispersion": self.output_path("dispersion"), "bias": self.output_path("bias"), "resolution": self.output_path("resolution"), }
@property def name(self): """Return friendly name of this benchmark.""" return f"{self.reco_column} resolution"
[docs] @override def generate_metrics(self, metric_store: MetricsStore) -> dict | None: events_file = getattr(metric_store.get_inputdata(), self.input_data_level) options = dict( dl1_images=True if "image" in self.input_data_level else False, true_images=True if "image" in self.input_data_level else False, dl1_parameters=True if "dl1" in self.input_data_level else False, true_parameters=True if "dl1" in self.input_data_level else False, dl2=True if "dl2" in self.input_data_level else False, simulated=True, instrument=False, ) self._log.info(f"{self.reco_column} vs {self.true_axis.name}") self._log.info(f" * Filling from: {events_file}") self._log.info(f" * per_tel: {self.per_tel_type}") with TableLoader(events_file, **options) as loader: if self.per_tel_type: categories = list( { normalize_category_name(str(t)) for t in loader.subarray.telescope_types } ) event_chunks = loader.read_telescope_events_by_type_chunked( chunk_size=self.chunk_size ) else: categories = [SINGLE_CATEGORY_NAME] event_chunks = single_category_chunks( loader.read_subarray_events_chunked(chunk_size=self.chunk_size) ) dispersion_metric = DispersionMetric.setup( reco_column=self.reco_column, true_axis=self.true_axis, categories=categories, unit=u.dimensionless_unscaled, # these are counts reco_symbol=self.reco_symbol, num_delta_bins=self.num_delta_bins, delta_range=self.delta_range, ) dispersion_metric.update_name_with_store(metric_store) num_iters = None if self.max_chunks: event_chunks = itertools.islice(event_chunks, self.max_chunks) num_iters = self.max_chunks # loop over chunks for _, _, events_by_type in tqdm( event_chunks, desc=f" * Filling {self.reco_column:.<20.20}", unit="chunks", total=num_iters, ): for category, events in events_by_type.items(): # we use ravel() here so that if the column is an image, # it still works. If it is a scalar, ravel() has no # effect. 
samples_true = events[self.true_axis.name].ravel() samples_reco = events[self.reco_column].ravel() mask = self.filter_function(samples_true, samples_reco) delta = relative_difference(samples_reco[mask], samples_true[mask]) dispersion_metric.fill( normalize_category_name(category), samples_true[mask], delta ) self._log.info( "Filled with %s samples for %s", len(samples_true), category ) # now compute the derived metrics: the bias and resolution. For that we # can use hist's profile method to generate a profile along the delta # direction. bias_hist = dispersion_metric._storage.profile("delta") bias_metric = Metric( axis_list=bias_hist.axes, hist=bias_hist, label=f"{self.reco_column} bias", unit=dispersion_metric.unit, ) bias_metric.update_name_with_store(metric_store) # for the resolution, we have to compute it from the bias_hist's # variances on the mean, which need to be converted to stdandard deviations: # std = sqrt(var*N) # rms = sqrt(var*N + mean**2) resolution = compute_resolution(self.resolution_method, bias_metric) resolution_metric = Metric( axis_list=bias_metric.axes, label=f"{self.reco_column} resolution ({self.resolution_method})", data=resolution, unit=dispersion_metric.unit, ) resolution_metric.update_name_with_store(metric_store) outputs = self.outputs() metric_store.store_data(dispersion_metric, outputs["dispersion"]) metric_store.store_data(bias_metric, outputs["bias"]) metric_store.store_data(resolution_metric, outputs["resolution"]) return (dispersion_metric, bias_metric, resolution_metric)
[docs] @override def compare_to_reference( self, metric_store_list: list[MetricsStore], result_store: ResultStore ) -> dict | None: # ======================================== # Compare dispersion # ======================================== disp_ref, *disp_others = [ m.retrieve_data(self.output_path("dispersion")) for m in metric_store_list ] comp = CompareByCat2D(disp_ref.axes[0]) scale = "log" if str(disp_ref.axes[1].transform) == "log" else "linear" fig1 = comp.plot_compare_disp_by_category( disp_ref, disp_others, normed=False, xscale=scale ) # define some plot options to make resolution and bias not look like # histograms (since they are not). profile_plot_opts = dict( xerr=True, histtype="errorbar", markersize=0, elinewidth=2 ) # ======================================== # compare resolution # ======================================== res_ref, *res_others = [ m.retrieve_data(self.output_path("resolution")) for m in metric_store_list ] categories = list(res_ref.axes[0]) figsize = (0.8 * 4 * len(categories), 5) fig2 = plt.figure(figsize=figsize, layout="constrained") fig2.suptitle(f"{self.reco_column} resolution ({self.resolution_method})") subfigs = fig2.subfigures(nrows=1, ncols=len(res_ref.axes[0])) subfigs = np.atleast_1d(subfigs) sharex, sharey, sharey_diff = None, None, None if self.resolution_requirement_table: req = np.loadtxt( resources.files("datapipe_testbench.requirements_data") / self.resolution_requirement_table ) for ii, category in enumerate(categories): _, ax = res_ref[ii, ...].plot_compare_1d( [r[category, ...] for r in res_others if category in r.axes[0]], fig=subfigs[ii], legend=False, yerr=0.0, # y-error is meaningless here **profile_plot_opts, ) ax["value"].set_title(category) if self.resolution_requirement_table: ax["value"].plot( req[:, 0], req[:, 1], label="requirement", linestyle="dotted", color="black", ) ax["value"].legend(loc="best") # also fix the x-range, which should be the same as the original # disp axis. 
If we don't set it here, it gets autoscaled to where there is # data, which then doesn't match the disp plots. ax["value"].set_xlim(disp_ref.axes[1].edges[0], disp_ref.axes[1].edges[-1]) # ensure subsequent plots share an axis sharex = ax["value"] sharey = ax["value"] sharey_diff = ax["diff"] # ======================================== # compare bias # ======================================== bias_ref, *bias_others = [ m.retrieve_data(self.output_path("bias")) for m in metric_store_list ] fig3 = plt.figure(figsize=figsize, layout="constrained") fig3.suptitle(f"{self.reco_column} bias") subfigs = fig3.subfigures(nrows=1, ncols=len(bias_ref.axes[0])) subfigs = np.atleast_1d(subfigs) sharex, sharey, sharey_diff = None, None, None for ii, category in enumerate(bias_ref.axes[0]): with warnings.catch_warnings( action="ignore", category=RuntimeWarning ): # ignore div by zero fig, ax = bias_ref[ii, ...].plot_compare_1d( [r[ii, ...] for r in bias_others if category in r.axes[0]], fig=subfigs[ii], yerr=0.0, # y-error is meaningless, Mean storage type is lost on load sharex=sharex, sharey=sharey, sharey_diff=sharey_diff, **profile_plot_opts, ) ax["value"].axhline(0.0, ls="dotted", color="black") ax["value"].set_title(category) if self.bias_plot_range: ax["value"].set_ylim(*self.bias_plot_range) # also fix the x-range, which should be the same as the original # disp axis. If we don't set it here, it gets autoscaled to where there is # data, which then doesn't match the disp plots. ax["value"].set_xlim(disp_ref.axes[1].edges[0], disp_ref.axes[1].edges[-1]) # ensure subsequent plots share an axis sharex = ax["value"] sharey = ax["value"] sharey_diff = ax["diff"] # Save the results result = ComparisonResult( name=self.reco_column, plots={ f"{self.reco_column}_dispersion": fig1, f"{self.reco_column}_resolution": fig2, f"{self.reco_column}_bias": fig3, }, ) result.store(store=result_store, metric_id=self.input_data_level) return result