#!/usr/bin/env python3
"""
Defines DataSets: collections of stored Data objects related to a single analysis.
"""
import glob
import json
import logging
import os
from pathlib import Path
import deprecation
from astropy.time import Time
from matplotlib.figure import FigureBase
from . import storable, utils
from .constants import plot_format
from .inputdataset import InputDataset
__all__ = [
"Metadata",
"MetricsStore",
"RequirementStore",
"ResultStore",
"Store",
]
logger = logging.getLogger(__name__)
class ValueChangedError(ValueError):
    """Raised when a user tries to overwrite an existing key with a new value."""
class StoreTypeError(Exception):
    """Raised when the wrong store type is used."""
class SubpathStoreError(Exception):
    """Raised when the wrong sub-path is used."""
class Store:
    """Base class for storage of sets of items and common metadata.

    A ``Store`` owns a directory (``base_path``) holding stored files plus a
    ``metadata.json`` file with metadata shared by all of them.  Instances
    are unique per resolved path: constructing a second ``Store`` on the
    same path returns the existing instance.
    """

    # Only one instance may exist per base_path.  The registry is shared by
    # ALL subclasses because two stores of different classes on the same
    # path would still conflict on disk.
    _instances = {}

    def __new__(cls, base_path, *args, **kw):
        """Handle the per-path singleton edge cases before instantiation.

        - Same class and same resolved base_path: return the existing instance.
        - Different class with the same path: raise StoreTypeError.
        - Path nested inside an existing store's path: raise SubpathStoreError.
        - Otherwise: create a new instance normally.
        """
        input_path = Path(base_path).resolve()
        known_paths = list(Store._instances.keys())
        # Test if base_path is already claimed by another store.
        if input_path in known_paths:
            existing_instance = Store._instances[input_path]
            existing_class = existing_instance.__class__
            if cls == existing_class:
                if args or kw:
                    # Extra arguments are ignored when reusing an instance;
                    # warn so the caller knows they had no effect.
                    logger.warning(
                        "Store '%s' already exists, (args=%s;kwargs=%s)",
                        cls.__name__,
                        args,
                        kw,
                    )
                return existing_instance
            msg = f"Can't create {cls.__name__}, {existing_class.__name__} already exist with the same path: {base_path}"
            raise StoreTypeError(msg)
        # Refuse to nest a new store inside a pre-existing one.
        for p in known_paths:
            if input_path.is_relative_to(p):
                msg = f"Can't init a {cls.__name__} ({input_path}) in subpath of pre-existing Store ({p})"
                raise SubpathStoreError(msg)
        # Instantiate after all other checks.
        return super().__new__(cls)

    def __init__(
        self,
        base_path: Path,
        name: str | None = None,
    ):
        """Create or load the store rooted at *base_path*.

        Parameters
        ----------
        base_path : Path
            Directory where the store lives; created (recursively) if missing.
        name : str, optional
            Label of the store.  When None and the store already has a name
            in its metadata, the stored name is kept.

        Raises
        ------
        StoreTypeError
            If the directory already holds a store of a different type.
        """
        # Store absolute path.
        self.base_path = Path(base_path).resolve()
        # Create the directory tree on first use (handles recursivity).
        self.base_path.mkdir(parents=True, exist_ok=True)
        self.metadata = Metadata(path=self.base_path / "metadata.json")
        if "creation_date" not in self.metadata:
            self.metadata["creation_date"] = Time.now().iso
        # Record the store type.  Metadata refuses to overwrite an existing,
        # different value, which prevents loading an existing store through
        # a different class than the one recorded in its metadata.
        my_class = self.__class__.__name__
        try:
            self.metadata["store_type"] = my_class
        except ValueError as err:
            store_type = self.metadata["store_type"]
            msg = f"Unable to set {self.__class__.__name__} in '{self.base_path}', {store_type} expected."
            raise StoreTypeError(msg) from err
        # Only assign the name when one is given or none is stored yet, so
        # loading an existing store without a name does not clobber the
        # stored name with None.
        if name is not None or "name" not in self.metadata:
            self.name = name
        # Register only at the end of __init__, once the instance is valid.
        Store._instances[self.base_path] = self

    def contents(self, pattern="**"):
        """List all files in the store, ordered by extension.

        Parameters
        ----------
        pattern : str, optional
            Glob pattern relative to ``base_path`` (recursive). Default "**".

        Returns
        -------
        list[str]
            File paths relative to ``base_path``, sorted by extension then name.
        """
        filelist = glob.glob(os.path.join(self.base_path, pattern), recursive=True)
        # Hide the internal metadata file.  A restrictive pattern may not
        # match it at all, so remove it defensively instead of calling
        # list.remove() unconditionally (which raises ValueError if absent).
        metadata_file = str(self.metadata.path)
        if metadata_file in filelist:
            filelist.remove(metadata_file)
        # Keep only files.
        filelist = [it for it in filelist if os.path.isfile(it)]
        # Use relative paths.
        filelist = [os.path.relpath(it, self.base_path) for it in filelist]
        # Order files primarily by extension, then filename.
        return utils.sort_files_by_ext(filelist)

    def __str__(self):
        """Human-readable summary: name, path, metadata and file list."""
        text = f"{self.__class__.__name__}: {self.name}\n"
        text += f"base_path: {self.base_path}\n\n"
        # NOTE(review): metadata_repr is not defined in this chunk of the
        # file -- confirm it is provided elsewhere (mixin or later code).
        text += f"metadata:\n{self.metadata_repr(indent=1)}\n"
        text += "files:\n"
        for f in self.contents():
            text += f"\t{f}\n"
        return text

    @property
    def name(self):
        """The name of this dataset, used as a label when comparing."""
        return self.metadata["name"]

    @name.setter
    def name(self, thename):
        """Name this results set, if not done already."""
        self.metadata["name"] = thename

    def data_exists(self, path: Path) -> bool:
        """Check if the corresponding path exists in this store.

        Parameters
        ----------
        path : Path
            Path (relative to the store) to test.
        """
        # TODO: replace this very simple test with something better
        input_file = self.base_path / Path(path)
        return input_file.exists()

    def store_data(self, data: storable.Storable | FigureBase, path: Path | str):
        """Store data in this Store at the given relative path.

        Storable objects are saved through their own ``save`` method;
        matplotlib figures are written with the configured plot format.

        Raises
        ------
        TypeError
            If *data* is neither a Storable nor a matplotlib figure.
        """
        output = self.base_path / Path(path)
        # Create the intermediate directories if needed.
        output.parent.mkdir(parents=True, exist_ok=True)
        if isinstance(data, storable.Storable):
            data.save(output)
        elif isinstance(data, FigureBase):
            output = output.with_suffix(f".{plot_format}")
            data.savefig(output, format=plot_format, bbox_inches="tight")
        else:
            # Previously unsupported types were silently dropped while still
            # logging "Wrote:"; fail loudly instead.
            msg = f"Can't store data of type {type(data).__name__} at '{path}'."
            raise TypeError(msg)
        logger.info("Wrote: %s", output)

    def retrieve_data(self, path: Path) -> storable.Storable:
        """Retrieve data from this store corresponding to the input path.

        Parameters
        ----------
        path : Path
            Path of the data, relative to the store.

        Returns
        -------
        storable.Storable
        """
        input_file = self.base_path / Path(path)
        data = storable.open(input_file)
        return data
