# Repository: https://gitlab.com/quantify-os/quantify-scheduler
# Licensed according to the LICENCE file on the main branch
"""Plotting functions used in the visualization backend of the sequencer."""
from __future__ import annotations
from copy import deepcopy
from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Iterable, Iterator
import matplotlib
import quantify_scheduler.schedules._visualization.pulse_scheme as ps
from quantify_scheduler.compilation import _determine_absolute_timing
from quantify_scheduler.helpers.importers import import_python_object_from_string
from quantify_scheduler.operations.control_flow_library import (
ConditionalOperation,
LoopOperation,
)
from quantify_scheduler.operations.operation import Operation
from quantify_scheduler.schedules._visualization import constants
from quantify_scheduler.schedules.schedule import Schedule, ScheduleBase
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from quantify_scheduler.resources import Resource
[docs]
def gate_box(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Draw one labelled gate box on every given device_element row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position of the gate (passed to ``ps.box_text`` as ``x0``).
    device_element_idxs :
        Row indices of the device_elements the gate acts on; one box is drawn per row.
    text :
        The gate name, written inside each box.
    kw :
        Extra keyword arguments forwarded to ``ps.box_text``.
    """
    box_style = {"fillcolor": constants.COLOR_LAZURE, "width": 0.8, "height": 0.5}
    for row in device_element_idxs:
        ps.box_text(ax, x0=time, y0=row, text=text, **box_style, **kw)
[docs]
def pulse_baseband(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Draw a baseband-pulse cartoon with a label on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal centre of the pulse cartoon and of its label.
    device_element_idxs :
        Row indices to draw the pulse on.
    text :
        The pulse name, written slightly above each cartoon.
    kw :
        Extra keyword arguments forwarded to ``ps.flux_pulse``.
    """
    cartoon_width = 0.6
    # The cartoon is specified by its left edge, so shift by half its width to centre it.
    left_edge = time - cartoon_width / 2
    for row in device_element_idxs:
        ps.flux_pulse(
            ax, pos=left_edge, y_offs=row, width=cartoon_width, s=0.0025, amp=0.33, **kw
        )
        label_y = row + 0.45
        ax.text(time, label_y, text, ha="center", va="center", zorder=6)
[docs]
def pulse_modulated(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Draw a modulated (microwave) pulse cartoon with a label on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal centre of the pulse cartoon and of its label.
    device_element_idxs :
        Row indices to draw the pulse on.
    text :
        The pulse name, written slightly above each cartoon.
    kw :
        Extra keyword arguments forwarded to ``ps.mw_pulse``.
    """
    cartoon_width = 0.6
    # The cartoon is specified by its left edge, so shift by half its width to centre it.
    left_edge = time - cartoon_width / 2
    for row in device_element_idxs:
        ps.mw_pulse(ax, pos=left_edge, y_offs=row, width=cartoon_width, amp=0.33, **kw)
        label_y = row + 0.45
        ax.text(time, label_y, text, ha="center", va="center", zorder=6)
[docs]
def meter(
    ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw  # Noqa: ARG001
) -> None:
    """
    Draw a grey measurement meter on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position of the meter (passed to ``ps.meter`` as ``x0``).
    device_element_idxs :
        Row indices to draw a meter on.
    text :
        The measurement name; not drawn by this function.
    kw :
        Extra keyword arguments forwarded to ``ps.meter``.
    """
    meter_style = {
        "fillcolor": constants.COLOR_GREY,
        "y_offs": 0,
        "width": 0.8,
        "height": 0.5,
    }
    for row in device_element_idxs:
        ps.meter(ax, x0=time, y0=row, **meter_style, **kw)
[docs]
def acq_meter(
    ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw  # Noqa: ARG001
) -> None:
    """
    Draw a white acquisition meter on every given row.

    This is a variant of :func:`meter` with a white fill and a frame width of
    ``constants.ACQ_METER_LINEWIDTH``.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position of the meter (passed to ``ps.meter`` as ``x0``).
    device_element_idxs :
        Row indices to draw a meter on.
    text :
        The acquisition name; not drawn by this function.
    kw :
        Extra keyword arguments forwarded to ``ps.meter``.
    """
    meter_style = {
        "fillcolor": "white",
        "y_offs": 0.0,
        "width": 0.8,
        "height": 0.5,
        "framewidth": constants.ACQ_METER_LINEWIDTH,
    }
    for row in device_element_idxs:
        ps.meter(ax, x0=time, y0=row, **meter_style, **kw)
[docs]
def acq_meter_text(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Same as acq_meter, but also displays text.

    The label is drawn once, just above the meter with the highest row index.

    Parameters
    ----------
    ax :
        The matplotlib Axes.
    time :
        The time of the measurement.
    device_element_idxs :
        The device_element indices. If empty, nothing is drawn.
    text :
        The measurement name.
    kw :
        Additional keyword arguments to be passed to drawing the acq meter.
    """
    acq_meter(ax, time, device_element_idxs, text, **kw)
    if not device_element_idxs:
        # No meter was drawn, and ``max`` of an empty sequence would raise ValueError.
        return
    ax.text(time, max(device_element_idxs) + 0.45, text, ha="center", va="center", zorder=6)
[docs]
def cnot(
    ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw  # Noqa: ARG001
) -> None:
    """
    Draw CNOT markers between two device_elements.

    A vertical connector with round markers on both rows is drawn, and a white "+"
    is overlaid on the second row (the target).

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position of the gate.
    device_element_idxs :
        Row indices of the two device_elements; entry ``[1]`` receives the "+".
    text :
        The CNOT name; not drawn by this function.
    kw :
        Not used by this function.
    """
    connector_x = [time] * 2
    ax.plot(
        connector_x, device_element_idxs, marker="o", markersize=15, color=constants.COLOR_BLUE
    )
    target_row = device_element_idxs[1]
    ax.plot([time], target_row, marker="+", markersize=12, color="white")
[docs]
def cz(
    ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw  # Noqa: ARG001
) -> None:
    """
    Draw CZ markers: a vertical connector with round markers on the involved rows.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position of the gate.
    device_element_idxs :
        Row indices of the device_elements the gate acts on.
    text :
        The CZ name; not drawn by this function.
    kw :
        Not used by this function.
    """
    connector_x = [time] * 2
    ax.plot(
        connector_x, device_element_idxs, marker="o", markersize=15, color=constants.COLOR_BLUE
    )
[docs]
def reset(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Denote device_element initialization with a narrow white, labelled box per row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position to draw the reset at.
    device_element_idxs :
        Row indices of the device_elements that the reset is performed on.
    text :
        The reset name, written inside each box.
    kw :
        Extra keyword arguments forwarded to ``ps.box_text``.
    """
    box_style = {"color": "white", "fillcolor": "white", "width": 0.4, "height": 0.5}
    for row in device_element_idxs:
        ps.box_text(ax, x0=time, y0=row, text=text, **box_style, **kw)
[docs]
class _ControlFlowEnd(Enum):
"""Identifer for end of a control-flow scope."""
[docs]
def _walk_schedule(
    sched_or_op: Schedule | Operation, time_offset: int = 0
) -> Iterator[tuple[int, Operation | Schedule | _ControlFlowEnd]]:
    """
    Depth-first traversal yielding ``(absolute time, item)`` pairs.

    - A schedule is yielded itself, then the contents of each of its schedulables,
      offset by the schedulable's ``abs_time``.
    - A loop or conditional operation is yielded itself, then its body (expanded
      recursively when the body is a schedule), then an end-of-scope
      ``_ControlFlowEnd`` marker at ``time_offset + int(duration)``.
    - Any other operation is yielded as-is.

    Raises
    ------
    ValueError
        If an item is neither a schedule nor an operation.
    """
    if isinstance(sched_or_op, ScheduleBase):
        yield time_offset, sched_or_op
        for schedulable in sched_or_op.schedulables.values():
            child = sched_or_op.operations[schedulable["operation_id"]]
            child_offset = time_offset + schedulable["abs_time"]
            yield from _walk_schedule(sched_or_op=child, time_offset=child_offset)
        return

    if isinstance(sched_or_op, (LoopOperation, ConditionalOperation)):
        yield time_offset, sched_or_op
        body = sched_or_op.body
        if isinstance(body, ScheduleBase):
            yield from _walk_schedule(body, time_offset)
        else:
            yield time_offset, body
        # The marker is looked up only after the body has been walked.
        yield time_offset + int(sched_or_op.duration), (
            _ControlFlowEnd.LOOP_END
            if isinstance(sched_or_op, LoopOperation)
            else _ControlFlowEnd.CONDI_END
        )
        return

    if not isinstance(sched_or_op, Operation):
        raise ValueError(f"Unknown operation type {type(sched_or_op)}.")
    yield time_offset, sched_or_op
[docs]
def _walk_schedule_only_operations(
    sched_or_op: Schedule | Operation,
) -> Iterator[Operation]:
    """
    Yield every leaf operation contained in a schedule or operation.

    Schedules are expanded over their ``operations`` mapping. Loop and conditional
    operations are replaced by their body, expanded when the body is a schedule.
    Any other operation is yielded itself.

    Raises
    ------
    ValueError
        If an item is neither a schedule nor an operation.
    """
    if isinstance(sched_or_op, ScheduleBase):
        for child in sched_or_op.operations.values():
            yield from _walk_schedule_only_operations(child)
        return

    if isinstance(sched_or_op, (LoopOperation, ConditionalOperation)):
        body = sched_or_op.body
        if isinstance(body, ScheduleBase):
            yield from _walk_schedule_only_operations(body)
        else:
            yield body
        return

    if not isinstance(sched_or_op, Operation):
        raise ValueError(f"Unknown operation type {type(sched_or_op)}.")
    yield sched_or_op
[docs]
def _draw_operation(
    operation: Operation,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
    ax: Axes,
    time: int,
    schedule_resources: dict[str, Resource],
) -> None:
    """
    Draw a single (non-control-flow) operation on the circuit diagram.

    - Gates are drawn with the plot function named in ``gate_info["plot_func"]``,
      on the rows of their device elements.
    - Pulses are drawn on the row of their own port: a baseband cartoon if the
      clock resource has frequency 0, otherwise a modulated cartoon. Pulses without
      a port are not drawn.
    - Acquisitions are drawn as an acquisition meter on the rows of their ports.

    Parameters
    ----------
    operation :
        The operation to draw.
    device_element_map :
        Row index per device element name.
    port_map :
        Row index per port name.
    ax :
        The matplotlib Axes to draw on.
    time :
        Diagram time (horizontal position) at which to draw the operation.
    schedule_resources :
        Resources of the schedule, used to look up the frequency of pulse clocks.

    Raises
    ------
    ValueError
        If the operation is not a valid gate, pulse or acquisition.
    """
    if operation.valid_gate:
        plot_func = import_python_object_from_string(operation["gate_info"]["plot_func"])
        idxs = [
            device_element_map[device_element]
            for device_element in operation["gate_info"]["device_elements"]
        ]
        plot_func(ax, time=time, device_element_idxs=idxs, text=operation["gate_info"]["tex"])
    elif operation.valid_pulse:
        # Draw each pulse on the row of its own port, and only once per
        # (row, cartoon kind) so repeated pulse_info entries do not stack artists.
        drawn: set[tuple[int, bool]] = set()
        for pulse_info in operation["pulse_info"]:
            port = pulse_info["port"]
            if port is None:
                # Can be None e.g. in case of NCO operations; there is no row to draw on.
                continue
            idx = port_map[port]
            clock_id: str = pulse_info["clock"]
            is_baseband = schedule_resources[clock_id]["freq"] == 0
            if (idx, is_baseband) in drawn:
                continue
            drawn.add((idx, is_baseband))
            draw_pulse = pulse_baseband if is_baseband else pulse_modulated
            draw_pulse(ax, time=time, device_element_idxs=[idx], text=operation.name)
    elif operation.valid_acquisition:
        idxs = sorted({port_map[acq_info["port"]] for acq_info in operation["acquisition_info"]})
        # One meter per row; drawing it once per acquisition only stacks identical artists.
        acq_meter(ax, time=time, device_element_idxs=idxs, text=operation.name)
    else:
        raise ValueError("Unknown operation")
[docs]
def _get_indices(
    sched_or_op: Schedule | Operation,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
) -> set[int]:
    """
    Collect the diagram row indices touched by an operation or (sub-)schedule.

    Loop and conditional operations are expanded into their body. Gate operations
    contribute the rows of their device elements. Other operations contribute the
    rows of the non-``None`` ports in their pulse and acquisition info. This matches
    how ``_get_device_element_and_port_map_from_schedule`` fills the maps: gate
    ports and ``None`` ports are never registered there, so they must not be
    looked up.

    Parameters
    ----------
    sched_or_op :
        The schedule or operation to inspect.
    device_element_map :
        Row index per device element name.
    port_map :
        Row index per port name.

    Returns
    -------
    :
        The set of involved row indices.
    """
    indices: set[int] = set()
    for operation in _walk_schedule_only_operations(sched_or_op):
        if operation.valid_gate:
            indices.update(
                device_element_map[device_element]
                for device_element in operation["gate_info"]["device_elements"]
            )
            continue
        indices.update(
            port_map[info["port"]]
            for info in chain(operation["pulse_info"], operation["acquisition_info"])
            # Can be None e.g. in case of NCO operations.
            if info["port"] is not None
        )
    return indices
[docs]
def _draw_loop(
    ax: Axes,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
    operation: LoopOperation,
    start_time: int,
    end_time: int,
    x_offset: float = 0.35,
    y_offset: float = 0.3,
    fraction: float = 0.2,
) -> None:
    """
    Draw brackets around the body of a loop, labelled with its repetition count.

    If the loop body involves as many rows as the diagram has, a single pair of
    brackets spans all rows. Otherwise a separate pair is drawn around each
    involved row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    device_element_map :
        Row index per device element name.
    port_map :
        Row index per port name.
    operation :
        The loop; ``control_flow_info["repetitions"]`` is shown as ``x<reps>``.
    start_time :
        Diagram time of the first column of the loop body.
    end_time :
        Diagram time of the last column of the loop body.
    x_offset :
        Horizontal distance of the brackets outside ``start_time``/``end_time``.
    y_offset :
        Vertical extent of the brackets beyond the outermost rows.
    fraction :
        ``fraction`` parameter of matplotlib's ``bar`` connection style, divided by
        the number of rows a bracket spans.
    """
    reps = operation["control_flow_info"]["repetitions"]

    def bracket_side(x: float, y_bottom: float, y_top: float, connectionstyle: str) -> None:
        # A vertical bar with short arms, drawn as an arrow-less annotation.
        ax.annotate(
            "",
            xy=(x, y_bottom),
            xytext=(x, y_top),
            arrowprops=dict(
                arrowstyle="-",
                linewidth=constants.CTRL_FLOW_ARROW_LINEWIDTH,
                facecolor=constants.COLOR_DARK_MODE_LINE,
                connectionstyle=connectionstyle,
            ),
        )

    def brackets(bottom_row: int, top_row: int) -> None:
        arm = fraction / (top_row - bottom_row + 1)
        y_top = top_row + y_offset
        y_bottom = bottom_row - y_offset
        x_right = end_time + x_offset
        # Opposite signs make the two brackets open towards each other.
        bracket_side(start_time - x_offset, y_bottom, y_top, f"bar,fraction={arm}")
        bracket_side(x_right, y_bottom, y_top, f"bar,fraction=-{arm}")
        ax.text(x_right + 0.1, y_top + 0.05, f"x{reps}")

    involved_indices = _get_indices(operation.body, device_element_map, port_map)
    row_count = len(device_element_map)
    if len(involved_indices) != row_count:
        for idx in involved_indices:
            brackets(idx, idx)
    else:
        brackets(0, row_count - 1)
[docs]
def _draw_conditional(
    ax: Axes,
    measure_time: int,
    measure_device_element_idx: int,
    body: Operation | Schedule,
    body_start: int,
    body_end: int,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
) -> None:
    """
    Draw a conditional as arrows from the triggering measurement to its body.

    Every arrow starts at the measurement and is accompanied by an ``"m=1"`` label.

    - A single-operation body gets one arrow per involved row.
    - A schedule body involving as many rows as the diagram has is outlined by one
      rectangle spanning all rows.
    - Any other schedule body is outlined by one rectangle per involved row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    measure_time :
        Diagram time of the measurement that sets the feedback trigger.
    measure_device_element_idx :
        Row index of that measurement.
    body :
        The conditionally executed operation or schedule.
    body_start :
        Diagram time of the first column of the body.
    body_end :
        Diagram time of the last column of the body.
    device_element_map :
        Row index per device element name.
    port_map :
        Row index per port name.
    """
    arrow_tail = (measure_time, measure_device_element_idx + 0.25)

    def arrow_props(connectionstyle: str) -> dict:
        return dict(
            arrowstyle="->",
            facecolor=constants.COLOR_DARK_MODE_LINE,
            linewidth=constants.CTRL_FLOW_ARROW_LINEWIDTH,
            connectionstyle=connectionstyle,
        )

    def condition_label(x: float) -> None:
        ax.text(
            x,
            measure_device_element_idx + 0.5,
            "m=1",
            ha="center",
            va="center",
            backgroundcolor="white",
        )

    def outline_with_arrow(bottom_row: int, top_row: int) -> None:
        outline = matplotlib.patches.Rectangle(  # type: ignore
            (body_start - 0.45, bottom_row - 0.45),
            body_end - body_start + 0.9,
            top_row - bottom_row + 0.95,
            edgecolor=constants.COLOR_DARK_MODE_LINE,
            fill=False,
        )
        ax.add_patch(outline)
        ax.annotate(
            "",
            xy=(body_start - 0.45, top_row + 0.4),
            xytext=arrow_tail,
            arrowprops=arrow_props("angle,angleA=90,angleB=180,rad=0"),
        )

    involved_indices = _get_indices(body, device_element_map, port_map)

    if isinstance(body, Operation):
        for row in involved_indices:
            ax.annotate(
                "",
                xy=(body_start, row + 0.25),
                xytext=arrow_tail,
                arrowprops=arrow_props("bar,angle=180,fraction=-0.25"),
            )
            condition_label(measure_time + 0.5)
        return

    if len(involved_indices) == len(device_element_map):
        outline_with_arrow(0, len(device_element_map) - 1)
        condition_label(measure_time - 0.1)
        return

    for row in involved_indices:
        outline_with_arrow(row, row)
        condition_label(measure_time - 0.1)
[docs]
def _get_device_element_and_port_map_from_schedule(
    schedule: Schedule,
) -> tuple[dict[str, int], dict[str, int]]:
    """
    Assign a diagram row to every device element and port used in ``schedule``.

    Device elements are taken from gate operations and numbered in sorted order.
    Ports are taken from the pulse and acquisition info of non-gate operations
    (``None`` ports are skipped). A port ``"<name>:<...>"`` whose ``<name>`` is a
    device element shares that device element's row. All other ports share one
    extra row labelled ``"other"`` at index 0. In that case every device element
    row is shifted up by one.

    Parameters
    ----------
    schedule :
        The schedule to inspect.

    Returns
    -------
    :
        ``(device_element_map, port_map)``, mapping device element names and port
        names to row indices.
    """
    ports: set[str] = set()
    device_elements: set[str] = set()
    for operation in _walk_schedule_only_operations(schedule):
        if operation.valid_gate:
            device_elements.update(operation["gate_info"]["device_elements"])
            continue
        for info in chain(operation["pulse_info"], operation["acquisition_info"]):
            if (port := info["port"]) is not None:
                # Can be None e.g. in case of NCO operations.
                ports.add(port)

    device_element_map = {
        device_element: idx for idx, device_element in enumerate(sorted(device_elements))
    }

    # Decide up front whether an "other" row is needed. Shifting the device element
    # rows before any port is assigned keeps every port on the correct row,
    # independent of the (hash-dependent) iteration order of the ``ports`` set.
    other_ports = {port for port in ports if port.split(":")[0] not in device_elements}
    if other_ports:
        device_element_map = {
            device_element: idx + 1 for device_element, idx in device_element_map.items()
        }
        device_element_map["other"] = 0

    port_map: dict[str, int] = {
        port: (
            device_element_map["other"]
            if port in other_ports
            else device_element_map[port.split(":")[0]]
        )
        for port in ports
    }
    return device_element_map, port_map
[docs]
def _get_feedback_label_and_device_element_idx(
operation: Operation, port_map: dict[str, int], device_element_map: dict[str, int]
) -> tuple[str, int] | None:
"""Check if the operation is an acquisition/measure gate with a feedback trigger label."""
if (
len(operation["acquisition_info"])
and (feedback_label := operation["acquisition_info"][0].get("feedback_trigger_label", None))
is not None
):
return feedback_label, port_map[operation["acquisition_info"][0]["port"]]
if (
operation.valid_gate
and (feedback_label := operation["gate_info"].get("feedback_trigger_label", None))
is not None
):
return feedback_label, device_element_map[operation["gate_info"]["device_elements"][0]]
return None
[docs]
def circuit_diagram_matplotlib(
    schedule: Schedule,
    figsize: tuple[int, int] | None = None,
    ax: Axes | None = None,
) -> tuple[Figure | None, Axes]:
    """
    Create a circuit-diagram visualization of a schedule using matplotlib.

    Timing is determined with ideal time units on a deep copy of ``schedule``, so
    the input schedule is not modified. Each device element (plus an ``"other"``
    row for ports not tied to a device element) gets one horizontal line.
    Operations are drawn in order of execution, and every later start time advances
    the diagram by one column. The horizontal axis therefore shows ordering, not
    physical duration. Loops are drawn as brackets labelled ``x<repetitions>``.
    Conditionals are drawn as arrows from the measurement that sets their feedback
    trigger to the conditional body.

    Parameters
    ----------
    schedule :
        The schedule to render.
    figsize :
        Matplotlib figure size. Defaults to ``(10, number of rows)``.
    ax :
        Axes to draw on, passed to ``ps.new_pulse_fig``.

    Returns
    -------
    fig :
        The figure returned by ``ps.new_pulse_fig``. The annotation allows ``None``;
        presumably this happens when ``ax`` is given (TODO confirm).
    ax :
        The axes the diagram was drawn on.

    Raises
    ------
    KeyError
        If a conditional operation refers to a feedback trigger label that is not
        set by any preceding measurement or acquisition.
    """
    # to prevent the original input schedule from being modified.
    schedule = _determine_absolute_timing(deepcopy(schedule), "ideal")
    device_element_map, port_map = _get_device_element_and_port_map_from_schedule(schedule)

    if figsize is None:
        figsize = (10, len(device_element_map))
    fig, ax = ps.new_pulse_fig(figsize=figsize, ax=ax)
    ax.set_title(schedule.data["name"])
    ax.set_aspect("equal")
    ax.set_ylim(-0.5, len(device_element_map) - 0.5)
    for y in device_element_map.values():
        ax.axhline(y, color=constants.COLOR_DARK_MODE_LINE, linewidth=0.9)

    # plot the device_element names on the y-axis
    ax.set_yticks(list(device_element_map.values()))
    ax.set_yticklabels(device_element_map.keys())

    current_diagram_time = 0
    last_operation_time = 0
    # Map from absolute (ideal) start time to the diagram column it was drawn in.
    # Control flow (e.g. a loop) can make diagram columns lag behind absolute time,
    # so the column of an earlier start time must be looked up, not taken from it.
    diagram_time_of: dict[float, int] = {0: 0}
    # Stack of (loop start time, loop operation) tuples
    loop_scopes: list[tuple[int, LoopOperation]] = []
    # Stack of (conditional start time, conditional operation) tuples
    conditional_scopes: list[tuple[int, ConditionalOperation]] = []
    # Map from feedback_trigger_label to (thresholded acq draw time, device_element) tuple
    feedback_acq_map: dict[str, tuple[float, int]] = {}

    for abs_time, operation in _walk_schedule(schedule):
        if isinstance(operation, LoopOperation):
            loop_scopes.append((current_diagram_time + 1, operation))
        elif isinstance(operation, ConditionalOperation):
            conditional_scopes.append((current_diagram_time + 1, operation))
        elif isinstance(operation, Operation):
            if abs_time > last_operation_time:
                current_diagram_time += 1
                last_operation_time = abs_time
                diagram_time_of[abs_time] = current_diagram_time
            # An operation that does not advance time (e.g. a single operation that
            # is simultaneous with a sub-schedule of several operations) is drawn in
            # the column where its start time was drawn. If that start time was
            # never seen, fall back to the absolute time.
            draw_time = diagram_time_of.get(abs_time, abs_time)
            feedback = _get_feedback_label_and_device_element_idx(
                operation, port_map, device_element_map
            )
            if feedback is not None:
                feedback_label, device_element_idx = feedback
                # Start conditional arrows where the measurement is actually drawn.
                feedback_acq_map[feedback_label] = draw_time, device_element_idx
            _draw_operation(
                operation=operation,
                device_element_map=device_element_map,
                port_map=port_map,
                ax=ax,
                time=draw_time,
                schedule_resources=schedule.resources,
            )
        elif isinstance(operation, ScheduleBase):
            # Sub-schedules are expanded by _walk_schedule; nothing to draw for them.
            pass
        elif operation is _ControlFlowEnd.LOOP_END:
            start_time, loop_op = loop_scopes.pop()
            _draw_loop(
                ax=ax,
                device_element_map=device_element_map,
                port_map=port_map,
                operation=loop_op,
                start_time=start_time,
                end_time=current_diagram_time,
            )
        elif operation is _ControlFlowEnd.CONDI_END:
            body_start, conditional_op = conditional_scopes.pop()
            feedback_trigger_label = conditional_op["control_flow_info"]["feedback_trigger_label"]
            try:
                measure_time, measure_device_element_idx = feedback_acq_map[feedback_trigger_label]
            except KeyError as err:
                raise KeyError(
                    f"Feedback trigger label '{feedback_trigger_label}' not found in "
                    "any preceding Measure or acquisition operation."
                ) from err
            _draw_conditional(
                ax=ax,
                measure_time=measure_time,
                measure_device_element_idx=measure_device_element_idx,
                body=conditional_op.body,
                body_start=body_start,
                body_end=current_diagram_time,
                device_element_map=device_element_map,
                port_map=port_map,
            )

    ax.set_xlim(-1, current_diagram_time + 1)
    return fig, ax