# Source code for datapipe_testbench.compare_1d

"""
Classes and functions for comparing 1d metrics (histograms) in various situations.
"""

import logging

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
from astropy.units import Quantity

logger = logging.getLogger(__name__)


class CompareByCat1D:
    """Class for comparing 1d histograms grouped by category.

    Histograms handed to the methods are indexed as ``hist[cat, ...]`` with
    ``cat`` taken from :attr:`categories`.  For comparison histograms the
    first axis (``other.axes[0]``) is treated as the category axis: only
    categories present in *every* comparison histogram are overlaid.

    Parameters
    ----------
    cat_axis : iterable
        Category labels to compare; stored as the list :attr:`categories`.
    """

    def __init__(self, cat_axis):
        # Materialize once so the categories can be iterated repeatedly.
        self.categories = list(cat_axis)

    @staticmethod
    def _grid_shape(n_items, n_cols=None):
        """Return ``(n_rows, n_cols)`` of a subplot grid with room for ``n_items`` axes.

        Without ``n_cols`` everything goes on a single row.  The grid always
        has at least one cell so ``plt.subplots`` accepts it.
        """
        if not n_cols:
            return 1, max(n_items, 1)
        # Ceiling division: a partially filled last row still needs a row.
        return max(-(-n_items // n_cols), 1), n_cols

    @staticmethod
    def _ylabel(density):
        """Return the y-axis label implied by ``density`` (a bool or a custom label)."""
        if isinstance(density, bool):
            return "Density" if density else "Counts"
        return density

    def _common_categories(self, others):
        """Return the categories of ``self`` found on the first axis of every histogram in ``others``."""
        common = set(self.categories)
        for other in others:
            common.intersection_update(other.axes[0])
        return common

    @staticmethod
    def _bin_midpoints(edges):
        """Return the arithmetic midpoints of consecutive bin ``edges``."""
        return np.convolve(edges, 2 * [0.5], mode="valid")

    def _plot_grid(self, ref, others, xscale, yscale, density, yerr, n_cols):
        """Draw one subplot per category, overlaying ``others`` on ``ref``; return the figure."""
        common = self._common_categories(others)
        ylabel = self._ylabel(density)
        n_rows, n_cols = self._grid_shape(len(self.categories), n_cols)
        # squeeze=False always yields a 2d array, even for a 1x1 grid.
        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=(1.1 * 4 * n_cols, 4 * n_rows), squeeze=False
        )
        fig.set_tight_layout(True)
        fig.suptitle(ref.__class__.__name__)
        axs = axs.flatten()
        for ax, cat in zip(axs, self.categories):
            ref[cat, :].plot(
                ax=ax, label=f"{ref.label}", lw=2, density=density, yerr=yerr
            )
            # ``common`` is the intersection over all others, so the check is
            # the same for each of them.
            if cat in common:
                for other in others:
                    other[cat, :].plot(
                        ax=ax,
                        label=f"{other.label}",
                        alpha=0.6,
                        density=density,
                        yerr=yerr,
                    )
            ax.legend(loc="best")
            ax.set(yscale=yscale, xscale=xscale, ylabel=ylabel)
            ax.set_title(cat)
        # Hide cells left over when the grid is larger than the category count.
        for ax in axs[len(self.categories):]:
            ax.set_visible(False)
        return fig

    def plot_compare_by_category(
        self,
        ref,
        others,
        xscale="linear",
        yscale="log",
        density=False,
        yerr=True,
        n_cols=None,
    ):
        """Plot 1d ref and comparison histograms grouped by category.

        Parameters
        ----------
        ref
            Reference histogram, indexed as ``ref[cat, :]``.
        others : iterable
            Comparison histograms (must be re-iterable).
        xscale, yscale : str
            Axis scales passed to ``Axes.set``.
        density : bool or str
            Forwarded to each ``plot`` call; a non-bool value is used as the y label.
        yerr : bool
            Forwarded to each ``plot`` call.
        n_cols : int, optional
            Number of grid columns; by default all categories share one row.

        Returns
        -------
        matplotlib.figure.Figure
        """
        return self._plot_grid(ref, others, xscale, yscale, density, yerr, n_cols)

    def distance_by_category(self, ref, others):
        """Measure the wasserstein distance between ref and comparison histograms grouped by category.

        Each histogram is reduced to its bin contents weighted at the bin
        midpoints, which allows histograms with different binnings to be
        compared.

        Returns
        -------
        dict
            ``{cat: {ref.name: [{other.name: distance}, ...]}}``; the list is
            empty for categories not present in every comparison histogram.
        """
        result = {}
        common = self._common_categories(others)
        for cat in self.categories:
            ref_cnt, ref_edges = ref._storage[cat, :].to_numpy()
            # TODO: use axis.center instead
            ref_cent = self._bin_midpoints(ref_edges)
            distances = []
            if cat in common:
                for other in others:
                    cnt, edges = other._storage[cat, :].to_numpy()
                    # Weird call needed to properly work with histograms of different binning
                    # https://stackoverflow.com/questions/76049158/wasserstein-distance-in-scipy-definition-of-support
                    distances.append(
                        {
                            other.name: st.wasserstein_distance(
                                ref_cent, self._bin_midpoints(edges), ref_cnt, cnt
                            )
                        }
                    )
            result[cat] = {ref.name: distances}
        return result

    def plot_compare_by_category_from_method(
        self,
        ref,
        others,
        method_name,
        xscale="linear",
        yscale="log",
        density=False,
        yerr=True,
        n_cols=None,
    ):
        """Plot 1d ref and comparison histograms grouped by category.

        ``method_name`` is currently unused: the plot is identical to
        :meth:`plot_compare_by_category`.
        """
        # TODO: actually use method_name, e.g.
        #   ref_func = getattr(ref, method_name)
        #   other_funcs = [getattr(other, method_name) for other in others]
        return self._plot_grid(ref, others, xscale, yscale, density, yerr, n_cols)

    def plot_compare_tuples_by_category_in_bins(
        self,
        ref,
        others,
        bins,
        ybin_label="energy",
        ybin_unit="TeV",
        tup1_label="gammas",
        tup2_label="protons",
        density=True,
        yerr=True,
    ):
        """Plot two lines from same metric against two lines of other metric.

        Takes tuples of metrics with the shape: ``[category, bin_property, property]``
        and plots the 1d distribution of property in a given bin from the two tuples
        into the same axes, making one column of axes per category, and one row per
        bin value provided.

        Parameters
        ----------
        ref : tuple
            ``(tup1, tup2)`` reference metrics.
        others : iterable of tuple
            ``(tup1, tup2)`` comparison metrics (must be re-iterable).
        bins : sequence
            Values used to index the second axis; the row label shows the
            edges returned by ``ref[0].axes[1].bin(value)``.
        """
        common = self._common_categories(tup[0] for tup in others)
        ylabel = self._ylabel(density)
        n_rows, n_cols = len(bins), len(self.categories)
        # squeeze=False keeps ``axs[row, col]`` valid for single-row/column grids.
        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=(1.3 * 4 * n_cols, 4 * n_rows), squeeze=False
        )
        ref_tup1, ref_tup2 = ref
        for idx, cat in enumerate(self.categories):
            for idy, bin_idx in enumerate(bins):
                ax = axs[idy, idx]
                low, high = ref_tup1.axes[1].bin(bin_idx)
                ref_tup1[cat, bin_idx, :].plot(
                    ax=ax,
                    label=f"{tup1_label} {ref_tup1.name}",
                    lw=2,
                    density=density,
                    yerr=yerr,
                )
                ref_tup2[cat, bin_idx, :].plot(
                    ax=ax,
                    label=f"{tup2_label} {ref_tup2.name}",
                    lw=2,
                    density=density,
                    yerr=yerr,
                )
                ax.set(
                    ylabel=f"{ylabel}\n{ybin_label} {low:.2f} to {high:.2f} {ybin_unit}"
                )
                if cat in common:
                    for tup1, tup2 in others:
                        tup1[cat, bin_idx, :].plot(
                            ax=ax,
                            label=f"{tup1_label} {tup1.name}",
                            density=density,
                            yerr=yerr,
                        )
                        tup2[cat, bin_idx, :].plot(
                            ax=ax,
                            label=f"{tup2_label} {tup2.name}",
                            density=density,
                            yerr=yerr,
                        )
                if idy == 0:
                    ax.set_title(cat)
        axs[-1, -1].legend(loc="best")
        return fig
