"""
Classes for comparing 2d metrics in various situations.
"""
import matplotlib.colors as col
import matplotlib.pyplot as plt
import numpy as np
from hist import loc
from .stat_helpers import (
calc_log_fraction_of_spans,
find_line_segments_by_var_of_prop,
rms_profile_from_2d_hist,
)
def _update_range_from_2dmetric(ranges, metric):
    """Widen the running x/y limits stored in ``ranges`` to cover ``metric``.

    ``ranges`` is modified in place: its ``"x"`` and ``"y"`` entries are each
    replaced by a ``(low, high)`` tuple that spans both the previous entry and
    the first/last bin edge of ``metric.axes[0]`` / ``metric.axes[1]``.
    """
    for axis_index, key in enumerate(("x", "y")):
        # Index with a list to pick only the outermost two bin edges.
        first_edge, last_edge = metric.axes[axis_index].edges[[0, -1]]
        previous = ranges[key]
        ranges[key] = (min(first_edge, previous[0]), max(last_edge, previous[1]))
[docs]
class CompareByCat2D:
"""Class for comparing 2d histograms grouped by category."""
def __init__(self, cat_axis):
self.categories = [cat for cat in cat_axis]
# Generic dictionary to store all comparison results to be saved later.
# First level, the value identify the test we consider (linearity, requirement, etc...)
# inside the level, we expect:
# * xrange keyword for xmin and xmax (define once and for all categories)
# * xspans that will be a list of tuple (x1;x2) that represent ranges where the condition of the test is
# valid (sometimes the test is inverted, and the range represent where it's invalid, this is specific to each
# test. One list of tuple per category
# Format:
# Comparison_name: (e.g. , "linear region size", "linear spans", "fraction violating", "fail spans")
# xrange: (min, max)
# xspans:
# category name:
# histogram name (e.g. production id):
# value
self.comp_res = {}
def _check_num_axes(self, ref, others):
"""Check num axes."""
for o in others:
ref.compatible(o)
[docs]
def plot_compare_disp_by_category(
    self, ref, others, xscale="log", yscale="linear", normed=False
):
    """Plot "dispersion" plots, one row per histogram, one column per category.

    Takes a set of 3d histograms with axes ``[category, x-axis, y-axis]``
    and groups them by category.

    Parameters
    ----------
    ref :
        Reference histogram, drawn in the first row.
    others :
        Iterable of histograms drawn in the following rows; may be empty.
    xscale, yscale : str
        Axis scales applied to every panel that is drawn.
    normed : bool
        If True, draw ``normalize_to_probability()`` of each category slice;
        otherwise draw the slice itself with a logarithmic colour scale.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding one sub-figure per histogram.
    """
    # Materialise once so tuples and one-shot iterables are accepted too.
    others = list(others)
    self._check_num_axes(ref, others)
    metric_list = [ref, *others]
    n_cats = len(self.categories)
    fig = plt.figure(
        figsize=(1.1 * 4 * n_cats, len(metric_list) * 4),
        layout="constrained",
    )
    # squeeze=False keeps the result an array even for a single row; with the
    # default a lone SubFigure comes back and cannot be zipped over.
    subfigs = fig.subfigures(
        nrows=len(metric_list), ncols=1, squeeze=False
    ).ravel()
    for metric, subfig in zip(metric_list, subfigs):
        axs = subfig.subplots(nrows=1, ncols=n_cats)
        axs = np.atleast_1d(axs)  # for the case where ncols=1
        for idx, cat in enumerate(self.categories):
            if cat not in metric.axes[0]:
                continue  # can't compare this category since doesn't exist
            ax = axs[idx]
            metric_for_cat = metric[cat, ...]
            if normed:
                metric_for_cat.normalize_to_probability().plot(
                    ax=ax, cbarextend=False
                )
            else:
                metric_for_cat.plot(ax=ax, norm=col.LogNorm(), cbarextend=False)
            # Dotted guide line at y = 0.
            ax.axhline(0, color="black", linestyle=":", lw=1.8)
            ax.set(title=cat)
            ax.set(xscale=xscale, yscale=yscale)
        subfig.suptitle(metric.name)
    return fig
