import base64
import logging
import subprocess
from enum import Enum
from functools import reduce
from pathlib import Path
from typing import Self
from pydantic import BaseModel
from ._latex import model_to_table
from ._visitor import (
Relation,
all_classes_in_model,
classes_in_model,
related_classes_in_model,
walk_model_to_depth,
)
logger = logging.getLogger(__name__)

__all__ = [
    "model_to_plantuml_class",
    "model_to_plantuml_relations",
    "generate_plantuml_diagrams",
    "PlantUMLDiagram",
]

# UML multiplicity label drawn on the contained end of a composition edge
# (see model_to_plantuml_relations). "*" is the standard UML "any number";
# the previous "..*" was not a valid multiplicity and became "0....*" once
# the optional-field prefix "0.." was added.
RELATION_MAP: dict[Relation, str] = {
    Relation.contains: "1",
    Relation.contains_many: "*",
    Relation.contains_one_of: "0..1",
}

# Delimiters of every PlantUML document produced by this module.
START_UML = "@startuml"
END_UML = "@enduml"
def model_name(model: type[BaseModel | Enum]) -> str:
    """Return the identifier PlantUML uses for ``model``.

    When the class carries a ``_namespace`` attribute the result is
    ``"<_namespace>.<ClassName>"``; otherwise it is just the class name.
    """
    name = model.__name__
    if not hasattr(model, "_namespace"):
        return name
    return f"{model._namespace}.{name}"
def model_to_plantuml_class(model: type[BaseModel | Enum]) -> str:
    """Return a standalone PlantUML class definition for ``model``.

    Pydantic models get one ``+ name: <i>type</i>`` line per row of
    ``model_to_table``; Enum classes get the ``<<enum>>`` stereotype and one
    line per member.

    Parameters
    ----------
    model : type[BaseModel | Enum]
        The pydantic model class or Enum class to render.

    Returns
    -------
    str
        PlantUML source wrapped in ``@startuml`` / ``@enduml``.

    Raises
    ------
    TypeError
        If ``model`` is not a pydantic model class or an Enum class
        (including when it is not a class at all).
    """
    # issubclass() itself raises on non-classes, so validate first to give a
    # single, descriptive TypeError for every unsupported input.
    if not (isinstance(model, type) and issubclass(model, BaseModel | Enum)):
        msg = f"couldn't convert: {model}"
        raise TypeError(msg)

    name = model_name(model)
    if issubclass(model, Enum):
        header = f"class {name} <<enum>> {{"
        members = [f" + {member}" for member in model]
    else:
        header = f"class {name} {{"
        table = model_to_table(model, optional_symbol=" [opt]", type_sep=" | ")
        members = [f" + {row['Name']}: <i>{row['Type']}</i>" for row in table]

    return "\n".join([START_UML, header, *members, "}", END_UML])
def _nearest_related_bases(cls: type, related: set, allowed_types) -> list[type]:
    """Return the closest ancestors of ``cls`` that are in ``related`` and allowed.

    ``__bases__`` is walked breadth-first and a branch stops climbing as soon
    as a matching ancestor is found, so only the nearest visible parents are
    returned; ancestors that are not in ``related`` are skipped transparently.
    """
    found: list[type] = []
    queue = list(cls.__bases__)
    visited: set[type] = set()
    while queue:
        base = queue.pop(0)
        if base in visited:
            continue
        visited.add(base)
        if base in related and issubclass(base, allowed_types):
            found.append(base)
        else:
            queue.extend(base.__bases__)
    return found


def _uml_relations_for_models(
    models: list[type[BaseModel]],
    allowed_types=BaseModel,
) -> tuple[list[str], set[type[BaseModel | Enum]]]:
    """
    Generate PlantUML inheritance relationships.

    The search goes both up (parents) and down (children), recursively, via
    ``related_classes_in_model``. Only direct edges between the collected
    classes are emitted: if ``A <- B <- C`` are all collected, the result has
    ``A <|-- B`` and ``B <|-- C`` but not the redundant ``A <|-- C``.

    Parameters
    ----------
    models : list[type[BaseModel]]
        Classes to analyze.
    allowed_types : type | tuple[type, ...], optional
        Only draw edges to parents that are subclasses of these types.

    Returns
    -------
    list[str]
        Sorted list of PlantUML relationship lines like "Parent <|-- Child".
    set[type[BaseModel | Enum]]
        Set of models whose definitions should be imported (every class
        collected by the search).
    """
    # Expand outward from the given classes until no new class shows up.
    # A class enters ``to_visit`` only when first added to ``all_related``,
    # so each class is expanded exactly once.
    all_related = set(models)
    to_visit = set(models)
    while to_visit:
        cls = to_visit.pop()
        parents, children = related_classes_in_model(cls)
        new_related = (set(parents) | set(children)) - all_related
        all_related |= new_related
        to_visit |= new_related

    relations = {
        f"{model_name(parent)} <|-- {model_name(cls)}"
        for cls in all_related
        for parent in _nearest_related_bases(cls, all_related, allowed_types)
    }
    return sorted(relations), set(all_related)
