Source code for quantify_scheduler.structure.types
# Repository: https://gitlab.com/quantify-os/quantify-scheduler
# Licensed according to the LICENCE file on the main branch
"""
Types that support validation in Pydantic.
Pydantic recognizes the magic method ``__get_pydantic_core_schema__`` to obtain a custom
core schema, which can be used, e.g., for custom validation, serialization and deserialization.
We implement several custom types here to tune behavior of our models.
See `Pydantic documentation`_ for more information about implementing new types.
.. _Pydantic documentation: https://docs.pydantic.dev/latest/usage/types/custom/
"""
from __future__ import annotations
import base64
from collections.abc import Callable, Mapping
from importlib.metadata import version as pkg_version
from typing import TYPE_CHECKING, Any
import networkx as nx
import numpy as np
from pydantic_core import core_schema
if TYPE_CHECKING:
from numpy.typing import ArrayLike
def _release_tuple(version_string: str) -> tuple[int, ...]:
    """
    Extract the leading numeric release components of a version string.

    Parsing stops at the first component that is not purely numeric, so
    ``"3.10.1"`` yields ``(3, 10, 1)`` and ``"3.3rc0"`` yields ``(3, 3)``.

    Parameters
    ----------
    version_string
        Version string as reported by :func:`importlib.metadata.version`.

    Returns
    -------
    :
        Tuple of integers that can be compared numerically.

    """
    release: list[int] = []
    for component in version_string.split("."):
        digits = ""
        for char in component:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if len(digits) != len(component):
            # A suffix such as "rc0" ends the numeric release segment.
            break
    return tuple(release)


# NOTE: NetworkX dropped support for Python 3.9 with version 3.3
# (see: https://networkx.org/documentation/stable/release/release_3.3.html#maintenance)
# The last version that supported Python 3.9 is 3.2.1, which however has a different function
# signature for `node_link_data()`. For this reason, we must pass different parameters to the
# serialization function depending on the package version.
# The version is compared as a tuple of integers: comparing the raw strings is lexicographic
# and would, e.g., sort "3.10" before "3.3".
if _release_tuple(pkg_version("networkx")) < (3, 3):  # noqa: SIM108
    NODE_LINK_DATA_KWARGS = {"link": "links"}
else:
    NODE_LINK_DATA_KWARGS = {"edges": "links"}
[docs]
class NDArray(np.ndarray):
    """
    Pydantic-compatible version of :class:`numpy.ndarray`.

    Serialization is implemented using custom methods :meth:`.to_dict` and
    :meth:`.from_dict`. Data array is encoded in Base64.
    """

    def __new__(cls: type[NDArray], array_like: ArrayLike) -> NDArray:  # noqa: D102
        # View the converted data as this subclass.
        return np.asarray(array_like).view(cls)

    @classmethod
    def __get_pydantic_core_schema__(
        cls: type[NDArray],
        _source_type: Any,  # noqa: ANN401
        _handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        # Validation is delegated to :meth:`validate`; the dictionary representation
        # produced by :meth:`to_dict` is only used when serializing to JSON.
        return core_schema.no_info_plain_validator_function(
            cls.validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls.to_dict, when_used="json"
            ),
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Convert the array to JSON-compatible dictionary.

        Returns
        -------
        :
            Dictionary with the keys ``"data"`` (Base64-encoded raw bytes of the
            array, as an ASCII string), ``"shape"`` (the array shape) and
            ``"dtype"`` (string representation of the data type).

        """
        return {
            "data": base64.b64encode(self.tobytes()).decode("ascii"),
            "shape": self.shape,
            "dtype": str(self.dtype),
        }

    @classmethod
    def from_dict(cls: type[NDArray], serialized: Mapping[str, Any]) -> NDArray:
        """
        Construct an instance from a dictionary generated by :meth:`to_dict`.

        Parameters
        ----------
        serialized
            Dictionary that has ``"data"``, ``"shape"`` and ``"dtype"`` keys,
            where data is a base64-encoded bytes array, shape is a tuple and dtype is
            a string representation of a Numpy data type.

        Returns
        -------
        :
            Writable array holding its own copy of the decoded data.

        """
        raw = base64.b64decode(serialized["data"])
        # ``np.frombuffer`` on immutable ``bytes`` yields a read-only array; copy it so
        # that the deserialized array is writable like one created by the constructor.
        flat = np.frombuffer(raw, dtype=serialized["dtype"]).copy()
        return flat.reshape(serialized["shape"]).view(cls)

    @classmethod
    def validate(cls: type[NDArray], v: Any) -> NDArray:  # noqa: ANN401
        """
        Validate the data and cast from all known representations.

        Mappings are interpreted as the output of :meth:`to_dict`; any other value
        is passed to the class constructor.
        """
        if isinstance(v, Mapping):
            return cls.from_dict(v)  # type: ignore
        return cls(v)
[docs]
class Graph(nx.Graph):
    """
    Pydantic-compatible version of :class:`networkx.Graph`.

    The graph is serialized to the node-link format using
    :func:`networkx.node_link_data` and restored from it using
    :func:`networkx.node_link_graph`.
    """

    # Avoid showing inherited init docstring (which leads to cross-reference issues)
    def __init__(self, incoming_graph_data=None, **attr) -> None:  # noqa: ANN001
        """Create a new graph instance."""
        super().__init__(incoming_graph_data, **attr)

    @classmethod
    def __get_pydantic_core_schema__(
        cls: type[Graph],
        _source_type: Any,  # noqa: ANN401
        _handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        # Validation is delegated to :meth:`validate`; the node-link representation
        # is used for every serialization mode (``when_used="always"``).
        return core_schema.no_info_plain_validator_function(
            cls.validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda g: nx.node_link_data(g, **NODE_LINK_DATA_KWARGS), when_used="always"
            ),
        )

    @classmethod
    def validate(cls: type[Graph], v: Any) -> Graph:  # noqa: ANN401
        """
        Validate the data and cast from all known representations.

        Dictionaries are interpreted as node-link data; any other value is passed
        to the class constructor.
        """
        if isinstance(v, dict):
            # Use the same keyword arguments as for serialization, so that the key
            # under which the edges are stored does not depend on the defaults of
            # the installed NetworkX version.
            return cls(nx.node_link_graph(v, **NODE_LINK_DATA_KWARGS))
        return cls(v)