# Source code for quantify_scheduler.schedules._visualization.circuit_diagram

# Repository:
# Licensed according to the LICENCE file on the main branch
"""Plotting functions used in the visualization backend of the sequencer."""

from __future__ import annotations

from copy import deepcopy
from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Iterable, Iterator

import matplotlib

import quantify_scheduler.schedules._visualization.pulse_scheme as ps
from quantify_scheduler.compilation import _determine_absolute_timing
from quantify_scheduler.helpers.importers import import_python_object_from_string
from quantify_scheduler.operations.control_flow_library import (
    ConditionalOperation,
    LoopOperation,
)
from quantify_scheduler.operations.operation import Operation
from quantify_scheduler.schedules._visualization import constants
from quantify_scheduler.schedules.schedule import Schedule, ScheduleBase

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from quantify_scheduler.resources import Resource

def gate_box(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Draw a labelled box for a single gate on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) of the gate.
    device_element_idxs :
        Row indices of the device elements; one box is drawn per index.
    text :
        The label written inside the box.
    kw :
        Additional keyword arguments forwarded to ``ps.box_text``.

    """
    box_style = {"fillcolor": constants.COLOR_LAZURE, "width": 0.8, "height": 0.5}
    for row in device_element_idxs:
        ps.box_text(ax, x0=time, y0=row, text=text, **box_style, **kw)
def pulse_baseband(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Draw a baseband-pulse cartoon with a label above it on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal centre (diagram time) of the pulse.
    device_element_idxs :
        Row indices of the device elements; one cartoon is drawn per index.
    text :
        The pulse name, written 0.45 above each row.
    kw :
        Additional keyword arguments forwarded to ``ps.flux_pulse``.

    """
    cartoon_width = 0.6
    # The cartoon is positioned by its left edge, so shift by half its width to centre it.
    left_edge = time - cartoon_width / 2
    for row in device_element_idxs:
        ps.flux_pulse(
            ax, pos=left_edge, y_offs=row, width=cartoon_width, s=0.0025, amp=0.33, **kw
        )
        ax.text(time, row + 0.45, text, ha="center", va="center", zorder=6)
def pulse_modulated(
    ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw
) -> None:
    """
    Draw a modulated-pulse cartoon with a label above it on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal centre (diagram time) of the pulse.
    device_element_idxs :
        Row indices of the device elements; one cartoon is drawn per index.
    text :
        The pulse name, written 0.45 above each row.
    kw :
        Additional keyword arguments forwarded to ``ps.mw_pulse``.

    """
    cartoon_width = 0.6
    # The cartoon is positioned by its left edge, so shift by half its width to centre it.
    left_edge = time - cartoon_width / 2
    for row in device_element_idxs:
        ps.mw_pulse(ax, pos=left_edge, y_offs=row, width=cartoon_width, amp=0.33, **kw)
        ax.text(time, row + 0.45, text, ha="center", va="center", zorder=6)
def meter(
    ax: Axes,
    time: float,
    device_element_idxs: list[int],
    text: str,  # noqa: ARG001
    **kw,
) -> None:
    """
    Draw a grey meter symbol depicting a measurement on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) of the measurement.
    device_element_idxs :
        Row indices of the device elements; one meter is drawn per index.
    text :
        The measurement name. Accepted for a uniform plot-function signature; not drawn.
    kw :
        Additional keyword arguments forwarded to ``ps.meter``.

    """
    meter_style = {
        "fillcolor": constants.COLOR_GREY,
        "y_offs": 0,
        "width": 0.8,
        "height": 0.5,
    }
    for row in device_element_idxs:
        ps.meter(ax, x0=time, y0=row, **meter_style, **kw)