def compare_1d_metrics(ax, ref, others, external_config=None):
    """Compare several 1d histograms.

    Overlays ``ref`` (thick line) and each of ``others`` (semi-transparent)
    on ``ax``.  Histograms without any axes are skipped.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    ref
        Reference histogram; needs ``axes``, ``label`` and ``plot(ax=..., ...)``.
    others : iterable
        Comparison histograms with the same interface as ``ref``.
    external_config : dict, optional
        Recognised keys (defaults in brackets): ``xscale`` ["log"],
        ``yscale`` ["log"], ``ylabel`` [""], ``textsize`` ["larger"],
        ``add_legend`` [True], ``legend_framealpha`` [1],
        ``legend_loc`` ["best"], ``plot_opts`` [{}] (extra kwargs for ``plot``).
    """
    # None instead of a shared mutable ``{}`` default.
    config = {} if external_config is None else external_config
    conf = {
        "xscale": config.get("xscale", "log"),
        "yscale": config.get("yscale", "log"),
    }
    ylabel = config.get("ylabel", "")
    textsize = config.get("textsize", "larger")
    add_legend = config.get("add_legend", True)
    legend_framealpha = config.get("legend_framealpha", 1)
    legend_loc = config.get("legend_loc", "best")
    plot_opts = config.get("plot_opts", {})

    # Silence divide/invalid warnings that plotting (e.g. density) may raise.
    with np.errstate(divide="ignore", invalid="ignore"):
        if len(ref.axes) > 0:
            ref.plot(ax=ax, yerr=False, label=f"{ref.label}", lw=2, **plot_opts)
        for other in others:
            if len(other.axes) > 0:
                other.plot(
                    ax=ax, yerr=False, label=f"{other.label}", alpha=0.6, **plot_opts
                )
    ax.set(**conf)
    if add_legend:
        ax.legend(loc=legend_loc, fontsize=textsize, framealpha=legend_framealpha)
    ax.set_ylabel(ylabel, fontsize=textsize)