def _generate_imports(models: list[type[BaseModel]]) -> list[str]:
    """Return PlantUML ``!include`` lines for the definitions of ``models``.

    Each model maps to ``./definitions/<model_name>.plantuml``. Duplicate
    names are emitted once, and the lines are sorted by name so the generated
    text is identical across runs (iterating a ``set`` of classes, as callers
    pass in, has no reproducible order).
    """
    names = sorted({model_name(model) for model in models})
    return [f"!include ./definitions/{name}.plantuml" for name in names]
def model_to_plantuml_relations(
    model: type[BaseModel],
    *,
    max_depth=1000,
    import_definitions: bool = False,
    show_inheritance: bool = False,
) -> str:
    """
    Output PlantUML relationships for a model.

    By default containment relations are generated, and optionally inheritance
    relations. Each item yielded by ``walk_model_to_depth`` whose type is a
    pydantic model or Enum class becomes an edge of the form
    ``Parent "+item_name" *-- "multiplicity" Child``.

    Parameters
    ----------
    model: type[BaseModel]
        the model
    max_depth: int
        how deep to go
    import_definitions: bool
        include plantuml imports of the class definitions
    show_inheritance: bool
        include parent/child relations

    Returns
    -------
    str:
        PlantUML code
    """
    puml = [START_UML]
    models_to_import = {model}

    for _, path in walk_model_to_depth(model, max_depth=max_depth):
        if len(path) < 2:
            # The root has no parent, hence no edge.
            continue
        parent, item = path[-2], path[-1]
        item_type = item.item_type
        if not (
            isinstance(item_type, type) and issubclass(item_type, BaseModel | Enum)
        ):
            # Only model/enum classes are nodes in the diagram.
            continue

        rel = RELATION_MAP[item.parent_relation]
        if item.is_optional and item.parent_relation != Relation.contains_one_of:
            # An optional field may be absent, so the lower bound becomes 0.
            # Keep only the upper bound of the mapped multiplicity:
            # "1" -> "0..1", "*" or "..*" or "1..*" -> "0..*". Blindly
            # prefixing "0.." produced the malformed "0....*".
            rel = "0.." + rel.rsplit("..", 1)[-1]

        models_to_import.add(item_type)
        puml.append(
            f'{model_name(parent.item_type)} "+{item.item_name}" *-- "{rel}"'
            f" {model_name(item_type)} "
        )

    if show_inheritance:
        relations, relation_models = _uml_relations_for_models([model])
        models_to_import.update(relation_models)
        puml.extend(relations)

    if import_definitions:
        puml.extend(_generate_imports(list(models_to_import)))

    puml.append(END_UML)
    return "\n".join(puml)
