"""
Ndhistogram metric and datamodel location definition.
"""
import logging
from abc import abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Self
import asdf
import numpy as np
from asdf import AsdfFile
from astropy import units as u
from astropy.table import Row, Table
from hist import Hist, axis
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.pyplot import FigureBase, figure
from . import auto_constants, utils
from .stat_helpers import normalize_along_axis
from .storable import Storable
# Names exported by ``from <module> import *``.
__all__ = ["Metric"]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@contextmanager
def _open_asdf(init, schema_url):
    """Open an existing or create a new ASDF file as a context manager.

    Parameters
    ----------
    init : dict | asdf.AsdfFile | str | Path | file-like
        Either an in-memory tree (dict or AsdfFile), which is wrapped in a new
        ``asdf.AsdfFile``, or something ``asdf.open`` can read (a path or an
        object with a ``read`` attribute).
    schema_url : str
        Custom schema passed to asdf for validation.

    Raises
    ------
    ValueError
        If ``init`` is none of the supported kinds.
    """
    # In-memory tree: wrap it; no file handle is opened here.
    if isinstance(init, dict | asdf.AsdfFile):
        yield asdf.AsdfFile(tree=init, custom_schema=schema_url)
        return

    readable = isinstance(init, str | Path) or hasattr(init, "read")
    if not readable:
        raise ValueError(f"Unexpected init value in Metric.load(): '{init}'")

    # On-disk / stream source: asdf.open's own context manager closes it.
    with asdf.open(init, custom_schema=schema_url) as opened:
        yield opened
def get_transform_func(string):
    """
    Return transform object from a string representation.

    Functions are expected to be part of hist.axis.transform.
    Only the following functions are supported now:

    * log
    * sqrt
    * pow(n) -- matched case-insensitively, so both ``"pow(2)"`` and
      ``"Pow(2)"`` are accepted; ``n`` may be an integer or a float.

    Parameters
    ----------
    string : str
        String representation of the transform, e.g. as produced by
        ``str(ax.transform)`` in ``serialize_hist_ax``.

    Returns
    -------
    The matching transform from ``hist.axis.transform``. For a power
    transform, a new ``axis.transform.Pow`` instance.

    Raises
    ------
    AttributeError
        If ``string`` is not the name of an attribute of hist.axis.transform.
    ValueError
        If the argument of a ``pow(...)`` string is not a number or is
        malformed.
    """
    # Special case for "Pow(2)" that has a parameter in it. Compare
    # case-insensitively so both "Pow(2)" and "pow(2)" are recognised.
    if "pow(" in string.lower():
        # Drop the trailing ")" and keep what follows the "(".
        _func_name, str_power = string[:-1].split("(")
        power = float(str_power)
        # Keep integral powers as int (as before); allow fractional ones too.
        if power.is_integer():
            power = int(power)
        return axis.transform.Pow(power)
    return getattr(axis.transform, string)
def _get_axis_unit(ax: axis) -> u.Unit:
    """Return the "unit" entry of the axis metadata, or dimensionless if absent.

    No conversion is applied: whatever object is stored under "unit" is
    returned as-is.
    """
    meta = ax.metadata
    has_unit = bool(meta) and "unit" in meta
    return meta["unit"] if has_unit else u.dimensionless_unscaled