def acq_meter(
    ax: Axes,
    time: float,
    device_element_idxs: list[int],
    text: str,  # noqa: ARG001
    **kw,
) -> None:
    """
    Draw a white meter symbol depicting an acquisition on every given row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) of the acquisition.
    device_element_idxs :
        Row indices of the device elements; one meter is drawn per index.
    text :
        The acquisition name. Accepted for a uniform plot-function signature; not drawn.
    kw :
        Additional keyword arguments forwarded to ``ps.meter``.

    """
    meter_style = {
        "fillcolor": "white",
        "y_offs": 0.0,
        "width": 0.8,
        "height": 0.5,
        "framewidth": constants.ACQ_METER_LINEWIDTH,
    }
    for row in device_element_idxs:
        ps.meter(ax, x0=time, y0=row, **meter_style, **kw)
def acq_meter_text(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Draw acquisition meters like :func:`acq_meter` and add a single text label.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) of the acquisition.
    device_element_idxs :
        Row indices of the device elements; must be non-empty because the label
        is placed relative to the largest index.
    text :
        The label, written 0.45 above the highest involved row.
    kw :
        Additional keyword arguments forwarded to :func:`acq_meter`.

    """
    acq_meter(ax, time, device_element_idxs, text, **kw)
    label_y = max(device_element_idxs) + 0.45
    ax.text(time, label_y, text, ha="center", va="center", zorder=6)
def cnot(
    ax: Axes,
    time: float,
    device_element_idxs: list[int],
    text: str,  # noqa: ARG001
    **kw,  # noqa: ARG001
) -> None:
    """
    Draw the markers of a CNOT gate between two device elements.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) of the CNOT.
    device_element_idxs :
        Row indices of the two device elements; the second entry receives the
        additional "+" marker.
    text :
        The CNOT name. Accepted for a uniform plot-function signature; not drawn.
    kw :
        Accepted for a uniform plot-function signature; not used.

    """
    both_ends_x = [time, time]
    # Two filled circles joined by a vertical line.
    ax.plot(
        both_ends_x,
        device_element_idxs,
        marker="o",
        markersize=15,
        color=constants.COLOR_BLUE,
    )
    # White "+" drawn over the circle at the second index.
    second_row = device_element_idxs[1]
    ax.plot([time], second_row, marker="+", markersize=12, color="white")
def cz(
    ax: Axes,
    time: float,
    device_element_idxs: list[int],
    text: str,  # noqa: ARG001
    **kw,  # noqa: ARG001
) -> None:
    """
    Draw the markers of a CZ gate between two device elements.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) of the CZ.
    device_element_idxs :
        Row indices of the two device elements joined by the gate.
    text :
        The CZ name. Accepted for a uniform plot-function signature; not drawn.
    kw :
        Accepted for a uniform plot-function signature; not used.

    """
    both_ends_x = [time, time]
    # Two filled circles joined by a vertical line.
    ax.plot(
        both_ends_x,
        device_element_idxs,
        marker="o",
        markersize=15,
        color=constants.COLOR_BLUE,
    )
def reset(ax: Axes, time: float, device_element_idxs: list[int], text: str, **kw) -> None:
    """
    Draw a white labelled box on every given row to denote initialization.

    The box is white on white, so it visually interrupts the row line it is
    drawn over.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) of the reset.
    device_element_idxs :
        Row indices of the device elements the reset is performed on.
    text :
        The reset label written inside the box.
    kw :
        Additional keyword arguments forwarded to ``ps.box_text``.

    """
    box_style = {"color": "white", "fillcolor": "white", "width": 0.4, "height": 0.5}
    for row in device_element_idxs:
        ps.box_text(ax, x0=time, y0=row, text=text, **box_style, **kw)
class _ControlFlowEnd(Enum):
    """Marker emitted by the schedule walker when a control-flow scope ends."""

    # End of a loop body.
    LOOP_END = auto()
    # End of a conditional body.
    CONDI_END = auto()
def _walk_schedule(
    sched_or_op: Schedule | Operation, time_offset: int = 0
) -> Iterator[tuple[int, Operation | Schedule | _ControlFlowEnd]]:
    """
    Walk a schedule depth-first and yield ``(time, item)`` pairs.

    Schedules and control-flow operations are yielded before their contents.
    After the body of a loop or conditional, a :class:`_ControlFlowEnd` marker
    is yielded at ``time_offset + int(duration)`` of that operation. A
    control-flow body that is not a schedule is yielded as-is, without being
    walked further.

    Parameters
    ----------
    sched_or_op :
        The schedule or operation to walk.
    time_offset :
        Time added to the ``abs_time`` of every schedulable below this node.

    Raises
    ------
    ValueError
        If ``sched_or_op`` is neither a schedule nor an operation.

    """
    if isinstance(sched_or_op, ScheduleBase):
        yield time_offset, sched_or_op
        for schedulable in sched_or_op.schedulables.values():
            child = sched_or_op.operations[schedulable["operation_id"]]
            child_offset = time_offset + schedulable["abs_time"]
            yield from _walk_schedule(sched_or_op=child, time_offset=child_offset)
        return

    if isinstance(sched_or_op, (LoopOperation, ConditionalOperation)):
        # Loop is tested first, matching the precedence of the scope markers.
        if isinstance(sched_or_op, LoopOperation):
            end_marker = _ControlFlowEnd.LOOP_END
        else:
            end_marker = _ControlFlowEnd.CONDI_END

        yield time_offset, sched_or_op
        body = sched_or_op.body
        if isinstance(body, ScheduleBase):
            yield from _walk_schedule(body, time_offset)
        else:
            yield time_offset, body
        yield time_offset + int(sched_or_op.duration), end_marker
        return

    if isinstance(sched_or_op, Operation):
        yield time_offset, sched_or_op
        return

    raise ValueError(f"Unknown operation type {type(sched_or_op)}.")
def _walk_schedule_only_operations(
    sched_or_op: Schedule | Operation,
) -> Iterator[Operation]:
    """
    Yield the leaf operations contained in a schedule or operation.

    Schedules are expanded through their ``operations`` mapping, and loops and
    conditionals are replaced by their body. A control-flow body that is not a
    schedule is yielded as-is, without being expanded further.

    Parameters
    ----------
    sched_or_op :
        The schedule or operation to expand.

    Raises
    ------
    ValueError
        If ``sched_or_op`` is neither a schedule nor an operation.

    """
    if isinstance(sched_or_op, ScheduleBase):
        for child in sched_or_op.operations.values():
            yield from _walk_schedule_only_operations(child)
        return

    if isinstance(sched_or_op, (LoopOperation, ConditionalOperation)):
        body = sched_or_op.body
        if isinstance(body, ScheduleBase):
            yield from _walk_schedule_only_operations(body)
        else:
            yield body
        return

    if isinstance(sched_or_op, Operation):
        yield sched_or_op
        return

    raise ValueError(f"Unknown operation type {type(sched_or_op)}.")
def _draw_operation(
    operation: Operation,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
    ax: Axes,
    time: int,
    schedule_resources: dict[str, Resource],
) -> None:
    """
    Draw a single operation on the circuit diagram.

    Gates are drawn with the plot function named in their ``gate_info``.
    Pulses are drawn as baseband cartoons when their clock frequency is 0 and
    as modulated cartoons otherwise. Acquisitions are drawn as acquisition
    meters. Pulses and acquisitions are labelled with ``operation.name``.

    Parameters
    ----------
    operation :
        The operation to draw.
    device_element_map :
        Mapping from device element name to row index.
    port_map :
        Mapping from port name to row index.
    ax :
        The matplotlib Axes to draw on.
    time :
        Horizontal position (diagram time) at which to draw.
    schedule_resources :
        Resources of the schedule, indexed by clock name, used to look up the
        clock frequency of each pulse.

    Raises
    ------
    ValueError
        If the operation is neither a valid gate, pulse nor acquisition.

    """
    if operation.valid_gate:
        gate_info = operation["gate_info"]
        plot_func = import_python_object_from_string(gate_info["plot_func"])
        idxs = [
            device_element_map[device_element]
            for device_element in gate_info["device_elements"]
        ]
        plot_func(ax, time=time, device_element_idxs=idxs, text=gate_info["tex"])
    elif operation.valid_pulse:
        # A set removes duplicate rows; ports can be None (e.g. NCO operations).
        idxs = list(
            {
                port_map[pulse_info["port"]]
                for pulse_info in operation["pulse_info"]
                if pulse_info["port"] is not None
            }
        )
        for pulse_info in operation["pulse_info"]:
            clock_id: str = pulse_info["clock"]
            clock_resource = schedule_resources[clock_id]
            if clock_resource["freq"] == 0:
                pulse_baseband(ax, time=time, device_element_idxs=idxs, text=operation.name)
            else:
                pulse_modulated(ax, time=time, device_element_idxs=idxs, text=operation.name)
    elif operation.valid_acquisition:
        idxs = list({port_map[acq_info["port"]] for acq_info in operation["acquisition_info"]})
        for _ in operation["acquisition_info"]:
            acq_meter(ax, time=time, device_element_idxs=idxs, text=operation.name)
    else:
        raise ValueError("Unknown operation")
def _get_indices(
    sched_or_op: Schedule | Operation,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
) -> set[int]:
    """
    Collect the row indices touched by a schedule or operation.

    Gate operations contribute the rows of their device elements. Pulse and
    acquisition infos contribute the row of their port. Ports that are absent
    from ``port_map`` (for example ``None`` ports) are skipped instead of
    raising ``KeyError``.

    Parameters
    ----------
    sched_or_op :
        The schedule or operation to inspect. An operation is inspected
        directly; anything else is expanded into its leaf operations first.
    device_element_map :
        Mapping from device element name to row index.
    port_map :
        Mapping from port name to row index.

    Returns
    -------
    :
        The set of involved row indices.

    """

    def add_index_from_operation(operation: Operation, index_set: set[int]) -> None:
        if operation.valid_gate:
            index_set.update(
                device_element_map[device_element]
                for device_element in operation["gate_info"]["device_elements"]
            )
        index_set.update(
            port_map[port]
            for info in chain(operation["pulse_info"], operation["acquisition_info"])
            if (port := info["port"]) in port_map
        )

    indices: set[int] = set()
    if isinstance(sched_or_op, Operation):
        add_index_from_operation(sched_or_op, indices)
        return indices

    for operation in _walk_schedule_only_operations(sched_or_op):
        add_index_from_operation(operation, indices)
    return indices
def _draw_loop(
    ax: Axes,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
    operation: LoopOperation,
    start_time: int,
    end_time: int,
    x_offset: float = 0.35,
    y_offset: float = 0.3,
    fraction: float = 0.2,
) -> None:
    """
    Draw brackets and a repetition count around the body of a loop.

    If the loop body involves every row, one pair of brackets spans all rows.
    Otherwise a separate pair is drawn around each involved row.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    device_element_map :
        Mapping from device element name to row index.
    port_map :
        Mapping from port name to row index.
    operation :
        The loop operation; its repetition count is shown as ``x<reps>``.
    start_time :
        Diagram time of the first operation in the loop body.
    end_time :
        Diagram time of the last operation in the loop body.
    x_offset :
        Horizontal distance between the brackets and the start/end times.
    y_offset :
        Vertical distance the brackets extend beyond the outermost rows.
    fraction :
        Bar fraction of the bracket for a single row; it is divided by the
        number of rows spanned.

    """
    reps = operation["control_flow_info"]["repetitions"]

    def draw_brackets(bottom_device_element: int, top_device_element: int) -> None:
        left_x = start_time - x_offset
        right_x = end_time + x_offset
        upper_y = top_device_element + y_offset
        lower_y = bottom_device_element - y_offset
        bar_fraction = fraction / (top_device_element - bottom_device_element + 1)

        # The closing bracket mirrors the opening one via a negated bar fraction.
        for bracket_x, sign in ((left_x, ""), (right_x, "-")):
            ax.annotate(
                "",
                xy=(bracket_x, lower_y),
                xytext=(bracket_x, upper_y),
                arrowprops={
                    "arrowstyle": "-",
                    "linewidth": constants.CTRL_FLOW_ARROW_LINEWIDTH,
                    "facecolor": constants.COLOR_DARK_MODE_LINE,
                    "connectionstyle": f"bar,fraction={sign}{bar_fraction}",
                },
            )
        ax.text(right_x + 0.1, upper_y + 0.05, f"x{reps}")

    involved_indices = _get_indices(operation.body, device_element_map, port_map)
    row_count = len(device_element_map)
    if len(involved_indices) == row_count:
        draw_brackets(0, row_count - 1)
        return
    for idx in involved_indices:
        draw_brackets(idx, idx)
def _draw_conditional(
    ax: Axes,
    measure_time: int,
    measure_device_element_idx: int,
    body: Operation | Schedule,
    body_start: int,
    body_end: int,
    device_element_map: dict[str, int],
    port_map: dict[str, int],
) -> None:
    """
    Draw the link between a measurement and the body it conditions.

    For a single-operation body, an arrow runs from the measurement to every
    involved row. For a schedule body, a rectangle is drawn around the body
    (one spanning all rows if every row is involved, otherwise one per involved
    row) with an arrow from the measurement to it. An "m=1" label is placed
    next to the measurement.

    Parameters
    ----------
    ax :
        The matplotlib Axes to draw on.
    measure_time :
        Diagram time of the measurement providing the condition.
    measure_device_element_idx :
        Row index of that measurement.
    body :
        The conditional body.
    body_start :
        Diagram time of the first operation in the body.
    body_end :
        Diagram time of the last operation in the body.
    device_element_map :
        Mapping from device element name to row index.
    port_map :
        Mapping from port name to row index.

    """
    # Imported explicitly: a bare ``import matplotlib`` does not guarantee that
    # the ``matplotlib.patches`` submodule has been loaded.
    from matplotlib.patches import Rectangle

    arrow_origin = (measure_time, measure_device_element_idx + 0.25)

    def draw_condition_label(x: float) -> None:
        ax.text(
            x,
            measure_device_element_idx + 0.5,
            "m=1",
            ha="center",
            va="center",
            backgroundcolor="white",
        )

    def draw_for_single_operation(index: int) -> None:
        ax.annotate(
            "",
            xy=(body_start, index + 0.25),
            xytext=arrow_origin,
            arrowprops=dict(
                arrowstyle="->",
                facecolor=constants.COLOR_DARK_MODE_LINE,
                linewidth=constants.CTRL_FLOW_ARROW_LINEWIDTH,
                connectionstyle="bar,angle=180,fraction=-0.25",
            ),
        )
        draw_condition_label(measure_time + 0.5)

    def draw_rectangle_with_arrow(bottom_device_element: int, top_device_element: int) -> None:
        frame = Rectangle(
            (body_start - 0.45, bottom_device_element - 0.45),
            body_end - body_start + 0.9,
            top_device_element - bottom_device_element + 0.95,
            edgecolor=constants.COLOR_DARK_MODE_LINE,
            fill=False,
        )
        ax.add_patch(frame)
        ax.annotate(
            "",
            xy=(body_start - 0.45, top_device_element + 0.4),
            xytext=arrow_origin,
            arrowprops=dict(
                arrowstyle="->",
                facecolor=constants.COLOR_DARK_MODE_LINE,
                linewidth=constants.CTRL_FLOW_ARROW_LINEWIDTH,
                connectionstyle="angle,angleA=90,angleB=180,rad=0",
            ),
        )

    def draw_for_schedule_all_device_elements() -> None:
        draw_rectangle_with_arrow(0, len(device_element_map) - 1)
        draw_condition_label(measure_time - 0.1)

    def draw_for_schedule_single_device_elements(involved_indices: Iterable[int]) -> None:
        for index in involved_indices:
            draw_rectangle_with_arrow(index, index)
        draw_condition_label(measure_time - 0.1)

    involved_indices = _get_indices(body, device_element_map, port_map)
    if isinstance(body, Operation):
        for idx in involved_indices:
            draw_for_single_operation(idx)
    elif len(involved_indices) == len(device_element_map):
        draw_for_schedule_all_device_elements()
    else:
        draw_for_schedule_single_device_elements(involved_indices)
def _get_device_element_and_port_map_from_schedule(
    schedule: Schedule,
) -> tuple[dict[str, int], dict[str, int]]:
    """
    Build the row layout of the circuit diagram for a schedule.

    Device elements are taken from gate operations and ports from the pulse and
    acquisition infos of all other operations. Device elements are assigned
    rows in sorted order. A port is placed on the row of the device element
    named by the part of the port before the first ``":"``. If any port does
    not belong to a known device element, an extra row ``"other"`` is added at
    index 0 and all device element rows are shifted up by one.

    The layout is decided before any port is mapped, so the result does not
    depend on the iteration order of the collected ports.

    Parameters
    ----------
    schedule :
        The schedule to inspect.

    Returns
    -------
    :
        ``(device_element_map, port_map)``: mappings from device element name
        and from port name to row index.

    """
    ports: set[str] = set()
    device_elements: set[str] = set()
    for operation in _walk_schedule_only_operations(schedule):
        if operation.valid_gate:
            device_elements.update(operation["gate_info"]["device_elements"])
            continue
        for info in chain(operation["pulse_info"], operation["acquisition_info"]):
            if (port := info["port"]) is not None:
                # Can be None e.g. in case of NCO operations.
                ports.add(port)

    port_owner = {port: port.split(":")[0] for port in ports}
    needs_other_row = any(owner not in device_elements for owner in port_owner.values())

    # Row 0 is reserved for "other" when needed; device elements start after it.
    first_row = 1 if needs_other_row else 0
    device_element_map = {
        device_element: idx
        for idx, device_element in enumerate(sorted(device_elements), start=first_row)
    }
    if needs_other_row:
        device_element_map["other"] = 0

    port_map: dict[str, int] = {
        port: device_element_map[owner if owner in device_elements else "other"]
        for port, owner in port_owner.items()
    }
    return device_element_map, port_map
def _get_feedback_label_and_device_element_idx(
    operation: Operation, port_map: dict[str, int], device_element_map: dict[str, int]
) -> tuple[str, int] | None:
    """
    Return the feedback trigger label of an acquisition or gate, with its row.

    The first acquisition info is checked first; if it carries a
    ``feedback_trigger_label``, the row of its port is returned. Otherwise, for
    a valid gate whose ``gate_info`` carries a label, the row of the gate's
    first device element is returned.

    Returns
    -------
    :
        ``(label, row index)``, or ``None`` if no feedback trigger label is set.

    """
    acquisitions = operation["acquisition_info"]
    if len(acquisitions) > 0:
        first_acquisition = acquisitions[0]
        label = first_acquisition.get("feedback_trigger_label")
        if label is not None:
            return label, port_map[first_acquisition["port"]]

    if operation.valid_gate:
        gate_info = operation["gate_info"]
        label = gate_info.get("feedback_trigger_label")
        if label is not None:
            return label, device_element_map[gate_info["device_elements"][0]]

    return None
def circuit_diagram_matplotlib(
    schedule: Schedule,
    figsize: tuple[int, int] | None = None,
    ax: Axes | None = None,
) -> tuple[Figure | None, Axes]:
    """
    Create a circuit diagram visualization of a schedule using matplotlib.

    Absolute timing is determined on a deep copy of the schedule in ``"ideal"``
    mode, so the input schedule is not modified. Every device element (plus an
    ``"other"`` row for ports not belonging to one) gets a horizontal line, and
    operations are drawn on the lines they act on. Loops are marked with
    brackets and conditionals are linked to the measurement that provides their
    feedback trigger label.

    Parameters
    ----------
    schedule :
        The schedule to render.
    figsize :
        Matplotlib figure size. Defaults to ``(10, number of rows)``.
    ax :
        Axes to draw on; passed on to ``ps.new_pulse_fig``.

    Returns
    -------
    fig :
        The figure returned by ``ps.new_pulse_fig``.
    ax :
        The Axes the diagram was drawn on.

    Raises
    ------
    KeyError
        If a conditional refers to a feedback trigger label that no preceding
        Measure or acquisition operation has set.

    """
    # to prevent the original input schedule from being modified.
    schedule = _determine_absolute_timing(deepcopy(schedule), "ideal")

    device_element_map, port_map = _get_device_element_and_port_map_from_schedule(schedule)

    if figsize is None:
        figsize = (10, len(device_element_map))
    fig, ax = ps.new_pulse_fig(figsize=figsize, ax=ax)
    ax.set_title(schedule.data["name"])
    ax.set_aspect("equal")

    ax.set_ylim(-0.5, len(device_element_map) - 0.5)
    for y in device_element_map.values():
        ax.axhline(y, color=constants.COLOR_DARK_MODE_LINE, linewidth=0.9)

    # plot the device_element names on the y-axis
    ax.set_yticks(list(device_element_map.values()))
    ax.set_yticklabels(device_element_map.keys())

    current_diagram_time = 0
    last_operation_time = 0
    # Stack of (loop start time, loop operation) tuples
    loop_scopes: list[tuple[int, LoopOperation]] = []
    # Stack of (conditional start time, conditional operation) tuples
    conditional_scopes: list[tuple[int, ConditionalOperation]] = []
    # Map from feedback_trigger_label to (thresholded acq time, device_element) tuple
    feedback_acq_map: dict[str, tuple[int, int]] = {}

    for abs_time, operation in _walk_schedule(schedule):
        if isinstance(operation, LoopOperation):
            # The body starts at the next diagram time slot.
            loop_scopes.append((current_diagram_time + 1, operation))
        elif isinstance(operation, ConditionalOperation):
            conditional_scopes.append((current_diagram_time + 1, operation))
        elif isinstance(operation, Operation):
            if abs_time > last_operation_time:
                current_diagram_time += 1
                last_operation_time = abs_time
                # draw_time is a quick fix for displaying simultaneity consistently if a
                # single operation is simultaneous with a sub-schedule with multiple
                # operations.
                draw_time = current_diagram_time
            else:
                draw_time = abs_time

            if (
                feedback_label_and_device_element_idx := _get_feedback_label_and_device_element_idx(
                    operation, port_map, device_element_map
                )
            ) is not None:
                feedback_label, device_element_idx = feedback_label_and_device_element_idx
                feedback_acq_map[feedback_label] = current_diagram_time, device_element_idx

            _draw_operation(
                operation=operation,
                device_element_map=device_element_map,
                port_map=port_map,
                ax=ax,
                time=draw_time,
                schedule_resources=schedule.resources,
            )
        elif isinstance(operation, ScheduleBase):
            # Nested schedules are only containers; their operations follow in the walk.
            pass
        elif operation == _ControlFlowEnd.LOOP_END:
            start_time, loop_op = loop_scopes.pop()
            _draw_loop(
                ax=ax,
                device_element_map=device_element_map,
                port_map=port_map,
                operation=loop_op,
                start_time=start_time,
                end_time=current_diagram_time,
            )
        elif operation == _ControlFlowEnd.CONDI_END:
            body_start, conditional_op = conditional_scopes.pop()
            feedback_trigger_label = conditional_op["control_flow_info"]["feedback_trigger_label"]
            try:
                measure_time, measure_device_element_idx = feedback_acq_map[feedback_trigger_label]
            except KeyError as err:
                raise KeyError(
                    f"Feedback trigger label '{feedback_trigger_label}' not found in "
                    "any preceding Measure or acquisition operation."
                ) from err
            _draw_conditional(
                ax=ax,
                measure_time=measure_time,
                measure_device_element_idx=measure_device_element_idx,
                body=conditional_op.body,
                body_start=body_start,
                body_end=current_diagram_time,
                device_element_map=device_element_map,
                port_map=port_map,
            )

    ax.set_xlim(-1, current_diagram_time + 1)

    return fig, ax