# Source code for quantify_scheduler.backends.corrections

# Repository: https://gitlab.com/quantify-os/quantify-scheduler
# Licensed according to the LICENCE file on the main branch
"""Pulse and acquisition corrections for hardware compilation."""
from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING, Any, Generator

import numpy as np

from quantify_scheduler.backends.types.common import (
    HardwareCompilationConfig,
    HardwareDistortionCorrection,
    SoftwareDistortionCorrection,
)
from quantify_scheduler.helpers.importers import import_python_object_from_string
from quantify_scheduler.helpers.schedule import _extract_port_clocks_used
from quantify_scheduler.helpers.waveforms import get_waveform
from quantify_scheduler.operations.control_flow_library import ControlFlowOperation
from quantify_scheduler.operations.pulse_library import NumericalPulse
from quantify_scheduler.schedules.schedule import Schedule, ScheduleBase

if TYPE_CHECKING:
    from quantify_scheduler.operations.operation import Operation

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def determine_relative_latency_corrections(
    hardware_cfg: HardwareCompilationConfig | dict[str, Any],
    schedule: Schedule | None = None,
) -> dict[str, float]:
    """
    Generates the latency configuration dict for all port-clock combinations that are
    present in the schedule (or in the hardware config, if an old-style zhinst config
    is passed). This is done by first setting unspecified latency corrections to zero,
    and then subtracting the minimum latency from all latency corrections.

    Parameters
    ----------
    hardware_cfg
        Either a ``HardwareCompilationConfig``, whose
        ``hardware_options.latency_corrections`` are used, or a legacy hardware
        config dict (zhinst backend only), whose top-level ``"latency_corrections"``
        entry is used and whose nested ``"port"``/``"clock"`` entries define the
        port-clock combinations.
    schedule
        The schedule whose port-clock combinations are considered. Required when
        ``hardware_cfg`` is a ``HardwareCompilationConfig``; not used for a legacy
        config dict.

    Returns
    -------
    :
        Mapping from ``"<port>-<clock>"`` to its latency correction relative to the
        smallest one (so the smallest becomes zero). Empty if no latency corrections
        are configured.

    Raises
    ------
    ValueError
        If ``hardware_cfg`` is a ``HardwareCompilationConfig`` and ``schedule`` is
        ``None``.
    """

    def _extract_port_clocks(hardware_cfg: Any) -> Generator[str, None, None]:
        """
        Extracts all port-clock combinations that are present in a hardware
        configuration.

        Based on:
        https://stackoverflow.com/questions/9807634/find-all-occurrences-of-a-key-in-nested-dictionaries-and-lists.
        """
        # List items need not be mappings; anything without ``items`` has no ports.
        if not hasattr(hardware_cfg, "items"):
            return
        for key, value in hardware_cfg.items():
            if key == "port":
                yield f'{hardware_cfg["port"]}-{hardware_cfg["clock"]}'
            elif isinstance(value, dict):
                yield from _extract_port_clocks(value)
            elif isinstance(value, list):
                for item in value:
                    yield from _extract_port_clocks(item)

    if isinstance(hardware_cfg, HardwareCompilationConfig):
        if schedule is None:
            raise ValueError(
                f"{determine_relative_latency_corrections.__name__} requires the "
                "`schedule` argument if `hardware_cfg` is a "
                "`HardwareCompilationConfig`."
            )
        port_clocks = [
            "-".join(map(str, port_clock))
            for port_clock in _extract_port_clocks_used(schedule)
        ]
        latency_corrections = hardware_cfg.hardware_options.latency_corrections
    else:
        # Support for legacy hardware config dict (zhinst backend only).
        # Lazily evaluated: only walked once latency corrections are known to exist.
        port_clocks = _extract_port_clocks(hardware_cfg=hardware_cfg)
        latency_corrections = hardware_cfg.get("latency_corrections")

    if latency_corrections is None:
        return {}

    # Set unspecified latency corrections to zero to avoid ending up with
    # negative latency corrections after subtracting the minimum.
    absolute_latencies = {
        port_clock: latency_corrections.get(port_clock, 0)
        for port_clock in port_clocks
    }

    # Subtract the lowest value to ensure minimal latency is used and offset the
    # latency corrections to be relative to the minimum. Note that this supports
    # negative delays (which is useful for calibrating).
    minimum_latency = min(absolute_latencies.values(), default=0)
    return {
        port_clock: latency - minimum_latency
        for port_clock, latency in absolute_latencies.items()
    }
