
Commit

Fix use of jsonargparse avoiding reliance on non-public internal logic

mauvilsa authored Jun 27, 2023
1 parent 30bf0f5 commit ebe80e3
Showing 4 changed files with 32 additions and 57 deletions.

requirements/base.txt (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ torchmetrics >0.7.0, <0.11.0  # strict
pytorch-lightning >1.8.0, <2.0.0 # strict
pyDeprecate >0.2.0
pandas >1.1.0, <=1.5.2
-jsonargparse[signatures] >4.0.0, <=4.9.0
+jsonargparse[signatures] >=4.22.0, <4.23.0
click >=7.1.2, <=8.1.3
protobuf <=3.20.1
fsspec[http] >=2022.5.0,<=2023.6.0

src/flash/core/utilities/flash_cli.py (18 changes: 13 additions & 5 deletions)
@@ -20,8 +20,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import pytorch_lightning as pl
-from jsonargparse import ArgumentParser
-from jsonargparse.signatures import get_class_signature_functions
+from jsonargparse import ArgumentParser, class_from_function
from lightning_utilities.core.overrides import is_overridden
from pytorch_lightning import LightningModule, Trainer

@@ -31,7 +30,6 @@
    LightningArgumentParser,
    LightningCLI,
    SaveConfigCallback,
-    class_from_function,
)
from flash.core.utilities.stability import beta

@@ -107,6 +105,16 @@ def wrapper(*args, **kwargs):
    return wrapper


+def get_class_signature_functions(classes):
+    signatures = []
+    for num, cls in enumerate(classes):
+        if cls.__new__ is not object.__new__ and not any(cls.__new__ is c.__new__ for c in classes[num + 1 :]):
+            signatures.append((cls, cls.__new__))
+        if not any(cls.__init__ is c.__init__ for c in classes[num + 1 :]):
+            signatures.append((cls, cls.__init__))
+    return signatures


def get_overlapping_args(func_a, func_b) -> Set[str]:
    func_a = get_class_signature_functions([func_a])[0][1]
    func_b = get_class_signature_functions([func_b])[0][1]
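
Note: the get_class_signature_functions helper added above is a small vendored copy of logic that previously came from the non-public jsonargparse.signatures module. As a rough illustration of what it returns — a sketch assuming the helper above is in scope, with Plain and WithNew as made-up classes:

```python
class Plain:
    def __init__(self, a: int = 0):
        self.a = a


class WithNew(Plain):
    def __new__(cls, *args, **kwargs):
        return super().__new__(cls)


# A class that only defines __init__ contributes a single (class, __init__) pair.
assert get_class_signature_functions([Plain]) == [(Plain, Plain.__init__)]

# A custom __new__ is listed before __init__; get_overlapping_args takes the
# first entry's callable ([0][1]) as the signature-bearing function for the class.
assert get_class_signature_functions([WithNew]) == [
    (WithNew, WithNew.__new__),
    (WithNew, WithNew.__init__),
]
```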
@@ -214,7 +222,7 @@ def add_arguments_to_parser(self, parser) -> None:
    def add_subcommand_from_function(self, subcommands, function, function_name=None):
        subcommand = ArgumentParser()
        if get_kwarg_name(function) == "data_module_kwargs":
-            datamodule_function = class_from_function(function, return_type=self.local_datamodule_class)
+            datamodule_function = class_from_function(function, self.local_datamodule_class)
            subcommand.add_class_arguments(
                datamodule_function,
                fail_untyped=False,
@@ -233,7 +241,7 @@ def add_subcommand_from_function(self, subcommands, function, function_name=None
                },
            )
        else:
-            datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
+            datamodule_function = class_from_function(drop_kwargs(function), self.local_datamodule_class)
            subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
        subcommand_name = function_name or function.__name__
        subcommands.add_subcommand(subcommand_name, subcommand)
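
For context, a minimal sketch (not part of this commit) of how jsonargparse's public class_from_function is used above: it wraps a factory function in a dynamically created subclass of the target class, so add_class_arguments can expose the factory's signature on the CLI. DataModule and from_folders below are hypothetical stand-ins for Flash's datamodule class and its factory methods:

```python
from jsonargparse import ArgumentParser, class_from_function


class DataModule:
    def __init__(self, batch_size: int = 32):
        self.batch_size = batch_size


def from_folders(train_folder: str, batch_size: int = 32) -> DataModule:
    """Hypothetical factory that builds a DataModule from a folder of files."""
    return DataModule(batch_size=batch_size)


# The second argument plays the role of self.local_datamodule_class above;
# instantiating the returned class is equivalent to calling from_folders().
DataModuleFromFolders = class_from_function(from_folders, DataModule)

parser = ArgumentParser()
parser.add_class_arguments(DataModuleFromFolders, "data", fail_untyped=False)
cfg = parser.parse_args(["--data.train_folder=./train", "--data.batch_size=8"])

dm = DataModuleFromFolders(**cfg.data.as_dict())
assert isinstance(dm, DataModule) and dm.batch_size == 8
```

Because the generated class subclasses the class passed as the second argument, issubclass checks against LightningDataModule keep working, which is presumably why the lightning_cli.py change below can drop the ClassFromFunctionBase special case.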

src/flash/core/utilities/lightning_cli.py (41 changes: 4 additions & 37 deletions)
@@ -4,14 +4,11 @@
import os
import warnings
from argparse import Namespace
-from functools import wraps
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import torch
-from jsonargparse import ActionConfigFile, ArgumentParser, set_config_read_mode
-from jsonargparse.signatures import ClassFromFunctionBase
-from jsonargparse.typehints import ClassType
+from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -25,46 +22,16 @@
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]


-def class_from_function(
-    func: Callable[..., ClassType],
-    return_type: Optional[Type[ClassType]] = None,
-) -> Type[ClassType]:
-    """Creates a dynamic class which if instantiated is equivalent to calling func.
-    Args:
-        func: A function that returns an instance of a class. It must have a return type annotation.
-    """
-
-    @wraps(func)
-    def __new__(cls, *args, **kwargs):
-        return func(*args, **kwargs)
-
-    if return_type is None:
-        return_type = inspect.signature(func).return_annotation
-
-    if isinstance(return_type, str):
-        raise RuntimeError("Classmethod instantiation is not supported when the return type annotation is a string.")
-
-    class ClassFromFunction(return_type, ClassFromFunctionBase):  # type: ignore
-        pass
-
-    ClassFromFunction.__new__ = __new__  # type: ignore
-    ClassFromFunction.__doc__ = func.__doc__
-    ClassFromFunction.__name__ = func.__name__
-
-    return ClassFromFunction


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

-    def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize argument parser that supports configuration file input.
        For full details of accepted arguments see
        `ArgumentParser.__init__ <https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.__init__>`_.
        """
-        super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
+        super().__init__(*args, **kwargs)
        self.add_argument(
            "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
        )
@@ -95,7 +62,7 @@ def add_lightning_class_args(

        if inspect.isclass(lightning_class) and issubclass(
            cast(type, lightning_class),
-            (Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
+            (Trainer, LightningModule, LightningDataModule, Callback),
        ):
            if issubclass(cast(type, lightning_class), Callback):
                self.callback_keys.append(nested_key)

tests/core/utilities/test_lightning_cli.py (28 changes: 14 additions & 14 deletions)
@@ -40,7 +40,7 @@ def test_default_args(mock_argparse, tmpdir):
"""Tests default argument parser for Trainer."""
mock_argparse.return_value = Namespace(**Trainer.default_attributes())

-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
args = parser.parse_args([])

args.max_epochs = 5
@@ -54,7 +54,7 @@ def test_add_argparse_args_redefined(cli_args):
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--default_root_dir=./"], []])
def test_add_argparse_args_redefined(cli_args):
"""Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)

args = parser.parse_args(cli_args)
@@ -79,19 +79,19 @@ def test_add_argparse_args_redefined(cli_args):
("--auto_lr_find=True --auto_scale_batch_size=power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}),
(
"--auto_lr_find any_string --auto_scale_batch_size ON",
{"auto_lr_find": "any_string", "auto_scale_batch_size": True},
{"auto_lr_find": "any_string", "auto_scale_batch_size": "ON"},
),
("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": True}),
("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": False}),
("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}),
("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": "On"}),
("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": "No"}),
("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": "FALSE"}),
("--limit_train_batches=100", {"limit_train_batches": 100}),
("--limit_train_batches 0.8", {"limit_train_batches": 0.8}),
],
)
def test_parse_args_parsing(cli_args, expected):
"""Test parsing simple types and None optionals not modified."""
cli_args = cli_args.split(" ") if cli_args else []
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
@@ -112,7 +112,7 @@
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
"""Test parsing complex types."""
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
@@ -137,7 +137,7 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu):
"""Test parsing of gpus and instantiation of Trainer."""
mocker.patch("lightning_lite.utilities.device_parser._get_all_available_gpus", return_value=[0, 1])
cli_args = cli_args.split(" ") if cli_args else []
-    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+    parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
@@ -310,8 +310,8 @@ def test_lightning_cli_args(tmpdir):
config = yaml.safe_load(f.read())
assert "model" not in config
assert "model" not in cli.config
assert config["data"] == cli.config["data"]
assert config["trainer"] == cli.config["trainer"]
assert config["data"] == cli.config["data"].as_dict()
assert config["trainer"] == cli.config["trainer"].as_dict()


@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
@@ -363,9 +363,9 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
assert os.path.isfile(config_path)
with open(config_path) as f:
config = yaml.safe_load(f.read())
assert config["model"] == cli.config["model"]
assert config["data"] == cli.config["data"]
assert config["trainer"] == cli.config["trainer"]
assert config["model"] == cli.config["model"].as_dict()
assert config["data"] == cli.config["data"].as_dict()
assert config["trainer"] == cli.config["trainer"].as_dict()


def any_model_any_data_cli():
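
With parse_as_dict support gone from LightningArgumentParser (see the lightning_cli.py change above), parse results are jsonargparse Namespace objects, which is consistent with why these assertions now call .as_dict() before comparing against the YAML-loaded config. A minimal sketch of that behaviour, using a hypothetical --trainer.max_epochs option:

```python
from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--trainer.max_epochs", type=int, default=1)

cfg = parser.parse_args(["--trainer.max_epochs=5"])
# parse_args yields a jsonargparse Namespace rather than a plain dict ...
assert cfg.trainer.max_epochs == 5
# ... so the tests convert before comparing with yaml.safe_load output.
assert cfg["trainer"].as_dict() == {"max_epochs": 5}
```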