class MetricsStore(Store):
    """Manage storage of multiple related Data objects.

    This class manages a directory structure defined by the ``base_path``
    attribute under which data objects are stored and retrieved by
    identifier, which are simply relative paths.  Additionally, this class
    manages global metadata that apply to all Data inside.
    """

    def __init__(self, base_path: Path, name: str | None = None, label=None):
        """Store designed for Metrics, i.e. reduced bins that were or will be produced using input_data.

        Parameters
        ----------
        base_path : Path
            Base path on the hard drive where all the store structure will
            be saved.
        name : str, optional
            Name of the store. By default, None.
        label : optional
            Display label; falls back to the store name. By default, None.
        """
        super().__init__(base_path, name=name)
        self.label = self.name if label is None else label

    def index_metrics(self):
        """Return all identified metrics in the store.

        Structure of output dict is as follow:
        - key: data_level/BenchmarkClass
        - value: dict of histograms:
            - key: identifier returned by the metric's ``get_identifier()``
            - value: Metric instance

        Returns
        -------
        dict[str,dict[str,Metric]]
            Dict of all Metrics in the store, sorted by data_level/Benchmark.
        """
        metrics = {}
        # Only .asdf files hold metrics; group them by their directory.
        asdf_files = (f for f in self.contents() if f.endswith(".asdf"))
        for relpath in asdf_files:
            group_key = os.path.split(relpath)[0]
            metric = self.retrieve_data(relpath)
            metrics.setdefault(group_key, {})[metric.get_identifier()] = metric
        return metrics
class ResultStore(Store):
    """Storage of plots and reports."""

    def __init__(self, *args, **kw):
        """Initialize the store and ensure a 'benchmarks' metadata entry exists."""
        super().__init__(*args, **kw)
        if "benchmarks" not in self.metadata:
            self.metadata["benchmarks"] = {}

    @deprecation.deprecated(
        details="Use Store.contents() instead. (filetype can't be given, this new method list ALL "
        "files in the store)"
    )
    def list_store(self):
        """List files in the store matching a supplied glob.

        Default glob: ``*.asdf``.

        NOTE(review): the ``*svg``/``*json`` patterns match on name endings,
        not strictly on extensions (e.g. ``foosvg`` also matches) -- confirm
        this is intended before tightening.
        """
        plot_files = [
            str(itm.relative_to(self.base_path)) for itm in self.base_path.rglob("*svg")
        ]
        data_files = [
            str(itm.relative_to(self.base_path))
            for itm in self.base_path.rglob("*json")
        ]
        return plot_files + data_files

    @classmethod
    @deprecation.deprecated(
        details="Use constructor with path of pre-existing store as argument instead."
    )
    def from_path(cls, path: Path | str):
        """Create a store of this class from the path to an existing store.

        Bug fix: ``@classmethod`` must be the OUTERMOST decorator.  The
        previous order applied ``deprecation.deprecated`` to the classmethod
        object itself, breaking the descriptor protocol, so calling
        ``ResultStore.from_path(p)`` passed *p* as ``cls``.
        """
        path = Path(path)  # ensure it's a path object
        meta = Metadata(path / "metadata.json")
        store = cls(base_path=path, name=meta["name"])
        # NOTE(review): reaches into Metadata's private _metadata attribute
        # -- confirm Metadata exposes no public way to iterate its items.
        store.metadata.update(meta._metadata)
        return store
