Extending or modifying erdantic

Note

The backend of erdantic was significantly updated in v1.0. Extending erdantic now works very differently than in previous versions.

Plugins for model frameworks

erdantic supports data modeling frameworks through a plugin system. Each plugin must implement specific functionality and be registered for erdantic to use it.

The following built-in plugins are provided by erdantic:

  1. attrs — for classes decorated using the attrs package
  2. dataclasses — for classes decorated using the dataclasses standard library module
  3. pydantic — for classes that subclass Pydantic's BaseModel class
  4. pydantic_v1 — for classes that subclass Pydantic's legacy pydantic.v1.BaseModel class

You can customize erdantic by registering a custom plugin, either to override a built-in one or to add support for a new framework. The following sections document what you need to register your own plugin.

Components of a plugin

A plugin must implement the following two components: a predicate function and a field extractor function.

Predicate function

A predicate function takes a single object as input and returns True if the object is a data model class that this plugin supports, and False otherwise. The return type of this function is a TypeGuard for the model class. A protocol class ModelPredicate defines the specification for a valid predicate function.

Source from erdantic/plugins/__init__.py

class ModelPredicate(Protocol[_ModelType_co]):
    """Protocol class for a predicate function for a plugin."""

    def __call__(self, obj: Any) -> TypeGuard[_ModelType_co]: ...

Example implementations of predicate functions include is_pydantic_model and is_dataclass_class.

Field extractor function

A field extractor function takes a single model class of the appropriate type and returns a sequence of FieldInfo instances. A protocol class ModelFieldExtractor defines the specification for a valid field extractor function.

Source from erdantic/plugins/__init__.py

class ModelFieldExtractor(Protocol[_ModelType_contra]):
    """Protocol class for a field extractor function for a plugin."""

    def __call__(self, model: _ModelType_contra) -> Sequence["FieldInfo"]: ...

Example implementations of field extractor functions include get_fields_from_pydantic_model and get_fields_from_dataclass.

The field extractor function is the place where you should try to resolve forward references. Some frameworks provide utility functions to resolve forward references, like Pydantic's model_rebuild and attrs' resolve_types. If there isn't one, you should write your own using erdantic's resolve_types_on_dataclass as a reference implementation.
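For a framework without its own resolution utility, one possible starting point is the standard library's typing.get_type_hints. The sketch below is not erdantic's implementation. It only shows how string annotations on a hypothetical Node class could be evaluated before a field extractor builds its FieldInfo instances. The helper name resolve_annotations is an assumption made for this example.

```python
from typing import Any, Dict, List, Optional, get_type_hints


class Node:
    """Hypothetical model class whose annotations refer to itself by name."""

    value: int
    parent: Optional["Node"]
    children: List["Node"]


def resolve_annotations(model: type) -> Dict[str, Any]:
    """Return the model's annotations with forward references evaluated.

    The class is passed in localns under its own name so that
    self-references resolve even if the class is not a module-level global.
    """
    return get_type_hints(model, localns={model.__name__: model})


resolved = resolve_annotations(Node)
for name, tp in resolved.items():
    print(name, tp)
```

In a real field extractor, each resolved name and type pair would then be turned into a FieldInfo instance. If a name cannot be resolved, get_type_hints raises NameError, which the extractor can catch and report.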

Registering a plugin

A plugin must be registered by calling the register_plugin function with a key identifier and the two functions. If you use a key that already exists, it will overwrite the existing plugin.

Info

register_plugin

register_plugin(
    key: str,
    predicate_fn: ModelPredicate[_ModelType],
    get_fields_fn: ModelFieldExtractor[_ModelType],
)

Register a plugin for a specific model class type.

Parameters:

  • key (str, required): An identifier for this plugin.
  • predicate_fn (ModelPredicate, required): A predicate function to determine if an object is a class of the model that is supported by this plugin.
  • get_fields_fn (ModelFieldExtractor, required): A function to extract fields from a model class that is supported by this plugin.

Currently, manual registration is required. This means that custom plugins can only be loaded when using erdantic as a library, and not as a CLI. In the future, we may support automatic loading of plugins that are distributed with packages through the entry points specification.

Modifying model analysis or diagram rendering

If you would like to make any major changes to the functionality of erdantic, such as:

  • Changing what data gets extracted when analyzing a model
  • Structural changes to how models are represented in the diagram

then you can subclass EntityRelationshipDiagram, ModelInfo, FieldInfo, and/or Edge to modify any behavior.

Warning

Changes like these depend on the internal APIs of erdantic and are more likely to break between erdantic versions. If you're trying to do something like this, please consider letting the maintainers know in the repository discussions.

What to change

Here are some tips on what to change depending on your goals:

Get erdantic to use your subclasses

The best way to use custom subclasses is to subclass EntityRelationshipDiagram. Then, you can create an instance of it and call its methods.

Example

from erdantic.core import EntityRelationshipDiagram
from erdantic.examples.pydantic import Party

class CustomEntityRelationshipDiagram(EntityRelationshipDiagram):
    ...

diagram = CustomEntityRelationshipDiagram()
diagram.add_model(Party)
diagram.draw("diagram.png")

Then, depending on which classes you're implementing subclasses of, you will want to do the following:

  • If subclassing ModelInfo...
    • Also subclass EntityRelationshipDiagram and override the type annotation for models to use your custom subclass. The model info class used is determined by this type annotation.
  • If subclassing FieldInfo...
    • Also subclass ModelInfo and override the type annotation for fields to use your custom subclass.
    • Add or override a plugin's field extractor function. The field info instances are instantiated in the field extractor function.
  • If subclassing Edge...
    • Also subclass EntityRelationshipDiagram and override the type annotation for edges to use your custom subclass. The edge class used is determined by this type annotation.

Example: Adding a column with default field values

The example below modifies how Pydantic models are handled. It extracts and stores the default value of each field and adds the values as a third column to the tables in the diagram.

Here is how the rendered diagram looks:

Diagram with erdantic modified to add a third column

And here is the source code:

from html import escape
from typing import Any, Dict, List

import pydantic
import pydantic_core
from typenames import REMOVE_ALL_MODULES, typenames

from erdantic.core import (
    EntityRelationshipDiagram,
    FieldInfo,
    FullyQualifiedName,
    ModelInfo,
    SortedDict,
)
from erdantic.exceptions import UnresolvableForwardRefError
from erdantic.plugins import register_plugin
from erdantic.plugins.pydantic import is_pydantic_model


class FieldInfoWithDefault(FieldInfo):
    """Custom FieldInfo subclass that adds a 'default_value' field and diagram column."""

    default_value: str

    _dot_row_template = (
        """<tr>"""
        """<td>{name}</td>"""
        """<td>{type_name}</td>"""
        """<td port="{name}" width="36">{default_value}</td>"""
        """</tr>"""
    )

    @classmethod
    def from_raw_type(
        cls, model_full_name: FullyQualifiedName, name: str, raw_type: type, raw_default_value: Any
    ):
        default_value = (
            "" if raw_default_value is pydantic_core.PydanticUndefined else repr(raw_default_value)
        )
        field_info = cls(
            model_full_name=model_full_name,
            name=name,
            type_name=typenames(raw_type, remove_modules=REMOVE_ALL_MODULES),
            default_value=default_value,
        )
        field_info._raw_type = raw_type
        return field_info

    def to_dot_row(self) -> str:
        return self._dot_row_template.format(
            name=self.name,
            type_name=self.type_name,
            default_value=escape(self.default_value),  # Escape HTML-unsafe characters
        )


class ModelInfoWithDefault(ModelInfo):
    """Custom ModelInfo subclass that uses FieldInfoWithDefault instead of FieldInfo."""

    fields: Dict[str, FieldInfoWithDefault] = {}


class EntityRelationshipDiagramWithDefault(EntityRelationshipDiagram):
    """Custom EntityRelationshipDiagram subclass that uses ModelInfoWithDefault instead of
    ModelInfo.
    """

    models: SortedDict[str, ModelInfoWithDefault] = SortedDict()


def get_fields_from_pydantic_model_with_default(model) -> List[FieldInfoWithDefault]:
    """Copied from erdantic.plugins.pydantic.get_fields_from_pydantic_model and modified to
    extract default values of fields.
    """
    try:
        # Rebuild model schema to resolve forward references
        model.model_rebuild()
    except pydantic.errors.PydanticUndefinedAnnotation as e:
        model_full_name = FullyQualifiedName.from_object(model)
        forward_ref = e.name
        msg = (
            f"Failed to resolve forward reference '{forward_ref}' in the type annotations for "
            f"Pydantic model {model_full_name}. "
            "You should use the model's model_rebuild() method to manually resolve it."
        )
        raise UnresolvableForwardRefError(
            msg, name=forward_ref, model_full_name=model_full_name
        ) from e
    return [
        FieldInfoWithDefault.from_raw_type(
            model_full_name=FullyQualifiedName.from_object(model),
            name=name,
            # typing special forms currently get typed as object
            # https://github.com/python/mypy/issues/9773
            raw_type=pydantic_field_info.annotation or Any,  # type: ignore
            raw_default_value=pydantic_field_info.default,
        )
        for name, pydantic_field_info in model.model_fields.items()
    ]


# Register this plugin. Will override erdantic's built-in 'pydantic' plugin.
register_plugin("pydantic", is_pydantic_model, get_fields_from_pydantic_model_with_default)