Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --dump option to dump fiddle configurations #107

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/nemo_run/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@
)

import fiddle as fdl
import os
from fiddle.experimental import auto_config as _auto_config
from rich.pretty import Pretty
from rich.table import Table

from nemo_run.config import Config, Partial
from nemo_run.core.execution.base import Executor
from nemo_run.core.frontend.console.api import CONSOLE, CustomConfigRepr
from nemo_run.core.serialization.yaml import YamlSerializer
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer

F = TypeVar("F", bound=Callable[..., Any])
T = TypeVar("T")
Expand Down Expand Up @@ -233,6 +236,27 @@ def dryrun_fn(
fdl.build(configured_fn)


def dump_fn_or_script(
fn_or_script: Union[fdl.Partial, fdl.Config],
file_path: str,
serializer_cls: Type[ZlibJSONSerializer | YamlSerializer] = ZlibJSONSerializer,
) -> None:
"""
Serializes `fn_or_script` and writes it to the specified file path.

Args:
fn_or_script (Union[fdl.Partial, fdl.Config]): The function or script object to be serialized.
file_path (str): The file path where the serialized data will be saved.
serializer_cls (Type[ZlibJSONSerializer | YamlSerializer], optional):
The serializer class to use. Defaults to `ZlibJSONSerializer`.
"""
serialized_data = serializer_cls().serialize(fn_or_script)
os.makedirs(os.path.dirname(file_path), exist_ok=True)

with open(file_path, "w") as f:
f.write(serialized_data)


@runtime_checkable
class AutoConfigProtocol(Protocol):
def __auto_config__(self) -> bool: ...
12 changes: 12 additions & 0 deletions src/nemo_run/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from typer.models import OptionInfo
from typing_extensions import ParamSpec

from nemo_run.api import dump_fn_or_script
from nemo_run.cli import devspace as devspace_cli
from nemo_run.cli import experiment as experiment_cli
from nemo_run.cli.cli_parser import parse_cli_args, parse_factory
Expand Down Expand Up @@ -692,6 +693,7 @@ class RunContext:
name (str): Name of the run.
direct (bool): If True, execute the run directly without using a scheduler.
dryrun (bool): If True, print the scheduler request without submitting.
dump (str): If True, serialize and output the configuration without execution.
factory (Optional[str]): Name of a predefined factory to use.
load (Optional[str]): Path to load a factory from a directory.
repl (bool): If True, enter interactive mode.
Expand All @@ -705,6 +707,7 @@ class RunContext:
name: str
direct: bool = False
dryrun: bool = False
dump: str = ""
factory: Optional[str] = None
load: Optional[str] = None
repl: bool = False
Expand Down Expand Up @@ -755,6 +758,9 @@ def command(
dryrun: bool = typer.Option(
False, "--dryrun", help="Print the scheduler request without submitting"
),
dump: Optional[str] = typer.Option(
None, "--dump", help="Serialize and dump configuration without executing"
),
factory: Optional[str] = typer.Option(
None, "--factory", "-f", help="Predefined factory to use"
),
Expand All @@ -775,6 +781,7 @@ def command(
name=name,
direct=direct,
dryrun=dryrun,
dump=dump,
factory=factory or default_factory,
load=load,
repl=repl,
Expand Down Expand Up @@ -841,6 +848,9 @@ def cli_execute(
_, run_args, filtered_args = _parse_prefixed_args(args, "run")
self.parse_args(run_args)

if self.dump:
self.dryrun = True

if self.load:
raise NotImplementedError("Load is not implemented yet")

Expand All @@ -867,6 +877,8 @@ def _execute_task(self, fn: Callable, task_args: List[str]):
def run_task():
nonlocal task
run.dryrun_fn(task, executor=self.executor)
if (self.dump is not None) and (self.dump != ""):
dump_fn_or_script(task, self.dump)

if self.dryrun:
console.print(f"[bold cyan]Dry run for {self.name}:[/bold cyan]")
Expand Down
Loading