Source code for datapipe_testbench.auto_benchmark

#!/usr/bin/env python3
"""
Define AutoBenchmark, a Benchmark that needs only minimal configuration.
"""

import datetime
import inspect
import itertools
import logging
import os
import time
from pathlib import Path
from typing import override

from ctapipe.io import TableLoader
from hist import axis

from . import auto_constants, auto_lib
from .benchmark import Benchmark
from .store import MetricsStore, ResultStore

__all__ = ["AutoBenchmark"]


class AutoBenchmark(Benchmark):
    """A Benchmark with minimal configuration needed.

    Subclasses declare the columns they are interested in; the TableLoader
    options, the histograms, the plots and the report are derived from that
    declaration through ``auto_lib`` and ``auto_constants``.

    Attributes
    ----------
    data_level : str
        dl1, dl2, dl0 (dl1images not supported atm). Must be set by subclasses.
    col_lists : list[tuple]
        List of tuples of columns. Must be non-empty. Each tuple presumably
        defines the axes of one histogram (they are built by
        ``auto_lib.make_hists``).
    chunk_size : int
        Chunk size when reading input event data.
    nevent_threshold : int or None
        If set, stop reading once this many events have been covered. The
        check is done after each chunk, so the limit has chunk granularity.
    custom_cols : dict[str, Callable]
        Custom columns to generate.
    custom_axis : dict[str, axis.Axis]
        Override the automatic axis definition for columns.
    """

    col_lists = []
    data_level = None
    chunk_size = auto_constants.DEFAULT_CHUNK_SIZE
    nevent_threshold = None  # If set, do not read all events

    # Custom accumulators, when an axis is not a simple column for an input file
    # key: axis_name (either a made up name for a new column, or the name of a
    #      pre-existing column in the input file)
    # value: callable with one parameter (that will be the table of events in
    #      table loader). This function needs to return the list of values, one
    #      per line in the chunk. You basically do what you want in it to
    #      extract, filter and prepare the data to be filled in the histogram.
    custom_cols = {}

    # All columns in the input files should have default axis (min, max, number
    # of bins,...). Here you can overwrite those definitions if you want to. For
    # custom columns you created, you have to, because no default values are
    # defined for those.
    # key: axis_name (either a made up name for a new column, or the name of a
    #      pre-existing column in the input file)
    # value: hist.axis.axis object (e.g. hist.axis.Regular)
    custom_axis = {}

    def __init__(self):
        """Validate the subclass configuration and prepare reader and histograms.

        Raises
        ------
        ValueError
            If ``data_level`` or ``col_lists`` is not set, if a non-default
            column lacks an entry in ``custom_cols`` or ``custom_axis``, or if
            a ``custom_cols`` value is not a callable taking one argument.
        TypeError
            If a ``custom_axis`` value is not a ``hist.axis.AxesMixin``.
        """
        super().__init__()

        if self.data_level is None:
            raise ValueError("Subclass need to set data_level")
        if not self.col_lists:
            raise ValueError("Subclass need to define col_lists")

        # Work on a per-instance copy: default accumulators are inserted further
        # down and must not leak into the class-level dict, which is shared by
        # every instance (and by every subclass that does not override it).
        self.custom_cols = dict(self.custom_cols)

        # From col_lists, extract the custom ones for which we do not have a
        # default axis. For those, we need both an axis and a function to
        # accumulate.
        default_cols = set(auto_constants.default_axis)
        all_cols = set(itertools.chain.from_iterable(self.col_lists))
        required_extra_cols = all_cols - default_cols

        # Check that every custom column is defined in custom_cols
        missing_cols = required_extra_cols.difference(self.custom_cols)
        if missing_cols:
            msg = "The following columns need to be defined in custom_cols (not default columns):\n"
            msg += ", ".join(sorted(missing_cols))
            raise ValueError(msg)

        # Check that every custom column is defined in custom_axis
        missing_cols = required_extra_cols.difference(self.custom_axis)
        if missing_cols:
            msg = "The following columns need to be defined in custom_axis (not default columns):\n"
            msg += ", ".join(sorted(missing_cols))
            raise ValueError(msg)

        # Check that every value in custom_cols is a function of one argument
        wrong_values = [
            key
            for key, func in self.custom_cols.items()
            if not callable(func) or len(inspect.signature(func).parameters) != 1
        ]
        if wrong_values:
            msg = "custom_cols: Value is not a callable with one argument for the following keys:\n"
            msg += ", ".join(wrong_values)
            raise ValueError(msg)

        # Check that every value in custom_axis has an axis type
        wrong_values = [
            key
            for key, ax in self.custom_axis.items()
            if not isinstance(ax, axis.AxesMixin)
        ]
        if wrong_values:
            msg = "custom_axis: Value is not of type hist.axis.AxesMixin for the following keys (e.g. axis.Regular):\n"
            msg += ", ".join(wrong_values)
            raise TypeError(msg)

        # Columns to read from the file: the requested ones plus whatever the
        # custom accumulation functions access, minus the custom names
        # themselves (those are axis names, not columns of the file).
        accessible_cols = set(all_cols)
        for func in self.custom_cols.values():
            accessible_cols.update(auto_lib.inspect_func(func))
        accessible_cols.difference_update(self.custom_cols)

        # Some columns need a custom accumulator, so give them a default one if
        # none was provided
        for col, func in auto_constants.custom_cols.items():
            if col in accessible_cols:
                self.custom_cols.setdefault(col, func)

        # Find and set flags for TableLoader
        method_type, flags = auto_lib.get_tableloader_flags(accessible_cols)
        print(f"Method_type: {method_type} ; flags: {flags}")
        self.callable_method_name = auto_constants.method_name[method_type]

        # Prepare table loader options: everything to False, except the ones
        # necessary to access the columns we want
        self.default_options = dict.fromkeys(
            auto_constants.table_loader_options[method_type], False
        )
        for flag in flags:
            self.default_options[flag] = True

        # Path to store files inside the store
        self.metric_path = Path(f"{self.data_level}/{self.__class__.__name__}")

        self.hists = auto_lib.make_hists(self.col_lists, custom_axis=self.custom_axis)

        # Set default logger for instance
        self._log = logging.getLogger(self.__class__.__name__)

    def _fill_histograms(self, events):
        """Fill every histogram of ``self.hists`` with one chunk of events."""
        # Compute the custom values once per chunk
        custom_values = {name: func(events) for name, func in self.custom_cols.items()}

        for col_identifier, h in self.hists.items():
            cols = col_identifier.split(auto_constants.column_separator)
            h.fill(
                *(
                    custom_values[col] if col in custom_values else events[col]
                    for col in cols
                )
            )

    @override
    def generate_metrics(self, metric_store: MetricsStore):
        """Fill the histograms from the input file, or load them from the store.

        If every histogram file already exists in ``metric_store`` they are
        loaded; otherwise all histograms are (re)accumulated from the input
        file of the store and saved.

        Parameters
        ----------
        metric_store : MetricsStore
            Store providing the input file and holding the histogram files.

        Returns
        -------
        dict
            The histograms, with the same keys as ``self.hists``.

        Raises
        ------
        ValueError
            If the store has no input file for ``self.data_level``.
        """
        filename = getattr(metric_store.get_inputdata(), self.data_level)
        if filename is None:
            msg = f"{metric_store.name} has no file set for data_level: {self.data_level}."
            raise ValueError(msg)

        asdf_files = {
            k: os.path.join(self.metric_path, f"hist_{k}.asdf") for k in self.hists
        }

        # If all metric files exist, just load them.
        if all(metric_store.data_exists(f) for f in asdf_files.values()):
            self._log.info("Load files from %s.", metric_store.base_path)
            for name, asdf_file in asdf_files.items():
                self.hists[name] = metric_store.retrieve_data(asdf_file)
            return self.hists

        # Otherwise accumulate (or reaccumulate) all of them. Start from empty
        # histograms so that a previous call on this instance (for another
        # store) does not contaminate the new ones.
        self.hists = auto_lib.make_hists(self.col_lists, custom_axis=self.custom_axis)

        tstart = time.monotonic()
        with TableLoader(filename) as loader:
            nevents = len(loader)
            if self.nevent_threshold is not None:
                nevents = min(nevents, self.nevent_threshold)

            read_chunks = getattr(loader, self.callable_method_name)
            for start, stop, events in read_chunks(
                chunk_size=self.chunk_size, **self.default_options
            ):
                # nevents can be 0 when nevent_threshold is 0
                progress = start / nevents * 100 if nevents else 100.0
                duration = time.monotonic() - tstart
                eta = duration * (100 - progress) / (progress + 1e-9)
                print(
                    f"\rRead {filename}: {progress:.1f}% (ETA: "
                    f"{datetime.timedelta(seconds=eta)}s)    ",
                    end="",
                )

                self._fill_histograms(events)

                # Don't read everything. NOTE(review): assumes ``stop`` is the
                # exclusive end index of the chunk, so ``stop == nevents``
                # already means the requested events were read — confirm
                # against TableLoader.
                if stop >= nevents:
                    break

        duration = time.monotonic() - tstart
        self._log.info(
            "\rRead %s: Done in %s", filename, datetime.timedelta(seconds=duration)
        )

        # Save for later
        self._log.info("Save files to %s.", metric_store.base_path)
        for key, h in self.hists.items():
            metric_store.store_data(h, asdf_files[key])

        return self.hists

    def plot_all(self, mstore: MetricsStore, rstore: ResultStore):
        """Plot all histograms related to the current benchmark from one metric store.

        This is a convenience function that basically calls
        compare_to_reference with only one metric store.

        Parameters
        ----------
        mstore : MetricsStore
            Input MetricsStore (contains all histograms intermediate files).
        rstore : ResultStore
            Output ResultStore (will contain all plot files).
        """
        self.compare_to_reference([mstore], rstore)

    @override
    def compare_to_reference(
        self, metric_store_list: list[MetricsStore], result_store: ResultStore
    ):
        """Plot each metric of this benchmark for all given stores together.

        Looks into each MetricsStore for the histograms of this benchmark, then
        hands them, metric by metric, to ``auto_lib.auto_plot``.

        Parameters
        ----------
        metric_store_list : list[MetricsStore]
            Stores to compare. Must not be empty; the first one is the
            reference that defines which metrics are plotted.
        result_store : ResultStore
            Output ResultStore passed to ``auto_lib.auto_plot``.
        """
        # Take only the metrics for the current benchmark
        benchmark_key = str(self.metric_path)
        store_metrics = {
            mstore.name: mstore.index_metrics()[benchmark_key]
            for mstore in metric_store_list
        }

        # Use first store as reference, then loop over all metrics.
        # Each metric is one plot.
        reference_metrics = store_metrics[metric_store_list[0].name]
        for metric_identifier in reference_metrics:
            hists = {
                store_name: metrics[metric_identifier]
                for store_name, metrics in store_metrics.items()
            }
            # Call individual plot for the list of histograms
            auto_lib.auto_plot(hists, result_store)

    @override
    def make_report(self, result_store: ResultStore):
        """Write the report for ``result_store`` via ``auto_lib.write_report``."""
        # Report is automatic, and uses figure names to infer the columns and
        # their numbers.
        auto_lib.write_report(result_store)