Source code for ctao_datamodel._plantuml

import base64
import html
import logging
import subprocess
from enum import Enum
from functools import reduce
from pathlib import Path
from typing import Self

from pydantic import BaseModel

from ._latex import model_to_table
from ._visitor import (
    Relation,
    all_classes_in_model,
    classes_in_model,
    related_classes_in_model,
    walk_model_to_depth,
)

logger = logging.getLogger(__name__)


__all__ = [
    "model_to_plantuml_class",
    "model_to_plantuml_relations",
    "generate_plantuml_diagrams",
    "PlantUMLDiagram",
]

# Multiplicity label drawn on the contained end of a composition arrow,
# keyed by the relation through which the parent holds the child.
RELATION_MAP = {
    Relation.contains: "1",
    Relation.contains_many: "..*",
    Relation.contains_one_of: "0..1",
}

# Markers that open and close every standalone PlantUML document.
START_UML = "@startuml"
END_UML = "@enduml"


def model_name(model: type[BaseModel | Enum]) -> str:
    """Build the identifier under which *model* appears in PlantUML.

    Classes carrying a ``_namespace`` attribute are qualified with it
    (``<namespace>.<ClassName>``); all others use their bare class name.
    """
    try:
        namespace = model._namespace
    except AttributeError:
        # No namespace attribute: the plain class name is the identifier.
        return model.__name__
    return f"{namespace}.{model.__name__}"


def model_to_plantuml_class(model: type[BaseModel | Enum]) -> str:
    """Return a PlantUML class definition for a pydantic model or an Enum.

    Parameters
    ----------
    model : type[BaseModel | Enum]
        The class to describe. Pydantic models list one line per field
        (name and type, optional fields marked ``[opt]``). Enums get the
        ``<<enum>>`` stereotype and list their members. A pydantic model
        *instance* is accepted too and is rendered as its class.

    Returns
    -------
    str
        A complete ``@startuml`` ... ``@enduml`` document holding one class.

    Raises
    ------
    TypeError
        If *model* is neither a pydantic model nor an Enum class.
    """
    # Instances have no ``__name__`` and cannot be passed to issubclass(),
    # so describe their class instead.
    if isinstance(model, BaseModel):
        model = type(model)

    # issubclass() raises its own, less helpful TypeError for non-classes;
    # reject everything unsupported up front with a single clear message.
    if not (isinstance(model, type) and issubclass(model, BaseModel | Enum)):
        msg = f"couldn't convert: {model}"
        raise TypeError(msg)

    is_enum = issubclass(model, Enum)
    stereotype = " <<enum>>" if is_enum else ""
    puml = [START_UML, f"class {model_name(model)}{stereotype} {{"]

    if is_enum:
        puml.extend(f" + {member}" for member in model)
    else:
        table = model_to_table(model, optional_symbol=" [opt]", type_sep=" | ")
        puml.extend(f" + {row['Name']}: <i>{row['Type']}</i>" for row in table)

    puml.append("}")
    puml.append(END_UML)
    return "\n".join(puml)
def _uml_relations_for_models(
    models: list[type[BaseModel]],
    allowed_types: type | tuple[type, ...] = BaseModel,
) -> tuple[list[str], set[type[BaseModel | Enum]]]:
    """
    Generate PlantUML inheritance relationships.

    The search goes both up (parents) and down (children), recursively.
    Only the nearest included ancestor of each class gets an arrow. An
    ancestor that is reachable through another included ancestor is not
    connected directly, so ``A <|-- B <|-- C`` is drawn without an extra
    ``A <|-- C``.

    Parameters
    ----------
    models : list[type]
        Classes to analyze.
    allowed_types : type | tuple[type], optional
        Only include classes that are subclasses of these types.

    Returns
    -------
    list[str]
        Sorted list of PlantUML relationship lines like "Parent <|-- Child".
    set[type[BaseModel|Enum]]
        Set of models whose definitions should be imported.
    """
    # Start from the given classes and expand outward
    to_visit = set(models)
    all_related = set(models)
    seen = set()

    while to_visit:
        cls = to_visit.pop()
        if cls in seen:
            continue
        seen.add(cls)

        parents, children = related_classes_in_model(cls)
        new_related = set(parents) | set(children)
        to_visit |= new_related - all_related
        all_related |= new_related

    # Every related class (and therefore every parent drawn below) is needed.
    models_to_import = set(all_related)

    # Now build PlantUML relationships among all related classes
    relations = set()
    for cls in all_related:
        ancestors = [
            parent
            for parent in cls.__mro__[1:]
            if issubclass(parent, allowed_types) and parent in all_related
        ]
        for parent in ancestors:
            # Skip ancestors already reached through a closer included
            # ancestor; a direct arrow would only duplicate that path.
            if any(
                other is not parent and issubclass(other, parent)
                for other in ancestors
            ):
                continue
            relations.add(f"{model_name(parent)} <|-- {model_name(cls)}")

    return sorted(relations), models_to_import


