Source code for datapipe_testbench.metric

"""
Ndhistogram metric and datamodel location definition.
"""

import logging
from abc import abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Self

import asdf
import numpy as np
from asdf import AsdfFile
from astropy import units as u
from astropy.table import Row, Table
from hist import Hist, axis
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.pyplot import FigureBase, figure

from . import auto_constants, utils
from .stat_helpers import normalize_along_axis
from .storable import Storable

logger = logging.getLogger(__name__)


__all__ = ["Metric"]


@contextmanager
def _open_asdf(init, schema_url):
    """Yield an ``AsdfFile`` built from, or opened from, ``init``.

    A ``dict`` or ``AsdfFile`` is wrapped in a new in-memory ``AsdfFile``;
    a ``str``, ``Path`` or any object with a ``read`` attribute is opened
    with ``asdf.open`` and closed again when the context exits.

    Raises
    ------
    ValueError
        If ``init`` is none of the supported kinds.
    """
    # In-memory tree: nothing to open on disk.
    if isinstance(init, (dict, asdf.AsdfFile)):
        yield asdf.AsdfFile(tree=init, custom_schema=schema_url)
        return

    is_file_like = isinstance(init, (str, Path)) or hasattr(init, "read")
    if not is_file_like:
        raise ValueError(f"Unexpected init value in Metric.load(): '{init}'")

    with asdf.open(init, custom_schema=schema_url) as handle:
        yield handle


def get_transform_func(string):
    """
    Return transform object from a string representation.

    Functions are expected to be part of hist.axis.transform.

    Only the following functions are supported now:
    * log
    * sqrt
    * pow(<power>)

    Parameters
    ----------
    string : str
        Either the name of an attribute of ``hist.axis.transform``
        (e.g. ``"log"``) or a string of the form ``"pow(<power>)"``.

    Returns
    -------
    The attribute looked up on ``hist.axis.transform``, or a
    ``hist.axis.transform.Pow`` instance for the ``pow(...)`` form.

    Raises
    ------
    ValueError
        If the ``pow(...)`` argument is not a number.
    AttributeError
        If ``string`` does not name an attribute of ``hist.axis.transform``.
    """
    # Special case for "pow(2)", which carries its parameter in the string.
    if "pow(" in string:
        # Drop the closing parenthesis, then keep only the argument.
        _, str_power = string[:-1].split("(")
        power = float(str_power)
        # Keep integral powers as ints so "pow(2)" behaves as before, while
        # fractional powers such as "pow(0.5)" no longer fail in int().
        if power.is_integer():
            power = int(power)
        return axis.transform.Pow(power)
    return getattr(axis.transform, string)


def _get_axis_unit(ax: axis) -> u.Unit:
    """Return the ``"unit"`` entry of the axis metadata, or dimensionless if absent."""
    metadata = ax.metadata
    if not metadata or "unit" not in metadata:
        return u.dimensionless_unscaled
    return metadata["unit"]


