# Repository: https://gitlab.com/quantify-os/quantify-scheduler
# Licensed according to the LICENCE file on the main branch
"""Functions for drawing pulse diagrams."""
from __future__ import annotations
import logging
from collections import defaultdict
from dataclasses import dataclass
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import quantify_scheduler.operations.pulse_library as pl
from quantify_core.visualization.SI_utilities import set_xlabel, set_ylabel
from quantify_scheduler.helpers.waveforms import (
exec_waveform_function,
modulate_waveform,
)
from quantify_scheduler.operations.acquisition_library import Acquisition
from quantify_scheduler.operations.control_flow_library import (
ConditionalOperation,
LoopOperation,
)
from quantify_scheduler.schedules.schedule import ScheduleBase
from quantify_scheduler.waveforms import interpolated_complex_waveform
if TYPE_CHECKING:
from quantify_scheduler import CompiledSchedule, Operation, Schedule
from collections import Counter
[docs]
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class SampledPulse:
    """Class containing the necessary information to display pulses in a plot."""

    time: np.ndarray
    """Time points at which the pulse was sampled."""
    signal: np.ndarray
    """Sampled signal values at :attr:`time` (real or complex)."""
    label: str
    """Label shown in the legend / hover text for this pulse."""
@dataclass
class SampledAcquisition:
    """Class containing the necessary information to display acquisitions in a plot."""

    t0: float
    """Start time of the acquisition."""
    duration: float
    """Duration of the acquisition."""
    label: str
    """Label of the acquisition (the name of the operation it belongs to)."""
@dataclass
class ScheduledInfo:
    """
    Class containing pulse or acquisition info, with some additional information.

    This class is used in the schedule sampling process to temporarily hold pulse info
    or acquisition info dictionaries, together with some useful information from the
    operation and schedulable that they are a part of.
    """

    op_info: dict[str, Any]
    """Pulse info or acquisition info."""
    time: float
    """The sum of the ``Schedulable["abs_time"]`` and the ``info["t0"]``."""
    op_name: str
    """The name of the operation containing the pulse or acquisition info."""
[docs]
def get_sampled_pulses_from_voltage_offsets(
    schedule: Schedule | CompiledSchedule,
    offset_infos: dict[str, dict[str, list[ScheduledInfo]]],
    x_min: float,
    x_max: float,
    modulation: Literal["off", "if", "clock"] = "off",
    modulation_if: float = 0.0,
    sampling_rate: float = 1e9,
    sampled_pulses: dict[str, list[SampledPulse]] | None = None,
) -> dict[str, list[SampledPulse]]:
    """
    Generate :class:`.SampledPulse` objects from :class:`.VoltageOffset` pulse_info dicts.

    This function groups all VoltageOffset operations by port-clock combination and
    turns each of those groups of operations into a single SampledPulse. The returned
    dictionary contains these SampledPulse objects grouped by port.

    Port-clock combinations without offsets, or whose offset points all lie at or
    before ``x_min`` (so that nothing of them is visible), produce no SampledPulse.

    Parameters
    ----------
    schedule :
        The schedule to render.
    offset_infos :
        A nested dictionary containing lists of pulse_info dictionaries. The outer
        dictionary's keys are ports, and the inner dictionary's keys are clocks.
    x_min :
        The left limit of the x-axis of the intended plot.
    x_max :
        The right limit of the x-axis of the intended plot.
    modulation :
        Determines if modulation is included in the visualization.
    modulation_if :
        Modulation frequency used when modulation is set to "if".
    sampling_rate :
        Number of samples per second to draw when drawing modulated pulses.
    sampled_pulses :
        An already existing dictionary (same type as the return value). If provided,
        this dictionary will be extended with the SampledPulse objects created in this
        function.

    Returns
    -------
    dict[str, list[SampledPulse]] :
        SampledPulse objects grouped by port.
    """
    if sampled_pulses is None:
        sampled_pulses = defaultdict(list)

    for port, offset_info in offset_infos.items():
        for clock, info_list in offset_info.items():
            if not info_list:
                # No offsets for this port-clock combination: nothing to draw.
                continue

            time: list[float] = []
            signal: list[complex] = []
            for info in info_list:
                if time:
                    # Each offset is a point, so the previous offset is extended to a
                    # line before adding the next.
                    # Subtract a small number from the time so that interpolation (in
                    # sum_waveforms) looks correct visually.
                    time.append(info.time - 0.01 / sampling_rate)
                    signal.append(signal[-1])
                time.append(info.time)
                signal.append(info.op_info["offset_path_I"] + 1j * info.op_info["offset_path_Q"])
            if signal[-1] != 0:
                # If the offset is not 0, let it run to the end of the schedule.
                time.append(schedule.duration / schedule.repetitions)
                signal.append(signal[-1])

            # Filter in time: keep one point before and one point after the limits (if
            # possible).
            start_idx = next((i for i, v in enumerate(time) if v > x_min), None)
            if start_idx is None:
                # Every point lies at or before ``x_min``, so nothing of this offset is
                # visible. Without this guard ``next`` would raise StopIteration.
                continue
            start_idx = max(start_idx - 1, 0)
            stop_idx = next((i + 1 for i, v in enumerate(time) if v > x_max), len(time))
            time_arr = np.array(time[start_idx:stop_idx])
            signal_arr = np.array(signal[start_idx:stop_idx])

            if modulation != "off":
                # Resample the step-wise offset on a regular grid so it can be modulated.
                new_time = np.linspace(
                    time_arr[0],
                    time_arr[-1],
                    round((time_arr[-1] - time_arr[0]) * sampling_rate) + 1,
                )
                signal_arr = interpolated_complex_waveform(
                    t=new_time, samples=signal_arr, t_samples=time_arr
                )
                time_arr = new_time
            if modulation == "clock":
                signal_arr = modulate_waveform(
                    time_arr, signal_arr, schedule.resources[clock]["freq"]
                )
            elif modulation == "if":
                signal_arr = modulate_waveform(time_arr, signal_arr, modulation_if)
            signal_arr = np.real_if_close(signal_arr)

            sampled_pulses[port].append(
                SampledPulse(
                    time=time_arr,
                    signal=signal_arr,
                    label=f"VoltageOffset, clock {clock}",
                )
            )
    return sampled_pulses