def _generate_imports(models: list[type[BaseModel]]) -> list[str]:
    """Generate PlantUML ``!include`` lines for the definitions of *models*.

    Duplicates are removed and the lines are sorted by model name. This keeps
    generated files stable between runs: iterating a ``set`` of classes
    yields an order that can change from one interpreter run to the next.
    """
    names = sorted({model_name(model) for model in models})
    return [f"!include ./definitions/{name}.plantuml" for name in names]
def model_to_plantuml_relations(
    model: type[BaseModel],
    *,
    max_depth=1000,
    import_definitions: bool = False,
    show_inheritance: bool = False,
) -> str:
    """
    Output PlantUML relationships for a model.

    By default containment relations are generated, and optionally
    inheritance relations.

    Parameters
    ----------
    model: type[BaseModel]
        the model
    max_depth: int
        how deep to go
    import_definitions: bool
        include plantuml imports of the class definitions
    show_inheritance: bool
        include parent/child relations

    Returns
    -------
    str:
        PlantUML code
    """
    puml = [START_UML]
    models_to_import = {model}

    for _, path in walk_model_to_depth(model, max_depth=max_depth):
        if len(path) < 2:
            continue  # the root has no containing parent

        item = path[-1]
        parent = path[-2]
        item_type = item.item_type

        # Only nested models and enums become composition arrows.
        if not (
            isinstance(item_type, type) and issubclass(item_type, BaseModel | Enum)
        ):
            continue

        rel = RELATION_MAP[item.parent_relation]
        if item.is_optional and item.parent_relation != Relation.contains_one_of:
            # Give optional fields a lower bound of zero: "1" -> "0..1" and
            # "..*" -> "0..*". Dropping the leading ".." first prevents the
            # malformed "0....*" that plain concatenation would produce.
            rel = "0.." + rel.removeprefix("..")

        models_to_import.add(item_type)
        puml.append(
            f'{model_name(parent.item_type)} "+{item.item_name}" *-- "{rel}"'
            f" {model_name(item_type)} "
        )

    if show_inheritance:
        relations, relation_models = _uml_relations_for_models([model])
        models_to_import.update(relation_models)
        puml += relations

    if import_definitions:
        puml += _generate_imports(list(models_to_import))

    puml.append(END_UML)
    return "\n".join(puml)
def generate_plantuml_diagrams(
    models: list[type[BaseModel]],
    output_dir: Path | str,
    show_inheritance: bool = True,
    imports: bool = True,
    max_depth: int = 1000,
) -> None:
    """
    Generate plantuml diagrams for the given models.

    Writes one class-definition file per class found in ``models`` to
    ``<output_dir>/definitions/<name>.plantuml``. Then, for each model,
    relationship diagrams are written to ``<output_dir>/relations_<name>.plantuml``
    and ``<output_dir>/full_<name>.plantuml``.

    Parameters
    ----------
    models: list[type[BaseModel]]
        list of Models to generate diagrams for
    output_dir: Path | str
        where to write the plantuml code
    show_inheritance: bool
        If True, also include parent/child relations and classes in the
        ``relations_`` diagrams (the ``full_`` diagrams always include them)
    imports: bool
        guards writing of the ``relations_`` diagrams (see the review note
        in the loop below)
    max_depth: int
        maximum depth to include in the diagrams.
    """
    base = Path(output_dir)
    defs = base / "definitions"
    # The full_ diagrams reference these files via "!include ./definitions/...".
    defs.mkdir(parents=True, exist_ok=True)

    # First generate all necessary class definitions
    all_classes = []
    for model in models:
        all_classes += all_classes_in_model(model)
    # De-duplicate classes shared by several models.
    all_classes = list(set(all_classes))

    for model in all_classes:
        def_filename = defs / f"{model_name(model)}.plantuml"
        def_filename.write_text(model_to_plantuml_class(model))
        logger.debug("Wrote: %s", def_filename)

    # now generate the relationship diagrams:
    logger.debug("Writing relationship diagrams...")
    for model in models:
        # NOTE(review): confirm which writes `imports` is meant to guard. As laid
        # out here it only skips the relations_ file, which contains no
        # !include lines; the name suggests it may be intended for the full_
        # file (the one that imports the definitions) instead.
        if imports:
            # Containment diagram without !include lines; honours show_inheritance.
            rel_filename = base / f"relations_{model_name(model)}.plantuml"
            rel_filename.write_text(
                model_to_plantuml_relations(
                    model, max_depth=max_depth, show_inheritance=show_inheritance
                )
            )
            logger.debug("Wrote: %s", rel_filename)

        # Diagram that !includes the definition files written above; inheritance
        # is always shown here, independent of show_inheritance.
        full_filename = base / f"full_{model_name(model)}.plantuml"
        full_filename.write_text(
            model_to_plantuml_relations(
                model,
                import_definitions=True,
                max_depth=max_depth,
                show_inheritance=True,
            )
        )
        logger.debug("Wrote: %s", full_filename)