def distortion_correct_pulse(
    pulse_data: dict[str, Any],
    distortion_correction: SoftwareDistortionCorrection,
) -> NumericalPulse:
    """
    Sample pulse and apply filter function to the sample to distortion correct it.

    Parameters
    ----------
    pulse_data
        Definition of the pulse. The whole dict is passed to ``get_waveform`` for
        sampling; its ``"duration"``, ``"port"``, ``"clock"`` and ``"t0"`` entries
        are used to build the resulting pulse.
    distortion_correction
        The distortion_correction configuration for this pulse. The filter is
        imported from ``filter_func`` and called with the sampled waveform as the
        ``input_var_name`` keyword argument plus ``kwargs``. The output is clipped
        only when ``clipping_values`` holds exactly two values (min, max).

    Returns
    -------
    :
        The sampled, distortion corrected pulse wrapped in a ``NumericalPulse``.
    """
    waveform_data = get_waveform(
        pulse_info=pulse_data, sampling_rate=distortion_correction.sampling_rate
    )

    filter_func = import_python_object_from_string(distortion_correction.filter_func)
    # Entries in ``kwargs`` take precedence over the sampled waveform if they reuse
    # the ``input_var_name`` key.
    kwargs = {
        distortion_correction.input_var_name: waveform_data,
        **distortion_correction.kwargs,
    }
    # A user-supplied filter function may return a plain sequence rather than an
    # ndarray; coerce so ``.size`` and indexing below always work.
    corrected_waveform_data = np.asarray(filter_func(**kwargs))

    clipping_values = distortion_correction.clipping_values
    if clipping_values is not None and len(clipping_values) == 2:
        corrected_waveform_data = np.clip(
            corrected_waveform_data, clipping_values[0], clipping_values[1]
        )

    if corrected_waveform_data.size == 1:
        # Interpolation requires two sample points
        corrected_waveform_data = np.append(
            corrected_waveform_data, corrected_waveform_data[-1]
        )

    return NumericalPulse(
        samples=corrected_waveform_data,
        t_samples=np.linspace(
            start=0, stop=pulse_data["duration"], num=corrected_waveform_data.size
        ),
        port=pulse_data["port"],
        clock=pulse_data["clock"],
        t0=pulse_data["t0"],
    )
