# Source code for ctao_datamodel._core

"""Core classes for the data model."""

import types
from enum import Flag, StrEnum
from inspect import cleandoc
from typing import Annotated, Any, Literal, Union, get_origin

from astropy import units as u
from astropy.io.votable.ucd import check_ucd, parse_ucd
from pydantic import (
    AfterValidator,
    BaseModel,
    ConfigDict,
    Field,
    validate_call,
)
from pydantic_core import core_schema

# Public API of this module: the names bound by ``from <module> import *``.
# Note: the helper ``is_valid_ucd`` is not listed; ``ValidUCD`` wraps it.
__all__ = [
    "AstroField",
    "ModelBase",
    "doc",
    "enum_doc",
    "ValidUCD",
    "type_to_string",
    "StrFlag",
]


def is_valid_ucd(value: str) -> str:
    """
    Check that ``value`` is a UCD accepted by astropy's ``check_ucd``.

    The check is run with ``check_controlled_vocabulary=True``. The value is
    returned unchanged on success, so the function can be used as a pydantic
    "after" validator.

    Raises
    ------
    ValueError
        If ``value`` is not a valid UCD.
    """
    if not check_ucd(value, check_controlled_vocabulary=True):
        raise ValueError(f"Invalid UCD: '{value}'")
    return value


#: Reusable validator marker (e.g. for ``Annotated[str, ValidUCD]``) that
#: rejects strings failing :func:`is_valid_ucd`.
ValidUCD = AfterValidator(is_valid_ucd)


@validate_call
def AstroField(  # noqa: N802
    description: str | None = None,
    *,
    fits_keyword: str | None = None,
    ivoa_keyword: str | None = None,
    unit: str | None = None,
    ucd: str | None = None,
    examples: list[str] | None = None,
    **kwargs,
) -> Field:  # type: ignore
    """
    Return a pydantic ``Field`` with extra astronomy-related metadata.

    The extra info is not used for validation. It is stored in the field's
    ``json_schema_extra`` dict, for documentation or serialization. This is
    just a helper to avoid having to build that dict manually.

    Parameters
    ----------
    description : str, optional
        Human-readable description of the field.
    fits_keyword : str, optional
        FITS header keyword for this field; at most 8 characters.
    ivoa_keyword : str, optional
        IVOA keyword for this field.
    unit : str, optional
        Unit string. It is parsed with astropy and stored in its FITS
        string representation.
    ucd : str, optional
        Unified Content Descriptor, checked with astropy's ``parse_ucd``
        against the controlled vocabulary.
    examples : list of str, optional
        Example values.
    **kwargs
        Forwarded to ``pydantic.Field``. A ``json_schema_extra`` dict given
        here is merged with the astronomy metadata. The astronomy metadata
        wins on key clashes.

    Raises
    ------
    ValueError
        If ``fits_keyword`` is longer than 8 characters. Parsing errors
        raised by astropy for ``unit`` or ``ucd`` propagate as well.
    """
    if fits_keyword and len(fits_keyword) > 8:
        msg = f"FITS keyword must be at most 8 characters, got '{fits_keyword}'"
        raise ValueError(msg)

    # normalize unit string
    if unit:
        unit = u.Unit(unit).to_string("fits")

    # Start from a caller-supplied json_schema_extra (if any). Otherwise it
    # would be passed to Field twice below and raise TypeError.
    user_extra = kwargs.pop("json_schema_extra", None)
    json_schema_extra: dict[str, Any] = dict(user_extra) if user_extra else {}

    if fits_keyword:
        json_schema_extra["fits_keyword"] = fits_keyword
    if ivoa_keyword:
        json_schema_extra["ivoa_keyword"] = ivoa_keyword
    # re-checked: the normalized FITS string may differ from the input
    if unit:
        json_schema_extra["unit"] = unit
    if examples:
        json_schema_extra["examples"] = examples
    if ucd:
        # raises if the UCD is invalid under the controlled vocabulary
        parse_ucd(ucd, check_controlled_vocabulary=True)
        json_schema_extra["ucd"] = ucd

    return Field(description=description, json_schema_extra=json_schema_extra, **kwargs)
def doc(obj):
    """
    Flatten the docstring of ``obj`` into a single line.

    The docstring is cleaned with :func:`inspect.cleandoc`, and each newline
    is then replaced by a space. The result is suitable for a description
    field. Objects without a docstring give an empty string.
    """
    text = obj.__doc__
    if not text:
        text = ""
    lines = cleandoc(text).split("\n")
    return " ".join(lines)
def enum_doc(enum: type[StrEnum]) -> str:
    """
    Build a description of ``enum`` that lists its allowed options.

    The enum's flattened docstring (see ``doc``) is followed by each member
    in double quotes. For a ``StrEnum`` the quoted text is the member's
    value. Example: ``'Colour. Options are: "red", "blue".'``.
    """
    quoted = ", ".join(f'"{member}"' for member in enum)
    return f"{doc(enum)} Options are: {quoted}."
class ModelBase(BaseModel):
    """
    Common base class for the CTAO data models.

    It only carries the shared pydantic ``model_config``, so subclasses do
    not repeat the configuration.
    """

    # Unknown (extra) fields are rejected during validation.
    model_config = ConfigDict(extra="forbid")
def type_to_string(cls): """Turn a type into a string.""" if hasattr(cls, "__args__"): # Handling compound types return f"{'|'.join(map(type_to_string, cls.__args__))}" elif hasattr(cls, "__name__"): return cls.__name__ elif hasattr(cls, "__forward_arg__"): # Handle forward refs return cls.__forward_arg__ return str(cls)
class StrFlag(Flag):
    """
    Like enum.Flag, but allowing string input, like ``A|B``.

    Parsing strings:

    - A string is split on ``|``. Each part is stripped and upper-cased,
      then looked up by member name.
    - ``"0"`` is what :meth:`to_string` produces for the empty flag. It is
      accepted back as the empty flag, so string round-trips are lossless.

    Pydantic integration: the class validates from an instance, a string or
    an int, and serializes to the string form.
    """

    @classmethod
    def _missing_(cls, value):
        """Build a flag from a ``|``-separated string of member names."""
        # Only handle strings; defer everything else to Flag
        if not isinstance(value, str):
            return super()._missing_(value)

        # to_string() renders the empty flag as "0"; accept it back so that
        # serialized values validate again.
        if value.strip() == "0":
            return cls(0)

        result = cls(0)
        for part in value.split("|"):
            name = part.strip().upper()
            try:
                result |= cls[name]
            except KeyError:
                raise ValueError(
                    f"{name!r} is not a valid member of {cls.__name__}"
                ) from None
        return result

    def to_string(self) -> str:
        """Turn the flag into a string such as ``"A|B"``; ``"0"`` when empty."""
        if self.value == 0:
            return "0"
        return "|".join(member.name for member in type(self) if member in self)

    def __str__(self):
        return self.to_string()

    @classmethod
    def __get_pydantic_core_schema__(cls, source, handler):
        """Implement correct serialization for pydantic."""
        # Accept an instance, a string or an int. Each is passed through
        # cls(...) afterwards, which routes strings to _missing_ above.
        return core_schema.no_info_after_validator_function(
            cls,
            core_schema.union_schema(
                [
                    core_schema.is_instance_schema(cls),
                    core_schema.str_schema(),
                    core_schema.int_schema(),
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: v.to_string(),
                return_schema=core_schema.str_schema(),
            ),
        )