#!/usr/bin/env python3
"""
Helper functions for plotting.
"""
import matplotlib as mpl
import matplotlib.colors as col
import matplotlib.patheffects as ptheff
import matplotlib.pyplot as plt
import matplotlib.scale as scl
import numpy as np
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import (
AutoLocator,
AutoMinorLocator,
NullFormatter,
NullLocator,
ScalarFormatter,
)
from numpy import ma
from pyirf.binning import join_bin_lo_hi
from .compare_1d import compare_1d_metrics, compare_rat_1d_metrics
[docs]
class SqrtScale(scl.ScaleBase):
    """
    Square-root axis scale, for bins in theta square or an equivalent quantity.

    Values are displayed through ``sqrt``, so bins that are equally spaced in
    the square root of the plotted quantity appear equally wide.

    Parameters
    ----------
    axis : `~matplotlib.axis.Axis`
        The axis for the scale.
    """

    name = "sqrt"

    def __init__(self, axis):
        super().__init__(axis)
        # Forward: data -> display (sqrt); inverse: display -> data (square).
        self._transform = scl.FuncTransform(lambda e: e**0.5, lambda e: e**2)

    def get_transform(self):
        """Return the square-root `~matplotlib.transforms.Transform` of this scale."""
        return self._transform

    def set_default_locators_and_formatters(self, axis):
        """Install linear-style tick locators and formatters on ``axis``.

        These are the same defaults matplotlib uses for its function scale.
        Minor ticks are enabled only if the ``xtick.minor.visible`` /
        ``ytick.minor.visible`` rcParam for this axis is set.
        """
        axis.set_major_locator(AutoLocator())
        axis.set_major_formatter(ScalarFormatter())
        axis.set_minor_formatter(NullFormatter())
        if (axis.axis_name == "x" and mpl.rcParams["xtick.minor.visible"]) or (
            axis.axis_name == "y" and mpl.rcParams["ytick.minor.visible"]
        ):
            axis.set_minor_locator(AutoMinorLocator())
        else:
            axis.set_minor_locator(NullLocator())

    def limit_range_for_scale(self, vmin, vmax, minpos):
        """Clamp the axis range to non-negative values.

        The square root is undefined for negative values, which would
        otherwise become NaN in display space.
        """
        return max(vmin, 0), max(vmax, 0)
[docs]
def setup_profiledata_axes(data, ax: Axes | None = None) -> Axes:
    """Configure an axis from the attributes of a ProfileData object.

    Nothing is drawn here; only the title, scales and labels of the axis
    are set. This keeps callers free to plot the data however they like,
    which a full ``plot()`` function would not.

    Parameters
    ----------
    data
        Object providing ``title``, ``scales`` (x scale at index 0, y scale
        at index 1) and ``labels`` (x label at index 0, y label at index 1).
    ax : Axes, optional
        Axes to modify. If None, the current axes (``plt.gca()``) is used.

    Returns
    -------
    Axes
        The configured axes.
    """
    if ax is None:
        ax = plt.gca()
    ax.set_title(data.title)
    # Scales first, then labels, each in x/y order.
    for set_scale, index in ((ax.set_xscale, 0), (ax.set_yscale, 1)):
        set_scale(data.scales[index])
    for set_label, index in ((ax.set_xlabel, 0), (ax.set_ylabel, 1)):
        set_label(data.labels[index])
    return ax
[docs]
def plot_spans(ax, spans, y=0, height=0.2, alpha=0.3):
    """Draw each span as a horizontal bar on ``ax``.

    Meant for the spans found by
    `datapipe_testbench.stat_helpers.find_positive_spans`.

    Parameters
    ----------
    ax
        Axes to draw on; one ``ax.barh`` call is made per span.
    spans : iterable
        Each span is indexable, with the start at ``span[0]`` and the end at
        ``span[1]``.
    y, height, alpha : float
        Vertical position, bar height and transparency passed to ``barh``.
    """
    for span in spans:
        start, stop = span[0], span[1]
        ax.barh(y=y, width=stop - start, height=height, left=start, alpha=alpha)
