"""
Classes and functions for comparing 1d metrics in various situations.
"""
import logging
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
from astropy.units import Quantity
logger = logging.getLogger(__name__)
class CompareByCat1D:
    """Compare 1d histograms grouped by category.

    The histograms passed to the methods are indexed as ``hist[category, ...]``
    and their first axis (``hist.axes[0]``) is iterated to find out which
    categories they contain.
    """

    def __init__(self, cat_axis):
        """Store the categories to compare, in the iteration order of ``cat_axis``."""
        self.categories = list(cat_axis)

    @staticmethod
    def _ylabel(density):
        """Return the y label matching a ``density`` argument.

        Booleans map to "Density" / "Counts"; any other value is used as-is.
        """
        if isinstance(density, bool):
            return "Density" if density else "Counts"
        return density

    @staticmethod
    def _grid_shape(n_items, n_cols=None):
        """Return ``(n_rows, n_cols)`` of a subplot grid holding ``n_items`` panels.

        Without ``n_cols`` all panels go on a single row. Otherwise the number
        of rows is rounded up so every panel gets its own axes. Both dimensions
        are at least 1, as ``plt.subplots`` requires.
        """
        if not n_cols:
            return 1, max(n_items, 1)
        # Ceiling division: a partially filled last row still needs a row.
        return max((n_items + n_cols - 1) // n_cols, 1), n_cols

    @staticmethod
    def _bin_centers(edges):
        """Return the midpoints of consecutive bin ``edges``."""
        edges = np.asarray(edges)
        return 0.5 * (edges[1:] + edges[:-1])

    def _common_categories(self, hists):
        """Return the categories of ``self`` found on the first axis of every hist."""
        common = set(self.categories)
        for hist in hists:
            common.intersection_update(hist.axes[0])
        return common

    def _category_grid(self, n_cols, title):
        """Create a figure with one axes per category, in category order.

        Returns the figure and a flat array of its axes; grid cells beyond the
        number of categories are hidden.
        """
        n_rows, n_cols = self._grid_shape(len(self.categories), n_cols)
        fig, axs = plt.subplots(
            n_rows,
            n_cols,
            figsize=(1.1 * 4 * n_cols, 4 * n_rows),
            squeeze=False,  # always a 2d array, even for a single axes
            tight_layout=True,
        )
        fig.suptitle(title)
        axs = axs.flatten()
        for spare in axs[len(self.categories):]:
            spare.set_visible(False)
        return fig, axs

    def plot_compare_by_category(
        self,
        ref,
        others,
        xscale="linear",
        yscale="log",
        density=False,
        yerr=True,
        n_cols=None,
    ):
        """Plot 1d ref and comparison histograms, one axes per category.

        Parameters
        ----------
        ref
            Reference histogram; ``ref[cat, :]`` is drawn with a thick line.
        others : iterable
            Comparison histograms, drawn semi-transparent. They are drawn for
            a category only when every one of them contains it. Must be
            re-iterable (it is traversed once per category).
        xscale, yscale : str
            Axis scales applied to every axes.
        density : bool or str
            Forwarded to ``plot``; also selects the y label ("Density" /
            "Counts", or the value itself when it is not a bool).
        yerr : bool
            Forwarded to ``plot``.
        n_cols : int, optional
            Number of grid columns; rows are added as needed. By default all
            categories go on one row.

        Returns
        -------
        matplotlib.figure.Figure
        """
        other_cats = self._common_categories(others)
        ylabel = self._ylabel(density)
        fig, axs = self._category_grid(n_cols, ref.__class__.__name__)
        for ax, cat in zip(axs, self.categories):
            ref[cat, :].plot(
                ax=ax, label=f"{ref.label}", lw=2, density=density, yerr=yerr
            )
            if cat in other_cats:
                for other in others:
                    other[cat, :].plot(
                        ax=ax,
                        label=f"{other.label}",
                        alpha=0.6,
                        density=density,
                        yerr=yerr,
                    )
            ax.legend(loc="best")
            ax.set(yscale=yscale, xscale=xscale, ylabel=ylabel)
            ax.set_title(cat)
        return fig

    def distance_by_category(self, ref, others):
        """Measure the Wasserstein distance between ref and comparison histograms per category.

        Returns
        -------
        dict
            ``{category: {ref.name: [{other.name: distance}, ...]}}``. The list
            is empty for categories not contained in every comparison.
        """
        other_cats = self._common_categories(others)
        result = {}
        for cat in self.categories:
            ref_cnt, ref_edges = ref._storage[cat, :].to_numpy()
            # TODO: use axis.center instead
            ref_cent = self._bin_centers(ref_edges)
            distances = []
            if cat in other_cats:
                for other in others:
                    cnt, edges = other._storage[cat, :].to_numpy()
                    # Bin centers as support and counts as weights, so histograms
                    # with different binning can be compared, see
                    # https://stackoverflow.com/questions/76049158/wasserstein-distance-in-scipy-definition-of-support
                    distance = st.wasserstein_distance(
                        ref_cent,
                        self._bin_centers(edges),
                        u_weights=ref_cnt,
                        v_weights=cnt,
                    )
                    distances.append({other.name: distance})
            result[cat] = {ref.name: distances}
        return result

    def plot_compare_by_category_from_method(
        self,
        ref,
        others,
        method_name,
        xscale="linear",
        yscale="log",
        density=False,
        yerr=True,
        n_cols=None,
    ):
        """Plot 1d ref and comparison histograms grouped by category.

        ``method_name`` is currently ignored: the output is identical to
        :meth:`plot_compare_by_category` with the same remaining arguments.
        """
        # TODO: plot the result of getattr(hist, method_name) for ref and others.
        return self.plot_compare_by_category(
            ref,
            others,
            xscale=xscale,
            yscale=yscale,
            density=density,
            yerr=yerr,
            n_cols=n_cols,
        )

    def plot_compare_tuples_by_category_in_bins(
        self,
        ref,
        others,
        bins,
        ybin_label="energy",
        ybin_unit="TeV",
        tup1_label="gammas",
        tup2_label="protons",
        density=True,
        yerr=True,
    ):
        """Plot two lines from same metric against two lines of other metric.

        Takes pairs of metrics indexed as ``[category, bin_property, property]``
        and plots the 1d distribution of ``property`` in a given
        ``bin_property`` bin from both members of each pair into the same
        axes, with one column of axes per category and one row per entry of
        ``bins``.

        Parameters
        ----------
        ref : (tup1, tup2)
            Reference pair, drawn with thick lines.
        others : iterable of (tup1, tup2)
            Comparison pairs. They are drawn for a category only when both
            members of every pair contain it. Must be re-iterable.
        bins : sequence
            Bins along the second axis. Each entry is used to index the
            histograms and is passed to ``ref[0].axes[1].bin`` to get the
            ``(low, high)`` edges shown in the row's y label.
        ybin_label, ybin_unit : str
            Name and unit of the binned property in the y label.
        tup1_label, tup2_label : str
            Legend prefixes for the first / second member of each pair.
        density : bool or str
            Forwarded to ``plot``; also selects the y label.
        yerr : bool
            Forwarded to ``plot``.

        Returns
        -------
        matplotlib.figure.Figure
            The legend is drawn on the bottom-right axes only.
        """
        other_cats = self._common_categories(
            hist for pair in others for hist in pair
        )
        ylabel = self._ylabel(density)
        n_rows, n_cols = len(bins), len(self.categories)
        fig, axs = plt.subplots(
            n_rows,
            n_cols,
            figsize=(1.3 * 4 * n_cols, 4 * n_rows),
            squeeze=False,  # keep axs 2d so axs[row, col] works for 1 row/column
        )
        ref_tup1, ref_tup2 = ref
        for col, cat in enumerate(self.categories):
            for row, bin_ in enumerate(bins):
                ax = axs[row, col]
                low, high = ref_tup1.axes[1].bin(bin_)
                for hist, label in ((ref_tup1, tup1_label), (ref_tup2, tup2_label)):
                    hist[cat, bin_, :].plot(
                        ax=ax,
                        label=f"{label} {hist.name}",
                        lw=2,
                        density=density,
                        yerr=yerr,
                    )
                ax.set(
                    ylabel=f"{ylabel}\n{ybin_label} {low:.2f} to {high:.2f} {ybin_unit}"
                )
                if cat in other_cats:
                    for tup1, tup2 in others:
                        for hist, label in ((tup1, tup1_label), (tup2, tup2_label)):
                            hist[cat, bin_, :].plot(
                                ax=ax,
                                label=f"{label} {hist.name}",
                                density=density,
                                yerr=yerr,
                            )
                if row == 0:
                    ax.set_title(cat)
        axs[-1, -1].legend(loc="best")
        return fig
def compare_1d_metrics(ax, ref, others, external_config=None):
    """Compare several 1d histograms by drawing them on the same axes.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    ref
        Reference histogram, drawn with a thicker line; skipped when it has
        no axes.
    others : iterable
        Comparison histograms, drawn semi-transparent; those without axes are
        skipped.
    external_config : dict, optional
        Overrides for "xscale" and "yscale" (default "log"), "ylabel" (""),
        "textsize" ("larger"), "add_legend" (True), "legend_framealpha" (1),
        "legend_loc" ("best") and "plot_opts" (extra keyword arguments
        forwarded to every ``plot`` call). The dict is only read.
    """
    config = {} if external_config is None else external_config
    axis_conf = {
        "xscale": config.get("xscale", "log"),
        "yscale": config.get("yscale", "log"),
    }
    ylabel = config.get("ylabel", "")
    textsize = config.get("textsize", "larger")
    add_legend = config.get("add_legend", True)
    legend_framealpha = config.get("legend_framealpha", 1)
    legend_loc = config.get("legend_loc", "best")
    plot_opts = config.get("plot_opts", {})

    # Empty/zero bins on log axes would otherwise emit divide/invalid warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        if len(ref.axes) > 0:
            ref.plot(ax=ax, yerr=False, label=f"{ref.label}", lw=2, **plot_opts)
        for other in others:
            if len(other.axes) > 0:
                other.plot(
                    ax=ax, yerr=False, label=f"{other.label}", alpha=0.6, **plot_opts
                )
    ax.set(**axis_conf)
    if add_legend:
        ax.legend(loc=legend_loc, fontsize=textsize, framealpha=legend_framealpha)
    ax.set_ylabel(ylabel, fontsize=textsize)
def compare_rat_1d_metrics(ax, ref, others, external_config=None):
    """Compare several 1d histograms to ``ref`` by plotting their ratio.

    A dashed line marks a ratio of 1. A comparison sharing ``ref``'s axis is
    divided bin by bin; otherwise both histograms are interpolated onto a
    common grid by ``_interp_ratio`` first. Histograms that do not have
    exactly one axis are skipped with a warning (a bad ``ref`` skips all
    comparisons). A comparison raising ``RuntimeError`` or ``ValueError`` is
    logged and skipped without aborting the remaining ones.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    ref
        Reference histogram (the denominator).
    others : iterable
        Comparison histograms (the numerators); the i-th one (1-based) is
        drawn in color ``f"C{i}"``.
    external_config : dict, optional
        Overrides for "xscale" and "yscale" (default "log"), "ylabel"
        ("ratio"), "textsize" ("larger"), "add_legend" (True),
        "legend_framealpha" (1), "legend_loc" ("best") and "plot_opts"
        (extra keyword arguments for the equal-binning ``plot`` call only).
        The dict is only read.
    """
    config = {} if external_config is None else external_config
    axis_conf = {
        "xscale": config.get("xscale", "log"),
        "yscale": config.get("yscale", "log"),
    }
    plot_opts = config.get("plot_opts", {})
    ylabel = config.get("ylabel", "ratio")
    textsize = config.get("textsize", "larger")
    add_legend = config.get("add_legend", True)
    legend_framealpha = config.get("legend_framealpha", 1)
    legend_loc = config.get("legend_loc", "best")

    ax.axhline(y=1.0, linestyle="--")
    ref_is_1d = len(ref.axes) == 1
    if not ref_is_1d:
        # Every comparison below indexes ref.axes[0]; skip them all.
        logger.warning(
            "%s name '%s' label '%s' is empty, skipping",
            ref.__class__.__name__,
            ref.name,
            ref.label,
        )
    for ii, other in enumerate(others if ref_is_1d else (), start=1):
        if len(other.axes) != 1:
            logger.warning(
                "%s name '%s' label '%s' is empty, skipping",
                other.__class__.__name__,
                other.name,
                other.label,
            )
            continue
        color = f"C{ii}"
        # One failing comparison must not prevent plotting the others.
        try:
            with np.errstate(divide="ignore", invalid="ignore"):
                if other.axes[0] != ref.axes[0]:
                    xs, rat = _interp_ratio(ref, other)
                    ax.plot(xs, rat, color=color, label=f"{other.label}", alpha=0.9)
                else:
                    (other._storage / ref._storage).plot(
                        ax=ax,
                        yerr=False,
                        color=color,
                        label=f"{other.label}",
                        alpha=0.9,
                        **plot_opts,
                    )
        except (RuntimeError, ValueError) as err:  # noqa
            logger.warning(
                "Comparison of %s '%s' and %s '%s' failed due to %s: %s",
                ref.__class__.__name__,
                ref.name,
                other.__class__.__name__,
                other.name,
                err.__class__.__name__,
                err,
            )
    ax.set(**axis_conf)
    if add_legend:
        ax.legend(loc=legend_loc, fontsize=textsize, framealpha=legend_framealpha)
    ax.set_ylabel(ylabel)
    if ref_is_1d:
        ax.set_xlabel(ref.axes[0].name)
def _log_bin_centers(edges):
    """Return the geometric-mean bin centers of ``edges`` as a plain array.

    Energy-like units are converted to TeV first, so histograms binned in
    different energy units share one scale; unit-less edges pass through.
    """
    centers = Quantity(np.sqrt(edges[1:] * edges[:-1]))
    if centers.unit.is_equivalent("TeV"):
        centers = centers.to("TeV")
    return centers.value


def _interp_ratio(ref, other):
    """Return ``(xs, other / ref)`` for two 1d histograms with different binning.

    Both histograms are evaluated at their geometric-mean bin centers and
    linearly interpolated onto 90 log-spaced points spanning ``ref``'s
    centers. NOTE(review): the geometric mean and ``np.logspace`` assume
    strictly positive bin edges (log binning); confirm with callers.
    Outside ``other``'s range ``np.interp`` repeats ``other``'s first/last
    value instead of extrapolating. Division by zero yields inf/nan without
    warnings.
    """
    ref_h, ref_edges = ref.to_numpy()
    oth_h, oth_edges = other.to_numpy()
    ref_x = _log_bin_centers(ref_edges)
    oth_x = _log_bin_centers(oth_edges)
    xs = np.logspace(start=np.log10(ref_x.min()), stop=np.log10(ref_x.max()), num=90)
    ref_int = np.interp(xs, ref_x, ref_h)
    oth_int = np.interp(xs, oth_x, oth_h)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = oth_int / ref_int
    return xs, ratio