"""Core classes for the data model."""
import warnings
from enum import Flag, StrEnum, auto
from inspect import cleandoc
from typing import Annotated, Any
from astropy import units as u
from astropy.io.votable.ucd import check_ucd, parse_ucd
from astropy.units.typing import UnitLike
from pydantic import (
AfterValidator,
BaseModel,
ConfigDict,
Field,
PlainSerializer,
PlainValidator,
WithJsonSchema,
validate_call,
)
from pydantic_core import core_schema
# Public API of this module: the names exported by a star-import.
__all__ = [
"AstroField",
"ModelBase",
"doc",
"enum_doc",
"ValidUCD",
"type_to_string",
"StrFlag",
"Quantity",
"QuantityFormat",
]
# Checked by AstroField whenever a ``unit`` is passed: when true, a
# DeprecationWarning pointing users at ``dm.Quantity[unit]`` is emitted.
# NOTE: "DEPRACATED" is a misspelling of "DEPRECATED"; the name is kept as-is
# because it is a module-level name that other code refers to.
WARN_ON_DEPRACATED_UNIT = True #: emit warning on fields using old unit convention
# Format name handed to astropy's ``Unit.to_string`` wherever this module turns
# a unit into text (AstroField metadata, Quantity serialization, JSON schema).
UNIT_STRING_FORMAT = "fits" #: format to use when converting Quantity to string
def is_valid_ucd(value: str) -> str:
    """
    Validate a UCD string against the controlled vocabulary.

    Returns ``value`` unchanged when ``check_ucd`` accepts it and raises
    ``ValueError`` otherwise, which makes it usable as a pydantic validator.
    """
    accepted = check_ucd(value, check_controlled_vocabulary=True)
    if not accepted:
        msg = f"Invalid UCD: '{value}'"
        raise ValueError(msg)
    return value


#: Pydantic ``AfterValidator`` wrapping :func:`is_valid_ucd`
ValidUCD = AfterValidator(is_valid_ucd)
[docs]
@validate_call
def AstroField(  # noqa: N802
    description: str | None = None,
    *,
    fits_keyword: str | None = None,
    ivoa_keyword: str | None = None,
    unit: str | None = None,
    ucd: str | None = None,
    examples: list[str] | None = None,
    **kwargs,
) -> Field:  # type: ignore
    """
    Return a Field with extra astronomy-related metadata.

    The extra info is not for validation, but used for documentation or
    serialization. This is just a helper to avoid having to add a
    json_schema_extra dict manually.

    Parameters
    ----------
    description : str, optional
        Passed through as the ``description`` of the returned ``Field``.
    fits_keyword : str, optional
        Stored under ``"fits_keyword"``. Must not be longer than 8 characters.
    ivoa_keyword : str, optional
        Stored under ``"ivoa_keyword"``.
    unit : str, optional
        Deprecated. Parsed with ``astropy.units.Unit`` and stored under
        ``"unit"`` after normalisation to ``UNIT_STRING_FORMAT``. Emits a
        ``DeprecationWarning`` while ``WARN_ON_DEPRACATED_UNIT`` is true.
    ucd : str, optional
        Checked with ``parse_ucd`` (controlled vocabulary enforced) and
        stored unchanged under ``"ucd"``.
    examples : list of str, optional
        Stored under ``"examples"``.
    **kwargs
        Forwarded unchanged to ``pydantic.Field``.

    Returns
    -------
    The ``Field`` built from ``description``, the collected
    ``json_schema_extra`` dict and ``kwargs``. Arguments that are ``None``
    or empty do not produce an entry in ``json_schema_extra``.

    Raises
    ------
    ValueError
        If ``fits_keyword`` has more than 8 characters. Errors raised by
        ``astropy.units.Unit`` or ``parse_ucd`` for input they reject
        propagate unchanged.
    """
    # The check allows exactly 8 characters, so the message says "at most".
    if fits_keyword and len(fits_keyword) > 8:
        msg = f"FITS Keyword must be at most 8 characters, got '{fits_keyword}'"
        raise ValueError(msg)

    json_schema_extra: dict[str, Any] = {}

    if unit:
        # normalize unit string
        json_schema_extra["unit"] = u.Unit(unit).to_string(UNIT_STRING_FORMAT)
        if WARN_ON_DEPRACATED_UNIT:
            warnings.warn(
                "Specifying unit in AstroField is deprecated. Use dm.Quantity[unit]"
                " instead",
                category=DeprecationWarning,
            )

    if fits_keyword:
        json_schema_extra["fits_keyword"] = fits_keyword

    if ivoa_keyword:
        json_schema_extra["ivoa_keyword"] = ivoa_keyword

    if examples:
        json_schema_extra["examples"] = examples

    if ucd:
        # Called for validation only; the original string is what gets stored.
        parse_ucd(ucd, check_controlled_vocabulary=True)
        json_schema_extra["ucd"] = ucd

    return Field(description=description, json_schema_extra=json_schema_extra, **kwargs)