class PlantUMLDiagram:
    """Render a PlantUML diagram in a jupyter notebook.

    Diagrams can be composed using the ``+`` operator. Rendering runs the
    ``plantuml`` executable, which must be available on the ``PATH``.
    """

    def __init__(
        self,
        model_or_text: str | type[BaseModel | Enum],
        relations: bool = True,
        inheritance: bool = False,
        details: bool = False,
        max_depth: int = 1000,
        cwd: Path | str | None = None,
    ) -> None:
        """
        Create a PlantUML diagram.

        Parameters
        ----------
        model_or_text: str | type[BaseModel]
            Either a model, or any arbitrary plantuml-format text.
        relations: bool
            include composed classes in the diagram
        inheritance: bool
            include inherited classes in the diagram
        details: bool
            include details of inherited or related classes
        max_depth: int
            maximum depth used when collecting related classes
        cwd: Path | str
            current working directory, in case you have a relative
            ``!include`` directive in the plantuml text.

        Raises
        ------
        ValueError
            If ``model_or_text`` is neither a string nor a model/enum class.
        """
        # Set before any composition below: ``+`` propagates ``cwd``.
        self.cwd = cwd

        if isinstance(model_or_text, str):
            self.plantuml_text = model_or_text
            return

        if not (
            isinstance(model_or_text, type)
            and issubclass(model_or_text, BaseModel | Enum)
        ):
            raise ValueError("Expected plantuml string or type[BaseModel]")

        self.plantuml_text = model_to_plantuml_class(model_or_text)
        if not issubclass(model_or_text, BaseModel):
            return  # enums have no relations or nested classes to add

        if relations:
            relation_text = model_to_plantuml_relations(
                model_or_text,
                show_inheritance=inheritance,
                max_depth=max_depth,
            )
            self.plantuml_text = (self + PlantUMLDiagram(relation_text)).plantuml_text

        if details:
            nested = [
                PlantUMLDiagram(m, inheritance=False, relations=False)
                for m in classes_in_model(model_or_text, max_depth=max_depth)
                if m is not model_or_text
            ]
            # Starting the fold from ``self`` also handles an empty list,
            # where a bare reduce() would raise TypeError.
            combined = reduce(lambda acc, diagram: acc + diagram, nested, self)
            self.plantuml_text = combined.plantuml_text

    @classmethod
    def from_path(cls, path) -> Self:
        """Return diagram from an existing file at path.

        The file's directory becomes ``cwd``, so relative ``!include``
        directives resolve against it.
        """
        path = Path(path)
        return cls(path.read_text(), cwd=path.parent)

    def _repr_html_(self) -> str:
        """Return the diagram as an inline PNG ``<img>`` tag for Jupyter.

        Rendering failures are reported as a red HTML message instead of
        being raised, so a broken diagram does not break notebook display.
        """
        try:
            png_data = self._generate_png_pipe()
        except subprocess.CalledProcessError as err:
            # The exception text only carries the exit status; add whatever
            # PlantUML wrote to stderr, which usually explains the failure.
            stderr = (err.stderr or b"").decode("utf-8", errors="replace").strip()
            return self._error_html(f"{err}: {stderr}" if stderr else str(err))
        except Exception as err:  # display boundary: report, never raise
            return self._error_html(str(err))
        b64_data = base64.b64encode(png_data).decode("utf-8")
        return f'<img src="data:image/png;base64,{b64_data}" />'

    @staticmethod
    def _error_html(message: str) -> str:
        """Wrap *message* in red HTML, escaped so ``<``/``&`` cannot break it."""
        return (
            '<div style="color: red;">Error rendering PlantUML diagram:'
            f" {html.escape(message)}</div>"
        )

    def _generate_png_pipe(self) -> bytes:
        """Generate PNG using PlantUML's pipe mode.

        Raises ``subprocess.CalledProcessError`` on a non-zero exit status
        and ``FileNotFoundError`` if ``plantuml`` is not installed.
        """
        result = subprocess.run(
            ["plantuml", "-tpng", "-pipe"],
            cwd=self.cwd,
            input=self.plantuml_text.encode("utf-8"),
            capture_output=True,
            check=True,
        )
        return result.stdout

    def __add__(self, other) -> Self:
        """Merge two diagrams into one ``@startuml``/``@enduml`` document.

        The result keeps this diagram's ``cwd`` (or ``other``'s if this one
        has none), so relative ``!include`` directives still resolve.
        """
        if not isinstance(other, PlantUMLDiagram):
            return NotImplemented

        def body(diagram: "PlantUMLDiagram") -> str:
            return diagram.plantuml_text.replace(START_UML, "").replace(END_UML, "")

        cwd = self.cwd if self.cwd is not None else other.cwd
        return type(self)(
            f"{START_UML}\n{body(self)}\n{body(other)}\n{END_UML}", cwd=cwd
        )

    def __str__(self) -> str:
        """Turn diagram to string."""
        return self.plantuml_text