[docs]
def get_sampled_pulses(
    schedule: Schedule | CompiledSchedule,
    pulse_infos: dict[str, list[ScheduledInfo]],
    x_min: float,
    x_max: float,
    modulation: Literal["off", "if", "clock"] = "off",
    modulation_if: float = 0.0,
    sampling_rate: float = 1e9,
    sampled_pulses: dict[str, list[SampledPulse]] | None = None,
) -> dict[str, list[SampledPulse]]:
    """
    Generate :class:`.SampledPulse` objects from pulse_info dicts.

    This function creates a SampledPulse for each pulse_info dict. The pulse_info must
    contain a valid ``"wf_func"``. Only the part of each pulse that lies within
    ``[x_min, x_max]`` is sampled; the waveform is always evaluated relative to the
    start of the pulse, so a pulse cut off by ``x_min`` shows the correct portion of
    its waveform. The pulse_info dicts themselves are not modified.

    Parameters
    ----------
    schedule :
        The schedule to render.
    pulse_infos :
        A dictionary from ports to lists of pulse_info dictionaries.
    x_min :
        The left limit of the x-axis of the intended plot.
    x_max :
        The right limit of the x-axis of the intended plot.
    modulation :
        Determines if modulation is included in the visualization.
    modulation_if :
        Modulation frequency used when modulation is set to "if".
    sampling_rate :
        The time resolution used to sample the schedule in Hz.
    sampled_pulses :
        An already existing dictionary (same type as the return value). If provided,
        this dictionary will be extended with the SampledPulse objects created in this
        function.

    Returns
    -------
    dict[str, list[SampledPulse]] :
        SampledPulse objects grouped by port.

    Raises
    ------
    ValueError
        If ``modulation`` is not one of "off", "if" or "clock" (raised while
        processing the first pulse inside the plotted range).
    """
    if sampled_pulses is None:
        sampled_pulses = defaultdict(list)

    for port, info_list in pulse_infos.items():
        for info in info_list:
            pulse_start = info.time
            pulse_stop = pulse_start + info.op_info["duration"]
            if pulse_stop < x_min or pulse_start > x_max:
                continue
            # Only sample the part of the pulse that lies inside the plotted range.
            t0 = max(x_min, pulse_start)
            t1 = min(x_max, pulse_stop)
            t = np.arange(t0, t1 - 0.5 / sampling_rate, 1 / sampling_rate)
            if len(t) == 0:
                continue

            pulse_info = info.op_info
            if (
                pulse_info["wf_func"]
                == "quantify_scheduler.waveforms.interpolated_complex_waveform"
            ):
                # Make the sample times relative to the first sample. A shallow copy is
                # used so that the pulse_info dict owned by the operation is not modified.
                pulse_info = {
                    **pulse_info,
                    "t_samples": np.asarray(pulse_info["t_samples"])
                    - pulse_info["t_samples"][0],
                }

            # Add the final datapoint for nicer plots
            t = np.append(t, t[-1] + 0.99 / sampling_rate)
            # Evaluate relative to the pulse start, not to the start of the (possibly
            # clipped) sampling window; both coincide when the pulse is not clipped.
            waveform = exec_waveform_function(
                wf_func=pulse_info["wf_func"],
                t=t - pulse_start,
                pulse_info=pulse_info,
            )

            # Add 0 amplitude points before and after the pulse such that interpolation
            # in sum_waveforms looks correct visually.
            t = np.concatenate(
                (
                    [t[0] - 0.01 / sampling_rate],
                    t,
                    [t[-1] + 0.01 / sampling_rate],
                )
            )
            waveform = np.concatenate(([0], waveform, [0]))

            if modulation == "clock":
                freq = schedule.resources[pulse_info["clock"]]["freq"]
            elif modulation == "if":
                freq = modulation_if
            elif modulation == "off":
                freq = 0
            else:
                raise ValueError(f"Unknown modulation {modulation}")

            wf_func_name = pulse_info["wf_func"].rsplit(".", maxsplit=1)[-1]
            is_linear = wf_func_name in ["square", "ramp"]
            if freq == 0 and is_linear:
                # Unmodulated square/ramp pulses are fully described by 4 points.
                waveform = np.concatenate((waveform[:2], waveform[-2:]))
                t = np.concatenate((t[:2], t[-2:]))
            else:
                waveform = modulate_waveform(t, waveform, freq)
            waveform = np.real_if_close(waveform)

            label = f"{info.op_name}, clock {pulse_info['clock']}"
            sampled_pulses[port].append(SampledPulse(time=t, signal=waveform, label=label))
    return sampled_pulses