[docs]
def plot_compare_cut_disp_by_cat(
    self, ref, others, cuts_dict, xscale="log", yscale="linear", **fig_args
):
    """Plot "dispersion" plots in user defined bins along a third axis.

    Takes a set of 4d histograms with axes
    ``[category, cut_axis, x-axis, y-axis]`` and makes one figure per
    category.  Within a figure each column is one histogram and each row is
    one ``(start, end)`` interval of the ``cut_axis``, summed over.

    Parameters
    ----------
    ref :
        Reference histogram, drawn in the first column.
    others :
        Iterable of histograms drawn in the following columns; may be empty.
    cuts_dict : dict
        ``cuts_dict[hist.name][category]`` is an iterable of ``(start, end)``
        pairs selecting the ``cut_axis`` intervals to draw.
    xscale, yscale : str
        Axis scales applied to every panel that is drawn.
    **fig_args
        Forwarded to ``matplotlib.pyplot.subplots``.

    Returns
    -------
    list of (category, matplotlib.figure.Figure)
    """
    # Materialise once so tuples and one-shot iterables are accepted too.
    others = list(others)
    self._check_num_axes(ref, others)
    metrics = [ref, *others]
    n_cols = len(metrics)
    min_rows = 3  # grid height used when there are at most three cuts
    figs = []
    for cat in self.categories:
        cuts_per_metric = [list(cuts_dict[met.name][cat]) for met in metrics]
        # Grow the grid if any histogram has more cuts than the default rows.
        n_rows = max(min_rows, *(len(cuts) for cuts in cuts_per_metric))
        fig, axs = plt.subplots(
            n_rows,
            n_cols,
            figsize=(1.5 * 4 * n_cols, (1 + n_rows) * 3.2),
            layout="constrained",
            **fig_args,
        )
        # plt.subplots squeezes a single column to 1-d by default; force the
        # (row, column) shape so two-index access always works.
        axs = np.asarray(axs).reshape(n_rows, n_cols)
        fig.suptitle(cat)
        # Start from an empty interval so only real bin edges set the limits.
        ranges = {
            "x": (float("inf"), float("-inf")),
            "y": (float("inf"), float("-inf")),
        }
        drawn_axes = []
        for idx, (met, cuts) in enumerate(zip(metrics, cuts_per_metric)):
            for idy, (st, ed) in enumerate(cuts):
                ax = axs[idy, idx]
                plot_hist = met[cat, loc(st) : loc(ed) : sum, ...]
                plot_hist.plot(ax=ax, norm=col.LogNorm(), cbarextend=False)
                _update_range_from_2dmetric(ranges, plot_hist)
                ax.set(
                    xscale=xscale,
                    yscale=yscale,
                    title=f"{met.name}\n{st:.2f} < peak_time < {ed:.2f}",
                )
                drawn_axes.append(ax)
        # Apply the limits only once every panel has contributed, so all
        # panels of the figure share the same final range.
        for ax in drawn_axes:
            ax.set(xlim=ranges["x"], ylim=ranges["y"])
        figs.append((cat, fig))
    return figs