class Metric(Storable):
    """ND-Histogram that can be serialized to ASDF.

    Wraps a ``hist.Hist`` (kept in ``self._storage``) together with the unit
    of its contents, a label, and free-form ``other_attributes`` that are
    written to and read back from ASDF files.
    """

    schema_url = "asdf://cta-observatory.org/benchmark/ndhist.schema"

    # TODO: Storables are dataclasses: shouldn't this be __post_init__?
    def __init__(
        self,
        axis_list: list[axis.AxesMixin],
        unit: u.Unit | str = "",
        label: str | None = None,
        output_store: Storable | None = None,
        other_attributes: dict | None = None,
        data=None,
        hist: Hist | None = None,
    ):
        """Class stores ND-histogram based on scikit-hep Hist package to create bins by accumulation.

        This initializer **should not be overridden** in subclasses! If you need
        a specialized constructor, define a @classmethod that does what is
        needed and calls this generic one to return a properly constructed
        Metric. Otherwise serialization will not work.

        Parameters
        ----------
        axis_list : list[axis.AxesMixin]
            List of hist.axis objects. If an axis has a unit, set it in the
            "unit" key of its ``metadata`` dict. The axes are modified in
            place: a missing metadata becomes an empty dict and the "unit"
            entry is converted to an ``astropy.units.Unit`` (dimensionless if
            absent).
        unit : u.Unit | str
            Unit of the _data_ stored in the metric. Note that you should set
            _axis_ units by passing them as metadata to the axes in the
            axis_list, i.e. ``RegularAxis(..., metadata=dict(unit="deg"))``
        label : str | None
            If None, a label is generated by joining the axis names with
            ``auto_constants.column_separator``.
        output_store : ResultStore | None
            If specified, the name of this metric will be updated to reflect
            the output name. This is used e.g. by AutoBenchmark for nicer
            bookkeeping and to avoid name collisions.
        other_attributes : dict[str, Any] | None
            Other attributes of this metric that should be stored along with it.
        data : np.ndarray | None
            If specified, used as the initial content of the internal
            histogram. Must be the same shape as the axes specify. Ignored if
            ``hist`` is given.
        hist : Hist | None
            If set, use this for the internal histogram rather than
            constructing one. Must have the same axis definition as
            ``axis_list``.

        Raises
        ------
        ValueError
            If an axis has metadata that is not a dict, or if a category axis
            is present but is not the first axis.
        """
        super().__init__()

        # Normalize axis metadata so every axis carries a dict with a Unit.
        for ax in axis_list:
            if ax.metadata is not None and not isinstance(ax.metadata, dict):
                raise ValueError(
                    f"Axis {ax} metadata should be a dict, not `{ax.metadata}`"
                )
            # ensure we always have a metadata dict
            if ax.metadata is None:
                ax.metadata = dict()
            # convert unit strings in axis lists if necessary:
            ax.metadata["unit"] = u.Unit(ax.metadata.get("unit", ""))

        # Explicit None check: a histogram object should not be judged by bool().
        if hist is not None:
            self._storage = hist
        else:
            self._storage = Hist(*axis_list, data=data)
        self.unit = u.Unit(unit)

        # ensure that if we have a category axis, it is the first one. This is
        # currently an assumption we make in plotting and comparing.
        category_dims = self.category_dimensions()
        if category_dims and category_dims[0] != 0:
            raise ValueError("Metrics expect category axes to be the first.")

        # This will be used in auto-benchmark plots to name the plots, you can
        # still change the label but don't do it in AutoBenchmark because
        # you'll lose the axis information in the process.
        if label is None:
            label = auto_constants.column_separator.join(self.get_identifier())
        self.label = label

        if output_store:
            # rename this by prepending the output_store's name if one is given
            self.update_name_with_store(output_store)

        # add subclass-specific data so we can serialize it.
        self.other_attributes = other_attributes if other_attributes else dict()

        logger.debug(
            "Metric %s (%s) initialized: other_attributes=%s",
            self.name,
            self.label,
            self.other_attributes,
        )

    def __repr__(self):
        """Nicer repr than the base one."""
        return f"{self.__class__.__name__}(name='{self.name}', label='{self.label}', naxes={len(self.axes)})"

    @classmethod
    @abstractmethod
    def setup(cls, *args, **kwargs) -> Self:
        """Create a new Metric with specialized parameters.

        Overridden by subclasses to properly setup the Metric the first time
        with user-specified and metric-specific parameters. This should be
        called instead of the default constructor to set up the Metric
        correctly. Afterward, it will be loaded from a file using the default
        Metric constructor. This is because all Metrics must share the same
        ``__init__`` method for serialization to work.

        The base implementation simply forwards all arguments to ``cls(...)``.
        """
        return cls(*args, **kwargs)  # default is just to call __init__

    @abstractmethod
    def compare(self, others: list):
        """Compare several instances of the class.

        Uses ``self`` as reference when comparing with ``others``. The base
        implementation performs no comparison and returns None; subclasses
        provide the actual comparison.

        Parameters
        ----------
        others : list[Metric]
            List containing other instances of this class.

        Returns
        -------
        ComparisonResult | None
            Results of comparison of this reference to the others list.
        """
        return None

    def fill(self, *columns):
        """Fill the underlying histogram.

        Parameters
        ----------
        *columns :
            Positional arguments forwarded unchanged to ``Hist.fill``
            (presumably one array of values per axis, in axis order).
        """
        self._storage.fill(*columns)

    @property
    def axes(self):
        """Axes of the underlying histogram (forwarded from ``_storage``)."""
        return self._storage.axes

    def get_identifier(self):
        """Auto generation of a unique identifier for the metric.

        Returns the tuple of axis names in axis order. It needs to be unique
        to allow identification of the current metric throughout multiple
        metric stores.
        """
        return tuple(ax.name for ax in self.axes)

    def project(self, *columns):
        """Project the underlying histogram onto the given axes.

        Parameters
        ----------
        *columns :
            Forwarded to ``Hist.project``.

        Returns
        -------
        Metric
            A new instance of the same class with the projected histogram, the
            same unit and label, and the same name.
        """
        new_storage = self._storage.project(*columns)
        new_obj = self.__class__(
            axis_list=new_storage.axes,
            unit=self.unit,
            label=self.label,
            hist=new_storage,
        )
        new_obj.name = self.name
        return new_obj

    def stack(self, *columns):
        """Stack underlying histogram.

        Parameters
        ----------
        *columns :
            Forwarded to ``Hist.stack``.

        Returns
        -------
        Metric
            A plain ``Metric`` (not the subclass) whose ``_storage`` is the
            object returned by ``Hist.stack``, with the same unit, label and
            name as ``self``.
        """
        new_storage = self._storage.stack(*columns)
        # Force Metric class to prevent user trying to accumulate on hybrid
        # instances (accumulation is tied to axis_list for each subclass).
        # The unit is carried over: stacking does not change what is stored.
        new_obj = Metric(
            axis_list=new_storage.axes, label=self.label, unit=self.unit
        )
        # The stack result replaces the freshly built histogram.
        new_obj._storage = new_storage
        new_obj.name = self.name
        return new_obj

    def to_numpy(self, flow=False):
        """Return numpy representation of underlying histogram.

        Returns
        -------
        tuple
            ``(values, edges_0, edges_1, ...)`` as returned by
            ``Hist.to_numpy``. ``values`` is a Quantity in ``self.unit`` and
            each edge array a Quantity in its axis unit.
        """
        result = list(self._storage.to_numpy(flow=flow))
        # result is first the ndarray, then 1d samples for each axis
        for i, ax in enumerate(self.axes):
            result[i + 1] = u.Quantity(result[i + 1], _get_axis_unit(ax))
        result[0] = u.Quantity(result[0], self.unit)
        return tuple(result)

    def __getitem__(self, args):
        """Redirect getitem calls to underlying histogram and handle result.

        If the requested slice return a single bin, return a single value,
        else return a new metric of the same class with the same unit,
        label and name.

        Parameters
        ----------
        args :
            Any index accepted by ``Hist.__getitem__``.
        """
        new_storage = self._storage[args]

        # If we ask for one specific bin we just get its content
        # as a primitive type
        if not isinstance(new_storage, Hist):
            return new_storage

        new_obj = self.__class__(
            axis_list=new_storage.axes,
            label=self.label,
            unit=self.unit,
            hist=new_storage,
        )
        new_obj.name = self.name
        return new_obj

    def __truediv__(self, other):
        """Divide by another Metric or by a scalar/array/Quantity.

        For a Metric, the histograms are divided bin by bin, the unit becomes
        ``self.unit / other.unit`` and "ratio" is appended to the label. For
        anything else, ``other`` is converted to a Quantity; the bin contents
        are divided by its value and the unit by its unit.

        Raises
        ------
        IndexError
            If ``other`` is a Metric whose axes are not compatible (raised by
            ``compatible``).
        """
        if isinstance(other, Metric):
            denominator = other._storage
            extra_label = "ratio"
            if not self.compatible(other):
                raise ValueError("Can't divide incompatible metrics.")
        else:
            # is a scalar or array
            other = u.Quantity(other)
            denominator = other.value
            extra_label = f" / {other}"

        return self.__class__(
            axis_list=self.axes,
            label=" ".join([self.label, extra_label]),
            unit=self.unit / other.unit,
            hist=self._storage / denominator,
        )

    def _arith_operand(self, other, verb):
        """Return ``(value, label)`` of ``other`` expressed in ``self.unit``.

        Helper for ``__add__``/``__sub__``. A Metric operand must have a unit
        equivalent to ``self.unit``; its histogram is rescaled when the units
        differ (e.g. km vs m) so that bin contents are commensurate. Any other
        operand is converted to a Quantity and then to a plain value in
        ``self.unit``.

        Raises
        ------
        ValueError
            If ``other`` is a Metric with a non-equivalent unit.
        """
        if isinstance(other, Metric):
            if not self.unit.is_equivalent(other.unit):
                raise ValueError(f"Can't {verb} incompatible units.")
            value = other._storage
            if other.unit != self.unit:
                value = value * other.unit.to(self.unit)
            return value, other.label

        quantity = u.Quantity(other)
        return quantity.to_value(self.unit), f"{quantity:latex}"

    def __add__(self, other):
        """Add another metric or scalar (converted to ``self.unit``).

        Raises
        ------
        ValueError
            If ``other`` is a Metric with a unit not equivalent to ``self.unit``.
        """
        other_value, other_label = self._arith_operand(other, "add")
        return self.__class__(
            axis_list=self.axes,
            label=f"{self.label} + {other_label}",
            unit=self.unit,
            hist=self._storage + other_value,
        )

    def __sub__(self, other):
        """Subtract another metric or scalar (converted to ``self.unit``).

        Raises
        ------
        ValueError
            If ``other`` is a Metric with a unit not equivalent to ``self.unit``.
        """
        other_value, other_label = self._arith_operand(other, "subtract")
        return self.__class__(
            axis_list=self.axes,
            label=f"{self.label} - {other_label}",
            unit=self.unit,
            hist=self._storage - other_value,
        )

    def category_dimensions(self) -> list[int]:
        """Return indices of non-continuous (discrete) dimensions."""
        return [i for i, ax in enumerate(self.axes) if ax.traits.discrete]

    def plot(self, full=False, *args, **kwargs):
        """Redirect plot calls to underlying histogram.

        Axis labels get their unit appended (on a deep copy, so the stored
        histogram's axis labels are not modified). With one non-category
        dimension the y-label shows the metric label and unit; with two, the
        title shows the name and unit. The x-scale is set to "log" when the
        first non-category axis has a ``log`` transform, else "linear".

        Parameters
        ----------
        full : bool
            If True and there are two non-category dimensions, use
            ``Hist.plot2d_full`` instead of ``Hist.plot``.
        *args, **kwargs :
            Forwarded to the hist plotting function. If ``ax`` is given, it is
            also the axes whose labels/scale are adjusted; otherwise the
            current axes (``gca()``) is used.

        Returns
        -------
        Whatever the hist plotting function returns.
        """
        from matplotlib.pyplot import gca

        category_dims = set(self.category_dimensions())
        noncategory_dims = [
            dim for dim in range(len(self.axes)) if dim not in category_dims
        ]
        num_noncategory_dims = len(noncategory_dims)

        # Deep copy: a shallow Hist copy shares axis metadata (where labels
        # live), so relabeling it would leak into self._storage.
        hist = self._storage.copy(deep=True)
        for ax in hist.axes:
            # update the labels with units
            unit = ax.metadata.get("unit", u.dimensionless_unscaled)
            if unit != u.dimensionless_unscaled:
                ax.label = f"{ax.label} [{unit:latex}]"

        # Only fall back to gca() when no axes was given: gca() creates a
        # figure/axes as a side effect.
        plot_ax = kwargs.get("ax")
        if plot_ax is None:
            plot_ax = gca()

        if num_noncategory_dims == 1:
            # for 1D histograms, we want to label by the metric and its unit.
            ylabel = (
                self.label
                if self.unit == u.dimensionless_unscaled
                else f"{self.label} [{self.unit:latex}]"
            )
            plot_ax.set_ylabel(ylabel)

        if num_noncategory_dims == 2:
            plot_ax.set_title(f"{self.name} [{self.unit:latex}]")
            if full:
                return hist.plot2d_full(*args, **kwargs)

        ret = hist.plot(*args, **kwargs)

        # A histogram with only category axes has no continuous x-axis.
        if noncategory_dims:
            x_axis_dim = noncategory_dims[0]
            plot_ax.set_xscale(
                "log" if str(self.axes[x_axis_dim].transform) == "log" else "linear"
            )
        return ret

    def plot_compare_1d(
        self,
        others: Self | list[Self] | None = None,
        fig: None | FigureBase = None,
        legend=True,
        show_xlabel=True,
        layout: str | None = "constrained",
        sharex: Axes | None = None,
        sharey: Axes | None = None,
        sharey_diff: Axes | None = None,
        **kwargs,
    ) -> tuple[FigureBase, dict[str, Axes]]:
        r"""Generate a subfigure with comparisons and residuals.

        The residuals are the relative differences,
        :math:`\Delta_\mathrm{rel} \equiv \frac{(v_i - v_\mathrm{ref})}{v_\mathrm{ref}}`

        Parameters
        ----------
        others: Metric | list[Metric] | None
            Metric(s) to compare to this reference.
        fig: matplotlib.Figure | None
            Figure or SubFigure to add the axes to, or None to create a new
            figure.
        legend: bool
            Show legend on the value plot.
        show_xlabel: bool
            If exactly ``False``, hide the x-label and x-tick labels of the
            relative-difference axis.
        layout: str | None
            If fig is not passed in, use this matplotlib layout engine.
        sharex: Axes | None
            If specified, share all x-axes with this axis.
        sharey: Axes | None
            If specified share the value's y-axis with this axis
        sharey_diff: Axes | None
            If specified share the diff's y-axis with this axis.
        **kwargs:
            Any other args are passed to the plot() functions.

        Returns
        -------
        tuple[FigureBase, dict[str, matplotlib.axes.Axes]]
            The figure used, and a dict with keys "value" (value axis) and
            "diff" (relative-difference axis).

        Raises
        ------
        ValueError
            If this metric is not one-dimensional.
        """
        if fig is None:
            fig = figure(layout=layout)
        if others is None:
            others = []
        elif not isinstance(others, list):
            others = [others]

        if len(self.axes) != 1:
            raise ValueError(f"Expected 1D metric, but got {len(self.axes)}D.")

        gs = GridSpec(2, 1, figure=fig, hspace=0, height_ratios=[0.7, 0.3])
        ax_val = fig.add_subplot(gs[0], sharex=sharex, sharey=sharey)
        ax_diff = fig.add_subplot(gs[1], sharex=ax_val, sharey=sharey_diff)

        # plot the comparisons
        self.plot(ax=ax_val, label=self.name, **kwargs)
        for metric in others:
            metric.plot(ax=ax_val, label=metric.name, **kwargs)

        # plot the relative differences; colors continue after the reference (C0)
        ax_diff.axhline(0.0, lw=3, color="C0")
        for ii, metric in enumerate(others):
            (metric / self - 1).plot(
                ax=ax_diff, label=metric.name, color=f"C{ii + 1}", **kwargs
            )

        if legend:
            ax_val.legend(loc="best")

        ax_diff.set_ylabel(r"$\Delta_\mathrm{rel}$")
        ax_val.tick_params(axis="x", direction="in")
        ax_val.set_xlabel(None, labelpad=0)

        # don't show x-tic labels on the value axis
        for label in ax_val.get_xticklabels():
            label.set_visible(False)

        if show_xlabel is False:
            # optionally don't show the x-label and x-tic labels on the diff axis
            ax_diff.set_xlabel(None, labelpad=0)
            for label in ax_diff.get_xticklabels():
                label.set_visible(False)

        return fig, dict(value=ax_val, diff=ax_diff)

    def save(self, output_file: Path):
        """Save metric to file.

        The tree contains the ``_init_treedict()`` entries plus "data" (bin
        contents including flow bins, in ``self.unit``), "axes", "label",
        "metadata" (``other_attributes``) and "unit" (FITS string).

        Parameters
        ----------
        output_file : Path | str
            Output filename. ".asdf" is appended if it is not already the
            suffix.
        """
        output_file = Path(output_file)

        init = self._init_treedict()
        data = self.to_numpy(flow=True)[0]
        init["data"] = data.to_value(self.unit)
        init["axes"] = [serialize_hist_ax(ax) for ax in self.axes]
        init["label"] = self.label
        init["metadata"] = self.other_attributes
        init["unit"] = self.unit.to_string("fits")
        logger.debug("Saving ASDF file with tree: %s", init)

        if output_file.suffix != ".asdf":
            output_file = output_file.with_suffix(output_file.suffix + ".asdf")
        with AsdfFile(init, custom_schema=self.schema_url) as asdffile:
            asdffile.write_to(output_file)

    def compatible(self, other):
        """
        Compare objects and ensure number, type and name of axes match.

        Parameters
        ----------
        other : Metric
            Other instance to compare self to.

        Raises
        ------
        IndexError
            If the number of axes, an axis type or an axis name differs.

        Returns
        -------
        bool
            Always True; incompatibility is reported by raising.
        """
        self_ndim = len(self.axes)
        other_ndim = len(other.axes)
        if self_ndim != other_ndim:
            raise IndexError(
                f"Dimensions of {self.label} does not match those of {other.label}: {self_ndim} != {other_ndim}"
            )

        for self_ax, other_ax in zip(self.axes, other.axes):
            self_class, self_params = serialize_hist_ax(self_ax)
            other_class, other_params = serialize_hist_ax(other_ax)
            if self_class != other_class:
                raise IndexError(
                    f"Axes type of {self.label} does not match those of {other.label}: {self_class} != {other_class}"
                )
            if self_params["name"] != other_params["name"]:
                raise IndexError(
                    f"Axes name of {self.label} does not match those of {other.label}: {self_params['name']} != {other_params['name']}"
                )
        return True

    @classmethod
    def load(cls, init):
        """
        Load class instance from file.

        Parameters
        ----------
        init : str | Path | file-like | dict | asdf.AsdfFile
            ASDF file to read, or an already-loaded tree (see ``_open_asdf``).

        Returns
        -------
        Metric
            Instance of ``cls`` with the stored axes, label, unit, metadata,
            data and name.

        Raises
        ------
        TypeError
            If the stored "filetype" differs from ``cls.__name__``.
        ValueError
            If the stored axes differ from the constructed instance's axes, or
            if ``init`` is of an unsupported kind.
        """
        with _open_asdf(init, cls.schema_url) as asdffile:
            metric_type = asdffile.tree["filetype"]
            if metric_type != cls.__name__:
                # TODO: with the refactored init method, not sure if this is really necessary anymore
                msg = f"File is of type {metric_type}, can't load with {cls.__name__}"
                asdffile.close()  # Close file before raising error.
                raise TypeError(msg)

            axis_list = [unserialize_hist_ax(*args) for args in asdffile.tree["axes"]]

            # Create the instance and connect any other information
            instance = cls(
                axis_list=axis_list,
                label=asdffile.tree["label"],
                unit=asdffile.tree["unit"],
                output_store=None,
                other_attributes=asdffile.tree["metadata"],
            )

            # Ensure stored axis_list in file is compatible with default from
            # class (in case of class changes). However, if the instance axes is
            # set to None, then we want to use the one in the file.
            ax_equal, msg = compare_axis_list(axis_list, instance.axes)
            if instance.axes and not ax_equal:
                msg = (
                    f"Stored axis_list is different that the one defined for {cls}\n"
                    + msg
                )
                asdffile.close()  # Close file before raising error.
                raise ValueError(msg)

            # Overwrite axis_list, for StrCategory, now that we know they are compatible
            instance._storage = Hist(*axis_list)
            # Stored data includes flow bins (see save()).
            instance._storage[...] = asdffile.tree["data"]
            instance.name = asdffile.tree["name"]

        return instance

    def normalize_to_probability(
        self, axis_index: int = 1, threshold_frac=0.005, fill_value=np.nan
    ) -> Self:
        """
        Normalize the data such that the integral over axis ``axis_index`` is 1.0.

        Parameters
        ----------
        axis_index : int
            The axis along which to normalize.
        threshold_frac: float
            if the integral along the axis is below this fraction of
            the total integral, mask off this row as having too low stats to plot.
            This prevents low-stats values from saturating the color scale
        fill_value:
            value to replace low-stats entries with

        Returns
        -------
        Metric
            A new instance of the same class normalized along the specified
            axis, with the same axes and unit and "(normed)" appended to the
            label.

        Raises
        ------
        ValueError
            If ``self.unit`` is not equivalent to dimensionless.
        """
        if not self.unit.is_equivalent(u.dimensionless_unscaled):
            raise ValueError(
                f"Can't normalize a non-dimensionless quantity ({self.unit})"
            )

        data = self.to_numpy()[0]
        normed = normalize_along_axis(
            data, axis_index, threshold_frac=threshold_frac, fill_value=fill_value
        )
        return self.__class__(
            axis_list=self.axes,
            label=f"{self.label} (normed)",
            unit=self.unit,
            data=normed,
        )

    def split_by_category(self) -> dict[Self]:
        """Turn a Metric with categories into a dict of Metrics by category.

        Not implemented yet; always raises NotImplementedError.
        """
        raise NotImplementedError()
