"""Core classes for the data model."""
from enum import Flag, StrEnum
from inspect import cleandoc
from typing import Any
from astropy import units as u
from astropy.io.votable.ucd import check_ucd, parse_ucd
from pydantic import (
AfterValidator,
BaseModel,
ConfigDict,
Field,
validate_call,
)
from pydantic_core import core_schema
# Names exported by a star-import of this module.
__all__ = [
"AstroField",
"ModelBase",
"doc",
"enum_doc",
"ValidUCD",
"type_to_string",
"StrFlag",
]
def is_valid_ucd(value: str) -> str:
    """
    Return ``value`` unchanged if it passes the UCD check.

    The check is delegated to ``check_ucd`` with the controlled vocabulary
    enabled.

    Raises
    ------
    ValueError
        If ``check_ucd`` reports a falsy result for ``value``.
    """
    accepted = check_ucd(value, check_controlled_vocabulary=True)
    if not accepted:
        msg = f"Invalid UCD: '{value}'"
        raise ValueError(msg)
    return value
# Reusable pydantic validator object that wraps ``is_valid_ucd``.
ValidUCD = AfterValidator(is_valid_ucd)
[docs]
@validate_call
def AstroField(  # noqa: N802
    description: str | None = None,
    *,
    fits_keyword: str | None = None,
    ivoa_keyword: str | None = None,
    unit: str | None = None,
    ucd: str | None = None,
    examples: list[str] | None = None,
    **kwargs,
) -> Field:  # type: ignore
    """
    Return a Field with extra astronomy-related metadata.

    The extra info is not for validation, but used for documentation or
    serialization. This is just a helper to avoid having to add a
    json_schema_extra dict manually.

    Parameters
    ----------
    description
        Passed through to ``Field`` as its description.
    fits_keyword
        Stored under ``"fits_keyword"``; may be at most 8 characters long.
    ivoa_keyword
        Stored under ``"ivoa_keyword"``.
    unit
        Unit string; it is parsed with ``astropy.units.Unit`` and stored
        under ``"unit"`` in its ``"fits"`` string format.
    ucd
        Stored under ``"ucd"`` after being passed to ``parse_ucd`` with the
        controlled vocabulary check enabled.
    examples
        Stored under ``"examples"``.
    **kwargs
        Forwarded unchanged to ``Field``.

    Only arguments with a truthy value end up in ``json_schema_extra``.

    Raises
    ------
    ValueError
        If ``fits_keyword`` is longer than 8 characters. Errors raised by
        ``astropy.units.Unit`` or ``parse_ucd`` for unparsable input are
        propagated as they are.
    """
    if fits_keyword and len(fits_keyword) > 8:
        # Exactly 8 characters is accepted, so the message says "at most".
        msg = f"FITS Keyword must be at most 8 characters, got '{fits_keyword}'"
        raise ValueError(msg)

    # normalize unit string
    if unit:
        unit = u.Unit(unit).to_string("fits")

    json_schema_extra: dict[str, Any] = {}

    if fits_keyword:
        json_schema_extra["fits_keyword"] = fits_keyword

    if ivoa_keyword:
        json_schema_extra["ivoa_keyword"] = ivoa_keyword

    # Re-checked after normalization: an empty normalized string is skipped.
    if unit:
        json_schema_extra["unit"] = unit

    if examples:
        json_schema_extra["examples"] = examples

    if ucd:
        # The return value is discarded; the call is only used as a check
        # (presumably it raises on an invalid UCD -- see astropy docs).
        parse_ucd(ucd, check_controlled_vocabulary=True)
        json_schema_extra["ucd"] = ucd

    return Field(description=description, json_schema_extra=json_schema_extra, **kwargs)
[docs]
def doc(obj):
    """
    Return the docstring of ``obj`` collapsed onto a single line.

    The docstring is cleaned with ``inspect.cleandoc`` and every newline is
    replaced by a single space. An object with a missing or empty docstring
    yields an empty string.
    """
    raw = obj.__doc__
    if not raw:
        return ""
    return " ".join(cleandoc(raw).split("\n"))
[docs]
def enum_doc(enum: type[StrEnum]) -> str:
    """
    Return the enum's one-line docstring followed by a list of its members.

    The result has the form ``<doc> Options are: "a", "b".`` where each
    member is formatted into double quotes.
    """
    quoted_members = (f'"{member}"' for member in enum)
    options = ", ".join(quoted_members)
    return f"{doc(enum)} Options are: {options}."
[docs]
class ModelBase(BaseModel):
    """
    Base class for all CTAO models.
    Should just set the common model_config here.
    """

    # Shared configuration: ``extra="forbid"`` is pydantic's setting for
    # rejecting input fields that the model does not declare.
    model_config = ConfigDict(extra="forbid")
def type_to_string(cls):
"""Turn a type into a string."""
if hasattr(cls, "__args__"):
# Handling compound types
return f"{'|'.join(map(type_to_string, cls.__args__))}"
elif hasattr(cls, "__name__"):
return cls.__name__
elif hasattr(cls, "__forward_arg__"):
# Handle forward refs
return cls.__forward_arg__
return str(cls)
[docs]
class StrFlag(Flag):
    """Like enum.Flag, but allowing string input, like ``A|B``."""

    @classmethod
    def _missing_(cls, value):
        """
        Resolve a value that is not a direct member value.

        Strings are split on ``|``; each part is stripped, upper-cased and
        looked up by member name, and the members are OR-ed together. The
        string ``"0"`` (as produced by ``to_string`` for the empty flag) maps
        to the empty flag. Non-string values are handled by ``Flag``.

        Raises
        ------
        ValueError
            If a part of the string is not a member name.
        """
        # Only handle strings; defer everything else to Flag
        if not isinstance(value, str):
            return super()._missing_(value)

        # ``to_string`` renders the empty flag as "0"; accept that spelling
        # here so every flag value survives a string round trip.
        if value.strip() == "0":
            return cls(0)

        result = cls(0)
        for part in value.split("|"):
            name = part.strip().upper()
            try:
                result |= cls[name]
            except KeyError:
                raise ValueError(
                    f"{name!r} is not a valid member of {cls.__name__}"
                ) from None
        return result

    def to_string(self) -> str:
        """
        Turn the flag into a string representation.

        Returns ``"0"`` for the empty flag, otherwise the names of the
        contained members joined with ``|`` in class iteration order.
        """
        if self.value == 0:
            return "0"
        return "|".join(member.name for member in type(self) if member in self)

    def __str__(self):
        """Return the same text as ``to_string``."""
        return self.to_string()

    @classmethod
    def __get_pydantic_core_schema__(cls, source, handler):
        """
        Implement correct serialization for pydantic.

        Input may be an instance of the class, a string or an integer; the
        validated value is then passed to ``cls``. Serialization goes through
        ``to_string`` and is declared as a string.
        """
        return core_schema.no_info_after_validator_function(
            cls,
            core_schema.union_schema(
                [
                    core_schema.is_instance_schema(cls),
                    core_schema.str_schema(),
                    core_schema.int_schema(),
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: v.to_string(),
                return_schema=core_schema.str_schema(),
            ),
        )