"""
Functionality to recursively call a function on model sub-elements.
This is used in other places to simplify turning the model into other formats,
like PlantUML.
"""
import json
import logging
from collections import defaultdict
from collections.abc import Generator, MutableMapping, Sequence
from dataclasses import dataclass
from enum import Enum, auto
from types import UnionType
from typing import Annotated, ForwardRef, Literal, Union, get_args, get_origin
from pydantic import BaseModel, Field
from pydantic.fields import FieldInfo
from ._core import type_to_string
#: Default depth limit (path length) used by the walk/flatten helpers.
DEFAULT_MAX_DEPTH = 1000
#: Types that should not be stringified for e.g. FITS
NON_STRING_TYPES = (int, float)
# Module-level logger; deliberately named "_visitor" rather than __name__.
logger = logging.getLogger("_visitor")
__all__ = [
    "flatten_model_instance",
    "unflatten_model_instance",
    "Relation",
    "ParentInfo",
    "walk_model",
    "walk_model_to_depth",
    "classes_in_model",
    "related_classes_in_model",
    "all_classes_in_model",
    "extract_keyword_mapping",
    "print_model",
]
#: Classes to ignore when finding related models
# NOTE(review): matched by class *name*, not identity; not referenced in this
# chunk — presumably used by related_classes_in_model (defined elsewhere).
IGNORED_TYPES = (
    "ModelBase",
    "BaseModel",
    "ReprEnum",
    "StrEnum",
    "Enum",
    "Flag",
    "StrFlag",
)
class Relation(Enum):
    """Kind of relationship an element has to its parent.

    Used in :class:`ParentInfo` to describe how a model field holds
    its value.
    """

    #: parent holds exactly one element of this type
    contains = auto()
    #: parent holds one of several alternative types (union member)
    contains_one_of = auto()
    #: parent holds a list of elements of this type
    contains_many = auto()
    #: no parent relation (the root element of a walk)
    none = auto()
@dataclass
class ParentInfo:
    """Info about the parent of the current element.

    One instance is produced per ancestry step during a model walk; a list
    of these forms the *path* yielded by :func:`walk_model`.
    """

    #: field name in the parent model
    item_name: str
    #: declared type of the field (after unwrapping Annotated/Optional)
    item_type: type
    #: True when the field annotation allows None
    is_optional: bool
    #: how the parent holds this element (see Relation)
    parent_relation: Relation
    #: the pydantic FieldInfo describing the field
    field: FieldInfo
def _unwrap_annotated(tp):
    """Return ``tp`` with any ``Annotated[...]`` wrapper removed.

    ``Annotated[X, ...]`` becomes ``X``; any other type (including plain
    generics) is returned unchanged.
    """
    unwrapped = get_args(tp)[0] if get_origin(tp) is Annotated else tp
    return unwrapped
