def dump_fn_or_script(
    fn_or_script: Union[fdl.Partial, fdl.Config],
    file_path: str,
    serializer_cls: Type[Union[ZlibJSONSerializer, YamlSerializer]] = ZlibJSONSerializer,
) -> None:
    """Serialize ``fn_or_script`` and write the result to ``file_path``.

    The serializer class is instantiated with no arguments and its
    ``serialize`` method is called on ``fn_or_script``; the returned text is
    written to ``file_path``, replacing any existing file. Missing parent
    directories are created first.

    Args:
        fn_or_script: The function or script configuration to serialize.
        file_path: Destination path. May be a bare file name, in which case
            the file is created in the current working directory.
        serializer_cls: The serializer class to use. Defaults to
            ``ZlibJSONSerializer``.

    Raises:
        OSError: If the parent directory cannot be created or the file
            cannot be opened for writing.
    """
    serialized_data = serializer_cls().serialize(fn_or_script)

    # A bare file name (e.g. "config.yaml") has an empty directory component,
    # and os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when there actually is one.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Pin the encoding so the dumped file does not depend on the platform's
    # locale default.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(serialized_data)
diff --git a/src/nemo_run/cli/api.py b/src/nemo_run/cli/api.py index 3b49553..99d0e0e 100644 --- a/src/nemo_run/cli/api.py +++ b/src/nemo_run/cli/api.py @@ -53,6 +53,7 @@ from typer.models import OptionInfo from typing_extensions import ParamSpec +from nemo_run.api import dump_fn_or_script from nemo_run.cli import devspace as devspace_cli from nemo_run.cli import experiment as experiment_cli from nemo_run.cli.cli_parser import parse_cli_args, parse_factory @@ -692,6 +693,7 @@ class RunContext: name (str): Name of the run. direct (bool): If True, execute the run directly without using a scheduler. dryrun (bool): If True, print the scheduler request without submitting. + dump (str): Path of the file the serialized configuration is written to; when set, the run is forced into dry-run mode. factory (Optional[str]): Name of a predefined factory to use. load (Optional[str]): Path to load a factory from a directory. repl (bool): If True, enter interactive mode. @@ -705,6 +707,7 @@ class RunContext: name: str direct: bool = False dryrun: bool = False + dump: str = "" factory: Optional[str] = None load: Optional[str] = None repl: bool = False @@ -755,6 +758,9 @@ def command( dryrun: bool = typer.Option( False, "--dryrun", help="Print the scheduler request without submitting" ), + dump: Optional[str] = typer.Option( + None, "--dump", help="Serialize and dump configuration without executing" + ), factory: Optional[str] = typer.Option( None, "--factory", "-f", help="Predefined factory to use" ), @@ -775,6 +781,7 @@ def command( name=name, direct=direct, dryrun=dryrun, + dump=dump, factory=factory or default_factory, load=load, repl=repl, @@ -841,6 +848,9 @@ def cli_execute( _, run_args, filtered_args = _parse_prefixed_args(args, "run") self.parse_args(run_args) + if self.dump: + self.dryrun = True + if self.load: raise NotImplementedError("Load is not implemented yet") @@ -867,6 +877,8 @@ def _execute_task(self, fn: Callable, task_args: List[str]): def run_task(): nonlocal task run.dryrun_fn(task, 
executor=self.executor) + if (self.dump is not None) and (self.dump != ""): + dump_fn_or_script(task, self.dump) if self.dryrun: console.print(f"[bold cyan]Dry run for {self.name}:[/bold cyan]")