Source code for ctao_datamodel._core

"""Core classes for the data model."""

import warnings
from enum import Flag, StrEnum, auto
from inspect import cleandoc
from typing import Annotated, Any

from astropy import units as u
from astropy.io.votable.ucd import check_ucd, parse_ucd
from astropy.units.typing import UnitLike
from pydantic import (
    AfterValidator,
    BaseModel,
    ConfigDict,
    Field,
    PlainSerializer,
    PlainValidator,
    WithJsonSchema,
    validate_call,
)
from pydantic_core import core_schema

# Names exported by ``from ctao_datamodel._core import *``.
__all__ = [
    "AstroField",
    "ModelBase",
    "doc",
    "enum_doc",
    "ValidUCD",
    "type_to_string",
    "StrFlag",
    "Quantity",
    "QuantityFormat",
]

#: When true, :func:`AstroField` emits a ``DeprecationWarning`` each time it
#: is called with a ``unit`` argument (the old unit convention).
WARN_ON_DEPRACATED_UNIT = True
#: Unit format name handed to ``Unit.to_string`` wherever this module turns a
#: unit into text.
UNIT_STRING_FORMAT = "fits"


def is_valid_ucd(value: str) -> str:
    """
    Return ``value`` unchanged when it is a valid UCD.

    The check is delegated to ``check_ucd`` with the controlled vocabulary
    enabled.

    Raises
    ------
    ValueError
        If ``check_ucd`` rejects ``value``.
    """
    if not check_ucd(value, check_controlled_vocabulary=True):
        raise ValueError(f"Invalid UCD: '{value}'")
    return value


#: Pydantic ``AfterValidator`` that runs :func:`is_valid_ucd` on a field value.
ValidUCD = AfterValidator(is_valid_ucd)


@validate_call
def AstroField(  # noqa: N802
    description: str | None = None,
    *,
    fits_keyword: str | None = None,
    ivoa_keyword: str | None = None,
    unit: str | None = None,
    ucd: str | None = None,
    examples: list[str] | None = None,
    **kwargs,
) -> Field:  # type: ignore
    """
    Return a Field with extra astronomy-related metadata.

    The extra info is not for validation, but used for documentation or
    serialization. This is just a helper to avoid having to add a
    json_schema_extra dict manually.

    Parameters
    ----------
    description
        Passed through to ``Field`` as its description.
    fits_keyword
        Stored under ``"fits_keyword"``; may be at most 8 characters long.
    ivoa_keyword
        Stored under ``"ivoa_keyword"``.
    unit
        Deprecated. Parsed with ``astropy.units.Unit`` and stored under
        ``"unit"`` in the ``UNIT_STRING_FORMAT`` representation.
    ucd
        Checked with ``parse_ucd`` against the controlled vocabulary and
        stored under ``"ucd"``.
    examples
        Stored under ``"examples"``.
    **kwargs
        Forwarded unchanged to ``Field``.

    Only arguments with a truthy value end up in ``json_schema_extra``.

    Raises
    ------
    ValueError
        If ``fits_keyword`` is longer than 8 characters. Errors raised by
        ``u.Unit`` (bad ``unit``) or ``parse_ucd`` (bad ``ucd``) propagate
        unchanged.

    Warns
    -----
    DeprecationWarning
        If ``unit`` is given while ``WARN_ON_DEPRACATED_UNIT`` is true.
    """
    if fits_keyword and len(fits_keyword) > 8:
        # The check accepts exactly 8 characters, so say "at most", not "less than".
        msg = f"FITS keyword must be at most 8 characters, got '{fits_keyword}'"
        raise ValueError(msg)

    # normalize unit string
    if unit:
        unit = u.Unit(unit).to_string(UNIT_STRING_FORMAT)
        if WARN_ON_DEPRACATED_UNIT:
            warnings.warn(
                "Specifying unit in AstroField is deprecated. Use dm.Quantity[unit]"
                " instead",
                category=DeprecationWarning,
            )

    if ucd:
        # Called only for its validation side effect; the parsed result is unused.
        parse_ucd(ucd, check_controlled_vocabulary=True)

    # Insertion order matches the order the keys were historically added in.
    metadata: dict[str, Any] = {
        "fits_keyword": fits_keyword,
        "ivoa_keyword": ivoa_keyword,
        "unit": unit,
        "examples": examples,
        "ucd": ucd,
    }
    json_schema_extra: dict[str, Any] = {
        key: value for key, value in metadata.items() if value
    }

    return Field(description=description, json_schema_extra=json_schema_extra, **kwargs)
def doc(obj):
    """
    Return the docstring of ``obj`` flattened onto a single line.

    The docstring is cleaned with ``inspect.cleandoc`` and every newline is
    replaced by one space, which makes it usable as a description field.
    An object without a docstring yields the empty string.
    """
    raw_text = obj.__doc__
    if not raw_text:
        raw_text = ""
    return " ".join(cleandoc(raw_text).split("\n"))
def enum_doc(enum: type[StrEnum]) -> str:
    """
    Return the enum's flattened docstring followed by its allowed options.

    The options are the members of ``enum`` in iteration order, each one
    formatted inside double quotes and separated by commas.
    """
    quoted_options = ", ".join(f'"{member}"' for member in enum)
    return f"{doc(enum)} Options are: {quoted_options}."
class ModelBase(BaseModel):
    """
    Common base class of every CTAO model.

    Its only purpose is to carry the ``model_config`` shared by all models.
    """

    # validate_default: needed for Quantities with defaults
    # arbitrary_types_allowed: needed for Quantity serialization
    model_config = ConfigDict(
        validate_default=True,
        arbitrary_types_allowed=True,
        extra="forbid",
    )