class Metric(Storable):
    """ND-Histogram that can be serialized to ASDF."""

    schema_url = "asdf://cta-observatory.org/benchmark/ndhist.schema"

    # TODO: Storables are dataclasses: shouldn't this be __post_init__?
    def __init__(
        self,
        axis_list: list[axis.AxesMixin],
        unit: u.Unit | str = "",
        label: str | None = None,
        output_store: Storable | None = None,
        other_attributes: dict | None = None,
        data=None,
        hist: Hist | None = None,
    ):
        """Class stores ND-histogram based on scikit-hep Hist package to create bins by accumulation.

        This initializer **should not be overridden** in subclasses! If you need
        a specialized constructor, define a @classmethod that does what is
        needed and calls this generic one to return a properly constructed
        Metric. Otherwise serialization will not work.

        Parameters
        ----------
        axis_list : list[axis.AxesMixin]
            List of hist.axis objects. The ``metadata`` of each axis must be
            either ``None`` or a dict. If the axis has a unit, set it in the
            ``"unit"`` key of that dict. Note that the axes are modified in
            place: a missing metadata dict is created and the ``"unit"``
            entry is converted to an ``astropy.units.Unit``.
        unit : u.Unit | str
            Unit of the _data_ stored in the metric. Note that you should set
            _axis_ units by passing them as metadata to the axes in the
            axis_list, i.e. ``RegularAxis(..., metadata=dict(unit="deg"))``
        label : str | None
            If None, an automatic label will be generated from the axis names.
        output_store : ResultStore | None
            If specified, the name of this metric will be updated to reflect the
            output name. This is used e.g. by AutoBenchmark for nicer
            bookkeeping and to avoid name collisions.
        other_attributes : dict[str, Any] | None
            Other attributes of this metric that should be stored along with it.
        data : np.array | None
            If specified, will set the internal histogram to this value. Must
            be the same shape as the axes specify.
        hist : Hist | None
            If set, use this for the internal histogram rather than
            constructing one. Must have the same axis definition as specified
            in ``axis_list``.

        Raises
        ------
        ValueError
            If an axis carries metadata that is not a dict, or if a category
            axis is present but is not the first axis.
        """
        super().__init__()

        # clean up axis list
        for ax in axis_list:
            if ax.metadata is not None and not isinstance(ax.metadata, dict):
                raise ValueError(
                    f"Axis {ax} metadata should be a dict, not `{ax.metadata}`"
                )

            # ensure we always have a metadata dict
            if ax.metadata is None:
                ax.metadata = dict()

            # convert unit strings in axis lists if necessary:
            ax.metadata["unit"] = u.Unit(ax.metadata.get("unit", ""))

        # Test against None explicitly so that a supplied histogram is never
        # discarded because of how its truthiness happens to be defined.
        if hist is not None:
            self._storage = hist
        else:
            self._storage = Hist(*axis_list, data=data)

        self.unit = u.Unit(unit)

        # ensure that if we have a category axis, it is the first one. This is
        # currently an assumption we make in plotting and comparing.
        category_dims = self.category_dimensions()
        if category_dims and category_dims[0] != 0:
            raise ValueError("Metrics expect category axes to be the first.")

        # This will be used in auto-benchmark plots to name the plots, you can
        # still change the label but don't do it in AutoBenchmark because
        # you'll lose the axis information in the process.
        if label is None:
            label = auto_constants.column_separator.join(self.get_identifier())
        self.label = label

        if output_store:
            # rename this by prepending the output_store's name if one is given
            self.update_name_with_store(output_store)

        # add subclass-specific data so we can serialize it.
        self.other_attributes = other_attributes if other_attributes else dict()

        logger.debug(
            "Metric %s (%s) initialized: other_attributes=%s",
            self.name,
            self.label,
            self.other_attributes,
        )

    def __repr__(self):
        """Nicer repr than the base one."""
        return f"{self.__class__.__name__}(name='{self.name}', label='{self.label}', naxes={len(self.axes)})"