def serialize_hist_ax(ax):
    """
    Extract parameters from a Hist axes to pass to a future Metric class.

    Parameters
    ----------
    ax :
        A hist axis instance.

    Returns
    -------
    tuple[str, dict]
        The axis class name and a dict of keyword arguments intended for
        re-creating the axis (see ``unserialize_hist_ax``). The axis unit is
        stored as a FITS unit string under ``params["metadata"]["unit"]``.
    """
    klass_name = ax.__class__.__name__
    traits = ax.traits

    axis_meta = ax.metadata
    if axis_meta and isinstance(axis_meta, dict):
        unit = u.Unit(axis_meta.get("unit", ""))
    else:
        unit = u.dimensionless_unscaled

    params = {"name": ax.name, "label": ax.label}
    # Boolean and Variable axes are serialized without a "growth" entry.
    if klass_name not in ("Boolean", "Variable"):
        params["growth"] = traits.growth
    params["metadata"] = {"unit": unit.to_string("fits")}

    if klass_name in ("StrCategory", "IntCategory"):
        params["categories"] = list(ax)
    elif klass_name == "Variable":
        params["edges"] = ax.edges
    elif klass_name != "Boolean":
        # Regular-like axes: derive start/stop from the outer edges.
        # No need for nbins in StrCategory, it makes no sense.
        edges = ax.edges
        params["start"] = edges[0]
        params["stop"] = edges[-1]
        params["bins"] = ax.size

    params["underflow"] = traits.underflow
    params["overflow"] = traits.overflow
    params["circular"] = traits.circular
    if ax.transform is not None:
        params["transform"] = str(ax.transform)

    return (klass_name, params)