def type_to_string(cls): """Turn a type into a string.""" if hasattr(cls, "__args__"): # Handling compound types return f"{'|'.join(map(type_to_string, cls.__args__))}" elif hasattr(cls, "__name__"): return cls.__name__ elif hasattr(cls, "__forward_arg__"): # Handle forward refs return cls.__forward_arg__ return str(cls)
class StrFlag(Flag):
    """Like enum.Flag, but allowing string input, like ``A|B``."""

    @classmethod
    def _missing_(cls, value):
        """
        Build a flag from a ``|``-separated string of member names.

        Names are stripped of surrounding whitespace and upper-cased before
        lookup. The token ``0`` stands for the empty flag, so the output of
        :meth:`to_string` can always be parsed back. Non-string values are
        left to ``Flag._missing_``.

        Raises
        ------
        ValueError
            If a token is neither ``0`` nor the name of a member.
        """
        # Only handle strings; defer everything else to Flag
        if not isinstance(value, str):
            return super()._missing_(value)

        result = cls(0)
        for part in value.split("|"):
            name = part.strip().upper()
            if name == "0":
                # to_string() spells the empty flag as "0"; it contributes no bits.
                continue
            try:
                result |= cls[name]
            except KeyError:
                raise ValueError(
                    f"{name!r} is not a valid member of {cls.__name__}"
                ) from None
        return result

    def to_string(self) -> str:
        """
        Turn the flag into a string representation.

        Returns ``"0"`` for the empty flag, otherwise the names of the members
        contained in this value joined by ``|``.
        """
        if self.value == 0:
            return "0"
        return "|".join(member.name for member in type(self) if member in self)

    def __str__(self):
        """Return the same text as :meth:`to_string`."""
        return self.to_string()

    @classmethod
    def __get_pydantic_core_schema__(cls, source, handler):
        """
        Implement correct serialization for pydantic.

        Input may be an instance of the class, a string or an integer; it is
        then passed through ``cls(...)``. Serialization uses
        :meth:`to_string`.
        """
        return core_schema.no_info_after_validator_function(
            cls,
            core_schema.union_schema(
                [
                    core_schema.is_instance_schema(cls),
                    core_schema.str_schema(),
                    core_schema.int_schema(),
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: v.to_string(),
                return_schema=core_schema.str_schema(),
            ),
        )
[docs] class QuantityFormat(StrEnum): """Format used to serialize Quantities.""" DICT = auto() #: dict with value and unit keys STRING = auto() #: string representation FLOAT = auto() #: float-representation, units removed
def _make_quantity(unit: UnitLike) -> type[u.Quantity]:
    """
    Build a Quantity[x] type, where x is the given unit.

    The result is ``Annotated[u.Quantity[unit], ...]`` carrying a pydantic
    serializer, a validator and a JSON schema.

    Validation accepts a ``Quantity``, a string, a dict with ``value`` and
    ``unit`` keys, or a bare value that is taken to be in ``unit`` already.
    The first three are converted to ``unit``.

    Serialization depends on ``quantity_format`` in the serialization context
    (a dict); without it, ``QuantityFormat.DICT`` is used.
    """
    unit = u.Unit(unit)
    unit_string = unit.to_string(UNIT_STRING_FORMAT)

    def serialize(q: u.Quantity, info) -> dict | float | str:
        """Serialize ``q`` in the format requested by the context."""
        quantity_format = QuantityFormat.DICT
        if isinstance(info.context, dict):
            quantity_format = info.context.get("quantity_format", QuantityFormat.DICT)
            quantity_format = QuantityFormat(quantity_format)

        # Only the FLOAT branch converts to the target unit; STRING and DICT
        # emit the quantity in whatever unit it currently has.
        if quantity_format == QuantityFormat.FLOAT:
            return float(q.to_value(unit))
        if quantity_format == QuantityFormat.STRING:
            # NOTE(review): astropy documents ``subfmt`` as a LaTeX sub-format
            # option; confirm it really applies UNIT_STRING_FORMAT here.
            return q.to_string(subfmt=UNIT_STRING_FORMAT)
        return {
            "value": float(q.value),
            "unit": q.unit.to_string(UNIT_STRING_FORMAT),
        }

    def deserialize(d, info) -> u.Quantity:
        """Turn any accepted input form into a Quantity expressed in ``unit``."""
        if isinstance(d, u.Quantity | str):
            return u.Quantity(d).to(unit)
        if isinstance(d, dict):
            # Convert like the other inputs, so that a dict carrying an
            # incompatible unit is rejected instead of passing through.
            return u.Quantity(d["value"], d["unit"]).to(unit)
        return u.Quantity(d, unit=unit)

    return Annotated[
        u.Quantity[unit],
        PlainSerializer(serialize, return_type=dict | float | str, when_used="always"),
        PlainValidator(deserialize),
        WithJsonSchema(
            {
                "anyOf": [
                    {
                        "type": "object",
                        "properties": {
                            "value": {"type": "number"},
                            "unit": {"type": "string"},
                        },
                        "required": ["value", "unit"],
                    },
                    {"type": "string"},
                    {"type": "number"},
                ],
                "unit": unit_string,
            }
        ),
    ]


class _QuantityAlias:
    """Allows both `Quantity` and `Quantity["m"]` as type annotations."""

    def __getitem__(self, unit: UnitLike) -> type[u.Quantity]:
        """Return the annotated Quantity type for ``unit``."""
        return _make_quantity(unit)

    # Same lookup when the class itself, rather than an instance, is subscripted
    def __class_getitem__(cls, unit: UnitLike) -> type[u.Quantity]:
        """Return the annotated Quantity type for ``unit``."""
        return _make_quantity(unit)


#: Pydantic type for astropy Quantities that supports both JSON and FITS
#: serialization
Quantity = _QuantityAlias()