# Repository: https://gitlab.com/quantify-os/quantify-core
# Licensed according to the LICENCE file on the main branch
"""Module containing the MeasurementControl."""
from __future__ import annotations
import itertools
import logging
import math
import signal
import tempfile
import threading
import time
import types
from collections.abc import Iterable
from functools import reduce
from itertools import chain
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Hashable,
Dict,
Literal,
Optional,
Protocol,
Sequence,
TypeVar,
cast,
)
import adaptive
import numpy as np
import xarray as xr
from filelock import FileLock
from qcodes import validators as vals
from qcodes.instrument import Instrument, InstrumentChannel
from qcodes.parameters import InstrumentRefParameter, ManualParameter
from tqdm.auto import tqdm
from typing_extensions import Self
from quantify_core import __version__ as _quantify_version
from quantify_core.data.experiment import QuantifyExperiment
from quantify_core.data.handling import (
DATASET_NAME,
_is_uniformly_spaced_array,
create_exp_folder,
grow_dataset,
initialize_dataset,
snapshot,
trim_dataset,
)
from quantify_core.measurement.types import Gettable, Settable, is_batched
from quantify_core.utilities.general import call_if_has_method
if TYPE_CHECKING:
from numpy.typing import NDArray
from xarray import Dataset
logger = logging.getLogger(__name__)
# Intended for plotting monitors that run in separate processes
_DATASET_LOCKS_DIR = Path(tempfile.gettempdir())
class MeasurementControl(Instrument): # pylint: disable=too-many-instance-attributes
"""
Instrument responsible for controlling the data acquisition loop.
MeasurementControl (MC) is based on the notion that every experiment consists of
the following steps:
1. Set some parameter(s) (settable_pars)
2. Measure some other parameter(s) (gettable_pars)
3. Store the data.
Example:
.. code-block:: python
meas_ctrl.settables(mw_source1.freq)
meas_ctrl.setpoints(np.arange(5e9, 5.2e9, 100e3))
meas_ctrl.gettables(pulsar_QRM.signal)
dataset = meas_ctrl.run(name='Frequency sweep')
MC exists to enforce structure on experiments. Enforcing this structure allows:
- Standardization of data storage.
- Providing basic real-time visualization.
MC imposes minimal constraints and allows:
- Iterative loops, experiments in which setpoints are processed step by step.
- Batched loops, experiments in which setpoints are processed in batches.
- Adaptive loops, setpoints are determined based on measured values.
.. seealso:: :ref:`Measurement Control How-To <howto-measurement-control>`
Parameters
----------
name
name of this instrument.
"""
def __init__(self, name: str):
super().__init__(name=name)
# Parameters are attributes included in logging and which the user can change.
self.lazy_set = ManualParameter(
vals=vals.Bool(),
initial_value=False,
name="lazy_set",
instrument=self,
)
"""If set to ``True``, only set any settable if the setpoint differs
from the previous setpoint. Note that this parameter is overridden by the
``lazy_set`` argument passed to the :meth:`.run` and :meth:`.run_adaptive`
methods."""
self.verbose = ManualParameter(
vals=vals.Bool(),
initial_value=True,
instrument=self,
name="verbose",
)
"""If set to ``True``, prints to ``std_out`` during experiments."""
self.on_progress_callback = ManualParameter(
vals=vals.Callable(),
instrument=self,
name="on_progress_callback",
)
"""A callback to communicate progress. This should be a callable accepting
floats between 0 and 100 indicating the percentage done."""
self.instr_plotmon = InstrumentRefParameter(
vals=vals.MultiType(vals.Strings(), vals.Enum(None)),
instrument=self,
name="instr_plotmon",
)
"""Instrument responsible for live plotting. Can be set to ``None`` to disable
live plotting."""
self.update_interval = ManualParameter(
initial_value=0.5,
vals=vals.Numbers(min_value=0.1),
instrument=self,
name="update_interval",
)
"""Interval for updates during the data acquisition loop, every time more than
:attr:`.update_interval` time has elapsed when acquiring new data points, data
is written to file (and the live monitoring detects updated)."""
# Add experiment_data submodule to allow user to save custom metadata
experiment_data = InstrumentChannel(self, "experiment_data")
self.add_submodule("experiment_data", experiment_data)
self._soft_avg_validator = vals.Ints(1, int(1e8)).validate
# variables that are set before the start of any experiment.
self._settable_pars: list[Settable] = []
"""Parameter(s) to be set during the acquisition loop."""
self._setpoints: list[np.ndarray] = []
"""An (M, N) matrix of N setpoints for M settables."""
self._setpoints_input: Iterable[np.ndarray] = []
"""The values to loop over in the experiment."""
self._gettable_pars: list[Gettable] = []
"""Parameter(s) to be get during the acquisition loop."""
# variables used for book keeping during acquisition loop.
self._soft_avg = 1
self._nr_acquired_values = 0
self._loop_count = 0
self._begintime = time.time()
self._last_upd = time.time()
self._batch_size_last = None
self._dataarray_cache: Optional[Dict[str, Any]] = None
# variables used for persistence, plotting and data handling
self._dataset = None
        self._exp_folder: Optional[Path] = None
self._experiment = None
self._plotmon_name = ""
        # attributes named as if they are python attributes, e.g. dset.grid_2d == True
self._plot_info = {
"grid_2d": False,
"grid_2d_uniformly_spaced": False,
"1d_2_settables_uniformly_spaced": False,
}
# properly handling KeyboardInterrupts
self._interrupt_manager = _KeyboardInterruptManager()
def __repr__full__(self):
str_out = super().__repr__() + "\n"
# hasattr is necessary in case the instrument was closed
if hasattr(self, "_settable_pars"):
settable_names = [p.name for p in self._settable_pars]
str_out += f" settables: {settable_names}\n"
if hasattr(self, "_gettable_pars"):
gettable_names = [p.name for p in self._gettable_pars]
str_out += f" gettables: {gettable_names}\n"
if hasattr(self, "_setpoints_input") and self._setpoints_input is not None:
input_shapes = [
np.asarray(points).shape for points in self._setpoints_input
]
str_out += f" setpoints_grid input shapes: {input_shapes}\n"
# Report the transposed shape to keep consistency with the UI (self.setpoints).
if hasattr(self, "_setpoints") and self._setpoints is not None:
try:
setpoints_shape = (
len(self._setpoints[0]),
len(self._setpoints),
)
except IndexError:
setpoints_shape = (0, 0)
str_out += f" setpoints shape: {setpoints_shape}\n"
return str_out
def __repr__(self):
"""
Returns a string containing a summary of this object regarding settables,
gettables and setpoints.
Intended, for example, to give a more useful representation in interactive
shells.
"""
return self.__repr__full__()
def get_idn(self) -> dict[str, str | None]:
return {
"vendor": "Quantify",
"model": f"{self.__module__}.{self.__class__.__name__}",
"serial": self.name,
"firmware": _quantify_version,
}
def show(self):
"""Print short representation of the object to stdout."""
print(self.__repr__full__())
def set_experiment_data(
self, experiment_data: Dict[str, Any], overwrite: bool = True
):
"""
Populates the experiment_data submodule with experiment_data parameters
Parameters
        ----------
experiment_data:
Dict specifying the names of the experiment_data parameters and their
values. Follows the format:
.. code-block:: python
{
"parameter_name": {
"value": 10.2
"label": "parameter label"
"unit": "Hz"
}
}
overwrite:
If True, clear all previously saved experiment_data parameters and save new
ones.
            If False, keep all previously saved experiment_data parameters and
            change their values where needed.
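
        Example (a minimal sketch; ``meas_ctrl`` is an existing MeasurementControl
        instance and the parameter name is illustrative):

        .. code-block:: python

            meas_ctrl.set_experiment_data(
                {"attenuation": {"value": 10.2, "label": "Attenuation", "unit": "dB"}}
            )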
"""
if overwrite:
self.clear_experiment_data()
for name, parameter in experiment_data.items():
if name not in self.experiment_data.parameters:
self.experiment_data.add_parameter(
name=name, parameter_class=ManualParameter
)
self.experiment_data.parameters[name](parameter.get("value"))
self.experiment_data.parameters[name].label = parameter.get("label", name)
self.experiment_data.parameters[name].unit = parameter.get("unit", "")
def clear_experiment_data(self):
"""
Remove all experiment_data parameters from the experiment_data submodule
"""
self.experiment_data.parameters = {}
@staticmethod
def _reshape_data(
acq_protocol: str, vals: NDArray, real_imag: bool
) -> list[NDArray]:
"""Convert an array of complex numbers into two arrays of real numbers."""
if acq_protocol == "TriggerCount":
return [vals.real.astype(np.uint64)]
if acq_protocol == "Timetag":
return [vals.real.astype(np.float64)]
if acq_protocol == "ThresholdedAcquisition":
return [vals.real.astype(np.uint32)]
if acq_protocol in (
"Trace",
"SSBIntegrationComplex",
"ThresholdedAcquisition",
"WeightedIntegratedSeparated",
"NumericalSeparatedWeightedIntegration",
"NumericalWeightedIntegration",
):
ret_val = []
if real_imag:
ret_val.append(vals.real)
ret_val.append(vals.imag)
return ret_val
else:
ret_val.append(np.abs(vals))
ret_val.append(np.angle(vals, deg=True))
return ret_val
raise NotImplementedError(
f"Acquisition protocol {acq_protocol} is not supported."
)
@classmethod
def _process_acquired_data( # noqa: PLR0912
cls, # noqa: ANN102
acquired_data: Dataset,
batched: bool,
real_imag: bool,
) -> tuple[NDArray[np.float64], ...]:
"""
Reshapes the data as returned from the gettable into the form
accepted by the measurement control.
Parameters
----------
acquired_data
Data that is returned by gettable.
batched
            Parameter to distinguish between iterative and batched experiments.
real_imag:
If true, the gettable returns I, Q values. Otherwise, magnitude and phase
(degrees) are returned.
Returns
-------
:
            A tuple of data, cast to the historical conventions of the data format.
"""
# retrieve the acquisition results
return_data = []
# We sort acquisition channels so that the user
# has control over the order of the return data.
# https://gitlab.com/quantify-os/quantify-scheduler/-/issues/466
sorted_acq_channels: list[Hashable] = sorted(acquired_data.data_vars)
for idx, acq_channel in enumerate(sorted_acq_channels):
acq_channel_data = acquired_data[acq_channel]
acq_protocol = acq_channel_data.attrs["acq_protocol"]
num_dims = len(acq_channel_data.dims)
if acq_protocol == "Trace" and num_dims != 2:
                raise ValueError(
                    f"Data returned by a gettable for the "
                    f"{acq_protocol} acquisition protocol is expected to be an "
                    f"array of complex numbers with two dimensions: "
                    f"acquisition index and trace index. This is not the case for "
                    f"acquisition channel {acq_channel}, which has data "
                    f"type {acq_channel_data.dtype} and {num_dims} dimensions: "
                    f"{', '.join(str(dim) for dim in acq_channel_data.dims)}."
                )
if acq_protocol in (
"SSBIntegrationComplex",
"WeightedIntegratedSeparated",
"NumericalSeparatedWeightedIntegration",
"NumericalWeightedIntegration",
"ThresholdedAcquisition",
) and num_dims not in (1, 2):
                raise ValueError(
                    f"Data returned by a gettable for the "
                    f"{acq_protocol} acquisition protocol is expected to be an "
                    f"array of complex numbers with one or two dimensions: "
                    f"acquisition index and optionally repetition index. This is "
                    f"not the case for "
                    f"acquisition channel {acq_channel}, which has data "
                    f"type {acq_channel_data.dtype} and {num_dims} dimensions: "
                    f"{', '.join(str(dim) for dim in acq_channel_data.dims)}."
                )
if acq_protocol == "Trace" and acq_channel_data.shape[0] != 1:
raise ValueError(
"Trace acquisition protocol with several acquisitions on the "
"same acquisition channel is not supported by "
"a ScheduleGettable"
)
if acq_protocol not in (
"TriggerCount",
"Timetag",
"Trace",
"SSBIntegrationComplex",
"WeightedIntegratedSeparated",
"NumericalSeparatedWeightedIntegration",
"NumericalWeightedIntegration",
"ThresholdedAcquisition",
):
raise ValueError(f"ScheduleGettable does not support {acq_protocol}.")
vals = acq_channel_data.to_numpy().reshape((-1,))
if not batched and len(vals) != 1:
raise ValueError(
f"For iterative mode, only one value is expected for each "
f"acquisition channel. Got {len(vals)} values for acquisition "
f"channel '{acq_channel}' instead."
)
return_data.extend(cls._reshape_data(acq_protocol, vals, real_imag))
logger.debug(f"Returning {len(return_data)} values.")
return tuple(return_data)
############################################
# Methods used to control the measurements #
############################################
def _reset(self, save_data=True):
"""
Resets all experiment specific variables for a new run.
"""
self._nr_acquired_values = 0
self._loop_count = 0
self._begintime = time.time()
self._batch_size_last = None
self._save_data = save_data
self._dataarray_cache = None
def _reset_post(self):
"""
Resets specific variables that can change before `.run()`.
"""
self._plot_info = {
"grid_2d": False,
"grid_2d_uniformly_spaced": False,
"1d_2_settables_uniformly_spaced": False,
}
        # Make sure the tqdm progress bar attribute is closed and removed if MC is
        # interrupted and shut down gracefully
if self.verbose() and hasattr(self, "pbar"):
self.pbar.close()
del self.pbar
def _init(self, name):
"""
Initializes MC, such as creating the Dataset, experiment folder and such.
"""
# needs to be calculated here because we need the settables' `.batched`
if self._setpoints is None:
self._setpoints = grid_setpoints(self._setpoints_input, self._settable_pars)
# initialize an empty dataset
self._dataset = initialize_dataset(
self._settable_pars, self._setpoints, self._gettable_pars
)
self._dataset.attrs["name"] = name
# cannot add it as a separate (nested) dict so make it flat.
self._dataset.attrs.update(self._plot_info)
tuid = self._dataset.attrs["tuid"]
self._experiment = QuantifyExperiment(tuid=tuid)
if self._save_data:
self._exp_folder = Path(create_exp_folder(tuid=tuid, name=name))
self._safe_write_dataset() # Write the empty dataset
            # Save a snapshot of all instruments
            snap = snapshot(update=False, clean=True)
self._experiment.save_snapshot(snap)
else:
self._exp_folder = None
if self.instr_plotmon():
# Tell plotmon to start monitoring the new dataset
self.instr_plotmon.get_instr().update(tuid=tuid)
def run(
self,
name: str = "",
soft_avg: int = 1,
lazy_set: Optional[bool] = None,
save_data: bool = True,
) -> xr.Dataset:
"""
Starts a data acquisition loop.
Parameters
----------
name
Name of the measurement. It is included in the name of the data files.
soft_avg
Number of software averages to be performed by the measurement control.
            E.g. if ``soft_avg=3`` the full dataset will be measured 3 times and the
            measured values will be averaged element-wise; the averaged dataset is
            then returned.
lazy_set
If ``True`` and a setpoint equals the previous setpoint, the ``.set`` method
of the settable will not be called for that iteration.
If this argument is ``None``, the ``.lazy_set()`` ManualParameter is used
instead (which by default is ``False``).
.. warning:: This feature is not available yet when running in batched mode.
save_data
            If ``True``, the measurement data is stored.
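
        Example (a minimal sketch; assumes settables, setpoints and gettables have
        already been configured as in the class docstring example):

        .. code-block:: python

            dataset = meas_ctrl.run(name="Frequency sweep", soft_avg=3)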
"""
with self._interrupt_manager:
lazy_set = lazy_set if lazy_set is not None else self.lazy_set()
self._soft_avg_validator(soft_avg) # validate first
self._soft_avg = soft_avg
self._reset(save_data=save_data)
self._init(name)
self._prepare_settables()
try:
if self._get_is_batched():
if self.verbose():
print("Starting batched measurement...")
self._run_batched()
else:
if self.verbose():
print("Starting iterative measurement...")
self._run_iterative(lazy_set)
except KeyboardInterrupt:
print("\nInterrupt signaled, exiting gracefully...")
if self._save_data:
self._safe_write_dataset() # Wrap up experiment and store data
self._finish()
self._reset_post()
return self._dataset
def run_adaptive(self, name, params, lazy_set: Optional[bool] = None) -> xr.Dataset:
"""
Starts a data acquisition loop using an adaptive function.
.. warning ::
The functionality of this mode can be complex - it is recommended to read
the relevant long form documentation.
Parameters
----------
name
Name of the measurement. This name is included in the name of the data
files.
params
            Key-value parameters that describe the adaptive function to use, and
            any further parameters for that function.
lazy_set
If ``True`` and a setpoint equals the previous setpoint, the ``.set`` method
of the settable will not be called for that iteration.
If this argument is ``None``, the ``.lazy_set()`` ManualParameter is used
instead (which by default is ``False``).
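
        Example (a sketch using the ``adaptive`` package; the settables, gettable
        and learner arguments shown are assumptions for illustration):

        .. code-block:: python

            import adaptive

            meas_ctrl.settables(mw_source1.freq)
            meas_ctrl.gettables(pulsar_QRM.signal)
            dataset = meas_ctrl.run_adaptive(
                name="adaptive frequency sweep",
                params={
                    "adaptive_function": adaptive.Learner1D,
                    "goal": lambda learner: learner.npoints > 99,
                    "bounds": (5e9, 5.2e9),
                },
            )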
"""
lazy_set = lazy_set if lazy_set is not None else self.lazy_set()
def measure(vec) -> float:
"""
This function executes the measurement and is passed to the adaptive
function (often a minimization algorithm) to be evaluated many times.
Although the measure function acquires (and stores) all gettable parameters,
only the first value is returned to match the function signature for a valid
measurement function.
"""
if len(self._dataset["y0"]) == self._nr_acquired_values:
self._dataset = grow_dataset(self._dataset)
# 1D sweeps return single values, wrap in a list
if np.isscalar(vec):
vec = [vec]
self._iterative_set_and_get(vec, self._nr_acquired_values, lazy_set)
# only y0 is returned so as to match the function signature for a valid
# measurement function.
ret = self._dataset["y0"].values[self._nr_acquired_values]
self._nr_acquired_values += 1
self._update(".")
self._interrupt_manager.raise_if_interrupted()
return ret
def subroutine():
self._prepare_settables()
self._prepare_gettables()
adaptive_function = params.get("adaptive_function")
af_pars_copy = dict(params)
# if the adaptive function is part of the python adaptive library
if isinstance(adaptive_function, type) and issubclass(
adaptive_function, adaptive.learner.BaseLearner
):
goal = af_pars_copy["goal"]
                unused_pars = ["adaptive_function", "goal"]
                for unused_par in unused_pars:
                    af_pars_copy.pop(unused_par, None)
learner = adaptive_function(measure, **af_pars_copy)
adaptive.runner.simple(learner, goal)
# any object that is callable
elif callable(adaptive_function):
unused_pars = ["adaptive_function"]
for unused_par in unused_pars:
af_pars_copy.pop(unused_par, None)
adaptive_function(measure, **af_pars_copy)
else:
raise TypeError(
"The adaptive_function must either be a BaseLearner subclass,"
+ " or be callable."
)
with self._interrupt_manager:
self._reset()
self.setpoints(
np.zeros((64, len(self._settable_pars)))
) # block out some space in the dataset
self._init(name)
try:
print("Running adaptively...")
subroutine()
except KeyboardInterrupt:
print("\nInterrupt signaled, exiting gracefully...")
self._finish()
self._dataset = trim_dataset(self._dataset)
self._safe_write_dataset() # Wrap up experiment and store data
return self._dataset
def _run_iterative(self, lazy_set: bool = False):
while self._get_fracdone() < 1.0:
self._prepare_gettables()
self._dataarray_cache = {}
for idx in range(len(self._setpoints[0])):
self._iterative_set_and_get(
[spt[idx] for spt in self._setpoints],
self._curr_setpoint_idx(),
lazy_set,
)
self._nr_acquired_values += 1
self._update()
self._interrupt_manager.raise_if_interrupted()
self._dataarray_cache = None
self._loop_count += 1
def _run_batched(self): # pylint: disable=too-many-locals
batch_size = self._get_batch_size()
where_batched = self._get_where_batched()
where_iterative = self._get_where_iterative()
batched_settables = self._get_batched_settables()
iterative_settables = self._get_iterative_settables()
if self.verbose():
print(
"Iterative settable(s) [outer loop(s)]:\n\t",
", ".join(par.name for par in iterative_settables) or "--- (None) ---",
"\nBatched settable(s):\n\t",
", ".join(par.name for par in batched_settables),
f"\nBatch size limit: {batch_size:d}\n",
)
while self._get_fracdone() < 1.0:
setpoint_idx = self._curr_setpoint_idx()
self._batch_size_last = batch_size
slice_len = setpoint_idx + self._batch_size_last
for i, spar in enumerate(iterative_settables):
                # Here we ensure that all setpoints of each iterative settable are
                # the same within each batch
val, iterator = next(
itertools.groupby(
self._setpoints[where_iterative[i]][setpoint_idx:slice_len]
)
)
spar.set(val)
                # We also determine the size of the next batch
self._batch_size_last = min(self._batch_size_last, len(tuple(iterator)))
slice_len = setpoint_idx + self._batch_size_last
for i, spar in enumerate(batched_settables):
pnts = self._setpoints[where_batched[i]][setpoint_idx:slice_len]
spar.set(pnts)
# Update for `print_progress`
self._batch_size_last = min(self._batch_size_last, len(pnts))
self._prepare_gettables()
y_off = 0
for gpar in self._gettable_pars:
new_data_raw = gpar.get() # return xarray dataset
if isinstance(new_data_raw, xr.Dataset):
batched = gpar.batched
real_imag = gpar.real_imag
new_data = self._process_acquired_data(
new_data_raw, batched, real_imag
) # can return (N, M)
else:
new_data = new_data_raw
# if we get a simple array, shape it to (1, M)
if len(np.shape(new_data)) == 1:
                    new_data = new_data.reshape(1, len(new_data))
for row in new_data:
yi_name = f"y{y_off}"
slice_len = setpoint_idx + len(row) # the slice we will be updating
old_vals = self._dataset[yi_name].values[setpoint_idx:slice_len]
                    # old_vals is full of NaNs on the first iteration; set those to 0
                    old_vals[np.isnan(old_vals)] = 0
self._dataset[yi_name].values[setpoint_idx:slice_len] = (
self._build_data(row, old_vals)
)
y_off += 1
self._nr_acquired_values += np.shape(new_data)[1]
self._update()
self._interrupt_manager.raise_if_interrupted()
def _build_data(self, new_data, old_data):
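        # When soft averaging is disabled, ``old_data`` holds only zeroed NaNs from
        # the caller, so the addition below simply returns ``new_data``. Otherwise
        # compute the element-wise running average over the completed loops.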
if self._soft_avg == 1:
return old_data + new_data
return (new_data + old_data * self._loop_count) / (1 + self._loop_count)
def _iterative_set_and_get(
self, setpoints: np.ndarray, idx: int, lazy_set: bool = False
):
"""
Processes one row of setpoints. Sets all settables, gets all gettables, encodes
new data in dataset.
If lazy_set==True and any setpoint equals the corresponding previous setpoint,
that setpoint is not set in its corresponding settable.
        .. note ::

            Some lines in this function are redundant depending on mode (sweep vs
            adaptive). Specifically:

            - in sweep, the x dimensions are already filled
            - in adaptive, soft_avg is always 1
"""
assert self._dataset is not None
# set all individual setparams
for setpar_idx, (spar, spt) in enumerate(zip(self._settable_pars, setpoints)):
xi_name = f"x{setpar_idx}"
if self._dataarray_cache is None:
xi_dataarray_values = self._dataset[xi_name].values
else:
                if xi_name not in self._dataarray_cache:
self._dataarray_cache[xi_name] = self._dataset[xi_name].values
xi_dataarray_values = self._dataarray_cache[xi_name]
xi_dataarray_values[idx] = spt
prev_spt = xi_dataarray_values[idx - 1] if idx else None
# if lazy_set==True and the setpoint equals the previous setpoint, do not
# set the setpoint.
if not (lazy_set and spt == prev_spt):
spar.set(spt)
# get all data points
y_offset = 0
for gpar in self._gettable_pars:
new_data_raw = gpar.get() # return xarray dataset
if isinstance(new_data_raw, xr.Dataset):
batched = gpar.batched
real_imag = gpar.real_imag
new_data = self._process_acquired_data(
new_data_raw, batched, real_imag
) # can return (N, M)
else:
new_data = new_data_raw
# if the gettable returned a float, cast to list
if np.isscalar(new_data):
new_data = [new_data]
# iterate through the data list, each element is different y for these
# x coordinates
for val in new_data:
yi_name = f"y{y_offset}"
if self._dataarray_cache is None:
yi_dataarray_values = self._dataset[yi_name].values
else:
                    if yi_name not in self._dataarray_cache:
self._dataarray_cache[yi_name] = self._dataset[yi_name].values
yi_dataarray_values = self._dataarray_cache[yi_name]
old_val = yi_dataarray_values[idx]
if self._soft_avg == 1 or np.isnan(old_val):
if isinstance(val, np.ndarray) and val.size == 1:
# This branch avoids usage of deprecated code
yi_dataarray_values[idx] = val.item()
else:
# This is deprecated if val is an np.ndarray with ndim > 0
yi_dataarray_values[idx] = val
else:
averaged = (val + old_val * self._loop_count) / (
1 + self._loop_count
)
yi_dataarray_values[idx] = averaged
y_offset += 1
    ############################################
    # Helpers used during the acquisition loop #
    ############################################
    def _update(self, print_message: Optional[str] = None):
"""
Do any updates to/from external systems, such as saving, plotting, etc.
"""
update = (
time.time() - self._last_upd > self.update_interval()
or self._nr_acquired_values == self._get_max_setpoints()
)
if update:
self.print_progress(print_message)
if self._save_data:
self._safe_write_dataset()
self._last_upd = time.time()
def _prepare_gettables(self) -> None:
"""
Call prepare() on the Gettable, if prepare() exists
"""
for getpar in self._gettable_pars:
call_if_has_method(getpar, "prepare")
def _prepare_settables(self) -> None:
"""
Call prepare() on all Settable, if prepare() exists
"""
for setpar in self._settable_pars:
call_if_has_method(setpar, "prepare")
def _finish(self) -> None:
"""
Call finish() on all Settables and Gettables, if finish() exists
"""
for par in self._gettable_pars + self._settable_pars:
call_if_has_method(par, "finish")
def _get_batched_mask(self):
return tuple(is_batched(spar) for spar in self._settable_pars)
def _get_where_batched(self):
# Indices to select correct entries in results data
return np.where(self._get_batched_mask())[0]
def _get_where_iterative(self):
return np.where(tuple(not m for m in self._get_batched_mask()))[0]
def _get_iterative_settables(self):
return tuple(spar for spar in self._settable_pars if not is_batched(spar))
def _get_batched_settables(self):
return tuple(spar for spar in self._settable_pars if is_batched(spar))
def _get_batch_size(self):
# np.inf is not supported by the JSON schema, but we keep the code robust
min_with_inf = min(
getattr(par, "batch_size", np.inf)
for par in chain.from_iterable((self._settable_pars, self._gettable_pars))
)
return min(min_with_inf, len(self._setpoints[0]))
def _get_is_batched(self) -> bool:
if any(
is_batched(gpar) for gpar in chain(self._gettable_pars, self._settable_pars)
):
if not all(is_batched(gpar) for gpar in self._gettable_pars):
raise RuntimeError(
"Control mismatch; all Gettables must have batched Control Mode, "
"i.e. all gettables must have `.batched=True`."
)
if not any(is_batched(spar) for spar in self._settable_pars):
raise RuntimeError(
"Control mismatch; At least one settable must have "
"`settable.batched=True`, if the gettables are batched."
)
return True
return False
def _get_max_setpoints(self) -> int:
"""
The total number of setpoints to examine
"""
try:
return len(self._setpoints[0]) * self._soft_avg
except IndexError:
return 0
def _curr_setpoint_idx(self) -> int:
"""
Current position through the sweep
Returns
-------
int
setpoint_idx
"""
acquired = self._nr_acquired_values
setpoint_idx = acquired % len(self._setpoints[0])
self._loop_count = acquired // len(self._setpoints[0])
return setpoint_idx
def _get_fracdone(self) -> float:
"""
Returns the fraction of the experiment that is completed.
"""
return self._nr_acquired_values / self._get_max_setpoints()
    def print_progress(self, progress_message: Optional[str] = None):
"""
        Prints the provided `progress_message` or displays a tqdm progress bar, and
        calls the callback specified by `on_progress_callback`.
        NB: if called with no progress message (i.e. the progress bar is used), the
        `self.pbar` attribute is closed and removed once the run completes.
Printing and progress bar display can be suppressed with `.verbose(False)`.
"""
# There are no points initialized, progress does not make sense
if self._get_max_setpoints() == 0:
raise ValueError("No setpoints available, progress cannot be defined")
# by checking if `progress_message` is None we make sure we change print behavior for adaptive run
if self.verbose() and not hasattr(self, "pbar") and progress_message is None:
# when you use `unit` instead of `postfix` it removes unintended comma
# see https://github.com/tqdm/tqdm/issues/712
custom_bar_format = "{l_bar}{bar} [ elapsed time: {elapsed} | time left: {remaining} ] {unit}"
self.pbar = tqdm(total=100, desc="Completed", bar_format=custom_bar_format)
progress_percent = self._get_fracdone() * 100
if self.verbose() and progress_message is None:
progress_diff = math.floor(progress_percent) - self.pbar.n
if self._batch_size_last is not None:
self.pbar.unit = f" last batch size: {self._batch_size_last}"
else:
# if no unit attribute provided `custom_bar_format` breaks with extra `it` output
self.pbar.unit = ""
self.pbar.update(progress_diff)
if (
self.verbose()
and hasattr(self, "pbar")
and math.floor(progress_percent) >= 100
):
self.pbar.close()
del self.pbar
if self.on_progress_callback() is not None:
self.on_progress_callback()(progress_percent)
if self.verbose() and progress_message is not None:
print(progress_message, end="")
def _safe_write_dataset(self):
"""
Uses a lock when writing the file to stay safe for multiprocessing.
Locking files are written into a temporary dir to avoid polluting
the experiment container.
"""
# Multiprocess safe
lockfile = (
_DATASET_LOCKS_DIR / f"{self._dataset.attrs['tuid']}-{DATASET_NAME}.lock"
)
with FileLock(lockfile, 5):
self._experiment.write_dataset(self._dataset)
####################################
# Non-parameter get/set functions #
####################################
def settables(self, settable_pars):
"""
Define the settable parameters for the acquisition loop.
The :class:`.Settable` helper class defines the
requirements for a Settable object.
Parameters
        ----------
settable_pars
parameter(s) to be set during the acquisition loop, accepts a list or tuple
of multiple Settable objects or a single Settable object.
"""
# for native nD compatibility we treat this like a list of settables.
if not isinstance(settable_pars, (list, tuple)):
settable_pars = [settable_pars]
self._settable_pars = []
for settable in settable_pars:
self._settable_pars.append(Settable(settable))
def setpoints(self, setpoints: np.ndarray):
"""
Set setpoints that determine values to be set in acquisition loop.
.. tip::
Use :func:`~numpy.column_stack` to reshape multiple
1D arrays when setting multiple settables.
Parameters
----------
setpoints :
An array that defines the values to loop over in the experiment.
The shape of the array has to be either (N,) or (N,1) for a 1D loop;
or (N, M) in the case of an MD loop.
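
        Example (a minimal sketch; ``t`` and ``amp`` are assumed settables):

        .. code-block:: python

            meas_ctrl.settables([t, amp])
            meas_ctrl.setpoints(
                np.column_stack((np.linspace(0, 1e-6, 50), np.linspace(0, 1.0, 50)))
            )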
"""
if len(np.shape(setpoints)) == 1:
setpoints = setpoints.reshape((len(setpoints), 1))
elif len(np.shape(setpoints)) == 2:
# used in plotmon to detect need for interpolation in 2d plot
is_uniform = all(
_is_uniformly_spaced_array(setpoints_i) for setpoints_i in setpoints.T
)
self._plot_info["1d_2_settables_uniformly_spaced"] = is_uniform
# UI is to provide an (N, M) array, but internally we store an (M, N) array
self._setpoints = setpoints.T
# `.setpoints()` and `.setpoints_grid()` cannot be used at the same time
self._setpoints_input = None
def setpoints_grid(self, setpoints: Iterable[np.ndarray]):
"""
Makes a grid from the provided `setpoints` assuming each array element
corresponds to an orthogonal dimension.
The resulting gridded points determine values to be set in the acquisition loop.
        The gridding is such that the innermost loop corresponds to the batched
        settable with the smallest `.batch_size`.
.. seealso:: :ref:`Measurement Control How-To <howto-measurement-control>`
Parameters
----------
setpoints
The values to loop over in the experiment. The grid is reshaped in the same
order.
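
        Example (a minimal sketch; ``t`` and ``amp`` are assumed settables):

        .. code-block:: python

            meas_ctrl.settables([t, amp])
            meas_ctrl.setpoints_grid(
                [np.linspace(0, 1e-6, 50), np.linspace(0, 1.0, 10)]
            )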
"""
self._setpoints = None # assigned later in the `._init()`
self._setpoints_input = setpoints
if len(setpoints) == 2:
self._plot_info["xlen"] = len(setpoints[0])
self._plot_info["ylen"] = len(setpoints[1])
self._plot_info["grid_2d"] = True
is_uniform = all( # used in plotmon to detect need for interpolation
_is_uniformly_spaced_array(setpoints[i]) for i in (0, 1)
)
self._plot_info["grid_2d_uniformly_spaced"] = is_uniform
def gettables(self, gettable_pars):
"""
Define the parameters to be acquired during the acquisition loop.
The :class:`.Gettable` helper class defines the
requirements for a Gettable object.
Parameters
----------
gettable_pars
            parameter(s) to be acquired during the acquisition loop, accepts:
- list or tuple of multiple Gettable objects
- a single Gettable object
"""
if not isinstance(gettable_pars, (list, tuple)):
gettable_pars = [gettable_pars]
self._gettable_pars = []
for gpar in gettable_pars:
self._gettable_pars.append(Gettable(gpar))
def measurement_description(self) -> Dict[str, Any]:
"""Return a serializable description of the latest measurement
Users can add additional information to the description manually.
Returns
-------
:
Dictionary with description of the measurement
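
        Example of a returned description (a sketch; the names and TUID are
        illustrative and depend on the last run):

        .. code-block:: python

            {
                "name": "Frequency sweep",
                "settables": ["mw_source1_freq"],
                "gettables": ["pulsar_QRM_signal"],
                "setpoints_shape": (2000, 1),
                "soft_avg": 1,
                "acquired_dataset": {"tuid": "20240101-120000-123-abcdef"},
            }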
"""
experiment_description = {
"name": self._dataset.attrs["name"],
"settables": [str(s) for s in self._settable_pars],
}
experiment_description["gettables"] = [str(s) for s in self._gettable_pars]
# Report the transposed shape to keep consistency with the UI (self.setpoints).
try:
experiment_description["setpoints_shape"] = (
len(self._setpoints[0]),
len(self._setpoints),
)
except IndexError:
experiment_description["setpoints_shape"] = (0, 0)
experiment_description["soft_avg"] = self._soft_avg
experiment_description["acquired_dataset"] = {"tuid": self._dataset.tuid}
return experiment_description
def grid_setpoints(
setpoints: Sequence[Sequence],
settables: Iterable | None = None,
) -> list[np.ndarray]:
"""
Make gridded setpoints.
    If ``settables`` is provided, the gridding is such that the innermost loop
    corresponds to the batched settable with the smallest ``.batch_size``.
Parameters
----------
setpoints
        A list of arrays that define the values to loop over in the experiment for
        each orthogonal dimension. The grid is reshaped in the same order.
settables
        A list of settable objects to which the elements of ``setpoints``
        correspond. Used to correctly grid data when mixing batched and iterative
        settables.
Returns
-------
    list[np.ndarray]
        A list of 1D arrays, one per settable; together they form an (M, N)
        structure where the first axis corresponds to the settables and the
        second axis to the individual setpoints.
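
    Example (a minimal sketch without settables; the first array varies fastest):

    .. code-block:: python

        grid_setpoints([np.array([0, 1]), np.array([10, 20])])
        # [array([0, 1, 0, 1]), array([10, 10, 20, 20])]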
"""
if settables is None:
settables = [None] * len(setpoints)
coordinates_batched = [i for i, spar in enumerate(settables) if is_batched(spar)]
coordinates_iterative = [
i for i, spar in enumerate(settables) if not is_batched(spar)
][::-1]
stack_order = coordinates_iterative
if len(coordinates_batched):
batch_sizes = [
getattr(spar, "batch_size", np.inf)
for spar in settables
if is_batched(spar)
]
inner_coord = coordinates_batched[np.argmin(batch_sizes)]
coordinates_batched.remove(inner_coord)
        # The innermost coordinate must correspond to the batched settable with
        # min `.batch_size`
stack_order += coordinates_batched[::-1] + [inner_coord]
order_of_order = np.argsort(stack_order)
stacked_dset = _cartesian_product_transposed(*(setpoints[i] for i in stack_order))
stacked_dset = [stacked_dset[i] for i in order_of_order]
return stacked_dset
class _SupportsMul(Protocol):
"""A type that supports multiplication (*)."""
def __mul__(self, other): ...
def __rmul__(self, other): ...
T = TypeVar("T", bound=_SupportsMul)
def _prod(iter_: Iterable[T]) -> T:
return reduce(lambda x, y: x * y, iter_)
def _cartesian_product_transposed(*setpoints: Sequence) -> list[np.ndarray]:
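    """Return the transposed Cartesian product of ``setpoints``.

    A minimal behavior sketch:

    .. code-block:: python

        _cartesian_product_transposed([0, 1], [10, 20])
        # [array([0, 0, 1, 1]), array([10, 20, 10, 20])]
    """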
lengths = [len(arr) for arr in setpoints]
out = []
for i, arr in enumerate(setpoints):
row = np.array(arr)
if i < len(setpoints) - 1:
row = np.repeat(row, _prod(lengths[i + 1 :]))
if i > 0:
row = np.tile(row, _prod(lengths[:i]))
out.append(row)
return out
Handler = Callable[[int, Optional[types.FrameType]], Any]
class _KeyboardInterruptManager:
"""Support class for handling keyboard interrupts in a controlled way."""
def __init__(self, n_forced: int = 5) -> None:
self._n_forced = n_forced
self.n_interrupts = 0
self._previous_handler: Optional[Handler] = None
def __enter__(self) -> Self:
self.n_interrupts = 0
if threading.current_thread() is threading.main_thread():
# Signal handlers can only be installed in main thread,
# do nothing in other thread.
self._previous_handler = cast(
Handler, signal.signal(signal.SIGINT, self._handle_interrupt)
)
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[types.TracebackType],
) -> Literal[False]:
if self._previous_handler is not None:
signal.signal(signal.SIGINT, self._previous_handler)
if self.n_interrupts > 0: # call outside handler on exit
self._previous_handler(signal.SIGINT, None)
self.n_interrupts = 0
self._previous_handler = None
return False
def _handle_interrupt(self, sig: int, frame: Optional[types.FrameType]) -> None:
del sig, frame # unused arguments
self.n_interrupts += 1
if self.n_interrupts >= self._n_forced:
raise KeyboardInterrupt("Measurement interruption forced")
        print(
            f"\n\n[!!!] {self.n_interrupts} interruption(s) signaled. "
            "Stopping after this iteration/batch.\n"
            f"[Send {self._n_forced - self.n_interrupts} more interruptions to "
            f"force stop (not safe!)].\n"
        )
def raise_if_interrupted(self) -> None:
"""
Verifies if the user has signaled the interruption of the experiment.
Intended to be used after each iteration or after each batch of data.
"""
if self.n_interrupts > 0:
raise KeyboardInterrupt("Measurement interrupted")