def generate_plantuml_diagrams(
    models: list[type[BaseModel]],
    output_dir: Path | str,
    show_inheritance: bool = True,
    imports: bool = True,
    max_depth: int = 1000,
) -> None:
    """
    Generate plantuml diagrams for the given models.

    Files written below ``output_dir`` (all UTF-8):

    - ``definitions/<name>.plantuml``: one class definition for every class
      returned by ``all_classes_in_model`` for any of ``models``;
    - ``relations_<name>.plantuml``: relations-only diagram per model
      (only when ``imports`` is True);
    - ``full_<name>.plantuml``: relations plus ``!include`` lines for the
      definitions, with inheritance always shown.

    Parameters
    ----------
    models: list[type[BaseModel]]
        list of Models to generate diagrams for
    output_dir: Path | str
        where to write the plantuml code (created if missing)
    show_inheritance: bool
        If True, also include parent/child relations and classes in the
        relations-only diagrams
    imports: bool
        If True, also write the relations-only diagram for each model
    max_depth: int
        maximum depth to include in the diagrams.
    """
    base = Path(output_dir)
    defs = base / "definitions"
    defs.mkdir(parents=True, exist_ok=True)

    # First generate all necessary class definitions, once per class even if
    # several models share it.
    all_classes = {cls for model in models for cls in all_classes_in_model(model)}
    for cls in all_classes:
        def_filename = defs / f"{model_name(cls)}.plantuml"
        # Explicit UTF-8 so the output does not depend on the platform locale.
        def_filename.write_text(model_to_plantuml_class(cls), encoding="utf-8")
        logger.debug("Wrote: %s", def_filename)

    # now generate the relationship diagrams:
    logger.debug("Writing relationship diagrams...")
    for model in models:
        name = model_name(model)
        if imports:
            rel_filename = base / f"relations_{name}.plantuml"
            rel_filename.write_text(
                model_to_plantuml_relations(
                    model, max_depth=max_depth, show_inheritance=show_inheritance
                ),
                encoding="utf-8",
            )
            logger.debug("Wrote: %s", rel_filename)

        # NOTE(review): the full diagram always shows inheritance, regardless
        # of ``show_inheritance`` -- confirm this is intended.
        full_filename = base / f"full_{name}.plantuml"
        full_filename.write_text(
            model_to_plantuml_relations(
                model,
                import_definitions=True,
                max_depth=max_depth,
                show_inheritance=True,
            ),
            encoding="utf-8",
        )
        logger.debug("Wrote: %s", full_filename)
[docs]
class PlantUMLDiagram:
"""Render a PlantUML diagram in a jupyter notebook.
Diagrams can be composed using the ``+`` operator.
"""
def __init__(
self,
model_or_text: str | type[BaseModel | Enum],
relations: bool = True,
inheritance: bool = False,
details: bool = False,
max_depth: int = 1000,
cwd: Path | str | None = None,
) -> None:
"""
Create a PlantUML diagram.
Parameters
----------
model_or_text: str | type[BaseModel]
Either a model, or any arbitrary plantuml-format text.
relations: bool
include composed classes in the diagram
inheritance: bool
include inherited classes in the diagram
details: bool
include details of inherited or related classes
cwd: Path | str
current working directory, in case you have a relative
``!include`` directive in the plantuml text.
"""
if isinstance(model_or_text, str):
self.plantuml_text = model_or_text
elif isinstance(model_or_text, type) and issubclass(
model_or_text, BaseModel | Enum
):
self.plantuml_text = model_to_plantuml_class(model_or_text)
if relations and issubclass(model_or_text, BaseModel):
self.plantuml_text = (
self
+ PlantUMLDiagram(
model_to_plantuml_relations(
model_or_text,
show_inheritance=inheritance,
max_depth=max_depth,
)
)
).plantuml_text
if details and issubclass(model_or_text, BaseModel):
classes = classes_in_model(model_or_text, max_depth=max_depth)
if len(classes) > 1:
self.plantuml_text = (
self
+ reduce(
lambda x, y: x + y,
(
PlantUMLDiagram(m, inheritance=False, relations=False)
for m in classes
if m is not model_or_text
),
)
).plantuml_text
else:
raise ValueError("Expected plantuml string or type[BaseModel]")
self.cwd = cwd
[docs]
@classmethod
def from_path(cls, path) -> Self:
"""Return diagram from an existing file at path."""
path = Path(path)
return cls(path.read_text(), cwd=path.parent)
def _repr_html_(self) -> str:
try:
png_data = self._generate_png_pipe()
b64_data = base64.b64encode(png_data).decode("utf-8")
return f'<img src="data:image/png;base64,{b64_data}" />'
except Exception as e:
return (
'<div style="color: red;">Error rendering PlantUML diagram:'
f" {str(e)}</div>"
)
def _generate_png_pipe(self) -> bytes:
"""Generate PNG using PlantUML's pipe mode."""
result = subprocess.run(
["plantuml", "-tpng", "-pipe"],
cwd=self.cwd,
input=self.plantuml_text.encode("utf-8"),
capture_output=True,
check=True,
)
return result.stdout
def __add__(self, other) -> Self:
text = self.plantuml_text.replace(START_UML, "").replace(END_UML, "")
other_text = other.plantuml_text.replace(START_UML, "").replace(END_UML, "")
return PlantUMLDiagram(f"@startuml\n{text}\n{other_text}\n@enduml")
def __str__(self) -> str:
"""Turn diagram to string."""
return self.plantuml_text