Source code for datapipe_testbench.visualization

#!/usr/bin/env python3
"""
Functions to help visualize Benchmarks and their inputs.
"""

import html
from dataclasses import asdict

import graphviz

from .benchmark import Benchmark
from .inputdataset import InputDataset

__all__ = ["graphviz_inputs_to_benchmarks"]

# Marker shown in front of a field whose value is truthy.
SET = "✔"
# Marker shown in front of a field whose value is falsy.
UNSET = " - "
# Default number of trailing characters of a value kept in a table cell.
DEFAULT_FILENAME_LENGTH = 20


def _dict_to_label(title, thedict, file_length=DEFAULT_FILENAME_LENGTH):
    """Build a graphviz HTML-like table label from a mapping.

    The table has a header row with *title*, then one row per item of
    *thedict*.  An item whose key is ``"name"`` is rendered as a single
    centered, italic row spanning both columns.  Every other item becomes a
    two-column row: the left cell shows ``SET`` or ``UNSET`` (depending on
    the truthiness of the value) followed by the key, and the right cell,
    which carries ``PORT="<key>"`` so edges can attach to it, shows the
    tail of the value.

    Parameters
    ----------
    title:
        Text for the header row.
    thedict: dict
        Mapping of field name to value. Values are converted with ``str``.
    file_length: int
        Maximum number of trailing characters of each value to display.
        Longer values are cut at the front and prefixed with ``...``;
        values that fit are shown unchanged.  Negative values are treated
        as 0.

    Returns
    -------
    str:
        The label, wrapped in ``<`` ... ``>`` lines and joined with newlines.
    """
    keep = max(file_length, 0)

    rows = [
        "<",
        '<TABLE BORDER="1" CELLBORDER="1" CELLSPACING="0">',
        '<TR><TD COLSPAN="2" bgcolor="#cccccc">'
        f"<B>{html.escape(str(title), quote=False)}</B></TD></TR>",
    ]

    for field, value in thedict.items():
        if field == "name":
            shown = html.escape(str(value), quote=False)
            rows.append(
                '<TR><TD ALIGN="center" colspan="2" bgcolor="#dddddd">'
                f" <i>{shown}</i> </TD></TR>"
            )
            continue

        if value:
            full = str(value)
            if len(full) > keep:
                # Compute the start index explicitly: a ``[-keep:]`` slice
                # would return the whole string when keep == 0.
                tail = "..." + html.escape(full[len(full) - keep:], quote=False)
            else:
                tail = html.escape(full, quote=False)
            cell = f"<i>{tail}</i>"
            state = SET
        else:
            cell = ""
            state = UNSET

        # Escape after truncating so an entity is never cut in half; the key
        # is used inside a quoted attribute, so quotes are escaped there too.
        field_text = html.escape(str(field), quote=False)
        port = html.escape(str(field))
        rows.append(
            f'<TR><TD ALIGN="LEFT"> {state} {field_text} </TD>'
            f'<TD PORT="{port}"> {cell} </TD></TR>'
        )

    rows.append("</TABLE>")
    rows.append(">")
    return "\n".join(rows)


def graphviz_inputs_to_benchmarks(
    input_dataset_list: list[InputDataset],
    benchmark_list: list[Benchmark],
    file_length: int = DEFAULT_FILENAME_LENGTH,
) -> graphviz.Digraph:
    """Generate a graphviz diagram mapping InputDatasets to MetricsStores via benchmarks.

    Each InputDataset and each Benchmark is drawn as a table-shaped node.
    For every (InputDataset, Benchmark) pair, one edge is drawn per entry of
    the benchmark's ``required_inputs``, starting at the matching field of
    the dataset table.  The edge color reflects the state of that field:

    * ``red``: the field is unset (falsy) on the InputDataset,
    * ``yellow``: the field is set but its ``exists()`` returns a falsy value,
    * ``darkgreen``: the field is set and ``exists()`` is truthy.

    The output will display automatically in a Jupyter Notebook, or can be
    saved to a file by calling the ``render()`` method of the Digraph.

    Parameters
    ----------
    input_dataset_list: list[InputDataset]
        List of InputDatasets the user will pass to each Benchmark's
        generate_metrics method.
    benchmark_list: list[Benchmark]
        Which benchmarks will process the InputDatasets.
    file_length: int
        Max length of the filenames in the InputDataset. The last N
        characters of the filename will be retained, and the rest
        truncated.

    Returns
    -------
    graphviz.Digraph:
        graph that can be displayed or rendered.
    """
    digraph = graphviz.Digraph("Benchmark Workflow")
    digraph.attr(rankdir="LR")

    # One node per input dataset; remember the node names for the edge pass.
    dataset_nodes = []
    for ii, input_dataset in enumerate(input_dataset_list):
        nodename = f"InputDataset {ii}"
        digraph.node(
            nodename,
            label=_dict_to_label(
                nodename, asdict(input_dataset), file_length=file_length
            ),
            shape="plaintext",
        )
        dataset_nodes.append((nodename, input_dataset))

    # Declare each benchmark node exactly once, independent of how many
    # datasets there are (including none).
    for benchmark in benchmark_list:
        digraph.node(
            benchmark.name,
            label=_dict_to_label(
                " ".join([benchmark.name, " Metrics"]),
                benchmark.outputs(),
                file_length=file_length,
            ),
            shape="plaintext",
        )

    # Connect every dataset field a benchmark requires to that benchmark.
    for nodename, input_dataset in dataset_nodes:
        for benchmark in benchmark_list:
            for req in benchmark.required_inputs:
                value = getattr(input_dataset, req)
                if not value:
                    color = "red"
                elif value.exists():
                    color = "darkgreen"
                else:
                    color = "yellow"
                digraph.edge(
                    # "<node>:<port>:e" attaches the tail to the east side of
                    # the table cell whose PORT is the required field name.
                    f"{nodename}:{req}:e",
                    benchmark.name,
                    color=color,
                )

    return digraph