[docs] @classmethod @abstractmethod def setup(cls, *args, **kwargs) -> Self: """Create a new Metric with specialized parameters. Overridden by subclasses to properly setup the Metric the first time with user-specified and metric-specific parameters. This should be called instead of the default constructor to set up the Metric correctly. Afterward, it will be loaded from a file using the default Metric constructor. This is because all Metrics must share the same ``__init__`` method for serialization to work. """ return cls(*args, **kwargs) # default is just to call __init__
[docs] @abstractmethod def compare(self, others: list): """Compare several instances of the class. Uses ``self`` as reference when comparing with ``others``. Plots the spectra on top of each other, and calculates the wasserstein distance. Parameters ---------- others : list[Metric] List containing other instances of this class. Returns ------- ComparisonResult | None Results of comparison of this reference to the others list. """ return None
[docs] def fill(self, *columns): """Fill defined axis. Parameters ---------- *columns : """ self._storage.fill(*columns)
@property def axes(self): """Wrapper to transfer to _storage.""" return self._storage.axes
[docs] def get_identifier(self): """Auto generation of a unique identifier for the metric. Use axis order and name to return a tuple that need to be unique and allow identification of the current metric throughout multiple metric stores. """ # Force list then tuple to avoid a generator identifier = tuple([ax.name for ax in self.axes]) return identifier
[docs] def project(self, *columns): """Project underlying histogram. Parameters ---------- *columns : """ new_storage = self._storage.project(*columns) # Force Metric class to prevent user trying to accumulate on hybrid instances (accumulation is tied to # axis_list for each subclass) new_obj = self.__class__( axis_list=new_storage.axes, unit=self.unit, label=self.label ) new_obj._storage = new_storage new_obj.name = self.name return new_obj
[docs] def stack(self, *columns): """Stack underlying histogram. Parameters ---------- *columns : """ new_storage = self._storage.stack(*columns) # Force Metric class to prevent user trying to accumulate on hybrid instances (accumulation is tied to # axis_list for each subclass) new_obj = Metric(axis_list=new_storage.axes, label=self.label) new_obj._storage = new_storage new_obj.name = self.name return new_obj
[docs] def to_numpy(self, flow=False): """Return numpy representation of underlying histogram.""" result = list(self._storage.to_numpy(flow=flow)) # Add units to axis arrays # result is first the ndarray, then 1d samples for each axis for i, ax in enumerate(self.axes): result[i + 1] = u.Quantity(result[i + 1], _get_axis_unit(ax)) result[0] = u.Quantity(result[0], self.unit) return tuple(result)
def __getitem__(self, args): """Redirect getitem calls to underlying histogram and handle result. If the requested slice return a single bin, return a single value, else return a new metric. Parameters ---------- args : """ new_storage = self._storage.__getitem__(args) # If we ask for one specific bin we just get its content # as a primitive type if not isinstance(new_storage, Hist): return new_storage # Force Metric class to prevent user trying to accumulate on hybrid instances (accumulation is tied to # axis_list for each subclass) new_obj = self.__class__( axis_list=new_storage.axes, label=self.label, unit=self.unit ) new_obj._storage = new_storage new_obj.name = self.name return new_obj def __truediv__(self, other): """Divide a Metric.""" if isinstance(other, Metric): denominator = other._storage extra_label = "ratio" if not self.compatible(other): raise ValueError("Can't divide incompatible metrics.") else: # is a scalar or array other = u.Quantity(other) denominator = other.value extra_label = f" / {other}" return self.__class__( axis_list=self.axes, label=" ".join([self.label, extra_label]), unit=self.unit / other.unit, hist=self._storage / denominator, ) def __add__(self, other): """Add another metric or scalar.""" if isinstance(other, Metric): other_value = other._storage if not self.unit.is_equivalent(other.unit): raise ValueError("Can't add incompatible units.") other_label = other.label else: other_value = u.Quantity(other).to_value(self.unit) other_label = f"{u.Quantity(other):latex}" return self.__class__( axis_list=self.axes, label=f"{self.label} + {other_label}", unit=self.unit, hist=self._storage + other_value, ) def __sub__(self, other): """Add another metric or scalar.""" if isinstance(other, Metric): other_value = other._storage if not self.unit.is_equivalent(other.unit): raise ValueError("Can't add incompatible units.") other_label = other.label else: other_value = u.Quantity(other).to_value(self.unit) other_label = f"{u.Quantity(other):latex}" 
return self.__class__( axis_list=self.axes, label=f"{self.label} - {other_label}", unit=self.unit, hist=self._storage - other_value, )
[docs] def category_dimensions(self) -> list[int]: """Return indices of non-continuous dimensions.""" return [i for i, ax in enumerate(self.axes) if ax.traits.discrete]
[docs] def plot(self, full=False, *args, **kwargs): """Redirect plot calls to underlying histogram.""" from matplotlib.pyplot import gca noncategory_dims = list( set(range(len(self.axes))) - set(self.category_dimensions()) ) num_noncategory_dims = len(noncategory_dims) # make a copy so we can modify the labels without breaking things hist = self._storage.copy(deep=False) for ax in hist.axes: # update the labels with units unit = ax.metadata.get("unit", u.dimensionless_unscaled) if unit != u.dimensionless_unscaled: ax.label = f"{ax.label} [{unit:latex}]" plot_ax = kwargs.get("ax", gca()) if num_noncategory_dims == 1: # for 1D histograms, we want to label by the metric and it's unit. ylabel = ( self.label if self.unit == u.dimensionless_unscaled else f"{self.label} [{self.unit:latex}]" ) plot_ax.set_ylabel(ylabel) if num_noncategory_dims == 2: plot_ax.set_title(f"{self.name} [{self.unit:latex}]") if full and num_noncategory_dims == 2: return hist.plot2d_full(*args, **kwargs) ret = hist.plot(*args, **kwargs) x_axis_dim = noncategory_dims[0] plot_ax.set_xscale( "log" if str(self.axes[x_axis_dim].transform) == "log" else "linear" ) return ret
[docs] def plot_compare_1d( self, others: Self | list[Self] = None, fig: None | FigureBase = None, legend=True, show_xlabel=True, layout: str | None = "constrained", sharex: Axes | None = None, sharey: Axes | None = None, sharey_diff: Axes | None = None, **kwargs, ) -> dict[str, Axes]: r"""Generate a subfigure with comparisons and residuals. The residuals are the relative differences, :math:`\Delta_\mathrm{rel} \equiv \frac{(v_i - v_\mathrm{ref})}{v_\mathrm{ref}}` Parameters ---------- others: Metric | list[Metric] | None List of metrics to compare to reference fig: matplotlib.Figure | None Figure or SubFigure to add the axes to, or none to use current. legend: bool Show legend on plot layout: str | None If fig not passed, in, use this maptlotlib layout engine. sharex: Axes | None If specified, share all x-axes with this axis. sharey: Axes | None If specified share the value's y-axis with this axis sharey_diff: Axes | None If specified share the diff's y-axis with this axis. **kwargs: Any other args are passed to the plot() functions. 
Returns ------- dict[matplotlib.Axes]: "val": value_axis, "diff": relative_difference_axis """ fig = fig or figure(layout=layout) others = others or [] others = others if isinstance(others, list) else [others] if len(self.axes) != 1: raise ValueError(f"Expected 1D metric, but got {len(self.axes)}D.") gs = GridSpec(2, 1, figure=fig, hspace=0, height_ratios=[0.7, 0.3]) ax_val = fig.add_subplot(gs[0], sharex=sharex, sharey=sharey) ax_diff = fig.add_subplot(gs[1], sharex=ax_val, sharey=sharey_diff) # plot the comparisons self.plot(ax=ax_val, label=self.name, **kwargs) for metric in others: metric.plot(ax=ax_val, label=metric.name, **kwargs) # plot the relative differences ax_diff.axhline(0.0, lw=3, color="C0") for ii, metric in enumerate(others): (metric / self - 1).plot( ax=ax_diff, label=metric.name, color=f"C{ii + 1}", **kwargs ) if legend: ax_val.legend(loc="best") ax_diff.set_ylabel(r"$\Delta_\mathrm{rel}$") ax_val.tick_params(axis="x", direction="in") ax_val.set_xlabel(None, labelpad=0) # don't show x-tic labels on the value axis for label in ax_val.get_xticklabels(): label.set_visible(False) if show_xlabel is False: # optionally don't show x-tic labels on the diff axis for label in ax_diff.get_xticklabels(): ax_diff.set_xlabel(None, labelpad=0) # has no effect label.set_visible(False) return fig, dict(value=ax_val, diff=ax_diff)
[docs] def save(self, output_file: Path): """Save metric to file. Parameters ---------- output_file : Path Output filename. """ init = self._init_treedict() data, *dummy = self.to_numpy(flow=True) axis_spec = [] for ax in self.axes: axis_spec.append(serialize_hist_ax(ax)) init["data"] = data.to_value(self.unit) init["axes"] = axis_spec init["label"] = self.label init["metadata"] = self.other_attributes init["unit"] = self.unit.to_string("fits") logger.debug("Saving ASDF file with tree: %s", init) with AsdfFile(init, custom_schema=self.schema_url) as asdffile: if not output_file.suffix == ".asdf": output_file = output_file.with_suffix(output_file.suffix + ".asdf") asdffile.write_to(output_file)
[docs] def compatible(self, other): """ Compare objects and ensure number and name of axis match. Parameters ---------- other : Metric other : Other instance to compare self to. Raises ------ IndexError. Returns ------- True if the objects are compatible. """ is_compatible = True self_ndim = len(self.axes) other_ndim = len(other.axes) if self_ndim != other_ndim: raise IndexError( f"Dimensions of {self.label} does not match those of {other.label}: {self_ndim} != {other_ndim}" ) for self_ax, other_ax in zip(self.axes, other.axes): self_class, self_params = serialize_hist_ax(self_ax) other_class, other_params = serialize_hist_ax(other_ax) if self_class != other_class: raise IndexError( f"Axes type of {self.label} does not match those of {other.label}: {self_class} != {other_class}" ) if self_params["name"] != other_params["name"]: raise IndexError( f"Axes name of {self.label} does not match those of {other.label}: {self_params['name']} != {other_params['name']}" ) return is_compatible
[docs] @classmethod def load(cls, init): """ Load class instance from file. Parameters ---------- cls : init : Dict or string """ instance = None with _open_asdf(init, cls.schema_url) as asdffile: metric_type = asdffile.tree["filetype"] if metric_type != cls.__name__: # TODO: with the refactored init method, not sure if this is really necessary anymore msg = f"File is of type {metric_type}, can't load with {cls.__name__}" asdffile.close() # Close file before raising error. raise TypeError(msg) axis_list = [unserialize_hist_ax(*args) for args in asdffile.tree["axes"]] # Create the instance and connect any other information instance = cls( axis_list=axis_list, label=asdffile.tree["label"], unit=asdffile.tree["unit"], output_store=None, other_attributes=asdffile.tree["metadata"], ) # Ensure stored axis_list in file is compatible with default from # class (in case of class changes). However, if the instance axes is # set to None, then we want to use the one in the file. ax_equal, msg = compare_axis_list(axis_list, instance.axes) if instance.axes and not ax_equal: msg = ( f"Stored axis_list is different that the one defined for {cls}\n" + msg ) asdffile.close() # Close file before raising error. raise ValueError(msg) # Overwrite axis_list, for StrCategory, now that we know they are compatible instance._storage = Hist(*axis_list) instance._storage[...] = asdffile.tree["data"] instance.name = asdffile.tree["name"] return instance
[docs] def normalize_to_probability( self, axis_index: int = 1, threshold_frac=0.005, fill_value=np.nan ) -> Self: """ Normalize the data such that the integral over axis ``axis_index`` 1.0. Parameters ---------- arr : np.ndarray The input array. axis_index : int The axis along which to normalize. threshold_frac: float if the integral along the axis is below this fraction of the total integral, mask off this row as having too low stats to plot. This prevents low-stats values form saturating the color scale fill_value: value to replace low-stats entries with Returns ------- np.ndarray A new Metric normalized along the specified axis. """ if not self.unit.is_equivalent(u.dimensionless_unscaled): raise ValueError( f"Can't normalize a non-dimensionless quantity ({self.unit})" ) data = self.to_numpy()[0] normed = normalize_along_axis( data, axis_index, threshold_frac=threshold_frac, fill_value=fill_value ) return self.__class__( axis_list=self.axes, label=f"{self.label} (normed)", unit=self.unit, data=normed, )
[docs] def split_by_category(self) -> dict[Self]: """Turn a Metric with categories into a dict of Metrics by category.""" raise NotImplementedError()
def serialize_hist_ax(ax):
    """
    Extract parameters from a Hist axes to pass to a future Metric class.

    Parameters
    ----------
    ax :
        The hist axis to describe.

    Returns
    -------
    tuple[str, dict]
        The class name of the axis and the keyword arguments describing it.
        The unit stored in the axis metadata is written as a FITS-format
        string.
    """
    params = {}
    params["name"] = ax.name
    params["label"] = ax.label
    params["growth"] = ax.traits.growth

    if ax.metadata and isinstance(ax.metadata, dict):
        unit = u.Unit(ax.metadata.get("unit", ""))
    else:
        unit = u.dimensionless_unscaled
    params["metadata"] = dict(unit=unit.to_string("fits"))

    klass_name = ax.__class__.__name__
    if klass_name == "Boolean":
        del params["growth"]
    elif klass_name in ["StrCategory", "IntCategory"]:
        params["categories"] = list(ax)
    elif klass_name in ["Variable"]:
        del params["growth"]
        params["edges"] = ax.edges
    else:
        # get start and stop from edges
        # NOTE(review): these keywords look tailored to Regular axes; confirm
        # whether other axis types (e.g. Integer) are expected to end up here.
        params["start"] = ax.edges[0]
        params["stop"] = ax.edges[-1]
        params["bins"] = ax.size
        params["underflow"] = ax.traits.underflow
        params["overflow"] = ax.traits.overflow
        params["circular"] = ax.traits.circular

        # Read defensively: an axis type without a ``transform`` attribute is
        # treated like one without a transform.
        transform = getattr(ax, "transform", None)
        if transform is not None:
            params["transform"] = str(transform)

    return (klass_name, params)


