nsys-jax: re-work to be more pip-install-able #1165

Open
wants to merge 41 commits into
base: main
from

Conversation

olupton
Collaborator

@olupton olupton commented Nov 21, 2024

The overarching goal of this PR is to get closer to a world where the nsys-jax tooling is straightforwardly pip install-able. While the diff looks scary, it's mostly re-organisation.

Substantive changes:

  • nsys-jax no longer bundles Python code in the output archives. The install.sh script provided for users to run on local machines becomes, loosely, pip install 'nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax', where COMMIT is the JAX-Toolbox commit of the nsys-jax command that produced the archive. For the ghcr.io/nvidia/jax containers, this is the commit of JAX-Toolbox that triggered the container build.
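
The requirement string in the bullet above can be assembled from a commit hash roughly as follows. This is a minimal sketch: the helper name `nsys_jax_requirement` is hypothetical and not part of this PR; only the repository URL, the `jupyter` extra and the subdirectory are taken from the description.

```python
def nsys_jax_requirement(commit: str) -> str:
    """Hypothetical helper: build a direct-reference requirement pinned to `commit`."""
    return (
        "nsys-jax[jupyter] @ "
        f"git+https://github.com/NVIDIA/JAX-Toolbox.git@{commit}"
        "#subdirectory=.github/container/nsys_jax"
    )


# COMMIT is a placeholder here, exactly as in the description above.
print("pip install '" + nsys_jax_requirement("COMMIT") + "'")
```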

Changes included:

  • Introduce /opt/pip-tools-post-install.d, whose contents pip-finalize.sh executes after installing the pip-managed world
  • Move nsys-jax installation (specifically for the containers) into install-nsys-jax.sh and thereby clean up install-nsight.sh. The new script has to be told the git commit hash of JAX-Toolbox that is being built, because nsys-jax bakes this into an installation script in its output .zip archives to ensure the local environment matches the profile-collection environment.
  • The CLI tools like nsys-jax, nsys-jax-combine and install-protoc are now handled via [project.scripts] in pyproject.toml instead of being standalone Python scripts. This is "more standard", and also makes it easier to share code between nsys-jax and nsys-jax-combine.
  • The Python library is renamed from jax_nsys to nsys_jax for consistency.
  • It's now possible to set the default data loading path via the NSYS_JAX_DEFAULT_PREFIX environment variable; previously the default was the current working directory, but that can be inconvenient to steer in Jupyter environments.
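
As an illustration of the last bullet, the default prefix lookup could behave like the sketch below. The function name `default_data_prefix` is hypothetical, and the fallback to the current working directory when the variable is unset is an assumption based on the previous behaviour described above; the actual loader in nsys_jax may differ.

```python
import os
from pathlib import Path


def default_data_prefix() -> Path:
    """Hypothetical helper: choose the directory to load profile data from."""
    # Prefer the environment variable introduced by this PR ...
    prefix = os.environ.get("NSYS_JAX_DEFAULT_PREFIX")
    # ... and (assumption) fall back to the old default, the working directory.
    return Path(prefix) if prefix else Path.cwd()
```

In a Jupyter session, setting `os.environ["NSYS_JAX_DEFAULT_PREFIX"]` before loading data would then steer the lookup without changing the kernel's working directory.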

@olupton olupton force-pushed the olupton/pip-install-nsys-jax branch 2 times, most recently from 1d0b327 to 80c4ca0 Compare November 22, 2024 13:59
Collaborator Author

This one is new (but trivial)

This isn't really new, it's the old install-protoc.

This isn't really gone, it became .github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py

This isn't really new, it's the old nsys-jax

Comment on lines 49 to 86
install_script_template = r"""#!/bin/bash
#
# Usage: ./install.sh [optional arguments to virtualenv]
#
# If it doesn't already exist, this creates a virtual environment named
# `nsys_jax_venv` in the current directory and installs Jupyter Lab and the
# dependencies of the Analysis.ipynb notebook that is shipped alongside this
# script inside the output archives of the `nsys-jax` wrapper.
#
# The expectation is that those archives will be copied and extracted on a
# laptop or workstation, and this installation script will be run there, while
# the `nsys-jax` wrapper is executed on a remote GPU cluster.
set -ex
SCRIPT_DIR=$(cd -- "$( dirname -- "${{BASH_SOURCE[0]}}" )" &> /dev/null && pwd)
VIRTUALENV="${{SCRIPT_DIR}}/nsys_jax_venv"
BIN="${{VIRTUALENV}}/bin"
if [[ ! -d "${{VIRTUALENV}}" ]]; then
  # Let `virtualenv` find/choose a Python. Currently >=3.10 is supported.
  virtualenv -p 3.13 -p 3.12 -p 3.11 -p 3.10 "$@" "${{VIRTUALENV}}"
  "${{BIN}}/pip" install -U pip
  "${{BIN}}/pip" install 'nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@{jax_toolbox_commit}#subdirectory=.github/container/nsys_jax'
  "${{BIN}}/install-flamegraph" "${{VIRTUALENV}}"
  "${{BIN}}/install-protoc" "${{VIRTUALENV}}"
else
  echo "Virtual environment already exists, not installing anything..."
fi
if [ -z ${{NSYS_JAX_INSTALL_SKIP_LAUNCH+x}} ]; then
  # Pick up the current profile data by default
  export NSYS_JAX_DEFAULT_PREFIX="${{PWD}}"
  # https://setuptools.pypa.io/en/latest/userguide/datafiles.html#accessing-data-files-at-runtime
  NOTEBOOK=$("${{BIN}}/python" -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")')
  echo "Launching: cd ${{SCRIPT_DIR}} && ${{BIN}}/jupyter-lab ${{NOTEBOOK}}"
  cd "${{SCRIPT_DIR}}" && "${{BIN}}/jupyter-lab" "${{NOTEBOOK}}"
else
  echo "Skipping launch of Jupyter Lab due to NSYS_JAX_INSTALL_SKIP_LAUNCH"
fi
"""

This replaces the old install.sh; otherwise the content of this file is mostly the old nsys-jax.
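
The doubled braces in the template are there because it is rendered with `str.format`: `{{` and `}}` come out as literal `{` and `}`, so bash expansions such as `${SCRIPT_DIR}` survive, while the single-braced `{jax_toolbox_commit}` field is substituted. A shortened stand-in for the template shows the effect; the commit value `abc1234` is a made-up placeholder.

```python
# Shortened stand-in for install_script_template (two lines only).
template = r"""VIRTUALENV="${{SCRIPT_DIR}}/nsys_jax_venv"
pip install 'nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@{jax_toolbox_commit}#subdirectory=.github/container/nsys_jax'
"""

# Doubled braces become literal braces; the named field is filled in.
rendered = template.format(jax_toolbox_commit="abc1234")
print(rendered)
```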

Comment on lines +88 to +96
def create_install_script(output_queue):
    """
    Write an install.sh to the output archive that installs nsys-jax at the same
    version/commit that the current execution is using.
    """
    jax_toolbox_sha = jax_toolbox_sha_with_prefix[1:]
    install_script = install_script_template.format(jax_toolbox_commit=jax_toolbox_sha)
    output_queue.put(("install.sh", install_script.encode(), COMPRESS_DEFLATE))

This drives the install script generation. The [1:] slice drops the leading g (for git) that prefixes the hash.
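
For illustration, the prefix handling can be written as a small function. The name `strip_git_prefix` and the explicit check are hypothetical additions; the code in this PR slices unconditionally. The hash `gabc1234` is a made-up example value.

```python
def strip_git_prefix(node: str) -> str:
    """Hypothetical helper: turn a g-prefixed node such as 'gabc1234' into 'abc1234'."""
    # Extra sanity check that is not in the PR: fail loudly on an unexpected format.
    if not node.startswith("g"):
        raise ValueError(f"expected a 'g'-prefixed git node, got {node!r}")
    # Same operation as jax_toolbox_sha_with_prefix[1:] in the diff.
    return node[1:]


print(strip_git_prefix("gabc1234"))
```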

assert mirror_dir is not None
# Execute post-processing recipes and add any outputs to `ofile`
for analysis in args.analysis:
result, output_prefix = execute_analysis_script(

FWIW this has been factored out and shared between nsys-jax and nsys-jax-combine

version_file_template = """\
__version__ = version = {version!r}
__version_tuple__ = version_tuple = {version_tuple!r}
__sha__ = {scm_version.node!r}

This is propagated so that the install.sh in the nsys-jax[-combine] output .zip archives can install the same nsys-jax commit as was used to generate the archive.
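
A sketch of how such a template renders, assuming it is filled in with `str.format` using `version`, `version_tuple` and `scm_version` arguments. Only the three lines visible in the excerpt above are reproduced; the `SimpleNamespace` stand-in and all values are made up for illustration.

```python
from types import SimpleNamespace

# The three template lines visible in the diff excerpt.
version_file_template = """\
__version__ = version = {version!r}
__version_tuple__ = version_tuple = {version_tuple!r}
__sha__ = {scm_version.node!r}
"""

# Stand-in for the SCM version object; only the `node` attribute is used here.
scm_version = SimpleNamespace(node="gabc1234")

rendered = version_file_template.format(
    version="0.1.dev1",
    version_tuple=(0, 1, "dev1"),
    scm_version=scm_version,
)

# The rendered text is valid Python, so it can be executed like the generated module.
namespace = {}
exec(rendered, namespace)
print(namespace["__sha__"])
```

The `!r` conversion writes each value as a Python literal, which is what makes the generated file importable.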

.github/workflows/nsys-jax.yaml (outdated, resolved)
@olupton olupton force-pushed the olupton/pip-install-nsys-jax branch 2 times, most recently from bd221c5 to b66834f Compare November 25, 2024 10:45
1 participant