[docs]
def get_sampled_acquisitions(
    acq_infos: dict[str, list[ScheduledInfo]],
) -> dict[str, list[SampledAcquisition]]:
    """
    Generate :class:`.SampledAcquisition` objects from acquisition_info dicts.

    Each acquisition info is converted into a SampledAcquisition that starts at the
    scheduled time of the info, lasts ``info["duration"]`` and is labelled with the
    name of the operation it belongs to.

    Parameters
    ----------
    acq_infos :
        A dictionary from ports to lists of acquisition_info dictionaries.

    Returns
    -------
    dict[str, list[SampledAcquisition]] :
        SampledAcquisition objects grouped by port. Ports without acquisitions do
        not appear as keys.
    """
    grouped: dict[str, list[SampledAcquisition]] = defaultdict(list)
    port_and_acq = (
        (
            port,
            SampledAcquisition(
                t0=scheduled.time,
                duration=scheduled.op_info["duration"],
                label=scheduled.op_name,
            ),
        )
        for port, scheduled_list in acq_infos.items()
        for scheduled in scheduled_list
    )
    for port, acquisition in port_and_acq:
        grouped[port].append(acquisition)
    return grouped
[docs]
def merge_pulses_and_offsets(operations: list[SampledPulse]) -> SampledPulse:
    """
    Combine multiple ``SampledPulse`` objects into a single one.

    The ``signal`` of every pulse is interpolated onto the union of all ``time``
    points (sorted, duplicates kept), and the interpolated signals are summed.
    Interpolation outside of a ``SampledPulse.time`` array results in 0 for that pulse.

    Parameters
    ----------
    operations :
        The sampled pulses to combine.

    Returns
    -------
    SampledPulse
        The summed pulse. Its label joins the individual labels with ``"+\\n"``, or
        reads ``"<n> operations"`` when more than three pulses are combined.
    """
    merged_time = np.sort(np.concatenate([pulse.time for pulse in operations]))
    n_operations = len(operations)
    # Long joined labels become unreadable, so fall back to a count.
    if n_operations > 3:
        label = f"{n_operations} operations"
    else:
        label = "+\n".join(pulse.label for pulse in operations)
    resampled = [
        np.interp(merged_time, pulse.time, pulse.signal, left=0.0, right=0.0)
        for pulse in operations
    ]
    return SampledPulse(
        time=merged_time,
        signal=sum(resampled),  # type: ignore
        label=label,
    )