def unserialize_hist_ax(klass_name, in_params):
    """
    From the axis definition in the asdf file, provide the instantiated axis.

    Parameters
    ----------
    klass_name : str
        Name of the axis class, looked up on ``hist.axis``.
    in_params : dict
        Keyword arguments for that class, as produced by
        ``serialize_hist_ax``. The dict is not modified.
    """
    klass = getattr(axis, klass_name)

    out_params = {}
    for key, value in in_params.items():
        if key == "transform":
            out_params[key] = get_transform_func(value)
        elif key == "metadata":
            # Build a new dict rather than editing ``value`` in place, so the
            # caller's data (e.g. the loaded file tree) keeps its unit string.
            out_params[key] = {**value, "unit": u.Unit(value["unit"])}
        else:
            out_params[key] = value

    return klass(**out_params)


def compare_axis_list(l1, l2):
    """
    Compare 2 axis_list.

    Both arguments are sequences of hist axis objects; each axis is passed
    through ``serialize_hist_ax`` and the resulting class names and parameters
    are compared pairwise.

    Parameters
    ----------
    l1 :
        First axis_list.
    l2 :
        Second.

    Returns
    -------
    tuple[bool, str]
        Whether the lists are equal, and a message describing the differences
        (empty when they are equal).
    """
    is_equal = True
    msg = ""

    nax1 = len(l1)
    nax2 = len(l2)
    if nax1 != nax2:
        is_equal = False
        msg += f"axis_list have different nb axis ({nax1} vs {nax2}).\n"

    # enumerate() so the message names the axis that actually differs.
    for ax_id, (ax1, ax2) in enumerate(zip(l1, l2)):
        (class1, p1) = serialize_hist_ax(ax1)
        (class2, p2) = serialize_hist_ax(ax2)

        if class1 != class2:
            is_equal = False
            msg += f"ax{ax_id}: Got {class1}, expected {class2}\n"

        eq, m = utils.compare_ax_serialisation(p1, p2)
        if not eq:
            is_equal = False
            msg += m

    return is_equal, msg