[docs]
def plot_complement_spans(span_dict, is_good_span, xlabel, title_prefix, log=True):
    """Make summary plots of a set of spans, one figure per camera.

    For every test of a camera, one horizontal strip is drawn. The full
    x range is painted in one colour, and the listed spans are painted on
    top of it in the other colour.

    Parameters
    ----------
    span_dict : dict
        Must provide ``"xrange"`` as an ``(xmin, xmax)`` pair and ``"xspans"``
        as a mapping ``{camera: {test_name: [(x1, x2), ...]}}``.
    is_good_span : bool
        If True, the spans show something positive: they are coloured green
        on a red background. If False, the spans are coloured red on a green
        background.
    xlabel : str
        Label of the x axis.
    title_prefix : str
        Start of each figure title; ``", <camera>"`` is appended.
    log : bool, optional
        Use a logarithmic x axis. By default, True.

    Returns
    -------
    list of `~matplotlib.figure.Figure`
        One figure per camera that has at least one test.
    """
    if is_good_span:
        notok_color, ok_color = "red", "seagreen"
    else:
        notok_color, ok_color = "seagreen", "red"

    xmin, xmax = span_dict["xrange"]
    # Pad the x range slightly so the strips do not touch the axis edges.
    if log:
        xlim = (xmin * 0.9, xmax * 1.1)
    else:
        dx = 0.01 * (xmax - xmin)
        xlim = (xmin - dx, xmax + dx)

    figs = []
    for cam, res in span_dict["xspans"].items():
        tests = list(res)
        if not tests:
            # plt.subplots cannot create a figure with zero rows.
            continue
        # squeeze=False keeps a 2D array of Axes even for a single test, so
        # indexing works the same for one or many rows.
        fig, axs = plt.subplots(
            len(tests),
            1,
            figsize=(6, 1.2 * 0.5 * len(tests)),
            sharex=True,
            squeeze=False,
        )
        for ax, test in zip(axs[:, 0], tests):
            # Background over the whole range, then the listed spans on top.
            ax.broken_barh([(xmin, xmax - xmin)], (0, 1), color=notok_color)
            rect_params = [(x1, x2 - x1) for (x1, x2) in res[test]]
            ax.broken_barh(rect_params, (0, 1), color=ok_color)
            txt = ax.text(
                0.5,
                0.5,
                test,
                transform=ax.transAxes,
                horizontalalignment="center",
                verticalalignment="center",
            )
            # White outline keeps the test name readable on either colour.
            txt.set_path_effects([ptheff.withStroke(linewidth=3, foreground="white")])
            if log:
                ax.set(xscale="log", xlabel=xlabel, xlim=xlim)
            else:
                ax.set(xlabel=xlabel, xlim=xlim)
            ax.get_yaxis().set_visible(False)
        fig.suptitle(f"{title_prefix}, {cam}")
        figs.append(fig)
    return figs
[docs]
def plot_fail_spans(req_result, xlabel, title_prefix, log=True):
    """Make summary plots of the 'fail spans' of a requirement result.

    The spans stored under ``req_result["fail spans"]`` are handed to
    `plot_complement_spans` as *bad* spans (``is_good_span=False``), so
    they are coloured red on a green background.

    Returns
    -------
    list of `~matplotlib.figure.Figure`
        The figures produced by `plot_complement_spans`.
    """
    spans = req_result["fail spans"]
    return plot_complement_spans(
        spans,
        is_good_span=False,
        xlabel=xlabel,
        title_prefix=title_prefix,
        log=log,
    )
[docs]
def plot_linear_spans(req_result, xlabel, title_prefix, log=True):
    """Make summary plots of the linear regions of a requirement result.

    The spans stored under ``req_result["linear spans"]`` are handed to
    `plot_complement_spans` as *good* spans (``is_good_span=True``), so
    they are coloured green on a red background.

    Returns
    -------
    list of `~matplotlib.figure.Figure`
        The figures produced by `plot_complement_spans`.
    """
    spans = req_result["linear spans"]
    return plot_complement_spans(
        spans,
        is_good_span=True,
        xlabel=xlabel,
        title_prefix=title_prefix,
        log=log,
    )
[docs]
def make_table_rows(res_dict):
    """Transform a result dict of distance measures into rows for an mdutils table.

    Parameters
    ----------
    res_dict : dict
        Maps ``"<prefix> <reference name>"`` keys to a list of single-key
        dicts ``{"<prefix> <comparison name>": distance}``.

    Returns
    -------
    list of list of str
        A header row ``["<prefix>: Distance to", <comparison names>...]``,
        followed by one row per entry: ``[<reference name>, <distances
        formatted with two decimals>...]``. For an empty ``res_dict`` the
        result is ``[[]]`` (an empty header row and no data rows).
    """
    rows = []
    # A list (not a str) so the header row has the same type as the data rows
    # even when res_dict is empty.
    header = []
    for ref_key, distances in res_dict.items():
        prefix, *rest = ref_key.split()
        # The header is rebuilt for every entry and the last one wins. This
        # assumes all entries share the same prefix and comparison names.
        header = [f"{prefix}: Distance to"]
        row = [" ".join(rest)]
        for distance in distances:
            # Each dict should hold exactly one key; only the first is used.
            name = next(iter(distance))
            # Strip the input-data name: it is the same for the reference and
            # the comparison, so only the benchmark name is kept.
            header.append(name.removeprefix(f"{prefix} "))
            row.append(f"{distance[name]:.2f}")
        rows.append(row)
    # Add the header only once, at the top.
    rows.insert(0, header)
    return rows