def _walk_model(
    model: type, _path: list[ParentInfo] | None = None
) -> Generator[tuple[type, list[ParentInfo]]]:
    """Recurse for walk_model().

    Depth-first walk: yields ``(model, _path)`` for the current element
    first, then — when the element is a pydantic model — recurses into each
    of its fields.  ``_path`` accumulates one ParentInfo per ancestry step,
    so ``_path[-1]`` describes the element just yielded.
    """
    if not _path:
        _path = []
    # base case: yield the info
    yield model, _path
    # if it's a leaf element, just return at this point. we are done:
    # (typing constructs like Literal[...] are not `type` instances and
    # also stop here)
    if not isinstance(model, type) or not issubclass(model, BaseModel):
        return
    # If it's a base model, however, loop over its children and recurse:
    for fieldname, field in model.model_fields.items():
        # strip Annotated[...] so origin/args reflect the real container type
        outer_type = _unwrap_annotated(field.annotation)
        origin = get_origin(outer_type)
        args = get_args(outer_type)
        # forward refs require a bit of a hack for now...
        # the unresolved string is parsed textually: "X | None" marks the
        # field optional, and only the part before " |" is kept as the type
        if isinstance(outer_type, ForwardRef):
            yield from _walk_model(
                outer_type,
                _path=_path
                + [
                    ParentInfo(
                        item_name=fieldname,
                        item_type=ForwardRef(outer_type.__forward_arg__.split(" |")[0]),
                        is_optional="None" in outer_type.__forward_arg__,
                        parent_relation=Relation.contains,
                        field=field,
                    )
                ],
            )
        elif origin in [UnionType, Union]:
            # handle compound type, e.g. X|Y  (covers both X|Y and Union[X, Y])
            is_optional = type(None) in args
            non_none_args = [t for t in args if t is not type(None)]
            # a union that is only `X | None` is a plain "contains";
            # two or more real alternatives become "contains_one_of"
            relation = (
                Relation.contains
                if len(non_none_args) <= 1
                else Relation.contains_one_of
            )
            for sub_type in non_none_args:
                # union members may themselves be Annotated[...]
                sub_type = _unwrap_annotated(sub_type)
                if get_origin(sub_type) is list:
                    # list member of a union: recurse into the element type
                    # (only the first type argument of list[...] is used)
                    inner = get_args(sub_type)[0]
                    yield from _walk_model(
                        inner,
                        _path=_path
                        + [
                            ParentInfo(
                                item_name=fieldname,
                                item_type=inner,
                                is_optional=is_optional,
                                parent_relation=Relation.contains_many,
                                field=field,
                            )
                        ],
                    )
                else:
                    yield from _walk_model(
                        sub_type,
                        _path=_path
                        + [
                            ParentInfo(
                                item_name=fieldname,
                                item_type=sub_type,
                                is_optional=is_optional,
                                parent_relation=relation,
                                field=field,
                            )
                        ],
                    )
        # Handle list[x], assuming that all items of the list are the same.
        elif origin is list and args:
            sub_type = args[0]
            yield from _walk_model(
                sub_type,
                _path=_path
                + [
                    ParentInfo(
                        item_name=fieldname,
                        item_type=sub_type,
                        is_optional=False,
                        parent_relation=Relation.contains_many,
                        field=field,
                    )
                ],
            )
        # Literal types
        # item_type is recorded as the Literal origin; the recursion stops
        # immediately because Literal[...] is not a `type` instance
        elif origin is Literal:
            yield from _walk_model(
                outer_type,
                _path=_path
                + [
                    ParentInfo(
                        item_name=fieldname,
                        item_type=origin,
                        is_optional=False,
                        parent_relation=Relation.contains,
                        field=field,
                    )
                ],
            )
        # Simple single type
        elif isinstance(outer_type, type):
            yield from _walk_model(
                outer_type,
                _path=_path
                + [
                    ParentInfo(
                        item_name=fieldname,
                        item_type=outer_type,
                        is_optional=False,
                        parent_relation=Relation.contains,
                        field=field,
                    )
                ],
            )
        else:
            # e.g. dict[...] or other generics are not supported yet
            raise NotImplementedError(
                f"UNKNOWN type: {fieldname=} {outer_type=} {origin=} {args=}"
            )
def walk_model(
    model: type[BaseModel], parent_key: str = ""
) -> Generator[tuple[type, list[ParentInfo]]]:
    """
    Recursively walk the pydantic model.

    Parameters
    ----------
    model: Type[BaseModel]
        Model class (not instance) to visit
    parent_key: str
        Key to use for first element in the model
        If not specified, the lowercased model name is used.

    Returns
    -------
    Generator[tuple[BaseModel, list[Element]]]:
        Element class and path at each step. The path is a list of ParentInfo
        describing the parents of the current element in order of ancestry, e.g.
        the current model element's parent is path[-1], the grandparent is
        path[-2], ...

    Examples
    --------
    To generate a hierarchical list of the elements and sub-elements:

    >>> for model_element, path in walk_model(some_model):
    ...     name = '.'.join(p.item_name for p in path)
    ...     print(f"{'*' * (len(path)+1)} {name} : {model_element.__name__}")
    * Reference : Reference
    ** process : Process
    *** process.type : ObservatoryProcess
    ** data : Product
    *** data.category : DataCategory
    *** data.level : DataLevel
    *** data.division : DataDivision
    *** data.association : DataAssociation
    *** data.type : DataType
    ...
    """
    if parent_key == "":
        parent_key = model.__name__.lower()
    # synthetic root entry so every yielded path is non-empty; the model's
    # docstring becomes the root field description
    initial = ParentInfo(
        item_name=parent_key,
        item_type=model,
        parent_relation=Relation.none,
        is_optional=False,
        field=Field(description=model.__doc__),
    )
    yield from _walk_model(model, _path=[initial])
