"""
Various auxiliary functions for calculating stats.
"""
import astropy.units as u
import deprecation
import numpy as np
import scipy.stats as st
from .version import __version__
[docs]
@deprecation.deprecated(
    details="should be replaced with hist.profile()", current_version=__version__
)
def rms_profile_from_2d_hist(
    hist, density: bool = False, threshold=15
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate RMS and mean per x-bin in a 2d histogram.

    Each row ``h[i, :]`` of the histogram (one x-bin) is treated as a
    distribution over the y-bins via ``scipy.stats.rv_histogram``. For rows
    whose summed content exceeds ``threshold``, the mean and the RMS
    (``sqrt(mean**2 + std**2)``, i.e. the root of the second raw moment) are
    computed. Rows at or below the threshold are left as NaN.

    Parameters
    ----------
    hist: Metric|hist.Histogram
        Metric to profile, must be 2D. Must provide ``to_numpy()`` returning
        ``(h, x_bins, y_bins)`` and ``axes[0].centers``.
    density: bool
        Forwarded to ``scipy.stats.rv_histogram``. If False (default), each
        row is assumed to be proportional to counts per bin; if True, it is
        assumed to be proportional to a density. For constant y-bin widths
        these are equivalent; the distinction matters when bin widths vary.
    threshold: float
        A row is profiled only if its summed content is strictly greater
        than this value.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]:
        y_rms, y_mean, x_bin_center
    """
    # x edges are not needed: bin centers come from hist.axes[0] below
    h, _, y_bins = hist.to_numpy()
    h = u.Quantity(h)
    # TODO handle units properly (they fail to be dumped to json if I keep them)
    if isinstance(y_bins, u.Quantity):
        y_bins = y_bins.value
    x_cent = hist.axes[0].centers
    # one entry per x-bin; NaN marks rows with too few entries
    rms = np.full(h.shape[0], np.nan)
    mean = np.full(h.shape[0], np.nan)
    for idx, row in enumerate(h):
        if row.sum().value > threshold:
            dens = st.rv_histogram((row, y_bins), density=density)
            mean[idx] = dens.mean()
            # RMS about zero: E[y^2] = mean^2 + variance
            rms[idx] = np.sqrt(mean[idx] ** 2 + dens.std() ** 2)
    return rms, mean, x_cent
[docs]
def find_positive_spans(xs, ys, where="post"):
    """Find the x-ranges where a stepwise curve is strictly positive.

    The input points are treated as a stepwise function, as in
    ``matplotlib.pyplot.step``. With the default ``where="post"`` each
    constant section starts at ``xs[i]``, ends at ``xs[i+1]`` and has the
    value ``ys[i]``. Adjacent positive sections are merged into a single
    span. Results are sensitive to the density of points provided.

    Parameters
    ----------
    xs :
        x-coordinates of the input points.
    ys :
        y-coordinates of the input points.
    where : {"pre", "mid", "post"}
        Where the points provided sit within the steps of the stepwise
        function, with the same meaning as in ``matplotlib.pyplot.step``.
        Default is "post".

    Returns
    -------
    spans : list[tuple]
        One ``(start, stop)`` tuple of x-positions per maximal run of
        positive sections. Empty if no section is positive.

    Raises
    ------
    ValueError
        If ``where`` is not one of "pre", "mid" or "post".
    """
    if where == "post":
        # ys[i] holds on [xs[i], xs[i+1]]; the last value opens no section
        sel = slice(-1)
    elif where == "pre":
        # ys[i+1] holds on [xs[i], xs[i+1]]; the first value opens no section
        sel = slice(1, None)
    elif where == "mid":
        # every value gets its own section, bounded by midpoints
        sel = slice(None)
    else:
        raise ValueError(f'where must be "pre", "mid" or "post", got {where!r}')

    pos_val = np.nonzero(np.asarray(ys)[sel] > 0)[0]
    if pos_val.size == 0:
        return []

    if where == "mid":
        # section edges: first point, midpoints between points, last point
        edges = np.zeros(len(xs) + 1)
        edges[0] = xs[0]
        edges[-1] = xs[-1]
        edges[1:-1] = np.convolve(xs, 2 * [0.5], mode="valid")
    else:
        edges = xs

    # pos_val is sorted; a new run starts wherever consecutive indices jump
    breaks = np.nonzero(np.diff(pos_val) > 1)[0]
    starts = np.concatenate((pos_val[:1], pos_val[breaks + 1]))
    stops = np.concatenate((pos_val[breaks], pos_val[-1:]))
    # section k spans edges[k] .. edges[k + 1]
    return [(edges[start], edges[stop + 1]) for start, stop in zip(starts, stops)]
[docs]
def calc_log_fraction_of_spans(req_xs, req_ys, xs, ys, lower_better=True):
    """Calculate the log-x fraction of the range where a curve fails requirements.

    The curve ``(xs, ys)`` is interpolated linearly at the requirement
    points ``req_xs``; the requirement points themselves are not
    interpolated. The differences are treated as a stepwise ("post")
    function: the section ``[req_xs[i], req_xs[i+1]]`` fails when the curve
    violates the requirement at ``req_xs[i]``.

    Parameters
    ----------
    req_xs, req_ys :
        Requirement points; ``req_xs`` should be positive since lengths are
        measured as ``log10`` ratios.
    xs, ys :
        Curve points (``xs`` increasing, as required by ``np.interp``).
    lower_better : bool
        If True (default), the curve complies when it is at or below the
        requirement; otherwise it complies when it is at or above it.

    Returns
    -------
    fraction : float
        Sum of ``log10(stop / start)`` over failing spans divided by
        ``log10(req_xs[-1] / req_xs[0])``; 0.0 when nothing fails.
    spans : list[tuple]
        The failing ``(start, stop)`` spans.
    """
    curve_at_req = np.interp(x=req_xs, xp=xs, fp=ys)
    if lower_better:
        diff = curve_at_req - req_ys
    else:
        diff = req_ys - curve_at_req
    # With where="post" the last point opens no section; if no section
    # fails there is nothing to measure.
    if not np.any(np.asarray(diff)[:-1] > 0):
        return 0.0, []
    fail_spans = find_positive_spans(req_xs, diff)
    log_tot_diff = np.log10(req_xs[-1] / req_xs[0])
    log_span_diff = sum(np.log10(stop / start) for start, stop in fail_spans)
    return log_span_diff / log_tot_diff, list(fail_spans)
[docs]
def find_line_segments_by_var_of_prop(cent, prop, lin_threshold):
    """Primitive linearity estimator.

    A section is accepted when the absolute difference between successive
    points in ``prop`` is strictly below ``lin_threshold``, i.e. ``prop``
    varies sufficiently slowly there. Each difference is placed at the
    midpoint between the corresponding ``cent`` values, and the accepted
    regions are found with ``find_positive_spans`` (default "post" steps).

    Parameters
    ----------
    cent :
        x-positions of the points (e.g. bin centers).
    prop :
        Property values at ``cent``.
    lin_threshold :
        Maximum allowed absolute step in ``prop``.

    Returns
    -------
    spans : list[tuple]
        Accepted ``(start, stop)`` regions, in midpoint coordinates.
    span_len : float
        Largest ``log10(stop / start)`` among ``spans``; 0.0 if there are
        no spans. Only meaningful for positive x-positions.
    """
    qual_metric = np.abs(np.diff(prop))
    margin = lin_threshold - qual_metric
    # "post" steps: the last value opens no section. Bail out early when
    # nothing qualifies (also covers fewer than two input points).
    if not np.any(margin[:-1] > 0):
        return [], 0.0
    met_xs = np.convolve(cent, 2 * [0.5], mode="valid")
    spans = find_positive_spans(met_xs, margin)
    sp = np.array(spans)
    span_len = np.log10(sp[:, 1] / sp[:, 0]).max()
    return spans, span_len
[docs]
def normalize_along_axis(
    arr: np.ndarray, axis: int = 1, threshold_frac=0.0001, fill_value=np.nan
) -> np.ndarray:
    """
    Normalize an ndarray along the given axis so that the sum along that axis is 1.0.

    Parameters
    ----------
    arr : np.ndarray
        The input array (array-likes are converted).
    axis : int
        The axis along which to normalize.
    threshold_frac: float
        If the integral along the axis is at or below this fraction of
        the total integral, mask off this slice as having too low stats to
        plot. This prevents low-stats values from saturating the color
        scale. If the whole array sums to zero, every slice is masked.
    fill_value:
        Value to replace low-stats entries with.

    Returns
    -------
    np.ndarray
        A new array normalized along the specified axis.
    """
    arr = np.asanyarray(arr)
    integral = arr.sum(axis=axis, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Phrased as "keep if fraction > threshold" so that a NaN fraction
        # (whole array sums to zero) counts as low-stats and gets filled.
        keep = integral / integral.sum() > threshold_frac
        normalized = np.divide(arr, integral, where=keep)
    # np.divide leaves entries outside `keep` unset; replace them with fill_value
    return np.where(keep, normalized, fill_value)