[docs]
def plot_irf_table(
    ax, table, column, prefix=None, lo_name=None, hi_name=None, label=None, **mpl_args
):
    """Plot one column of an IRF table as a step histogram over its bins.

    Parameters
    ----------
    ax
        Axes to draw on; ``ax.stairs`` is used.
    table
        Table-like object indexable by column name.
    column : str or array-like
        Either the name of the column in ``table`` that holds the values
        (squeezed before plotting), or the values themselves.
    prefix : str, optional
        If given, the bin edges are read from the ``f"{prefix}_LO"`` and
        ``f"{prefix}_HI"`` columns.
    lo_name, hi_name : str, optional
        Explicit lower/upper bin-edge column names, used when ``prefix`` is
        not given.
    label : str, optional
        Legend label. Defaults to ``column`` when ``column`` is a column
        name. When raw values are passed, no default label is set.
    **mpl_args
        Forwarded to ``ax.stairs``.

    Raises
    ------
    ValueError
        If neither ``prefix`` nor both ``lo_name`` and ``hi_name`` are given.
    """
    column_is_name = isinstance(column, str)
    vals = np.squeeze(table[column]) if column_is_name else column

    if prefix:
        lo = table[f"{prefix}_LO"]
        hi = table[f"{prefix}_HI"]
    elif hi_name and lo_name:
        lo = table[lo_name]
        hi = table[hi_name]
    else:
        raise ValueError(
            "Either prefix or both `lo_name` and `hi_name` has to be given"
        )

    # Only a column *name* is a sensible default label; an array of values
    # would end up as its string representation in the legend.
    if not label and column_is_name:
        label = column

    bins = np.squeeze(join_bin_lo_hi(lo, hi))
    ax.stairs(vals, bins, label=label, **mpl_args)