def walk_model_to_depth(
    model: type[BaseModel], max_depth: int = DEFAULT_MAX_DEPTH, parent_key: str = ""
):
    """Walk the model as in walk_model(), but stop at depth.

    Elements whose path (including the synthetic root entry) is longer than
    *max_depth* are filtered out; the walk itself is unchanged.
    """
    return filter(
        lambda x: len(x[1]) <= max_depth, walk_model(model, parent_key=parent_key)
    )
def print_model(model, max_depth: int = DEFAULT_MAX_DEPTH) -> None:
    """Print out the model, for debugging.

    One indented line per element: name, type, optionality marker (``*``),
    and the parent relation, down to *max_depth* levels.
    """
    print(
        "Element : Type :Opt: Parent"
        " Relation"
    )
    print("=" * 85)
    for _, path in walk_model_to_depth(model, max_depth=max_depth):
        # indent by ancestry depth (root entry not counted)
        indent = " " * (len(path) - 1)
        relation = path[-1].parent_relation
        print(
            f" {indent+path[-1].item_name:30s} :"
            f" {type_to_string(path[-1].item_type):30s} :"
            f" {'*' if path[-1].is_optional else ' '} : {relation.name:20s}"
        )
def classes_in_model(
    model: type[BaseModel],
    allowed_types=BaseModel | Enum,
    max_depth: int = DEFAULT_MAX_DEPTH,
):
    """
    Return list of unique classes used in the given model.

    Parameters
    ----------
    model : BaseModel
        model to search
    allowed_types : type
        Types to find. Default is BaseModel|Enum, but you can also
        restrict to e.g. BaseModel
    max_depth : int
        maximum walk depth passed on to walk_model_to_depth()

    Returns
    -------
    list[type]:
        unique classes, in first-seen order
    """
    found = [model]
    for element, _ in walk_model_to_depth(model, max_depth=max_depth):
        # the walk also yields non-class items (e.g. Literal origins); keep
        # only real classes so issubclass() below is safe
        if isinstance(element, type):
            found.append(element)
    # dict.fromkeys dedupes while preserving first-seen order (a set would
    # make the result order arbitrary)
    return [cls for cls in dict.fromkeys(found) if issubclass(cls, allowed_types)]
def all_classes_in_model(model: type[BaseModel], allowed_types=BaseModel | Enum):
    """Return all classes and related classes in the model.

    Concatenates the classes found by walking the model with the parent and
    child classes reported by related_classes_in_model(); the result may
    contain duplicates when a class appears in more than one of those lists.
    """
    classes = classes_in_model(model, allowed_types=allowed_types)
    parents, children = related_classes_in_model(model, allowed_types=allowed_types)
    return classes + parents + children
def _flatten(
obj,
parent_key: str = "",
separator: str = ".",
to_string: bool = True,
expand_list: bool = True,
do_not_stringify: tuple[type, ...] = NON_STRING_TYPES,
):
"""Flatten dicts and optionally lists. Skip empty leaf values."""
items = {}
# -----------------------------
# CASE 1: dict-like
# -----------------------------
if isinstance(obj, MutableMapping):
for key, value in obj.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
result = _flatten(value, new_key, separator, to_string, expand_list)
items.update(result)
return items
# -----------------------------
# CASE 2: list-like, not stringified, and with simple non-structured elements:
# -----------------------------
if isinstance(obj, Sequence) and not isinstance(obj, str | bytes):
# Special rule:
# if to_string is False → do NOT expand the list.
if not expand_list:
if not obj: # empty list → skip
return {}
return {parent_key: str(obj) if to_string else obj}
# Normal flattening when to_string=True
for idx, value in enumerate(obj):
new_key = f"{parent_key}{separator}{idx}" if parent_key else str(idx)
result = _flatten(value, new_key, separator, to_string, expand_list)
items.update(result)
return items
# --------------------------
# CASE 3: leaf value
# -----------------------------
# Skip empty leaf values (None, "", 0-length sequence, {})
if not obj:
return {}
# Store final leaf
should_stringify = not isinstance(obj, do_not_stringify)
return {parent_key: str(obj) if to_string and should_stringify else obj}
def _unflatten(flat_dict: dict, parent_key: str = "", separator: str = ".") -> dict:
    """Rebuild the nested structure from a dict produced by _flatten.

    Purely numeric key segments are interpreted as list indices; missing
    intermediate list slots are padded ({} for inner nodes, None for leaves).
    Raises ValueError when a numeric segment lands on a non-list container.
    """
    nested = {}
    prefix = parent_key + separator
    for flat_key, leaf in flat_dict.items():
        # strip the requested prefix, if present on this key
        key = flat_key.removeprefix(prefix) if flat_key.startswith(prefix) else flat_key
        segments = key.split(separator)
        node = nested
        # walk/create intermediate containers for all but the last segment
        for pos, segment in enumerate(segments[:-1]):
            if segment.isdigit():
                idx = int(segment)
                if not isinstance(node, list):
                    raise ValueError(f"Expected list but found {type(node)}")
                # pad the list up to the requested index
                while len(node) <= idx:
                    node.append({})
                node = node[idx]
            else:
                # look ahead: a numeric next segment means this key holds a list
                child = [] if segments[pos + 1].isdigit() else {}
                node = node.setdefault(segment, child)
        # place the leaf at the final segment
        tail = segments[-1]
        if tail.isdigit():
            idx = int(tail)
            if not isinstance(node, list):
                raise ValueError(f"Expected list but found {type(node)}")
            while len(node) <= idx:
                node.append(None)
            node[idx] = leaf
        else:
            node[tail] = leaf
    return nested