[docs]
def sample_schedule(
    schedule: Schedule | CompiledSchedule,
    port_list: list[str] | None = None,
    modulation: Literal["off", "if", "clock"] = "off",
    modulation_if: float = 0.0,
    sampling_rate: float = 1e9,
    x_range: tuple[float, float] = (-np.inf, np.inf),
    combine_waveforms_on_same_port: bool = False,
) -> dict[str, tuple[list[SampledPulse], list[SampledAcquisition]]]:
    """
    Generate :class:`.SampledPulse` and :class:`.SampledAcquisition` objects grouped by
    port.

    This function generates SampledPulse objects for all pulses and voltage offsets
    defined in the Schedule, and SampledAcquisition for all acquisitions defined in the
    Schedule.

    Parameters
    ----------
    schedule :
        The schedule to render.
    port_list :
        A list of ports to show. if set to ``None`` (default), it will use all ports in
        the Schedule.
    modulation :
        Determines if modulation is included in the visualization.
    modulation_if :
        Modulation frequency used when modulation is set to "if".
    sampling_rate :
        The time resolution used to sample the schedule in Hz. Used for both pulses
        and (modulated) voltage offsets.
    x_range :
        The range of the x-axis that is plotted, given as a tuple (left limit, right
        limit). This can be used to reduce memory usage when plotting a small section of
        a long pulse sequence. By default (-np.inf, np.inf).
    combine_waveforms_on_same_port :
        By default False. If True, combines all waveforms on the same port into one
        single waveform. The resulting waveform is the sum of all waveforms on that
        port (small inaccuracies may occur due to floating point approximation). If
        False, the waveforms are shown individually.

    Returns
    -------
    dict[str, tuple[list[SampledPulse], list[SampledAcquisition]]] :
        SampledPulse and SampledAcquisition objects grouped by port. Ports with
        pulses or offsets come first, followed by ports that only have acquisitions.
    """
    offset_infos: dict[str, dict[str, list[ScheduledInfo]]] = defaultdict(
        lambda: defaultdict(list)
    )
    pulse_infos: dict[str, list[ScheduledInfo]] = defaultdict(list)
    acq_infos: dict[str, list[ScheduledInfo]] = defaultdict(list)
    _extract_schedule_infos(
        schedule,
        port_list,
        0,
        offset_infos,
        pulse_infos,
        acq_infos,
    )

    x_min, x_max = x_range
    sampled_pulses = get_sampled_pulses_from_voltage_offsets(
        schedule=schedule,
        offset_infos=offset_infos,
        x_min=x_min,
        x_max=x_max,
        modulation=modulation,
        modulation_if=modulation_if,
        # Forward the requested rate; otherwise modulated offsets use the 1 GHz default.
        sampling_rate=sampling_rate,
    )
    sampled_pulses = get_sampled_pulses(
        schedule=schedule,
        pulse_infos=pulse_infos,
        x_min=x_min,
        x_max=x_max,
        modulation=modulation,
        modulation_if=modulation_if,
        sampling_rate=sampling_rate,
        sampled_pulses=sampled_pulses,
    )

    if combine_waveforms_on_same_port:
        for port, pulses in list(sampled_pulses.items()):
            sampled_pulses[port] = [merge_pulses_and_offsets(pulses)]

    sampled_acqs = get_sampled_acquisitions(acq_infos)

    # Ordered, de-duplicated union of ports, without inserting keys into the
    # defaultdicts while building the result.
    ports = dict.fromkeys(chain(sampled_pulses, sampled_acqs))
    return {
        port: (sampled_pulses.get(port, []), sampled_acqs.get(port, []))
        for port in ports
    }