[docs]
class PerformancePoster:
    """Comparison plots of IRF performance metrics, plus the combined "Gernot plot".

    Each ``*_compare`` method passes a reference metric (``ref``) and the
    metric(s) to compare with it (``others``) to `compare_1d_metrics`.
    Each ``*_rat_compare`` method uses `compare_rat_1d_metrics` instead.
    The methods differ only in their default configuration dict. Extra
    keyword arguments are merged into that dict and override the defaults.
    """

    def sens_compare(self, ax, ref, others, yscale="log", **extra_conf):
        """Compare sensitivity metrics (errorbar style, legend upper right)."""
        defaults = {
            "yscale": yscale,
            "ylabel": "Sensitivity",
            "legend_loc": "upper right",
            "plot_opts": {"histtype": "errorbar", "xerr": True},
        }
        compare_1d_metrics(ax, ref, others, {**defaults, **extra_conf})

    def sens_rat_compare(self, ax, ref, others, yscale="linear", **extra_conf):
        """Compare sensitivity metrics as a ratio (legend upper center)."""
        defaults = {
            "yscale": yscale,
            "ylabel": "Sensitivity ratio",
            "legend_loc": "upper center",
        }
        compare_rat_1d_metrics(ax, ref, others, {**defaults, **extra_conf})

    def aeff_compare(self, ax, ref, others, yscale="log", **extra_conf):
        """Compare effective area metrics (legend lower center)."""
        defaults = {
            "yscale": yscale,
            "ylabel": "Effective area",
            "legend_loc": "lower center",
        }
        compare_1d_metrics(ax, ref, others, {**defaults, **extra_conf})

    def aeff_rat_compare(self, ax, ref, others, yscale="linear", **extra_conf):
        """Compare effective area metrics as a ratio (legend lower center)."""
        defaults = {
            "yscale": yscale,
            "ylabel": "Effective area ratio",
            "legend_loc": "lower center",
        }
        compare_rat_1d_metrics(ax, ref, others, {**defaults, **extra_conf})

    def angres_compare(self, ax, ref, others, yscale="linear", **extra_conf):
        """Compare angular resolution metrics (legend upper right)."""
        defaults = {
            "yscale": yscale,
            "ylabel": "Angular resolution",
            "legend_loc": "upper right",
        }
        compare_1d_metrics(ax, ref, others, {**defaults, **extra_conf})

    def angres_rat_compare(self, ax, ref, others, yscale="linear", **extra_conf):
        """Compare angular resolution metrics as a ratio (legend lower center)."""
        defaults = {
            "yscale": yscale,
            "ylabel": "Angular resolution ratio",
            "legend_loc": "lower center",
        }
        compare_rat_1d_metrics(ax, ref, others, {**defaults, **extra_conf})

    def e_bias_res_compare(
        self, ax, ref, others, kind="Energy resolution", yscale="linear", **extra_conf
    ):
        """Compare energy resolution or energy bias metrics; ``kind`` is the y label."""
        compare_1d_metrics(ax, ref, others, {"yscale": yscale, "ylabel": kind, **extra_conf})

    def e_bias_res_rat_compare(
        self, ax, ref, others, kind="Energy resolution", yscale="linear", **extra_conf
    ):
        """Compare energy resolution or energy bias metrics as a ratio.

        The y label is ``f"{kind} ratio"``.
        """
        compare_rat_1d_metrics(
            ax, ref, others, {"yscale": yscale, "ylabel": f"{kind} ratio", **extra_conf}
        )

    def bkg_compare(self, ax, ref, others, yscale="log", **extra_conf):
        """Compare background rate metrics."""
        compare_1d_metrics(
            ax, ref, others, {"yscale": yscale, "ylabel": "Background rate", **extra_conf}
        )

    def bkg_rat_compare(self, ax, ref, others, yscale="linear", **extra_conf):
        """Compare background rate metrics as a ratio."""
        compare_rat_1d_metrics(
            ax, ref, others, {"yscale": yscale, "ylabel": "Background ratio", **extra_conf}
        )

    def gernot_plot(self, ref, other):
        """Produce a Gernot plot: a one-figure summary comparing ``other`` with ``ref``.

        Both ``ref`` and ``other`` must map the keys "sensitivity",
        "angular resolution", "energy resolution", "effective area" and
        "background" to the corresponding metrics.

        The 12 x 8 inch figure uses constrained layout on a 6 x 10 grid. It
        holds a large sensitivity panel with its ratio below, an energy
        resolution panel, and absolute/ratio panel pairs for effective area,
        background rate and angular resolution. Every panel except the
        sensitivity one is drawn with ``add_legend=False``.

        Returns
        -------
        `~matplotlib.figure.Figure`
        """
        # Look up every metric pair before creating the figure, so a missing
        # key fails without leaving an empty figure behind.
        # NOTE(review): a single metric object is passed as ``others``;
        # confirm the compare helpers accept that.
        metrics = {
            key: (ref[key], other[key])
            for key in (
                "sensitivity",
                "angular resolution",
                "energy resolution",
                "effective area",
                "background",
            )
        }

        fig = plt.figure(figsize=(12, 8), layout="constrained")
        grid = GridSpec(6, 10, figure=fig)
        slots = {
            "sensitivity": grid[0:4, 0:6],
            "sensitivity_ratio": grid[4:6, 0:4],
            "eres": grid[4:6, 4:6],
            "eff_area": grid[0:2, 6:8],
            "eff_area_ratio": grid[0:2, 8:10],
            "bkg_rate": grid[2:4, 6:8],
            "bkg_rate_ratio": grid[2:4, 8:10],
            "angres": grid[4:6, 6:8],
            "angres_ratio": grid[4:6, 8:10],
        }
        axes = {panel: fig.add_subplot(slot) for panel, slot in slots.items()}

        legend_conf = {"legend_framealpha": 0}
        no_legend_conf = {**legend_conf, "add_legend": False}

        # Only the main sensitivity panel keeps its legend.
        self.sens_compare(axes["sensitivity"], *metrics["sensitivity"], **legend_conf)
        remaining = (
            (self.sens_rat_compare, "sensitivity_ratio", "sensitivity"),
            (self.e_bias_res_compare, "eres", "energy resolution"),
            (self.aeff_compare, "eff_area", "effective area"),
            (self.aeff_rat_compare, "eff_area_ratio", "effective area"),
            (self.bkg_compare, "bkg_rate", "background"),
            (self.bkg_rat_compare, "bkg_rate_ratio", "background"),
            (self.angres_compare, "angres", "angular resolution"),
            (self.angres_rat_compare, "angres_ratio", "angular resolution"),
        )
        for draw, panel, metric in remaining:
            draw(axes[panel], *metrics[metric], **no_legend_conf)
        return fig
def __non_zero(arr):
    """Return the flat index of the first non-zero element of ``arr``.

    Small helper for broadcasting.

    Returns
    -------
    numpy integer or 0-d ndarray
        ``np.flatnonzero(arr)[0]`` if any element is non-zero, otherwise
        ``np.array(-1)``.

    Notes
    -----
    Python name-mangles the double-underscore name when it is referenced
    inside a class body.
    """
    nonzero_positions = np.flatnonzero(arr)
    if nonzero_positions.size:
        return nonzero_positions[0]
    return np.array(-1)