#!/usr/bin/env python3
"""
Defines the AutoBenchmark class: a Benchmark that needs only minimal configuration.
"""
import datetime
import inspect
import itertools
import logging
import os
import time
from pathlib import Path
from typing import override
from ctapipe.io import TableLoader
from hist import axis
from . import auto_constants, auto_lib
from .benchmark import Benchmark
from .store import MetricsStore, ResultStore
__all__ = ["AutoBenchmark"]
class AutoBenchmark(Benchmark):
    """A Benchmark with minimal configuration needed.

    Subclasses only declare which columns to histogram and at which data
    level; this base class validates that configuration, determines the
    ``TableLoader`` flags needed to read those columns, accumulates the
    histograms chunk by chunk, and stores/retrieves them as metrics.

    Attributes
    ----------
    data_level : str
        dl1, dl2, dl0 (dl1images not supported atm)
    col_lists : list[tuple]
        list of tuples of columns, each tuple defines one histogram
    chunk_size : int
        Chunk size when reading input event data.
    custom_cols : dict[str, Callable]
        Custom columns to generate.
    custom_axis : dict[str, axis.Axis]
        Override the automatic axis definition for columns.
    """

    col_lists = []
    data_level = None
    chunk_size = auto_constants.DEFAULT_CHUNK_SIZE
    nevent_threshold = None  # If set, do not read all events

    # Custom accumulators, when an axis is not a simple column of an input file.
    # key: axis_name (either a made up name for a new column, or the name of a
    # pre-existing column in the input file). value: callable with one parameter
    # (that will be the table of events in table loader). This function needs to
    # return the list of values, one per line in the chunk. You basically do
    # what you want in it to extract, filter and prepare the data to be filled
    # in the histogram.
    custom_cols = {}

    # All columns in the input files should have a default axis (min, max,
    # number of bins, ...). Here you can overwrite those definitions if you
    # want to. For custom columns you created, you have to, because no default
    # values are defined for those. key: axis_name (either a made up name for
    # a new column, or the name of a pre-existing column in the input file).
    # value: hist.axis object (e.g. hist.axis.Regular).
    custom_axis = {}

    def __init__(self):
        """Validate the subclass configuration and prepare loader options.

        Raises
        ------
        ValueError
            If ``data_level``/``col_lists`` are unset, if a custom column is
            missing from ``custom_cols`` or ``custom_axis``, or if a
            ``custom_cols`` value is not a one-argument callable.
        TypeError
            If a ``custom_axis`` value is not a hist axis object.
        """
        super().__init__()
        # Default logger for the instance (created first so it is usable
        # throughout initialization).
        self._log = logging.getLogger(self.__class__.__name__)
        if self.data_level is None:
            raise ValueError("Subclass need to set data_level")
        if len(self.col_lists) == 0:
            raise ValueError("Subclass need to define col_lists")
        # From col_lists, extract the custom columns for which we do not have
        # a default axis. For those, we need both an axis and a function to
        # accumulate.
        default_cols = set(auto_constants.default_axis)
        all_cols = set(itertools.chain.from_iterable(self.col_lists))
        required_extra_cols = all_cols.difference(default_cols)
        self._validate_custom_definitions(required_extra_cols)
        accessible_cols = self._collect_accessible_cols(all_cols)
        # Some columns need a custom accumulator, so give them a default one
        # if none was provided.
        for col, func in auto_constants.custom_cols.items():
            if col in accessible_cols and col not in self.custom_cols:
                self.custom_cols[col] = func
        # Find and set flags for TableLoader.
        method_type, flags = auto_lib.get_tableloader_flags(accessible_cols)
        self._log.debug("Method_type: %s ; flags: %s", method_type, flags)
        self.callable_method_name = auto_constants.method_name[method_type]
        # Prepare table loader options (everything to False, except the ones
        # necessary to access the columns we want).
        self.default_options = dict.fromkeys(
            auto_constants.table_loader_options[method_type], False
        )
        for flag in flags:
            self.default_options[flag] = True
        # Path to store files inside the store.
        self.metric_path = Path(f"{self.data_level}/{self.__class__.__name__}")
        self.hists = auto_lib.make_hists(self.col_lists, custom_axis=self.custom_axis)

    def _validate_custom_definitions(self, required_extra_cols):
        """Check every non-default column has an accumulator and an axis."""
        # Check that every custom column is defined in custom_cols.
        # BUG FIX: the message now lists the *missing* columns, not every
        # extra column.
        missing_cols = required_extra_cols.difference(self.custom_cols)
        if missing_cols:
            msg = "The following columns need to be defined in custom_cols (not default columns):\n"
            msg += ", ".join(missing_cols)
            raise ValueError(msg)
        # Check that every custom column is defined in custom_axis.
        missing_cols = required_extra_cols.difference(self.custom_axis)
        if missing_cols:
            msg = "The following columns need to be defined in custom_axis (not default columns):\n"
            msg += ", ".join(missing_cols)
            raise ValueError(msg)
        # Check that every value in custom_cols is a one-argument callable.
        # BUG FIX: collect *all* offending keys (the original overwrote the
        # list on each hit, so only the last one was ever reported).
        wrong_values = [
            key
            for key, func in self.custom_cols.items()
            if not callable(func) or len(inspect.signature(func).parameters) != 1
        ]
        if wrong_values:
            msg = "custom_cols: Value is not a callable with one argument for the following keys:\n"
            msg += ", ".join(wrong_values)
            raise ValueError(msg)
        # Check that every value in custom_axis has an axis type.
        wrong_values = [
            key
            for key, ax in self.custom_axis.items()
            if not isinstance(ax, axis.AxesMixin)
        ]
        if wrong_values:
            msg = "custom_axis: Value is not of type hist.axis.AxesMixin for the following keys (e.g. axis.Regular):\n"
            msg += ", ".join(wrong_values)
            raise TypeError(msg)

    def _collect_accessible_cols(self, all_cols):
        """Return the set of columns that must be readable from the input file."""
        # Columns needed by the custom_cols accumulation functions themselves.
        extra_accessible_cols = []
        for func in self.custom_cols.values():
            extra_accessible_cols.extend(auto_lib.inspect_func(func))
        accessible_cols = set(extra_accessible_cols).union(all_cols)
        # Get rid of custom cols that are not columns in the file, but rather
        # custom ones used in the axis.
        return accessible_cols.difference(self.custom_cols)

    @override
    def generate_metrics(self, metric_store: MetricsStore):
        """Fill (or reload from the store) all histograms of this benchmark.

        Parameters
        ----------
        metric_store : MetricsStore
            Store providing the input file and the metric files location.

        Returns
        -------
        dict
            Mapping of column identifier to filled histogram.

        Raises
        ------
        ValueError
            If the store has no input file for this benchmark's data level.
        """
        filename = getattr(metric_store.get_inputdata(), self.data_level)
        if filename is None:
            msg = f"{metric_store.name} has no file set for data_level: {self.data_level}."
            raise ValueError(msg)
        asdf_files = {
            k: os.path.join(self.metric_path, f"hist_{k}.asdf")
            for k in self.hists.keys()
        }
        # Check if all metric files exist. If not, accumulate or reaccumulate
        # all of them.
        parse_file = any(
            not metric_store.data_exists(f) for f in asdf_files.values()
        )
        if parse_file:
            self._accumulate(filename)
            # Save for later.
            self._log.info(f"Save files to {metric_store.base_path}.")
            for key, h in self.hists.items():
                metric_store.store_data(h, asdf_files[key])
        else:
            self._log.info(f"Load files from {metric_store.base_path}.")
            for name, asdf_file in asdf_files.items():
                self.hists[name] = metric_store.retrieve_data(asdf_file)
        return self.hists

    def _accumulate(self, filename):
        """Read the input file chunk by chunk and fill every histogram."""
        tstart = time.time()
        with TableLoader(filename) as loader:
            nevents = len(loader)
            # Threshold can be None when there's no threshold, so we have to
            # filter it out before taking the minimum.
            nevents = min(
                x for x in (nevents, self.nevent_threshold) if x is not None
            )
            callable_method = getattr(loader, self.callable_method_name)
            for start, stop, events in callable_method(
                chunk_size=self.chunk_size, **self.default_options
            ):
                # BUG FIX: guard against division by zero on an empty input.
                progress = start / nevents * 100 if nevents else 100.0
                duration = time.time() - tstart
                eta = duration * (100 - progress) / (progress + 1e-9)
                print(
                    f"\rRead (unknown): {progress:.1f}% (ETA: "
                    f"{datetime.timedelta(seconds=eta)}s) ",
                    end="",
                )
                # Compute the custom values once per chunk.
                custom_values = {k: f(events) for k, f in self.custom_cols.items()}
                for col_identifier, h in self.hists.items():
                    cols = col_identifier.split(auto_constants.column_separator)
                    data = [
                        custom_values[col] if col in self.custom_cols else events[col]
                        for col in cols
                    ]
                    h.fill(*data)
                # Don't read everything past the event threshold.
                if stop > nevents:
                    break
        duration = time.time() - tstart
        self._log.info(
            f"\rRead (unknown): Done in {datetime.timedelta(seconds=duration)}"
        )

    def plot_all(self, mstore: MetricsStore, rstore: ResultStore):
        """Plot all histograms related to the current benchmark from one metric store.

        This is a convenience function that simply calls
        ``compare_to_reference`` with only one metric store.

        Parameters
        ----------
        mstore : MetricsStore
            Input MetricsStore (contains all histograms intermediate files).
        rstore : ResultStore
            Output ResultStore (will contain all plot files).
        """
        self.compare_to_reference([mstore], rstore)

    @override
    def compare_to_reference(
        self, metric_store_list: list[MetricsStore], result_store: ResultStore
    ):
        """Compare the same histograms across several metric stores.

        Looks into each MetricsStore for all histograms associated with this
        benchmark, then compares them one by one (one plot per metric). The
        first store of the list is used as the reference; the best plot type
        is chosen automatically based on the columns and their numbers.
        """
        # Take only the metrics for the current benchmark.
        store_metrics = {}
        for mstore in metric_store_list:
            indexed = mstore.index_metrics()
            store_metrics[mstore.name] = indexed[str(self.metric_path)]
        # Use first store as reference, then loop over all its metrics.
        reference_name = metric_store_list[0].name
        for metric_identifier in store_metrics[reference_name]:
            hists = {
                store_name: tmp_metrics[metric_identifier]
                for store_name, tmp_metrics in store_metrics.items()
            }
            # Call the individual plot for the list of histograms.
            auto_lib.auto_plot(hists, result_store)

    @override
    def make_report(self, result_store: ResultStore):
        """Write the automatic report for all figures in the result store.

        The report is automatic, and uses figure names to infer the columns
        and their numbers.
        """
        auto_lib.write_report(result_store)