[docs]
def plot_compare_2dhist_rms(
    self, ref, others, xscale="log", yscale="linear", **fig_args
):
    """Compare binwise RMS values, grouped by category.

    Takes a set of 3d histograms with axes ``[category, x-axis, y-axis]``.
    For each category the RMS profile returned by
    ``rms_profile_from_2d_hist`` is drawn for ``ref`` (red circles) and for
    every histogram in ``others`` (dots) on one panel.

    Parameters
    ----------
    ref :
        Reference histogram; also provides the axis labels.
    others :
        Iterable of histograms to overlay; may be empty.
    xscale, yscale : str
        Axis scales applied to every panel.
    **fig_args
        Forwarded to ``matplotlib.pyplot.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Materialise once: ``others`` is walked again for every category.
    others = list(others)
    self._check_num_axes(ref, others)
    n_cats = len(self.categories)
    fig, axs = plt.subplots(
        1,
        n_cats,
        figsize=(1.1 * 4 * n_cats, 4),
        layout="constrained",
        **fig_args,
    )
    # plt.subplots returns a bare Axes for a single category; normalise to a
    # flat array so per-category indexing always works.
    axs = np.atleast_1d(axs).ravel()
    xlabel = ref.axes[1].label
    ylabel = ref.axes[2].label
    for ax, cat in zip(axs, self.categories):
        ref_rms, _, ref_cent = rms_profile_from_2d_hist(ref[cat, ...])
        ax.plot(ref_cent, ref_rms, "ro", label=ref.label, markersize=5)
        ax.set(title="RMS " + cat)
        for other in others:
            rms, _, cent = rms_profile_from_2d_hist(other[cat, ...])
            ax.plot(cent, rms, ".", label=other.label)
        ax.set(
            xscale=xscale, yscale=yscale, xlabel=xlabel, ylabel=f"RMS {ylabel}"
        )
        ax.legend()
    return fig
[docs]
def plot_compare_2dhist_mean(
    self, ref, others, xscale="log", yscale="linear", **fig_args
):
    """Compare binwise mean values of 2d histograms, grouped by category.

    Takes a set of 3d histograms with axes ``[category, x-axis, y-axis]``.
    For each category the mean profile returned by
    ``rms_profile_from_2d_hist`` is drawn for ``ref`` (red circles) and for
    every histogram in ``others`` (dots) on one panel.

    Parameters
    ----------
    ref :
        Reference histogram; also provides the axis labels.
    others :
        Iterable of histograms to overlay; may be empty.
    xscale, yscale : str
        Axis scales applied to every panel.
    **fig_args
        Forwarded to ``matplotlib.pyplot.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Materialise once: ``others`` is walked again for every category.
    others = list(others)
    self._check_num_axes(ref, others)
    n_cats = len(self.categories)
    fig, axs = plt.subplots(
        1,
        n_cats,
        figsize=(1.1 * 4 * n_cats, 4),
        layout="constrained",
        **fig_args,
    )
    # plt.subplots returns a bare Axes for a single category; normalise to a
    # flat array so per-category indexing always works.
    axs = np.atleast_1d(axs).ravel()
    xlabel = ref.axes[1].label
    ylabel = ref.axes[2].label
    for ax, cat in zip(axs, self.categories):
        _, ref_mean, ref_cent = rms_profile_from_2d_hist(ref[cat, ...])
        ax.plot(ref_cent, ref_mean, "ro", label=ref.label, markersize=5)
        ax.set(title="Mean " + cat)
        for other in others:
            _, mean, cent = rms_profile_from_2d_hist(other[cat, ...])
            ax.plot(cent, mean, ".", label=other.label)
        ax.set(
            xscale=xscale, yscale=yscale, xlabel=xlabel, ylabel=f"Mean {ylabel}"
        )
        ax.legend()
    return fig
