# Source code for qupulse.plotting

"""This module defines plotting functionality for instantiated PulseTemplates using matplotlib.

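Classes:
    - PlottingNotPossibleException: Raised when a pulse cannot be plotted.
Functions:
    - render: Sample a pulse program into time and voltage arrays.
    - plot: Plot a pulse using matplotlib.
    - plot_2d: Plot a pulse in the plane spanned by two of its channels.
"""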

from typing import Dict, Tuple, Any, Optional, Set, List, Union, Mapping
from numbers import Real

import matplotlib.pyplot as plt
import numpy as np
import warnings
import operator
import itertools
import functools

try:
    from matplotlib import colormaps
    get_cmap = colormaps.get_cmap
except (ImportError, AttributeError):  # pragma: no cover
    # was deprecated in matplotlib 3.7, but we keep it around to allow this code to work with older versions
    get_cmap = plt.get_cmap

from qupulse.program import waveforms
from qupulse.utils.types import ChannelID, MeasurementWindow, has_type_interface
from qupulse.pulses.pulse_template import PulseTemplate
from qupulse.program.waveforms import Waveform
from qupulse.program.loop import Loop, to_waveform


# Names exported by ``from qupulse.plotting import *``.
# NOTE(review): plot_2d is a public function of this module but is not listed here — confirm whether intentional.
__all__ = ["render", "plot", "PlottingNotPossibleException"]


def render(program: Union[Loop],
           sample_rate: Real = 10.0,
           render_measurements: bool = False,
           time_slice: Tuple[Real, Real] = None,
           plot_channels: Optional[Set[ChannelID]] = None) -> Tuple[np.ndarray, Dict[ChannelID, np.ndarray], List[MeasurementWindow]]:
    """'Renders' a pulse program.

    Samples all contained waveforms into an array according to the control flow of the program.

    Args:
        program: The pulse (sub)program to render. Must be a Loop object; anything else raises ValueError.
        sample_rate: The sample rate in GHz.
        render_measurements: If True, the third return value is a list of measurement windows.
        time_slice: The time slice (start, end) to be rendered. If None, the entire pulse will be shown.
        plot_channels: Only channels in this set are rendered. If None, all will.

    Returns:
        A tuple (times, voltages, measurements).
        times is a numpy.ndarray of length sample_count containing the sample times. The last entry is nudged
        to the next float below the end time.
        voltages is a dictionary with one numpy.ndarray of length sample_count per rendered channel containing
        the sampled voltage values for that channel.
        measurements is a list of the measurement windows overlapping the rendered interval, each represented
        by a tuple (name, start_time, duration).

        If the program converts to no waveform at all, an empty times array and an empty dict are returned.

    Raises:
        ValueError: If program is not a Loop or time_slice is invalid (negative or reversed).
        PlottingNotPossibleException: If fewer than two samples would be rendered.
    """
    if has_type_interface(program, Loop):
        waveform, measurements = _render_loop(program, render_measurements=render_measurements)
    else:
        raise ValueError('Cannot render an object of type %r' % type(program), program)

    if waveform is None:
        return np.array([]), dict(), measurements

    if plot_channels is None:
        channels = waveform.defined_channels
    else:
        channels = waveform.defined_channels & plot_channels

    if time_slice is None:
        start_time, end_time = 0, waveform.duration
    elif time_slice[1] < time_slice[0] or time_slice[0] < 0 or time_slice[1] < 0:
        raise ValueError("time_slice is not valid.")
    else:
        start_time, end_time, *_ = time_slice

    # keep only measurement windows that overlap the rendered interval
    measurements = [(name, begin, length)
                    for name, begin, length in measurements
                    if begin < end_time and begin + length > start_time]

    sample_count = (end_time - start_time) * sample_rate + 1
    if sample_count < 2:
        raise PlottingNotPossibleException(pulse=None, description='cannot render sequence with less than 2 data points')
    if not round(float(sample_count), 10).is_integer():
        warnings.warn(f"Sample count {sample_count} is not an integer. Will be rounded (this changes the sample rate).",
                      stacklevel=2)

    # Round rather than truncate. Float products such as 0.29 * 100 == 28.999999999999996 pass the integer check
    # above without a warning, and int() would silently drop a sample (and thereby change the sample rate).
    # Rounding also matches what the warning message promises.
    times = np.linspace(float(start_time), float(end_time), num=round(float(sample_count)))
    # Move the final sample just below end_time.
    # NOTE(review): presumably because waveforms are sampled on a half-open interval [0, duration) — confirm.
    times[-1] = np.nextafter(times[-1], times[-2])

    # one output buffer per channel, allocated with the allocator configured in the waveforms module
    voltages = {ch: waveforms._ALLOCATION_FUNCTION(times, **waveforms._ALLOCATION_FUNCTION_KWARGS)
                for ch in channels}
    for ch, ch_voltage in voltages.items():
        waveform.get_sampled(channel=ch, sample_times=times, output_array=ch_voltage)

    return times, voltages, measurements
def _render_loop(loop: Loop, render_measurements: bool,) -> Tuple[Waveform, List[MeasurementWindow]]:
    """Convert a Loop program into a single waveform plus its measurement windows.

    This is the Loop-specific backend used by render().

    Args:
        loop: The program to convert.
        render_measurements: If False, no windows are collected and an empty list is returned.

    Returns:
        (waveform, measurements). waveform is the result of ``to_waveform(loop)``; render() also handles it
        being None. measurements is a list of ``(name, begin, length)`` tuples ordered by begin time.
    """
    waveform = to_waveform(loop)
    if not render_measurements:
        return waveform, []

    windows = [(name, begin, length)
               for name, (begins, lengths) in loop.get_measurement_windows().items()
               for begin, length in zip(begins, lengths)]
    # list.sort is stable: windows with equal begin times keep their per-name insertion order
    windows.sort(key=operator.itemgetter(1))
    return waveform, windows
def _automatic_sample_rate(pulse: PulseTemplate, time_slice: Optional[Tuple[Real, Real]]) -> Real:  # pragma: no cover
    """Return a sample rate giving about 1000 samples over the plotted duration (1 for zero duration)."""
    if time_slice is None:
        duration = pulse.duration
    else:
        duration = time_slice[1] - time_slice[0]
    if duration == 0:
        return 1
    duration_per_sample = float(duration) / 1000
    return 1 / duration_per_sample


def _plot_measurement_windows(axes: Any, measurements, plot_measurements) -> list:  # pragma: no cover
    """Shade the selected measurement windows on axes and return one legend handle per measurement name."""
    measurement_dict = dict()
    for name, begin, length in measurements:
        if name in plot_measurements:
            measurement_dict.setdefault(name, []).append((begin, begin + length))

    color_map = get_cmap('plasma')
    meas_colors = {name: color_map(i / len(measurement_dict)) for i, name in enumerate(measurement_dict)}

    handles = []
    for name, begin_end_list in measurement_dict.items():
        for window_idx, (begin, end) in enumerate(begin_end_list):
            poly = axes.axvspan(begin, end, alpha=0.2, label=name, edgecolor='black', facecolor=meas_colors[name])
            if window_idx == 0:
                # All windows of one measurement share label and colour, so a single legend entry suffices.
                # Appending every window produced duplicated legend entries.
                handles.append(poly)
    return handles


def plot(pulse: PulseTemplate, parameters: Dict[str, Real]=None, sample_rate: Optional[Real]=10, axes: Any=None, show: bool=True, plot_channels: Optional[Set[ChannelID]]=None, plot_measurements: Optional[Set[str]]=None, stepped: bool=True, maximum_points: int=10**6, time_slice: Tuple[Real, Real]=None, **kwargs) -> Any:  # pragma: no cover
    """Plots a pulse using matplotlib.

    The given pulse template is first turned into a pulse program (represented by a Loop object) with the
    provided parameters. render() is then invoked to obtain voltage samples over the entire duration of the
    pulse (or the given time slice). These samples are then plotted in a matplotlib figure.

    Args:
        pulse: The pulse to be plotted.
        parameters: An optional mapping of parameter names to parameter values.
        sample_rate: The rate with which the waveforms are sampled for the plot in samples per time unit. If None,
            the rate is chosen so that the plotted duration spans about 1000 samples. (default = 10)
        axes: matplotlib Axes object the pulse will be drawn into if provided. Otherwise a new figure is created.
        show: If true, the figure will be shown.
        plot_channels: If specified only channels from this set will be plotted. If omitted all channels will be.
        plot_measurements: If specified measurements in this set will be plotted. If omitted no measurements will be.
        stepped: If true pyplot.step is used for plotting.
        maximum_points: If the sampled waveform is bigger, it is not plotted.
        time_slice: The time slice to be plotted. If None, the entire pulse will be shown.
        kwargs: Forwarded to pyplot. Overwrites other settings.

    Returns:
        The matplotlib Figure in which the pulse is rendered. Returns None, with a warning, if the sampled pulse
        has more than maximum_points samples.

    Raises:
        PlottingNotPossibleException if the sequencing is interrupted before it finishes, e.g., because a parameter
            value could not be evaluated
        all Exceptions possibly raised during sequencing
    """
    channels = pulse.defined_channels

    if parameters is None:
        parameters = dict()

    if sample_rate is None:
        sample_rate = _automatic_sample_rate(pulse, time_slice)

    program = pulse.create_program(parameters=parameters,
                                   channel_mapping={ch: ch for ch in channels},
                                   measurement_mapping={w: w for w in pulse.measurement_names})

    if program is not None:
        # Let render() skip unwanted channels instead of sampling them and discarding the result.
        # set() also keeps list/tuple arguments working with render()'s set intersection.
        times, voltages, measurements = render(program,
                                               sample_rate,
                                               render_measurements=bool(plot_measurements),
                                               time_slice=time_slice,
                                               plot_channels=None if plot_channels is None else set(plot_channels))
    else:
        times, voltages, measurements = np.array([]), dict(), []

    duration = 0
    if times.size == 0:
        warnings.warn("Pulse to be plotted is empty!")
    elif times.size > maximum_points:
        # todo [2018-05-30]: since it results in an empty return value this should arguably be an exception, not just a warning
        warnings.warn(f"Sampled pulse of size {times.size} is larger than {maximum_points}", stacklevel=2)
        return None
    else:
        duration = times[-1]

    if time_slice is None:
        time_slice = (0, duration)

    if axes is None:
        # plot to a new figure
        figure = plt.figure()
        axes = figure.add_subplot(111)

    legend_handles = []
    for ch_name, voltage in voltages.items():
        label = 'channel {}'.format(ch_name)
        if stepped:
            line, = axes.step(times, voltage, **{**dict(where='post', label=label), **kwargs})
        else:
            line, = axes.plot(times, voltage, **{**dict(label=label), **kwargs})
        legend_handles.append(line)

    if plot_measurements:
        legend_handles.extend(_plot_measurement_windows(axes, measurements, plot_measurements))

    axes.legend(handles=legend_handles)

    # vectorised extrema; empty channels contribute 0 like the previous max(..., default=0)
    max_voltage = max((np.max(voltage) if len(voltage) else 0 for voltage in voltages.values()), default=0)
    min_voltage = min((np.min(voltage) if len(voltage) else 0 for voltage in voltages.values()), default=0)

    # add some margins in the presentation
    axes.set_xlim(-0.5 + time_slice[0], time_slice[1] + 0.5)
    voltage_difference = max_voltage - min_voltage
    if voltage_difference > 0:
        axes.set_ylim(min_voltage - 0.1 * voltage_difference, max_voltage + 0.1 * voltage_difference)

    axes.set_xlabel('Time (ns)')
    axes.set_ylabel('Voltage (a.u.)')

    if pulse.identifier:
        axes.set_title(pulse.identifier)

    if show:
        with warnings.catch_warnings():
            # do not show warnings in jupyter notebook with matplotlib inline backend
            warnings.filterwarnings(action="ignore", message=".*which is a non-GUI backend, so cannot show the figure.*")
            axes.get_figure().show()

    return axes.get_figure()
@functools.singledispatch
def plot_2d(program: Loop, channels: Tuple[ChannelID, ChannelID], sample_rate: float = None, ax: plt.Axes = None, plot_kwargs: Mapping = None) -> plt.Figure:
    """Plot the pulse/program in the plane of the given channels.

    Args:
        program: The program to plot
        channels: (x_axis, y_axis) name tuple
        sample_rate: Sample rate to use. Defaults to max(1000 samples per program, 10 per nano second)
        ax: Axis to plot into. If None, a new figure with a single axis is created.
        plot_kwargs: Forwarded to the plot function.

    Returns:
        The figure containing ax.
    """
    if sample_rate is None:
        if program.duration > 0:
            sample_rate = max(1000 / program.duration, 10)
        else:
            # Avoid a ZeroDivisionError for an empty program; leave reporting the empty program to render().
            sample_rate = 10

    _, rendered, _ = render(program, sample_rate, plot_channels=set(channels))
    x_y = np.array([rendered[channels[0]], rendered[channels[1]]])

    # Drop points identical to their predecessor so constant stretches become a single vertex.
    keep = np.full(x_y.shape[1], fill_value=True)
    keep[1:] = np.any(x_y[:, 1:] != x_y[:, :-1], axis=0)
    x_y_plt = x_y[:, keep]

    if ax is None:
        ax = plt.subplots()[1]
    ax.plot(x_y_plt[0, :], x_y_plt[1, :], **(plot_kwargs or {}))
    ax.set_xlabel(channels[0])
    ax.set_ylabel(channels[1])
    return ax.get_figure()


@plot_2d.register
def _(pulse_template: PulseTemplate, channels: Tuple[ChannelID, ChannelID], sample_rate: float = None, ax: plt.Axes = None, plot_kwargs: Mapping = None, parameters=None, channel_mapping=None) -> plt.Figure:
    """PulseTemplate overload of plot_2d: create a program first, then plot it.

    Args:
        pulse_template: The pulse template to plot.
        channels, sample_rate, ax, plot_kwargs: As for the Loop version of plot_2d.
        parameters: Forwarded to ``create_program`` if given.
        channel_mapping: Forwarded to ``create_program``. By default the channels in ``channels`` map to themselves
            and every other defined channel maps to None.
            NOTE(review): presumably None drops the channel from the program — confirm.
    """
    if channel_mapping is None:
        channel_mapping = {ch: ch if ch in channels else None for ch in pulse_template.defined_channels}
    create_program_kwargs = {'channel_mapping': channel_mapping}
    if parameters is not None:
        create_program_kwargs['parameters'] = parameters
    program = pulse_template.create_program(**create_program_kwargs)
    return plot_2d(program, channels, sample_rate=sample_rate, ax=ax, plot_kwargs=plot_kwargs)
class PlottingNotPossibleException(Exception):
    """Indicates that plotting is not possible.

    Raised, e.g., when the sequencing process did not translate the entire given PulseTemplate structure, or when
    render() would produce fewer than two samples.

    Attributes:
        pulse: The pulse that could not be plotted (may be None).
        description: Optional human-readable reason; used in the message if given.
    """

    def __init__(self, pulse, description=None) -> None:
        # Forward the constructor arguments so ``self.args`` can rebuild the exception.
        # With empty args, pickle/copy would call ``PlottingNotPossibleException()`` and fail on missing ``pulse``.
        super().__init__(pulse, description)
        self.pulse = pulse
        self.description = description

    def __str__(self) -> str:
        if self.description is None:
            return "Plotting is not possible. There are parameters which cannot be computed."
        return "Plotting is not possible: %s." % self.description