[docs]
def doc(obj):
    """
    Collapse the docstring of ``obj`` into a single line.

    The docstring is cleaned with ``inspect.cleandoc`` and every newline is
    replaced by a space, so the result fits a description field. An object
    without a docstring yields an empty string.
    """
    raw = obj.__doc__
    if not raw:
        return ""
    return " ".join(cleandoc(raw).split("\n"))
[docs]
def enum_doc(enum: type[StrEnum]) -> str:
    """
    Build a description for an Enum that also lists its members.

    The result is the one-line docstring of ``enum`` (see ``doc``) followed
    by the double-quoted string form of every member, comma separated.
    """
    quoted_members = ", ".join(f'"{member}"' for member in enum)
    return f"{doc(enum)} Options are: {quoted_members}."
[docs]
class ModelBase(BaseModel):
    """
    Common parent of all CTAO models.

    Its only job is to carry the shared ``model_config``; it defines no
    fields of its own.
    """

    model_config = ConfigDict(
        # unknown fields are rejected
        extra="forbid",
        # needed for Quantity serialization
        arbitrary_types_allowed=True,
        # needed for Quantities with defaults
        validate_default=True,
    )
def type_to_string(cls):
"""Turn a type into a string."""
if hasattr(cls, "__args__"):
# Handling compound types
return f"{'|'.join(map(type_to_string, cls.__args__))}"
elif hasattr(cls, "__name__"):
return cls.__name__
elif hasattr(cls, "__forward_arg__"):
# Handle forward refs
return cls.__forward_arg__
return str(cls)
[docs]
class StrFlag(Flag):
    """
    Like enum.Flag, but allowing string input, like ``A|B``.

    String input is split on ``|``; each part is stripped, upper-cased and
    looked up by member name. ``str()`` of a flag yields the same
    ``|``-separated form, and the empty flag is written as ``"0"``.
    """

    @classmethod
    def _missing_(cls, value):
        """
        Resolve values that are not direct members of the flag.

        Strings are parsed as ``|``-separated member names (case-insensitive,
        surrounding whitespace ignored). Any other value is handed to
        ``Flag._missing_``.

        Raises
        ------
        ValueError
            If a part of the string is not the name of a member.
        """
        # Only handle strings; defer everything else to Flag
        if not isinstance(value, str):
            return super()._missing_(value)

        result = cls(0)
        for part in value.split("|"):
            name = part.strip().upper()
            # to_string() renders the empty flag as "0"; accept that token so
            # the string form (and thus the pydantic serialization) of every
            # flag value can be parsed back.
            if name == "0":
                continue
            try:
                result |= cls[name]
            except KeyError:
                raise ValueError(
                    f"{name!r} is not a valid member of {cls.__name__}"
                ) from None
        return result

    def to_string(self) -> str:
        """
        Turn the flag into a string representation.

        Returns ``"0"`` for the empty flag, otherwise the names of all members
        of the class contained in this value, joined with ``|``.
        """
        if self.value == 0:
            return "0"
        return "|".join(member.name for member in type(self) if member in self)

    def __str__(self):
        """Return the same text as ``to_string``."""
        return self.to_string()

    @classmethod
    def __get_pydantic_core_schema__(cls, source, handler):
        """
        Implement correct serialization for pydantic.

        Input may be an instance of the class, a string or an int; the value
        is then passed to ``cls`` to build the flag. Serialization emits the
        ``to_string`` form.
        """
        return core_schema.no_info_after_validator_function(
            cls,
            core_schema.union_schema(
                [
                    core_schema.is_instance_schema(cls),
                    core_schema.str_schema(),
                    core_schema.int_schema(),
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: v.to_string(),
                return_schema=core_schema.str_schema(),
            ),
        )
def _make_quantity(unit: UnitLike) -> type[u.Quantity]:
    """
    Build a Quantity[x] type, where x is the given unit.

    The result is ``astropy.units.Quantity[unit]`` wrapped in ``Annotated``
    together with a pydantic serializer, validator and JSON schema.

    Parameters
    ----------
    unit : UnitLike
        Anything ``astropy.units.Unit`` accepts; it is the unit validated
        values are expressed in.
    """
    unit = u.Unit(unit)

    def serialize(q: u.Quantity, info) -> dict | float | str:
        """
        Serialize ``q`` in the format requested through the context.

        The format is read from ``info.context["quantity_format"]`` when the
        context is a dict and defaults to ``QuantityFormat.DICT``.
        """
        quantity_format = QuantityFormat.DICT
        if isinstance(info.context, dict):
            quantity_format = info.context.get("quantity_format", QuantityFormat.DICT)
        quantity_format = QuantityFormat(quantity_format)

        if quantity_format == QuantityFormat.FLOAT:
            # A bare number carries no unit, so express it in the declared one.
            return float(q.to_value(unit))
        if quantity_format == QuantityFormat.STRING:
            # NOTE(review): in astropy's Quantity.to_string, ``subfmt`` looks
            # like it is only used together with the LaTeX formats, so
            # UNIT_STRING_FORMAT may have no effect here -- confirm whether
            # the FITS unit format was intended for the string output.
            return q.to_string(subfmt=UNIT_STRING_FORMAT)
        return dict(value=float(q.value), unit=q.unit.to_string(UNIT_STRING_FORMAT))

    def deserialize(d, info) -> u.Quantity:
        """
        Turn the input into a Quantity expressed in the declared unit.

        Input that carries its own unit (Quantity, string or
        ``{"value": ..., "unit": ...}`` dict) is converted to ``unit``;
        anything else is interpreted as a value already given in ``unit``.
        """
        if isinstance(d, u.Quantity | str):
            return u.Quantity(d).to(unit)

        if isinstance(d, dict):
            try:
                value, given_unit = d["value"], d["unit"]
            except KeyError as err:
                # Raise ValueError so pydantic reports a validation error
                # instead of letting a bare KeyError escape.
                msg = f"Quantity dict requires 'value' and 'unit' keys, missing {err}"
                raise ValueError(msg) from None
            # Convert like the Quantity/str branch does, so a dict carrying an
            # incompatible unit is rejected instead of silently accepted.
            return u.Quantity(value, given_unit).to(unit)

        return u.Quantity(d, unit=unit)

    return Annotated[
        u.Quantity[unit],
        PlainSerializer(serialize, return_type=dict | float | str, when_used="always"),
        PlainValidator(deserialize),
        WithJsonSchema(
            {
                # One entry per input form accepted by ``deserialize``.
                "anyOf": [
                    {
                        "type": "object",
                        "properties": {
                            "value": {"type": "number"},
                            "unit": {"type": "string"},
                        },
                        "required": ["value", "unit"],
                    },
                    {"type": "string"},
                    {"type": "number"},
                ],
                "unit": unit.to_string(UNIT_STRING_FORMAT),
            }
        ),
    ]
class _QuantityAlias:
    """
    Subscriptable helper behind the module-level ``Quantity`` object.

    Subscripting either an instance or the class itself with a unit returns
    the annotated type built by ``_make_quantity``.
    """

    def __class_getitem__(cls, unit: UnitLike) -> type[UnitLike]:
        # Subscripting the class directly gives the same result as
        # subscripting an instance.
        return _make_quantity(unit)

    def __getitem__(self, unit: UnitLike) -> type[UnitLike]:
        # Subscripting an instance, e.g. ``Quantity["m"]``.
        return _make_quantity(unit)


#: Pydantic type for astropy Quantities that supports both JSON and FITS
#: serialization
Quantity = _QuantityAlias()