def compare_rat_1d_metrics(ax, ref, others, external_config=None):
    """Compare several 1d histograms by taking their ratio.

    For each 1d histogram in ``others`` plots ``other / ref`` on ``ax``,
    together with a dashed reference line at 1.  When the first axes are
    equal, the histograms are divided bin by bin.  Otherwise both are
    interpolated onto a common grid via ``_interp_ratio``.  Comparison
    histograms that are not 1d, and comparisons raising ``RuntimeError`` or
    ``ValueError``, are logged and skipped without affecting the others.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    ref
        Reference histogram; its first axis also provides the x label.
    others : iterable
        Comparison histograms.
    external_config : dict, optional
        Recognised keys (defaults in brackets): ``xscale`` ["log"],
        ``yscale`` ["log"], ``ylabel`` ["ratio"], ``textsize`` ["larger"],
        ``add_legend`` [True], ``legend_framealpha`` [1],
        ``legend_loc`` ["best"], ``plot_opts`` [{}] (extra kwargs for the
        bin-by-bin ratio ``plot``).
    """
    # None instead of a shared mutable ``{}`` default.
    config = {} if external_config is None else external_config
    conf = {
        "xscale": config.get("xscale", "log"),
        "yscale": config.get("yscale", "log"),
    }
    plot_opts = config.get("plot_opts", {})
    ylabel = config.get("ylabel", "ratio")
    textsize = config.get("textsize", "larger")
    add_legend = config.get("add_legend", True)
    legend_framealpha = config.get("legend_framealpha", 1)
    legend_loc = config.get("legend_loc", "best")

    ax.axhline(y=1.0, linestyle="--")
    for ii, other in enumerate(others):
        if len(other.axes) != 1:
            logger.warning(
                "%s name '%s' label '%s' is empty, skipping",
                other.__class__.__name__,
                other.name,
                other.label,
            )
            continue
        # Per-comparison error handling: one failure must not drop the rest.
        try:
            with np.errstate(divide="ignore", invalid="ignore"):
                if other.axes[0] != ref.axes[0]:
                    # Different binning: bin-by-bin division is impossible.
                    xs, rat = _interp_ratio(ref, other)
                    ax.plot(
                        xs, rat, color=f"C{ii + 1}", label=f"{other.label}", alpha=0.9
                    )
                else:
                    (other._storage / ref._storage).plot(
                        ax=ax,
                        yerr=False,
                        color=f"C{ii + 1}",
                        label=f"{other.label}",
                        alpha=0.9,
                        **plot_opts,
                    )
        except (RuntimeError, ValueError) as err:  # noqa
            logger.warning(
                "Comparison of %s '%s' and %s '%s' failed due to %s: %s",
                ref.__class__.__name__,
                ref.name,
                other.__class__.__name__,
                other.name,
                err.__class__.__name__,
                err,
            )
    ax.set(**conf)
    if add_legend:
        ax.legend(loc=legend_loc, fontsize=textsize, framealpha=legend_framealpha)
    ax.set_ylabel(ylabel, fontsize=textsize)
    ax.set_xlabel(ref.axes[0].name)