[docs]
def pulse_diagram_plotly(
    sampled_pulses_and_acqs: dict[str, tuple[list[SampledPulse], list[SampledAcquisition]]],
    title: str = "Pulse diagram",
    fig_ch_height: float = 300,
    fig_width: float = 1000,
) -> go.Figure:
    """
    Produce a plotly visualization of the pulses used in the schedule.

    Every port gets its own subplot row; all rows share the x-axis. Pulses are drawn
    as filled lines (the imaginary part of complex signals in dark grey), and each
    acquisition is drawn as a pair of arrow markers plus a shaded rectangle spanning
    the acquisition window.

    Parameters
    ----------
    sampled_pulses_and_acqs :
        SampledPulse and SampledAcquisition objects grouped by port.
    title :
        Plot title.
    fig_ch_height :
        Height for each channel subplot in px.
    fig_width :
        Width for the figure in px.

    Returns
    -------
    :class:`plotly.graph_objects.Figure` :
        the plot
    """
    n_ports = len(sampled_pulses_and_acqs)
    figure = make_subplots(rows=n_ports, cols=1, shared_xaxes=True, vertical_spacing=0.02)
    figure.update_layout(
        height=fig_ch_height * n_ports,
        width=fig_width,
        title=title,
        showlegend=False,
    )

    palette = px.colors.qualitative.Plotly
    next_color = 0
    group = -1

    def line_trace(time, values, name, group_id, line_color):
        """Build a filled line trace for one component of a sampled pulse."""
        return go.Scatter(
            x=time,
            y=values,
            mode="lines",
            name=name,
            legendgroup=group_id,
            showlegend=True,
            line_color=line_color,
            fill="tozeroy",
            hoverinfo="x+y+name",
            hoverlabel={"namelength": -1},
        )

    for row, (port, (pulses, acqs)) in enumerate(sampled_pulses_and_acqs.items(), start=1):
        for pulse in pulses:
            group += 1
            figure.add_trace(
                line_trace(pulse.time, pulse.signal.real, pulse.label, group, palette[next_color]),
                row=row,
                col=1,
            )
            next_color = (next_color + 1) % len(palette)
            if np.iscomplexobj(pulse.signal):
                figure.add_trace(
                    line_trace(
                        pulse.time,
                        pulse.signal.imag,
                        f"{pulse.label} (imag)",
                        group,
                        "darkgrey",
                    ),
                    row=row,
                    col=1,
                )

        # The y-axis domain of the first row is called "y", later rows "y2", "y3", ...
        domain_ref = "y domain" if row == 1 else f"y{row} domain"
        for acq in acqs:
            acq_end = acq.t0 + acq.duration
            figure.add_trace(
                go.Scatter(
                    x=[acq.t0, acq_end],
                    y=[0, 0],
                    name=acq.label,
                    mode="markers",
                    marker={
                        "size": 15,
                        "color": "rgba(0,0,0,.25)",
                        "symbol": ["arrow-bar-left", "arrow-bar-right"],
                    },
                ),
                row=row,
                col=1,
            )
            figure.add_shape(
                type="rect",
                xref="x",
                yref=domain_ref,
                x0=acq.t0,
                y0=0,
                x1=acq_end,
                y1=1,
                name=acq.label,
                line={"color": "rgba(0,0,0,0)", "width": 3},
                fillcolor="rgba(255,0,0,0.1)",
                layer="below",
            )

        # Axis formatting is identical for every row; apply it once per row.
        figure.update_xaxes(
            row=row,
            col=1,
            tickformat=".2s",
            hoverformat=".3s",
            ticksuffix="s",
            showgrid=True,
        )
        figure.update_yaxes(
            row=row,
            col=1,
            tickformat=".2s",
            hoverformat=".3s",
            ticksuffix="V",
            title=port,
            autorange=True,
        )

    figure.update_xaxes(
        row=n_ports,
        col=1,
        title="Time",
        tickformatstops=[
            {"dtickrange": [None, 1e-9], "value": ".10s"},
            {"dtickrange": [1e-9, 1e-6], "value": ".7s"},
            {"dtickrange": [1e-6, 1e-3], "value": ".4s"},
        ],
        ticksuffix="s",
    )
    return figure
[docs]
def deduplicate_legend_handles_labels(ax: mpl.axes.Axes) -> None:
    """
    Remove duplicate legend entries.

    For a label that occurs several times, the legend keeps the position of its
    first occurrence and uses the handle of its last occurrence.

    See also: https://stackoverflow.com/a/13589144
    """
    handles, labels = ax.get_legend_handles_labels()
    handle_per_label = {}
    for label, handle in zip(labels, handles):
        handle_per_label[label] = handle
    ax.legend(handle_per_label.values(), handle_per_label.keys())