class IRFMetric(Metric):
    """Metric that can be built from an IRF table in GADF format."""

    @staticmethod
    def _first_row(table):
        """Return ``table`` itself if it is a Row, or its first row if it is a Table."""
        if isinstance(table, Table):
            return table[0]
        if not isinstance(table, Row):
            raise ValueError(
                "table needs to be of type `astropy.Table` or `astropy.Row`"
            )
        return table

    @classmethod
    @abstractmethod
    def _from_specific_table(cls, irf_table):  # noqa
        """Create class instance from an IRF in GADF format."""
        raise NotImplementedError(
            "The subclass of IRFMetric doesn't implement _from_specific_table()"
        )

    @classmethod
    def from_table(cls, irf_table, output_store=None):
        """
        Create instance from an IRF table in GADF format.

        Parameters
        ----------
        irf_table: Table | Row
            astropy table containing the IRF in GADF/Fermi format. Only the
            first row of a Table is used.
        output_store: MetricStore | None
            passed to ``update_name_with_store`` to define the name of the
            Metric

        Raises
        ------
        ValueError
            If ``irf_table`` is neither a Table nor a Row.
        """
        new_metric = cls._from_specific_table(cls._first_row(irf_table))
        new_metric.update_name_with_store(output_store)
        return new_metric

    @classmethod
    def _get_metric_bins_from_table(cls, table, reg_cols=(), cat_cols=()):
        """Get bin edges from a GADF table easier.

        Parameters
        ----------
        table : Table | Row
            Table (first row is used) or row holding ``<NAME>_LO`` and
            ``<NAME>_HI`` columns.
        reg_cols :
            Column name prefixes whose LO/HI pairs are merged with
            ``utils.flatten_lo_hi_bins``.
        cat_cols :
            Column name prefixes whose LO/HI pairs become one string label
            per bin.

        Returns
        -------
        list
            One entry per name in ``reg_cols``, followed by one list of
            labels per name in ``cat_cols``.
        """
        row = cls._first_row(table)

        bins = [
            utils.flatten_lo_hi_bins(row[f"{reg}_LO"], row[f"{reg}_HI"])
            for reg in reg_cols
        ]
        for cat in cat_cols:
            bins.append(
                [
                    f"{lo.value} to {hi}"
                    for lo, hi in zip(row[f"{cat}_LO"], row[f"{cat}_HI"])
                ]
            )
        return bins