def _interp_ratio(ref, other, num=90):
    """Return the ratio ``other / ref`` interpolated onto a common log-spaced grid.

    Used when the two histograms have different binnings, so bin-by-bin
    division is impossible.  Bin centres are the geometric mean of
    neighbouring edges, and centres whose unit is equivalent to TeV are
    converted to TeV.  Both histograms are linearly interpolated
    (``np.interp``) onto ``num`` log-spaced points spanning the range of
    ``ref``'s bin centres, and then divided.

    Parameters
    ----------
    ref, other
        Objects whose ``to_numpy()`` returns ``(contents, edges)``.
    num : int
        Number of points in the common grid (default 90).

    Returns
    -------
    xs : numpy.ndarray
        The grid as plain floats (in TeV when the edges carry energy units).
    ratio : numpy.ndarray
        ``other / ref`` on ``xs``; ``inf``/``nan`` where ``ref`` interpolates to 0.
    """

    def centers(edges):
        # Geometric centres match the log-spaced grid built below.  Wrapping in
        # Quantity makes ``.value`` work for unitless ndarray edges as well.
        cent = Quantity(np.sqrt(edges[1:] * edges[:-1]))
        if cent.unit.is_equivalent("TeV"):
            cent = cent.to("TeV")
        return cent.value

    ref_h, ref_edges = ref.to_numpy()
    oth_h, oth_edges = other.to_numpy()
    ref_x = centers(ref_edges)
    oth_x = centers(oth_edges)
    # NOTE(review): assumes positive, increasing bin centres (log10 and
    # np.interp requirements) — confirm for non-energy axes.
    xs = np.logspace(start=np.log10(ref_x.min()), stop=np.log10(ref_x.max()), num=num)
    ref_int = np.interp(xs, ref_x, ref_h)
    oth_int = np.interp(xs, oth_x, oth_h)
    # x/0 -> inf ("divide") and 0/0 -> nan ("invalid") are expected here.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = oth_int / ref_int
    return xs, ratio