# Source code for ctao_datamodel._visitor

"""
Functionality to recursively call a function on model sub-elements.

This is used in other places to simplify turning the model into other formats,
like PlantUML.
"""

import json
import logging
from collections import defaultdict
from collections.abc import Generator, MutableMapping, Sequence
from dataclasses import dataclass
from enum import Enum, auto
from types import UnionType
from typing import Annotated, ForwardRef, Literal, Union, get_args, get_origin

from pydantic import BaseModel, Field
from pydantic.fields import FieldInfo

from ._core import type_to_string

DEFAULT_MAX_DEPTH = 1000

#: Types that should not be stringified for e.g. FITS
NON_STRING_TYPES = (int, float)


logger = logging.getLogger("_visitor")

__all__ = [
    "flatten_model_instance",
    "unflatten_model_instance",
    "Relation",
    "ParentInfo",
    "walk_model",
    "walk_model_to_depth",
    "classes_in_model",
    "related_classes_in_model",
    "all_classes_in_model",
    "extract_keyword_mapping",
    "print_model",
]

#: Classes to ignore when finding related models
IGNORED_TYPES = (
    "ModelBase",
    "BaseModel",
    "ReprEnum",
    "StrEnum",
    "Enum",
    "Flag",
    "StrFlag",
)


[docs] class Relation(Enum): """Parent relationship.""" contains = auto() contains_one_of = auto() contains_many = auto() none = auto()
[docs] @dataclass class ParentInfo: """Info about parent of the current element.""" item_name: str item_type: type is_optional: bool parent_relation: Relation field: FieldInfo
def _unwrap_annotated(tp): """Remove Annotated[...] wrapper.""" origin = get_origin(tp) if origin is Annotated: return get_args(tp)[0] return tp def _walk_model( model: type, _path: list[ParentInfo] | None = None ) -> Generator[tuple[type, list[ParentInfo]]]: """Recurse for walk_model().""" if not _path: _path = [] # base case: yield the info yield model, _path # if it's a leaf element, just return at this point. we are done: if not isinstance(model, type) or not issubclass(model, BaseModel): return # If it's a base model, however, loop over its children and recurse: for fieldname, field in model.model_fields.items(): outer_type = _unwrap_annotated(field.annotation) origin = get_origin(outer_type) args = get_args(outer_type) # forward refs require a bit of a hack for now... if isinstance(outer_type, ForwardRef): yield from _walk_model( outer_type, _path=_path + [ ParentInfo( item_name=fieldname, item_type=ForwardRef(outer_type.__forward_arg__.split(" |")[0]), is_optional="None" in outer_type.__forward_arg__, parent_relation=Relation.contains, field=field, ) ], ) elif origin in [UnionType, Union]: # handle compound type, e.g. X|Y is_optional = type(None) in args non_none_args = [t for t in args if t is not type(None)] relation = ( Relation.contains if len(non_none_args) <= 1 else Relation.contains_one_of ) for sub_type in non_none_args: sub_type = _unwrap_annotated(sub_type) if get_origin(sub_type) is list: inner = get_args(sub_type)[0] yield from _walk_model( inner, _path=_path + [ ParentInfo( item_name=fieldname, item_type=inner, is_optional=is_optional, parent_relation=Relation.contains_many, field=field, ) ], ) else: yield from _walk_model( sub_type, _path=_path + [ ParentInfo( item_name=fieldname, item_type=sub_type, is_optional=is_optional, parent_relation=relation, field=field, ) ], ) # Handle list[x], assuming that all items of the list are the same. 
elif origin is list and args: sub_type = args[0] yield from _walk_model( sub_type, _path=_path + [ ParentInfo( item_name=fieldname, item_type=sub_type, is_optional=False, parent_relation=Relation.contains_many, field=field, ) ], ) # Literal types elif origin is Literal: yield from _walk_model( outer_type, _path=_path + [ ParentInfo( item_name=fieldname, item_type=origin, is_optional=False, parent_relation=Relation.contains, field=field, ) ], ) # Simple single type elif isinstance(outer_type, type): yield from _walk_model( outer_type, _path=_path + [ ParentInfo( item_name=fieldname, item_type=outer_type, is_optional=False, parent_relation=Relation.contains, field=field, ) ], ) else: raise NotImplementedError( f"UNKNOWN type: {fieldname=} {outer_type=} {origin=} {args=}" )
[docs] def walk_model( model: type[BaseModel], parent_key: str = "" ) -> Generator[tuple[type, list[ParentInfo]]]: """ Recursively walk the pydantic model. Parameters ---------- model: Type[BaseModel] Model class (not instance) to visit parent_key: str Key to use for first element in the model If not specified, the model name is used. Returns ------- Generator[tuple[BaseModel, list[Element]]]: Element class and path at each step. The path is a list of ParentInfo describing the parents of the current element in order of ancestry, e.g. the current model element's parent is path[-1], the grandparent is path[-2], ... Examples -------- To generate a hierarchical list of the elements and sub-elements: >>> for model_element, path in walk_model(some_model): >>> name = '.'.join(p.item_name for p in path) >>> print(f"{'*' * (len(path)+1)} {name} : {model_element.__name__}") * Rererence : Reference ** process : Process *** process.type : ObservatoryProcess ** data : Product *** data.category : DataCategory *** data.level : DataLevel *** data.division : DataDivision *** data.association : DataAssociation *** data.type : DataType ... """ if parent_key == "": parent_key = model.__name__.lower() initial = ParentInfo( item_name=parent_key, item_type=model, parent_relation=Relation.none, is_optional=False, field=Field(description=model.__doc__), ) yield from _walk_model(model, _path=[initial])
[docs] def walk_model_to_depth( model: type[BaseModel], max_depth: int = DEFAULT_MAX_DEPTH, parent_key: str = "" ): """Walk the model as in walk_model, but stop at depth.""" return filter( lambda x: len(x[1]) <= max_depth, walk_model(model, parent_key=parent_key) )
[docs] def classes_in_model( model: type[BaseModel], allowed_types=BaseModel | Enum, max_depth: int = DEFAULT_MAX_DEPTH, ): """ Return list of unique classes used in the given model. Parameters ---------- model : BaseModel model to search allowed_types : type Types to find. Default is BaseModel, but you can also specify e.g. BaseModel|Enum """ classes_in_model = [model] for model, _ in walk_model_to_depth(model, max_depth=max_depth): if isinstance(model, type): classes_in_model.append(model) # remove dups: unique_classes = [c for c in set(classes_in_model) if issubclass(c, allowed_types)] return unique_classes
[docs] def all_classes_in_model(model: type[BaseModel], allowed_types=BaseModel | Enum): """Return all classes and related classes in the model.""" classes = classes_in_model(model, allowed_types=allowed_types) parents, children = related_classes_in_model(model, allowed_types=allowed_types) return classes + parents + children
def _flatten( obj, parent_key: str = "", separator: str = ".", to_string: bool = True, expand_list: bool = True, do_not_stringify: tuple[type, ...] = NON_STRING_TYPES, ): """Flatten dicts and optionally lists. Skip empty leaf values.""" items = {} # ----------------------------- # CASE 1: dict-like # ----------------------------- if isinstance(obj, MutableMapping): for key, value in obj.items(): new_key = f"{parent_key}{separator}{key}" if parent_key else key result = _flatten(value, new_key, separator, to_string, expand_list) items.update(result) return items # ----------------------------- # CASE 2: list-like, not stringified, and with simple non-structured elements: # ----------------------------- if isinstance(obj, Sequence) and not isinstance(obj, str | bytes): # Special rule: # if to_string is False → do NOT expand the list. if not expand_list: if not obj: # empty list → skip return {} return {parent_key: str(obj) if to_string else obj} # Normal flattening when to_string=True for idx, value in enumerate(obj): new_key = f"{parent_key}{separator}{idx}" if parent_key else str(idx) result = _flatten(value, new_key, separator, to_string, expand_list) items.update(result) return items # -------------------------- # CASE 3: leaf value # ----------------------------- # Skip empty leaf values (None, "", 0-length sequence, {}) if not obj: return {} # Store final leaf should_stringify = not isinstance(obj, do_not_stringify) return {parent_key: str(obj) if to_string and should_stringify else obj} def _unflatten(flat_dict: dict, parent_key: str = "", separator: str = ".") -> dict: """Unflatten a dictionary flattened with _flatten, supporting lists.""" result = {} for key, value in flat_dict.items(): if key.startswith(parent_key + separator): key = key.removeprefix(parent_key + separator) parts = key.split(separator) d = result for i, part in enumerate(parts[:-1]): if part.isdigit(): # Current container should be a list index = int(part) if not isinstance(d, list): raise 
ValueError(f"Expected list but found {type(d)}") # Extend list if needed while len(d) <= index: d.append({}) d = d[index] else: # Check if next part is a digit (this should be a list) next_part = parts[i + 1] if next_part.isdigit(): d = d.setdefault(part, []) else: d = d.setdefault(part, {}) # Handle the final part final_part = parts[-1] if final_part.isdigit(): index = int(final_part) if not isinstance(d, list): raise ValueError(f"Expected list but found {type(d)}") while len(d) <= index: d.append(None) d[index] = value else: d[final_part] = value return result
[docs] def flatten_model_instance( model_instance: BaseModel, parent_key: str = "", separator: str = ".", to_string: bool = True, expand_lists: bool = True, ): """ Return the flattened model instance. Parameters ---------- model_instance: BaseModel model instance to flatten (not the class!) parent_key: str starting key separator: str separator between keys to_string: bool if True, turn the leaf values into strings expand_lists: bool if True, split lists into key.N, where N is the index If this is not True, and to_string is True, the string rep of the list will be used, which might cause issues for round-tripping back to a Pydantic model. """ # to ensure more complex objects like AstroPydanticICRS, where the JSON # serialization creates sub-dicts, we do a round trip here. model_dict = json.loads(model_instance.model_dump_json()) return _flatten( model_dict, parent_key=parent_key, separator=separator, to_string=to_string, expand_list=expand_lists, )
[docs] def unflatten_model_instance( flat_dict: dict, model: type[BaseModel], parent_key: str = "", separator: str = "." ) -> BaseModel: """ Return the flattened model instance. Parameters ---------- flat_dict: dict flattened dictionary of keywords, e.g. as loaded from json model_instance: BaseModel model instance to flatten (not the class!) parent_key: str starting key separator: str separator between keys """ return model( **_unflatten( flat_dict, parent_key=parent_key, separator=separator, ) )
def get_field_metadata(field, metadata_key: str): """Get a metadata item from a Field.""" extra = field.json_schema_extra or dict() if extra and (metadata_key in extra): return extra[metadata_key] else: return getattr(field, metadata_key, None) def extract_keyword_mapping( model: type[BaseModel], metadata_key: str = "ivoa_keyword", *, sep="." ) -> dict[str, str]: """Create a dictionary mapping flat model keyword to a metadata item. All items found that have the metadata will be returned. Parameters ---------- model: Type[BaseModel] model to extract from metadata_key: str key in the json_schema_extra field to search for, e.g. "fits_keyword" sep: str separator for the model key Returns ------- dict[str,str]: dictionary of flat_model_key -> metadata_key """ keyword_map = defaultdict(list) for _, path in walk_model(model): item = path[-1] key = sep.join(p.item_name for p in path[1:]) val = get_field_metadata(field=item.field, metadata_key=metadata_key) if val and val not in keyword_map[key]: keyword_map[key].append(val) return keyword_map