def flatten_model_instance(
    model_instance: BaseModel,
    parent_key: str = "",
    separator: str = ".",
    to_string: bool = True,
    expand_lists: bool = True,
    do_not_stringify: tuple[type, ...] = NON_STRING_TYPES,
):
    """
    Return the flattened model instance.

    Parameters
    ----------
    model_instance: BaseModel
        model instance to flatten (not the class!)
    parent_key: str
        starting key
    separator: str
        separator between keys
    to_string: bool
        if True, turn the leaf values into strings
    expand_lists: bool
        if True, split lists into key.N, where N is the index
        If this is not True, and to_string is True, the string rep
        of the list will be used, which might cause
        issues for round-tripping back to a Pydantic model.
    do_not_stringify: tuple[type, ...]
        leaf types exempt from stringification even when to_string is True
    """
    # to ensure more complex objects like AstroPydanticICRS, where the JSON
    # serialization creates sub-dicts, we do a round trip here.
    model_dict = json.loads(model_instance.model_dump_json())
    return _flatten(
        model_dict,
        parent_key=parent_key,
        separator=separator,
        to_string=to_string,
        expand_list=expand_lists,
        do_not_stringify=do_not_stringify,
    )
def unflatten_model_instance(
    flat_dict: dict, model: type[BaseModel], parent_key: str = "", separator: str = "."
) -> BaseModel:
    """
    Return a validated model instance rebuilt from a flattened dictionary.

    Parameters
    ----------
    flat_dict: dict
        flattened dictionary of keywords, e.g. as loaded from json
    model: type[BaseModel]
        model class (not an instance!) to construct and validate against
    parent_key: str
        starting key
    separator: str
        separator between keys
    """
    # unflatten back into nested dicts/lists, then let pydantic validate
    return model(
        **_unflatten(
            flat_dict,
            parent_key=parent_key,
            separator=separator,
        )
    )
def get_field_metadata(field, metadata_key: str):
    """Return the *metadata_key* entry for a Field.

    The key is looked up in the field's ``json_schema_extra`` mapping first;
    when absent there, fall back to a plain attribute of the same name
    (None when neither exists).
    """
    schema_extra = field.json_schema_extra or {}
    if metadata_key in schema_extra:
        return schema_extra[metadata_key]
    return getattr(field, metadata_key, None)
def extract_keyword_mapping(
    model: type[BaseModel], metadata_key: str = "ivoa_keyword", *, sep="."
) -> dict[str, list[str]]:
    """Create a dictionary mapping flat model keyword to its metadata items.

    All items found that have the metadata will be returned.

    Parameters
    ----------
    model: Type[BaseModel]
        model to extract from
    metadata_key: str
        key in the json_schema_extra field to search for,
        e.g. "fits_keyword"
    sep: str
        separator for the model key

    Returns
    -------
    dict[str, list[str]]:
        dictionary of flat_model_key -> list of unique metadata values
        (a list because the same flat key can be reached via several
        union/type branches)
    """
    keyword_map = defaultdict(list)
    for _, path in walk_model(model):
        item = path[-1]
        # path[0] is the synthetic root entry; it is not part of the flat key
        key = sep.join(p.item_name for p in path[1:])
        val = get_field_metadata(field=item.field, metadata_key=metadata_key)
        # short-circuit: falsy values never touch keyword_map[key], so no
        # empty entries are auto-created by the defaultdict
        if val and val not in keyword_map[key]:
            keyword_map[key].append(val)
    return keyword_map