[docs]
def _multiset_intersection_with_duplicates(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
"""
Compute the multiset intersection of two arrays, preserving duplicates.
The result contains each common element repeated the minimum number of times
it appears in both input arrays, sorted in ascending order.
Parameters
----------
array1 :
First input array.
array2 :
Second input array.
Returns
-------
np.ndarray
Sorted array of elements present in both arrays, including duplicates.
"""
counter1 = Counter(array1)
counter2 = Counter(array2)
intersection_counter = counter1 & counter2 # min counts for each element
result_list = []
for element, count in intersection_counter.items():
result_list.extend([element] * count)
return np.array(sorted(result_list))
[docs]
def _merge_signal_and_offset(
    signal_x: np.ndarray,
    signal_y: np.ndarray,
    offset_x: np.ndarray,
    offset_y: np.ndarray,
    scale: float = 1e9,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Merge a sampled signal and a stepwise offset into one time/value pair.

    Both time axes are rounded to integers (``round(x * scale)``) so that time points
    can be compared exactly. Offset time points that coincide with a signal time point
    are not added again. The offset is held constant from each of its time points up
    to the next one (and after the last one), and is added to the signal. The signal
    itself is not interpolated: at time points that only come from the offset, the
    signal contribution is 0.

    Parameters
    ----------
    signal_x :
        Time points of the signal. NOTE(review): presumably sorted ascending — the
        values are written into the sorted merged grid in their given order; confirm.
    signal_y :
        Signal values at ``signal_x``.
    offset_x :
        Time points of the offset (presumably sorted ascending as well).
    offset_y :
        Offset values at ``offset_x``.
    scale :
        Scaling factor to convert time to the integer domain (default is 1e9, i.e.
        nanosecond resolution when times are given in seconds).

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Tuple containing:
        - the merged time vector, converted back to the input unit (rounded to a
          resolution of ``1 / scale``),
        - the merged signal + offset values.
    """
    signal_x = np.asarray(signal_x)
    signal_y = np.asarray(signal_y)
    offset_x = np.asarray(offset_x)
    offset_y = np.asarray(offset_y)
    # Convert time values to scaled integers for exact comparison
    signal_x_int = np.round(signal_x * scale).astype(np.int64)
    offset_x_int = np.round(offset_x * scale).astype(np.int64)
    # Detect duplicated time points between signal and offset
    common_time_int = _multiset_intersection_with_duplicates(signal_x_int, offset_x_int)
    # NOTE: np.isin drops *every* occurrence of a common value from the offset times,
    # regardless of its multiplicity in common_time_int.
    offset_unique_int = offset_x_int[~np.isin(offset_x_int, common_time_int)]
    # Merge signal and offset time points (excluding offset duplicates)
    all_time_int = np.sort(np.concatenate([signal_x_int, offset_unique_int]))
    # Initialize merged signal and offset arrays
    merged_signal_y = np.zeros(len(all_time_int), dtype=signal_y.dtype)
    merged_offset_y = np.zeros(len(all_time_int), dtype=offset_y.dtype)
    # Assign signal values. The remaining offset-only times never equal a signal time,
    # so this selects exactly the len(signal_x_int) signal positions.
    signal_indices = np.where(np.isin(all_time_int, signal_x_int))[0]
    merged_signal_y[signal_indices] = signal_y
    # Assign and fill stepwise offset values
    offset_indices = np.where(np.isin(all_time_int, offset_x_int))[0]
    # Repeated signal times that coincide with an offset time can yield more matching
    # positions than offset values; keep only the earliest ones.
    # NOTE(review): the opposite case (fewer positions than offset values, e.g. a
    # duplicated offset time that coincides with a single signal time) is not handled
    # and makes the assignment below raise a shape-mismatch ValueError — confirm.
    if len(offset_indices) > len(offset_y):
        offset_indices = offset_indices[: len(offset_y)]
    merged_offset_y[offset_indices] = offset_y
    # Fill the gap after each offset point with that point's value (any value, zero
    # included), up to the next offset point.
    for i in range(len(offset_indices) - 1):
        start = offset_indices[i]
        end = offset_indices[i + 1]
        val = merged_offset_y[start]
        merged_offset_y[start:end] = val
    # Extend last value forward if any
    if offset_indices.size > 0:
        merged_offset_y[offset_indices[-1] :] = merged_offset_y[offset_indices[-1]]
    # Convert time back to float (input unit)
    merged_time = all_time_int.astype(np.float64) / scale
    merged_values = merged_signal_y + merged_offset_y
    return merged_time, merged_values
[docs]
def _clean_flat_artifacts(x_data: np.ndarray, y_data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Remove artifacts where 4 identical consecutive x-values occur by discarding
the two middle points in each block.
Args:
x_data: The x-values of the signal.
y_data: The y-values of the signal.
Returns:
A tuple of filtered x and y arrays with flat artifacts removed.
"""
scale = 1e9
x_data = np.asarray(x_data)
x_data = np.round(x_data * scale).astype(np.int64)
y_data = np.asarray(y_data)
keep = np.ones(len(x_data), dtype=bool)
# Detect sequences of 4 repeated x values
repeated = (
(x_data[:-3] == x_data[1:-2]) & (x_data[:-3] == x_data[2:-1]) & (x_data[:-3] == x_data[3:])
)
idx_start = np.where(repeated)[0]
for i in idx_start:
keep[i + 1] = False
keep[i + 2] = False
return x_data[keep] / scale, y_data[keep]
[docs]
def plot_single_subplot_mpl(
    sampled_schedule: dict[str, list[SampledPulse]],
    ax: mpl.axes.Axes | None = None,
    title: str = "Pulse diagram",
) -> tuple[mpl.figure.Figure, mpl.axes.Axes]:
    """
    Plot all pulses for all ports in the same subplot using Matplotlib.

    The pulses of one port are combined into a single trace that has the port's
    color and legend entry. When a port's signal has a non-zero imaginary part, it is
    drawn as a dashed line in the same color with an additional "(imag)" legend entry,
    consistent with the multiple-subplot view.

    Parameters
    ----------
    sampled_schedule :
        Dictionary mapping port names to lists of SampledPulse objects.
    ax :
        Existing axes to draw on. If None, a new figure and axes will be created.
    title : str, optional
        Title of the plot (default is "Pulse diagram").

    Returns
    -------
    fig :
        The matplotlib figure object.
    ax :
        The axes used for the subplot.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    for i, (port, pulses) in enumerate(sampled_schedule.items()):
        color = f"C{i}"
        time, value, offset_time, offset_value = _extract_signal_component(pulses)
        time, value = _clean_flat_artifacts(time, value)
        time, value = _merge_signal_and_offset(time, value, offset_time, offset_value)
        ax.plot(time, value.real, color=color, label=f"port {port}")
        ax.fill_between(time, value.real, color=color, alpha=0.2)
        if np.any(np.imag(value) != 0):
            # A distinct label keeps the solid real-part line as the port's legend
            # handle instead of it being replaced by the dashed line on deduplication.
            ax.plot(
                time, value.imag, color=color, label=f"port {port} (imag)", linestyle="--"
            )
            ax.fill_between(time, value.imag, color=color, alpha=0.2)

    deduplicate_legend_handles_labels(ax)
    set_xlabel(label="Time", unit="s", axis=ax)
    set_ylabel(label=r"$\dfrac{V}{V_{max}}$", unit="", axis=ax)
    ax.set_title(title)
    return fig, ax
[docs]
def plot_multiple_subplots_mpl(
    sampled_schedule: dict[str, list[SampledPulse]],
    title: str = "Pulse diagram",
) -> tuple[mpl.figure.Figure, list[mpl.axes.Axes]]:
    """
    Plot pulses in a different subplot for each port in the sampled schedule.

    For each subplot, each different type of pulse gets its own color and legend
    entry. Works for any number of ports (at least one), including a single port.

    Parameters
    ----------
    sampled_schedule :
        Dictionary that maps each used port to the sampled pulses played on that port.
    title :
        Plot title.

    Returns
    -------
    fig :
        A matplotlib :class:`matplotlib.figure.Figure` containing the subplots.
    axs :
        An array of Axes objects belonging to the Figure, one per port.
    """
    fig, axs = plt.subplots(len(sampled_schedule), 1, sharex=True, squeeze=False)
    # ``squeeze=False`` always yields a 2D array of Axes. Keep its single column so
    # ``axs`` is a 1D array even for one port (a bare Axes is not indexable).
    axs = axs[:, 0]

    for ax, (port, data) in zip(axs, sampled_schedule.items()):
        # This automatically creates a label-to-color map as the plots get created.
        color: dict[str, str] = defaultdict(lambda: f"C{len(color)}")  # noqa: B023 false positive
        for pulse in data:
            ax.plot(
                pulse.time,
                pulse.signal.real,
                color=color[pulse.label],
                label=pulse.label,
            )
            ax.fill_between(pulse.time, pulse.signal.real, color=color[pulse.label], alpha=0.2)
            if np.iscomplexobj(pulse.signal):
                ax.plot(
                    pulse.time,
                    pulse.signal.imag,
                    color=color[pulse.label],
                    linestyle="--",
                    label=f"{pulse.label} (imag)",
                )
                ax.fill_between(
                    pulse.time, pulse.signal.imag, color=color[pulse.label], alpha=0.4
                )
        deduplicate_legend_handles_labels(ax)
        set_ylabel(label=f"port {port}\nAmplitude", unit="V", axis=ax)

    set_xlabel(label="Time", unit="s", axis=axs[-1])
    # Make the figure taller if y-labels overlap.
    fig.set_figheight(max(4.8 * len(axs) / 3, 4.8))
    axs[0].set_title(title)
    return fig, axs
[docs]
def pulse_diagram_matplotlib(
    sampled_pulses_and_acqs: dict[str, tuple[list[SampledPulse], list[SampledAcquisition]]],
    multiple_subplots: bool = False,
    ax: mpl.axes.Axes | None = None,
    title: str = "Pulse diagram",
) -> tuple[mpl.figure.Figure, mpl.axes.Axes | list[mpl.axes.Axes]]:
    """
    Plots a schedule using matplotlib.

    Acquisitions are ignored; only the sampled pulses of each port are drawn.

    Parameters
    ----------
    sampled_pulses_and_acqs :
        SampledPulse and SampledAcquisition objects grouped by port.
    multiple_subplots :
        Plot the pulses for each port on a different subplot if True, else plot
        everything in one subplot. By default False. When using just one
        subplot, the pulses are colored according to the port on which they
        play. For multiple subplots, each pulse has its own
        color and legend entry. With a single port, one subplot is always used.
    ax :
        Axis onto which to plot. If ``None`` (default), this is created within the
        function. Only used for the single-subplot layout.
    title :
        Plot title.

    Returns
    -------
    fig :
        A matplotlib :class:`matplotlib.figure.Figure` containing the subplot(s).
    ax :
        The Axes object belonging to the Figure, or an array of Axes if
        ``multiple_subplots=True``.

    Raises
    ------
    RuntimeError
        If ``sampled_pulses_and_acqs`` contains no ports.
    """
    pulses_per_port = {
        port: port_pulses for port, (port_pulses, _acqs) in sampled_pulses_and_acqs.items()
    }
    if not pulses_per_port:
        raise RuntimeError(
            "Attempting to sample schedule, "
            "but the schedule does not contain any `pulse_info`. "
            "Please verify that the schedule has been populated and "
            "device compilation has been performed."
        )
    if multiple_subplots and len(pulses_per_port) > 1:
        return plot_multiple_subplots_mpl(sampled_schedule=pulses_per_port, title=title)
    return plot_single_subplot_mpl(sampled_schedule=pulses_per_port, ax=ax, title=title)
[docs]
def get_window_operations(
    schedule: Schedule,
) -> list[tuple[float, float, Operation]]:
    r"""
    Return a list of all :class:`.WindowOperation`\s with start and end time.

    One entry is produced per pulse_info of every window operation. The start time is
    the schedulable's ``abs_time`` plus the pulse_info ``t0``; the end time adds the
    pulse_info ``duration``.

    Parameters
    ----------
    schedule:
        Schedule to use.

    Returns
    -------
    :
        List of ``(start, end, operation)`` tuples for all window operations in the
        schedule, in schedulable order.
    """
    windows: list[tuple[float, float, Operation]] = []
    for schedulable in schedule.schedulables.values():
        op = schedule.operations[schedulable["operation_id"]]
        if not isinstance(op, pl.WindowOperation):
            continue
        for info in op["pulse_info"]:
            start = schedulable["abs_time"] + info["t0"]
            stop = start + info["duration"]
            windows.append((start, stop, op))
    return windows
[docs]
def plot_window_operations(
    schedule: Schedule,
    ax: mpl.axes.Axes | None = None,
    time_scale_factor: float = 1,
) -> tuple[mpl.figure.Figure, mpl.axes.Axes]:
    """
    Plot the window operations in a schedule.

    Each window is drawn as a translucent vertical span, colored from the "jet"
    colormap and labelled with the window name.

    Parameters
    ----------
    schedule:
        Schedule from which to plot window operations.
    ax:
        Axis handle to use for plotting. If ``None``, the current axes are used.
    time_scale_factor:
        Used to scale the independent data before using as data for the
        x-axis of the plot.

    Returns
    -------
    fig
        The matplotlib figure.
    ax
        The matplotlib ax.
    """
    axis = plt.gca() if ax is None else ax
    windows = get_window_operations(schedule)
    cmap = mpl.colormaps.get_cmap("jet")
    n_windows = len(windows)
    for idx, (t0, t1, operation) in enumerate(windows):
        window_name = operation.window_name
        logger.debug(f"plot_window_operations: window {window_name}: {t0}, {t1}")
        axis.axvspan(
            time_scale_factor * t0,
            time_scale_factor * t1,
            alpha=0.2,
            color=cmap(idx / (1 + n_windows)),
            label=window_name,
        )
    return axis.get_figure(), axis
[docs]
def plot_acquisition_operations(
    schedule: Schedule, ax: mpl.axes.Axes | None = None, **kwargs
) -> list[Any]:
    """
    Plot the acquisition operations in a schedule.

    Each acquisition is drawn as a vertical span from its start time (schedulable
    ``abs_time`` plus the first acquisition_info ``t0``) to that start plus the
    operation duration.

    Parameters
    ----------
    schedule:
        Schedule from which to plot acquisition operations.
    ax:
        Axis handle to use for plotting. If ``None``, the current axes are used.
    kwargs:
        Passed to matplotlib plotting routine

    Returns
    -------
    :
        List of handles
    """
    target_ax = plt.gca() if ax is None else ax
    handles: list[Any] = []
    for schedulable in schedule.schedulables.values():
        operation = schedule.operations[schedulable["operation_id"]]
        if not isinstance(operation, Acquisition):
            continue
        start = schedulable["abs_time"] + operation.data["acquisition_info"][0]["t0"]
        stop = start + operation.duration
        handles.append(target_ax.axvspan(start, stop, **kwargs))
    return handles