def unserialize_hist_ax(klass_name, in_params):
    """
    From the axis definition in the asdf file, provide the instantiated axis.

    Inverse of ``serialize_hist_ax``: "transform" strings are turned back
    into transform objects and the metadata "unit" string into an
    ``astropy.units.Unit``. ``in_params`` itself is not modified.

    Parameters
    ----------
    klass_name : str
        Name of the axis class in ``hist.axis`` (e.g. "Regular").
    in_params : Mapping
        Keyword arguments for that class, as produced by
        ``serialize_hist_ax``.

    Returns
    -------
    The instantiated hist axis.

    Raises
    ------
    AttributeError
        If ``klass_name`` is not a class in ``hist.axis``.
    KeyError
        If a "metadata" entry has no "unit" key.
    """
    klass = getattr(axis, klass_name)
    out_params = {}
    for key, value in in_params.items():
        if key == "transform":
            value = get_transform_func(value)
        elif key == "metadata":
            # Build a new dict instead of mutating the caller's (ASDF tree's).
            value = {**value, "unit": u.Unit(value["unit"])}
        out_params[key] = value
    return klass(**out_params)
def compare_axis_list(l1, l2):
    """
    Compare 2 lists of hist axes.

    Each pair of axes (in order) is serialized with ``serialize_hist_ax``;
    the class names are compared and the parameter dicts are compared with
    ``utils.compare_ax_serialisation``. Extra axes in the longer list are
    only reported through the length mismatch.

    Parameters
    ----------
    l1 :
        First list of axes.
    l2 :
        Second list of axes (the expected one).

    Returns
    -------
    tuple[bool, str]
        Whether the lists are equal, and a message describing the
        differences (empty if equal).
    """
    is_equal = True
    messages = []

    nax1 = len(l1)
    nax2 = len(l2)
    if nax1 != nax2:
        is_equal = False
        messages.append(f"axis_list have different nb axis ({nax1} vs {nax2}).\n")

    # enumerate so each mismatch reports the right axis index
    for ax_id, (ax1, ax2) in enumerate(zip(l1, l2)):
        class1, p1 = serialize_hist_ax(ax1)
        class2, p2 = serialize_hist_ax(ax2)
        if class1 != class2:
            is_equal = False
            messages.append(f"ax{ax_id}: Got {class1}, expected {class2}\n")
        eq, m = utils.compare_ax_serialisation(p1, p2)
        if not eq:
            is_equal = False
            messages.append(m)

    return is_equal, "".join(messages)