[docs] def _is_distortion_correctable(operation: Operation) -> bool: """Checks whether distortion corrections can be applied to the given operation.""" return operation.valid_pulse and not operation.has_voltage_offset
def apply_software_distortion_corrections(
    operation: Operation | Schedule, distortion_corrections: dict
) -> Operation | Schedule | None:
    """
    Apply distortion corrections to operations in the schedule.

    Defined via the hardware configuration file, example:

    .. code-block::

        "distortion_corrections": {
            "q0:fl-cl0.baseband": {
                "filter_func": "scipy.signal.lfilter",
                "input_var_name": "x",
                "kwargs": {
                    "b": [0.0, 0.5, 1.0],
                    "a": [1]
                },
                "clipping_values": [-2.5, 2.5]
            }
        }

    Clipping values are the boundaries to which the corrected pulses will be clipped,
    upon exceeding, these are optional to supply.

    For pulses in need of correcting (indicated by their port-clock combination) we are
    **only** replacing the dict in ``"pulse_info"`` associated to that specific
    pulse. This means that we can have a combination of corrected (i.e., pre-sampled)
    and uncorrected pulses in the same operation.

    Note that we are **not** updating the ``"operation_id"`` key, used to reference
    the operation from schedulables.

    Schedules and control-flow operations are traversed recursively. Port-clocks that
    are absent from ``distortion_corrections`` are left untouched, and so are
    port-clocks whose entry is a ``HardwareDistortionCorrection`` or a list.

    Parameters
    ----------
    operation
        The operation that contains operations that are to be distortion corrected.
        Note, this function updates the operation.
    distortion_corrections
        The distortion_corrections configuration of the setup, keyed by
        ``"<port>-<clock>"``.

    Returns
    -------
    :
        The new operation with distortion corrected operations, if it needs to be
        replaced. If it doesn't need to be replaced in the schedule or control flow, it
        returns ``None``.

    Warns
    -----
    RuntimeWarning
        If distortion correction can not be applied to the type of Operation in the
        schedule.

    Raises
    ------
    Exception
        Errors from ``SoftwareDistortionCorrection.model_validate`` (presumably a
        pydantic ``ValidationError`` for an invalid correction config) or from
        sampling/filtering the pulse propagate unchanged. NOTE(review): earlier docs
        listed ``KeyError`` for missing config elements and for clipping values that
        are not exactly ``[min, max]``. In this module, clipping values that are not
        of length 2 are silently ignored unless the config model rejects them.
        Confirm against the config models.
    """
    if isinstance(operation, ScheduleBase):
        # Recurse into every operation of the schedule. Re-assigning existing keys
        # while iterating is safe because the dict's size does not change.
        for inner_operation_id, inner_operation in operation.operations.items():
            replacing_operation = apply_software_distortion_corrections(
                inner_operation, distortion_corrections
            )
            if replacing_operation is not None:
                operation.operations[inner_operation_id] = replacing_operation
        return None

    if isinstance(operation, ControlFlowOperation):
        replacing_operation = apply_software_distortion_corrections(
            operation.body, distortion_corrections
        )
        if replacing_operation is not None:
            operation.body = replacing_operation
        return None

    return _apply_software_distortion_corrections_to_pulses(
        operation, distortion_corrections
    )


def _apply_software_distortion_corrections_to_pulses(
    operation: Operation, distortion_corrections: dict
) -> Operation | None:
    """
    Distortion correct the ``pulse_info`` entries of a single leaf operation.

    Corrected entries are replaced in place in ``operation.data["pulse_info"]``. If
    the first entry was corrected, the corresponding ``NumericalPulse`` is returned
    (carrying the full, updated ``pulse_info`` list) so the caller can substitute it
    for ``operation``. Otherwise ``None`` is returned.
    """
    substitute_operation = None
    for pulse_info_idx, pulse_data in enumerate(operation.data["pulse_info"]):
        portclock_key = f"{pulse_data['port']}-{pulse_data['clock']}"
        if portclock_key not in distortion_corrections:
            continue

        correction_cfg = distortion_corrections[portclock_key]
        # Hardware corrections and lists of corrections are not handled here
        # (presumably they are applied by the hardware backend).
        if isinstance(correction_cfg, (HardwareDistortionCorrection, list)):
            continue

        if not _is_distortion_correctable(operation):
            warnings.warn(
                f"Schedule contains an operation, for which distortion "
                f"correction is not implemented. Please either replace the "
                f"operation, or omit the distortion correction setting for "
                f"this port in order to suppress this warning. Offending "
                f"operation: {operation}",
                RuntimeWarning,
            )
            continue

        # Zhinst support (still uses old hw dict). Lists were already skipped above,
        # so only non-model configs need their ``correction_type`` inspected.
        if not isinstance(correction_cfg, SoftwareDistortionCorrection):
            try:
                correction_type = correction_cfg.get("correction_type", "software")
            except AttributeError:
                correction_type = correction_cfg[0].get(
                    "correction_type", "software"
                )
            if correction_type != "software":
                continue

        corrected_pulse = distortion_correct_pulse(
            pulse_data=pulse_data,
            distortion_correction=SoftwareDistortionCorrection.model_validate(
                correction_cfg
            ),
        )
        operation.data["pulse_info"][pulse_info_idx] = corrected_pulse.data[
            "pulse_info"
        ][0]
        if pulse_info_idx == 0:
            substitute_operation = corrected_pulse

    # Convert to operation-type of first entry in pulse_info,
    # required as first entry in pulse_info is used to generate signature in __str__
    if substitute_operation is not None:
        substitute_operation.data["pulse_info"] = operation.data["pulse_info"]
    return substitute_operation