From 8eef90212c3658ef35a959ae53601d819e55a92c Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Mon, 13 May 2024 21:13:22 +0200 Subject: [PATCH] Avoid using Command for get --- src/qcodes/parameters/parameter.py | 116 +++++++++++++++++---------- src/qcodes/utils/function_helpers.py | 3 +- 2 files changed, 76 insertions(+), 43 deletions(-) diff --git a/src/qcodes/parameters/parameter.py b/src/qcodes/parameters/parameter.py index effda2573f30..1bb4fbeb5100 100644 --- a/src/qcodes/parameters/parameter.py +++ b/src/qcodes/parameters/parameter.py @@ -5,9 +5,10 @@ import logging import os -from functools import wraps from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +from qcodes.utils import is_function from .command import Command from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType @@ -22,6 +23,65 @@ log = logging.getLogger(__name__) +def _get_parameter_factory( + function: Callable[[str], ParamRawDataType] | None, + cmd: str | Callable[[], ParamRawDataType] | None, + parameter_name: str, +) -> Callable[[Parameter], ParamRawDataType]: + if cmd is None: + + def get_manual_parameter(self: Parameter) -> ParamRawDataType: + if self.root_instrument is not None: + mylogger: InstrumentLoggerAdapter | logging.Logger = ( + self.root_instrument.log + ) + else: + mylogger = log + mylogger.debug( + "Getting raw value of parameter: %s as %s", + self.full_name, + self.cache.raw_value, + ) + return self.cache.raw_value + + return get_manual_parameter + + elif isinstance(cmd, str) and is_function(function, 1): + # cast is safe since we just checked this above using is_function + function = cast(Callable[[str], ParamRawDataType], function) + + def get_parameter_ask(self: Parameter) -> ParamRawDataType: + # for some reason mypy does not understand + # that cmd is a str even if this is defined inside + # an if isinstance block + assert isinstance(cmd, str) + return function(cmd) + + return get_parameter_ask + + elif is_function(cmd, 0): + # cast is safe since we just checked this above using is_function + cmd = cast(Callable[[], ParamRawDataType], cmd) + + def get_parameter_func(self: Parameter) -> ParamRawDataType: + return cmd() + + return get_parameter_func + + elif isinstance(cmd, str) and function is None: + raise TypeError( + f"Cannot use a str get_cmd without " + f"binding to an instrument. " + f"Got: get_cmd {cmd} for parameter {parameter_name}" + ) + + else: + raise TypeError( + "Unexpected options for parameter get. " + f"Got: get_cmd {cmd} for parameter {parameter_name}" + ) + + class Parameter(ParameterBase): """ A parameter represents a single degree of freedom. Most often, @@ -172,7 +232,7 @@ def __init__( instrument: InstrumentBase | None = None, label: str | None = None, unit: str | None = None, - get_cmd: str | Callable[..., Any] | Literal[False] | None = None, + get_cmd: str | Callable[[], ParamRawDataType] | Literal[False] | None = None, set_cmd: str | Callable[..., Any] | Literal[False] | None = False, initial_value: float | str | None = None, max_val_age: float | None = None, @@ -211,25 +271,10 @@ def _set_manual_parameter( self.cache._set_from_raw_value(x) return x - def _get_command_caller(parameter: Parameter, command: Command) -> MethodType: - @wraps(Command.__call__) - def call_command(self: Parameter) -> Any: - return command() - - return MethodType(call_command, parameter) - - def _set_command_caller(parameter: Parameter, command: Command) -> MethodType: - @wraps(Command.__call__) - def call_command(self: Parameter, val: ParamRawDataType) -> Any: - return command(val) - - return MethodType(call_command, parameter) - if instrument is not None and bind_to_instrument: existing_parameter = instrument.parameters.get(name, None) if existing_parameter: - # this check is redundant since its also in the baseclass # but if we do not put it here it would be an api break # as parameter duplication check won't be done first, @@ -281,27 +326,15 @@ def call_command(self: Parameter, val: ParamRawDataType) -> Any: " get_raw is an error." ) elif not self.gettable and get_cmd is not False: - if get_cmd is None: - # ignore typeerror since mypy does not allow setting a method dynamically - self.get_raw = MethodType(_get_manual_parameter, self) # type: ignore[method-assign] - else: - if isinstance(get_cmd, str) and instrument is None: - raise TypeError( - f"Cannot use a str get_cmd without " - f"binding to an instrument. " - f"Got: get_cmd {get_cmd} for parameter {name}" - ) + exec_str_ask: Callable[[str], ParamRawDataType] | None = ( + getattr(instrument, "ask", None) if instrument else None + ) + + self.get_raw = MethodType( # type: ignore[method-assign] + _get_parameter_factory(exec_str_ask, cmd=get_cmd, parameter_name=name), + self, + ) - exec_str_ask = getattr(instrument, "ask", None) if instrument else None - # ignore typeerror since mypy does not allow setting a method dynamically - self.get_raw = _get_command_caller( # type: ignore[method-assign] - self, - Command( - arg_count=0, - cmd=get_cmd, - exec_str=exec_str_ask, - ), - ) self._gettable = True self.get = self._wrap_get(self.get_raw) @@ -326,10 +359,11 @@ def call_command(self: Parameter, val: ParamRawDataType) -> Any: exec_str_write = ( getattr(instrument, "write", None) if instrument else None ) + # TODO get_raw should also be a method here. This should probably be done by wrapping + # it with MethodType like above # ignore typeerror since mypy does not allow setting a method dynamically - self.set_raw = _set_command_caller( # type: ignore[method-assign] - self, - Command(arg_count=1, cmd=set_cmd, exec_str=exec_str_write), + self.set_raw = Command( # type: ignore[assignment] + arg_count=1, cmd=set_cmd, exec_str=exec_str_write ) self._settable = True self.set = self._wrap_set(self.set_raw) diff --git a/src/qcodes/utils/function_helpers.py b/src/qcodes/utils/function_helpers.py index ba919efcae99..1fca19de39ed 100644 --- a/src/qcodes/utils/function_helpers.py +++ b/src/qcodes/utils/function_helpers.py @@ -40,6 +40,5 @@ def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool: inputs = [0] * arg_count sig.bind(*inputs) return True - except TypeError as e: - raise e + except TypeError: return False