class RequirementStore(Store):
    """Store for requirement inputs."""
class Node(dict):
    """Dict subclass that refuses to overwrite an already existing value.

    An error is thrown only if a value already exists and differs from the
    one being set.  This extra check exists because the same value is
    expected to be set multiple times on reruns; setting a brand-new key is
    always allowed, as is replacing a value that is currently None.
    """

    def __init__(self, indict=None):
        """Build a Node from *indict*, converting nested dicts to Nodes.

        Parameters
        ----------
        indict : dict, optional
            Initial content; nested dicts become Node instances recursively.
        """
        converted = {}
        if indict is not None:
            for key, value in indict.items():
                if isinstance(value, dict):
                    value = Node(value)
                converted[key] = value
        super().__init__(converted)

    def update(self, indict):
        """Update from *indict*, refusing any change to an existing value.

        Raises
        ------
        ValueChangedError
            If applying *indict* would change an already-set value.
        """
        is_safe, msg = is_safe_update(self, indict)
        if not is_safe:
            msg = "Can't update already existing metadata: \n" + msg
            raise ValueChangedError(msg)
        super().update(indict)

    def __setitem__(self, key, value):
        """Set *key* to *value*, refusing to change an existing value.

        Setting the same value again, or replacing a None value, is allowed.

        Raises
        ------
        ValueChangedError
            If *key* already maps to a different, non-None value.
        """
        # Force all sub-nodes to be Node objects and not plain dicts.
        if isinstance(value, dict):
            value = Node(value)
        try:
            old_value = super().__getitem__(key)
        except KeyError:
            old_value = None
        if old_value in [None, value]:
            super().__setitem__(key, value)
        else:
            msg = f"Metadata key '{key}'='{old_value}' already exists. Can't set to '{value}'."
            raise ValueChangedError(msg)

    def to_dict(self):
        """Return a plain-dict copy, converting nested Nodes recursively.

        Bug fix: the previous implementation converted only the top level,
        leaving nested Node instances inside the returned dict.
        """
        return {
            key: value.to_dict() if isinstance(value, Node) else value
            for key, value in self.items()
        }
def json_handler(obj):
    """JSON ``default`` hook converting objects the encoder can't serialize.

    Paths are serialized as their string form; anything else raises
    TypeError, as the ``json`` module contract requires.
    """
    if not isinstance(obj, Path):
        raise TypeError(f"{obj!r} cannot be serialized to json")
    return str(obj)
def is_safe_update(d1, d2, msg=None, prefix=None):
"""Check if it's safe to update d1 with d2.
Check that no key in d2 (in nested directories) will overwrite an already existing key whose value is neither
None nor another sub-directory.
Everything in d1 should be in d2 Everything in d2 should either not be in
d1, or with safe values (None or dict)
Parameters
----------
d1 : dict
Reference dictionary.
d2 : dict
New dictionary to update the reference with.
msg : str, optional
Internal parameter for recursive function. DO NOT USE. By default, None.
prefix : str, optional
Internal parameter for recursive function. DO NOT USE. By default, None.
Returns
-------
is_safe : bool
If update is safe.
msg : str
Information about why failed.
"""
is_safe = True
if not msg:
msg = ""
keys1 = set(d1.keys())
keys2 = set(d2.keys())
d1_prefix = "meta"
d2_prefix = "update"
if prefix:
d1_prefix += prefix
d2_prefix += prefix
common_keys = keys1.intersection(keys2)
# Common keys
for key in common_keys:
value1 = d1[key]
value2 = d2[key]
# Ignore None values without complaint
if value2 is not None:
if isinstance(value1, Node):
new_prefix = prefix if prefix else ""
new_prefix += f"[{key}]"
(is_safe, tmp_msg) = is_safe_update(value1, value2, prefix=new_prefix)
if not is_safe:
is_safe = False
msg += tmp_msg
elif value1 is not None and value1 != value2:
is_safe = False
msg += "Value changed:\n"
msg += f"\t{d1_prefix}[{key}] = {value1}\n"
msg += f"\t{d2_prefix}[{key}] = {value2}\n"
return is_safe, msg