# Source code for datapipe_testbench.compare_2d

"""
Classes for comparing 2d metrics in various situations.
"""

import matplotlib.colors as col
import matplotlib.pyplot as plt
import numpy as np
from hist import loc

from .stat_helpers import (
    calc_log_fraction_of_spans,
    find_line_segments_by_var_of_prop,
    rms_profile_from_2d_hist,
)


def _update_range_from_2dmetric(ranges, metric):
    """Widen ``ranges`` in place so it also covers the extent of ``metric``.

    Parameters
    ----------
    ranges : dict
        Mapping with ``"x"`` and ``"y"`` entries, each an indexable
        ``(low, high)`` pair. Both entries are replaced by tuples.
    metric :
        Object whose ``axes[0]`` and ``axes[1]`` expose an ``edges`` array;
        the first axis feeds ``"x"`` and the second feeds ``"y"``.
    """
    for key, axis_index in (("x", 0), ("y", 1)):
        # First and last bin edge give the full extent of this axis.
        low_edge, high_edge = metric.axes[axis_index].edges[[0, -1]]
        current = ranges[key]
        ranges[key] = (min(low_edge, current[0]), max(high_edge, current[1]))


class CompareByCat2D:
    """Class for comparing 2d histograms grouped by category.

    Parameters
    ----------
    cat_axis : iterable
        Iterable yielding the category names; it is materialised into the
        ``categories`` list.

    Attributes
    ----------
    categories : list
        Category names, in the iteration order of ``cat_axis``.
    comp_res : dict
        Results filled in by the ``compare_*`` methods (layout described in
        ``__init__``).
    """

    # Position of each statistic in the tuple returned by
    # ``rms_profile_from_2d_hist`` (unpacked elsewhere as rms, mean, centers).
    _PROFILE_INDEX = {"RMS": 0, "Mean": 1}

    def __init__(self, cat_axis):
        self.categories = list(cat_axis)

        # Generic dictionary to store all comparison results to be saved later.
        # At the first level, the key identifies the test being considered
        # (linearity, requirement, etc.). Inside that level we expect:
        #  * xrange: xmin and xmax (defined once, for all categories)
        #  * xspans: lists of tuples (x1, x2) giving the ranges where the
        #    condition of the test is valid (sometimes the test is inverted
        #    and the ranges show where it is invalid; this is specific to
        #    each test). One list of tuples per category.
        # Format:
        # Comparison_name: (e.g. "linear region size", "linear spans",
        #                   "fraction violating", "fail spans")
        #    xrange: (min, max)
        #    xspans:
        #        category name:
        #            histogram name (e.g. production id):
        #                value
        self.comp_res = {}

    def _check_num_axes(self, ref, others):
        """Call ``ref.compatible`` on every histogram in ``others``.

        The return value of ``compatible`` is discarded, so this only has an
        effect if ``compatible`` itself raises on a mismatch.
        NOTE(review): confirm that ``compatible`` raises rather than
        returning a flag.
        """
        for other in others:
            ref.compatible(other)

    def plot_compare_disp_by_category(
        self, ref, others, xscale="log", yscale="linear", normed=False
    ):
        """Plot "dispersion" plots.

        Takes a set of 3d histograms with axes ``[category, x-axis, y-axis]``.
        Groups by category: one row of panels per histogram (``ref`` first),
        one column per category.

        Parameters
        ----------
        ref :
            Reference histogram, drawn in the first row.
        others : sequence
            Histograms to compare against, one row each. May be empty.
        xscale, yscale : str
            Axis scales applied to every panel.
        normed : bool
            If True, each panel shows the result of
            ``normalize_to_probability()`` with the default colour scale;
            otherwise the histogram is drawn as is with a logarithmic colour
            scale.

        Returns
        -------
        matplotlib.figure.Figure
        """
        self._check_num_axes(ref, others)

        metric_list = [ref, *others]
        fig = plt.figure(
            figsize=(1.1 * 4 * len(self.categories), (1 + len(others)) * 4),
            layout="constrained",
        )

        # A single row is squeezed to a bare SubFigure (when ``others`` is
        # empty), so force a 1d array before iterating.
        subfigs = np.atleast_1d(fig.subfigures(nrows=len(metric_list), ncols=1))

        for metric, subfig in zip(metric_list, subfigs):
            axs = subfig.subplots(nrows=1, ncols=len(self.categories))
            axs = np.atleast_1d(axs)  # for the case where ncols=1

            for idx, cat in enumerate(self.categories):
                if cat not in metric.axes[0]:
                    continue  # can't compare this category since doesn't exist

                metric_for_cat = metric[cat, ...]
                if normed:
                    metric_for_cat.normalize_to_probability().plot(
                        ax=axs[idx], cbarextend=False
                    )
                else:
                    metric_for_cat.plot(
                        ax=axs[idx], norm=col.LogNorm(), cbarextend=False
                    )

                axs[idx].axhline(0, color="black", linestyle=":", lw=1.8)
                axs[idx].set(title=cat)
                axs[idx].set(xscale=xscale, yscale=yscale)

            subfig.suptitle(metric.name)

        return fig

    def plot_compare_cut_disp_by_cat(
        self, ref, others, cuts_dict, xscale="log", yscale="linear", **fig_args
    ):
        """Plot "dispersion" plots in user defined bins along a third axis.

        Takes a set of 4d histograms with axes:
        ``[category, cut_axis, x-axis, y-axis]``.
        and makes a plot grouping by category and further sub-grouping using
        the provided ``cuts_dict`` to split the ``cut_axis``, making 2d
        histograms for each ``category`` and bin in the ``cut_axis``.

        Parameters
        ----------
        ref :
            Reference histogram, drawn in the first column.
        others : sequence
            Histograms to compare against, one column each. May be empty.
        cuts_dict : dict
            ``cuts_dict[hist.name][category]`` is an iterable of
            ``(start, end)`` pairs along the cut axis. The grid has three
            rows, so at most three pairs per category can be drawn.
        xscale, yscale : str
            Axis scales applied to every panel.
        **fig_args
            Forwarded to ``plt.subplots``.

        Returns
        -------
        list of tuple
            One ``(category, figure)`` pair per category.
        """
        self._check_num_axes(ref, others)

        metrics = [ref, *others]
        n_rows = 3
        figs = []

        for cat in self.categories:
            fig, axs = plt.subplots(
                n_rows,
                len(metrics),
                figsize=(1.5 * 4 * len(metrics), (1 + 3) * 3.2),
                layout="constrained",
                **fig_args,
            )
            # With a single metric the column dimension is squeezed away;
            # restore the 2d grid so ``axs[row, col]`` always works.
            axs = np.reshape(axs, (n_rows, len(metrics)))
            fig.suptitle(cat)

            # Start from an inverted range so the first histogram sets it;
            # it then keeps growing across the panels of this figure.
            ranges = {"x": [10, -10], "y": [10, -10]}

            for idx, met in enumerate(metrics):
                cuts = cuts_dict[met.name]
                for idy, (st, ed) in enumerate(cuts[cat]):
                    # Sum the cut axis over [st, ed) to get a 2d histogram.
                    plot_hist = met[cat, loc(st) : loc(ed) : sum, ...]
                    plot_hist.plot(
                        ax=axs[idy, idx], norm=col.LogNorm(), cbarextend=False
                    )
                    _update_range_from_2dmetric(ranges, plot_hist)
                    axs[idy, idx].set(
                        xscale=xscale,
                        yscale=yscale,
                        title=f"{met.name}\n{st:.2f} < peak_time < {ed:.2f}",
                        xlim=ranges["x"],
                        ylim=ranges["y"],
                    )

            figs.append((cat, fig))

        return figs

    def _plot_profile_by_category(
        self, stat, ref, others, xscale, yscale, **fig_args
    ):
        """Draw the per-x-bin ``stat`` ("RMS" or "Mean") of each histogram, one panel per category."""
        self._check_num_axes(ref, others)

        fig, axs = plt.subplots(
            1,
            len(self.categories),
            figsize=(1.1 * 4 * len(self.categories), 4),
            layout="constrained",
            **fig_args,
        )
        axs = np.atleast_1d(axs)  # for the case where there is one category

        which = self._PROFILE_INDEX[stat]
        xlabel = ref.axes[1].label
        ylabel = ref.axes[2].label

        for ax, cat in zip(axs, self.categories):
            ref_profile = rms_profile_from_2d_hist(ref[cat, ...])
            ax.plot(
                ref_profile[2], ref_profile[which], "ro", label=ref.label, markersize=5
            )
            ax.set(title=f"{stat} {cat}")

            for other in others:
                profile = rms_profile_from_2d_hist(other[cat, ...])
                ax.plot(profile[2], profile[which], ".", label=other.label)

            ax.set(
                xscale=xscale,
                yscale=yscale,
                xlabel=xlabel,
                ylabel=f"{stat} {ylabel}",
            )
            ax.legend()

        return fig

    def plot_compare_2dhist_rms(
        self, ref, others, xscale="log", yscale="linear", **fig_args
    ):
        """Compare binwise RMS values, groups by category.

        Takes a set of 3d histograms with axes ``[category, x-axis, y-axis]``,
        for each x-bin an RMS is calculated using the y-values.

        Parameters
        ----------
        ref :
            Reference histogram, drawn with red circles.
        others : sequence
            Histograms to compare against, drawn with dots.
        xscale, yscale : str
            Axis scales applied to every panel.
        **fig_args
            Forwarded to ``plt.subplots``.

        Returns
        -------
        matplotlib.figure.Figure
        """
        return self._plot_profile_by_category(
            "RMS", ref, others, xscale, yscale, **fig_args
        )

    def plot_compare_2dhist_mean(
        self, ref, others, xscale="log", yscale="linear", **fig_args
    ):
        """Compare means of 2d histograms.

        Takes a set of 3d histograms with axes ``[category, x-axis, y-axis]``,
        for each x-bin a mean is calculated using the y-values. Parameters
        and return value are the same as for ``plot_compare_2dhist_rms``.
        """
        return self._plot_profile_by_category(
            "Mean", ref, others, xscale, yscale, **fig_args
        )

    def plot_compare_2dhist_rms_with_req(
        self, req_xs, req_ys, ref, others, xscale="log", yscale="linear", **fig_args
    ):
        """Compare binwise RMS values, groups by category.

        Takes a set of 3d histograms with axes [category, x-axis, y-axis],
        for each x-bin an RMS is calculated using the y-values. Then it draws
        the provided requirement values on top.

        Parameters
        ----------
        req_xs, req_ys : array-like
            Coordinates of the requirement curve, drawn on every panel.
        ref, others, xscale, yscale, **fig_args
            Passed through to ``plot_compare_2dhist_rms``.

        Returns
        -------
        matplotlib.figure.Figure
        """
        fig = self.plot_compare_2dhist_rms(ref, others, xscale, yscale, **fig_args)

        for ax in fig.get_axes():
            ax.plot(
                req_xs,
                req_ys,
                alpha=0.5,
                color="firebrick",
                linestyle="dashed",
                label="Requirement",
            )
            # The legend was built before this line existed; rebuild it so
            # the "Requirement" entry is listed.
            ax.legend()

        return fig

    def compare_disp_with_req(self, req_xs, req_ys, hists):
        """Compare dispersion matrix with requirements.

        Takes a set of 3d histograms with axes [category, x-axis, y-axis],
        using ``x_ax`` as true values, and ``y_ax`` as reco/true - 1.
        It computes RMS in each true bin, comparing the result with the
        provided requirement values assumed to be in same units as the
        histograms. Bins with a non-finite RMS are ignored.

        Results are stored in ``self.comp_res`` under the keys
        ``"fraction violating"`` and ``"fail spans"``; nothing is returned.

        Parameters
        ----------
        req_xs, req_ys : sequence
            Requirement curve; the first and last ``req_xs`` values define
            the stored ``xrange``.
        hists : iterable
            Histograms to test; results are keyed by ``hist.name``.
        """
        # Init the tests in the general dict
        xrange = (req_xs[0], req_xs[-1])
        fracs = {"xrange": xrange, "xspans": {cat: {} for cat in self.categories}}
        fail_spans = {
            "xrange": xrange,
            "xspans": {cat: {} for cat in self.categories},
        }

        for cat in self.categories:
            for hist in hists:
                rms, _, cent = rms_profile_from_2d_hist(hist[cat, ...])
                sel = np.isfinite(rms)
                frac_tmp, fail_tmp = calc_log_fraction_of_spans(
                    req_xs, req_ys, cent[sel], rms[sel]
                )
                fracs["xspans"][cat][hist.name] = frac_tmp
                fail_spans["xspans"][cat][hist.name] = fail_tmp

        self.comp_res["fraction violating"] = fracs
        self.comp_res["fail spans"] = fail_spans

    def _get_default_range(self, hists):
        """Get the default xrange for comparisons.

        Uses the range of the x axis (first non category axis). Makes sure
        that the range includes the ranges of all histograms.

        Parameters
        ----------
        hists : iterable
            Histograms whose ``axes[1].edges`` arrays are inspected. Must not
            be empty.

        Returns
        -------
        tuple
            ``(xmin, xmax)`` over the x-axis edges of all histograms.
        """
        all_edges = [h.axes[1].edges for h in hists]
        xmins = [edges.min() for edges in all_edges]
        xmaxs = [edges.max() for edges in all_edges]
        return (np.min(xmins), np.max(xmaxs))

    def compare_disp_with_lin(self, hists, lin_threshold=0.015):
        """Compare dispersion matrix with linearity.

        Takes a set of 3d histograms with axes [category, x-axis, y-axis],
        using ``x_ax`` as true values, and ``y_ax`` as reco/true - 1.
        It computes the mean in each true bin and hands the bin centers and
        means to ``find_line_segments_by_var_of_prop`` together with
        ``lin_threshold``.

        Results are stored in ``self.comp_res``; nothing is returned:

        * ``"linear region size"`` holds the span length returned by the
          helper;
        * ``"linear spans"`` holds the first two elements of each returned
          span (presumably the span's start and end; verify against the
          helper).

        Parameters
        ----------
        hists : iterable
            Histograms to test; results are keyed by ``hist.name``. Their
            x-axis edges define the stored ``xrange``.
        lin_threshold : float
            Threshold forwarded to ``find_line_segments_by_var_of_prop``.
        """
        xrange = self._get_default_range(hists)

        region_size = {
            "xrange": xrange,
            "xspans": {cat: {} for cat in self.categories},
        }
        linear_spans = {
            "xrange": xrange,
            "xspans": {cat: {} for cat in self.categories},
        }
        self.comp_res["linear region size"] = region_size
        self.comp_res["linear spans"] = linear_spans

        for cat in self.categories:
            for hist in hists:
                _, mean, cent = rms_profile_from_2d_hist(hist[cat, ...])
                spans, span_len = find_line_segments_by_var_of_prop(
                    cent, mean, lin_threshold
                )
                region_size["xspans"][cat][hist.name] = span_len
                linear_spans["xspans"][cat][hist.name] = [itm[0:2] for itm in spans]