[docs]
def plot_compare_2dhist_rms_with_req(
    self, req_xs, req_ys, ref, others, xscale="log", yscale="linear", **fig_args
):
    """Compare binwise RMS values by category and overlay a requirement curve.

    Builds the figure with ``plot_compare_2dhist_rms`` and then draws
    ``req_ys`` against ``req_xs`` as a dashed line on every axes of it.

    Parameters
    ----------
    req_xs, req_ys :
        Coordinates of the requirement curve.
    ref, others, xscale, yscale, **fig_args
        Passed through to ``plot_compare_2dhist_rms``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = self.plot_compare_2dhist_rms(ref, others, xscale, yscale, **fig_args)
    for ax in fig.get_axes():
        ax.plot(
            req_xs,
            req_ys,
            alpha=0.5,
            color="firebrick",
            linestyle="dashed",
            label="Requirement",
        )
        # The legend created by plot_compare_2dhist_rms predates this line;
        # rebuild it so the "Requirement" label is actually displayed.
        ax.legend()
    return fig
[docs]
def compare_disp_with_req(self, req_xs, req_ys, hists):
    """Compare the RMS profile of each histogram with requirement values.

    Takes a set of 3d histograms with axes [category, x-axis, y-axis],
    using ``x_ax`` as true values, and ``y_ax`` as reco/true - 1.  For every
    category and histogram the RMS profile is taken from
    ``rms_profile_from_2d_hist``, bins with a non-finite RMS are dropped, and
    the remaining points are handed to ``calc_log_fraction_of_spans``
    together with the requirement, which is assumed to be in the same units
    as the histograms.

    The two values returned by that helper are stored in ``self.comp_res``
    under ``"fraction violating"`` and ``"fail spans"`` respectively, keyed by
    category and then by ``hist.name``.  Nothing is returned.
    """
    x_limits = (req_xs[0], req_xs[-1])
    violating_by_cat = {cat: {} for cat in self.categories}
    failing_by_cat = {cat: {} for cat in self.categories}

    for cat in self.categories:
        for histogram in hists:
            rms, _, centers = rms_profile_from_2d_hist(histogram[cat, ...])
            finite = np.isfinite(rms)
            fraction, failed = calc_log_fraction_of_spans(
                req_xs, req_ys, centers[finite], rms[finite]
            )
            violating_by_cat[cat][histogram.name] = fraction
            failing_by_cat[cat][histogram.name] = failed

    # Publish only once every category/histogram has been processed.
    self.comp_res["fraction violating"] = {
        "xrange": x_limits,
        "xspans": violating_by_cat,
    }
    self.comp_res["fail spans"] = {
        "xrange": x_limits,
        "xspans": failing_by_cat,
    }
def _get_default_range(self, hists):
"""Get the default xrange for comparisons.
Uses the range of the x axis (first non category
axis). Make sure that the range include all ranges for all histograms.
Parameters
----------
hists :
"""
xmins = []
xmaxs = []
for h in hists:
edges = h.axes[1].edges
xmins.append(edges.min())
xmaxs.append(edges.max())
xrange = (np.min(xmins), np.max(xmaxs))
return xrange
[docs]
def compare_disp_with_lin(self, hists, lin_threshold=0.015):
    """Find the linear regions of the mean profile of each histogram.

    Takes a set of 3d histograms with axes [category, x-axis, y-axis].  For
    every category and histogram the mean profile is taken from
    ``rms_profile_from_2d_hist`` and passed, with the bin centers and
    ``lin_threshold``, to ``find_line_segments_by_var_of_prop``.

    Results go into ``self.comp_res``, keyed by category and then by
    ``hist.name``: the second value returned by the helper is stored under
    ``"linear region size"`` and the first two elements of each returned
    span under ``"linear spans"``.  Both entries share the x range given by
    ``_get_default_range``.  Nothing is returned.
    """
    x_limits = self._get_default_range(hists)
    sizes_by_cat = {cat: {} for cat in self.categories}
    spans_by_cat = {cat: {} for cat in self.categories}
    # Register the containers up front; they are filled in place below.
    self.comp_res["linear region size"] = {
        "xrange": x_limits,
        "xspans": sizes_by_cat,
    }
    self.comp_res["linear spans"] = {
        "xrange": x_limits,
        "xspans": spans_by_cat,
    }

    for cat in self.categories:
        for histogram in hists:
            _, mean, centers = rms_profile_from_2d_hist(histogram[cat, ...])
            segments, total_length = find_line_segments_by_var_of_prop(
                centers, mean, lin_threshold
            )
            sizes_by_cat[cat][histogram.name] = total_length
            # Keep only the first two entries of every span.
            spans_by_cat[cat][histogram.name] = [seg[0:2] for seg in segments]