class IRFMetric(Metric):
    """Metric that can be constructed from an IRF table in GADF format.

    Subclasses implement ``_from_specific_table`` to build the Metric from a
    single table row; ``from_table`` accepts either a Table (its first row is
    used) or a Row.
    """

    @classmethod
    @abstractmethod
    def _from_specific_table(cls, irf_table):  # noqa
        """Create class instance from an IRF in GADF format."""
        raise NotImplementedError(
            "The subclass of IRFMetric doesn't implement _from_specific_table()"
        )

    @staticmethod
    def _first_table_row(table):
        """Return ``table[0]`` for an astropy Table, or ``table`` itself for a Row.

        Raises
        ------
        ValueError
            If ``table`` is neither an astropy Table nor a Row.
        """
        if isinstance(table, Table):
            return table[0]
        if isinstance(table, Row):
            return table
        raise ValueError(
            "table needs to be of type `astropy.Table` or `astropy.Row`"
        )

    @classmethod
    def from_table(cls, irf_table, output_store=None):
        """
        Create instance from an IRF table in GADF format.

        Parameters
        ----------
        irf_table: Table | Row
            astropy table containing the IRF in GADF/Fermi format. For a
            Table, only its first row is used.
        output_store: MetricStore | None
            if passed, used to define the name of the Metric

        Raises
        ------
        ValueError
            If ``irf_table`` is neither an astropy Table nor a Row.
        """
        irf_row = cls._first_table_row(irf_table)
        new_metric = cls._from_specific_table(irf_row)
        new_metric.update_name_with_store(output_store)
        return new_metric

    @classmethod
    def _get_metric_bins_from_table(cls, table, reg_cols=(), cat_cols=()):
        """Get bin edges from a GADF table easier.

        Parameters
        ----------
        table : Table | Row
            For a Table, only its first row is used.
        reg_cols : Iterable[str]
            Column prefixes whose ``<name>_LO``/``<name>_HI`` columns are
            merged into one edge array with ``utils.flatten_lo_hi_bins``.
        cat_cols : Iterable[str]
            Column prefixes turned into category labels "<lo value> to <hi>".

        Returns
        -------
        list
            One entry per column prefix: first all ``reg_cols``, then all
            ``cat_cols``.

        Raises
        ------
        ValueError
            If ``table`` is neither an astropy Table nor a Row.
        """
        row = cls._first_table_row(table)

        bins = [
            utils.flatten_lo_hi_bins(row[f"{reg}_LO"], row[f"{reg}_HI"])
            for reg in reg_cols
        ]
        for cat in cat_cols:
            bins.append(
                [
                    f"{lo.value} to {hi}"
                    for lo, hi in zip(row[f"{cat}_LO"], row[f"{cat}_HI"])
                ]
            )
        return bins