-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
nsys-jax: re-work to be more pip-install-able #1165
base: main
Are you sure you want to change the base?
Conversation
1d0b327
to
80c4ca0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one is new (but trivial)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't really new, it's the old install-protoc
.
.github/container/nsys-jax
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't really gone, it became .github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't really new, it's the old nsys-jax
install_script_template = r"""#!/bin/bash | ||
# | ||
# Usage: ./install.sh [optional arguments to virtualenv] | ||
# | ||
# If it doesn't already exist, this creates a virtual environment named | ||
# `nsys_jax_env` in the current directory and installs Jupyter Lab and the | ||
# dependencies of the Analysis.ipynb notebook that is shipped alongside this | ||
# script inside the output archives of the `nsys-jax` wrapper. | ||
# | ||
# The expectation is that those archives will be copied and extracted on a | ||
# laptop or workstation, and this installation script will be run there, while | ||
# the `nsys-jax` wrapper is executed on a remote GPU cluster. | ||
set -ex | ||
SCRIPT_DIR=$(cd -- "$( dirname -- "${{BASH_SOURCE[0]}}" )" &> /dev/null && pwd) | ||
VIRTUALENV="${{SCRIPT_DIR}}/nsys_jax_venv" | ||
BIN="${{VIRTUALENV}}/bin" | ||
if [[ ! -d "${{VIRTUALENV}}" ]]; then | ||
# Let `virtualenv` find/choose a Python. Currently >=3.10 is supported. | ||
virtualenv -p 3.13 -p 3.12 -p 3.11 -p 3.10 "$@" "${{VIRTUALENV}}" | ||
"${{BIN}}/pip" install -U pip | ||
"${{BIN}}/pip" install 'nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@{jax_toolbox_commit}#subdirectory=.github/container/nsys_jax' | ||
"${{BIN}}/install-flamegraph" "${{VIRTUALENV}}" | ||
"${{BIN}}/install-protoc" "${{VIRTUALENV}}" | ||
else | ||
echo "Virtual environment already exists, not installing anything..." | ||
fi | ||
if [ -z ${{NSYS_JAX_INSTALL_SKIP_LAUNCH+x}} ]; then | ||
# Pick up the current profile data by default | ||
export NSYS_JAX_DEFAULT_PREFIX="${{PWD}}" | ||
# https://setuptools.pypa.io/en/latest/userguide/datafiles.html#accessing-data-files-at-runtime | ||
NOTEBOOK=$("${{BIN}}/python" -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")') | ||
echo "Launching: cd ${{SCRIPT_DIR}} && ${{BIN}}/jupyter-lab ${{NOTEBOOK}}" | ||
cd "${{SCRIPT_DIR}}" && "${{BIN}}/jupyter-lab" "${{NOTEBOOK}}" | ||
else | ||
echo "Skipping launch of Jupyter Lab due to NSYS_JAX_INSTALL_SKIP_LAUNCH" | ||
fi | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This replaces the old install.sh
, otherwise the content of this file is mostly the old nsys-jax
.
def create_install_script(output_queue): | ||
""" | ||
Write an install.sh to the output archive that installs nsys-jax at the same | ||
version/commit that the current execution is using. | ||
""" | ||
jax_toolbox_sha = jax_toolbox_sha_with_prefix[1:] | ||
install_script = install_script_template.format(jax_toolbox_commit=jax_toolbox_sha) | ||
output_queue.put(("install.sh", install_script.encode(), COMPRESS_DEFLATE)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This drives the install script generation. The 1:
slice is because the hash is prefixed with g
(for git
).
assert mirror_dir is not None | ||
# Execute post-processing recipes and add any outputs to `ofile` | ||
for analysis in args.analysis: | ||
result, output_prefix = execute_analysis_script( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW this has been factored out and shared between nsys-jax
and nsys-jax-combine
version_file_template = """\ | ||
__version__ = version = {version!r} | ||
__version_tuple__ = version_tuple = {version_tuple!r} | ||
__sha__ = {scm_version.node!r} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is propagated so that the install.sh
in the nsys-jax[-combine]
output .zip
archives can install the same nsys-jax
commit as was used to generate the archive.
bd221c5
to
b66834f
Compare
b66834f
to
faed0fc
Compare
The overarching goal of this PR is to get closer to a world where the
nsys-jax
tooling is straightforwardlypip install
-able. While the diff looks scary, it's mostly re-organisation.Substantive changes:
nsys-jax
no longer bundles Python code in the output archives, theinstall.sh
script provided for users to run on local machines becomes, loosely,install 'pip nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax'
, whereCOMMIT
corresponds to thensys-jax
command that produced the archive. For theghcr.io/nvidia/jax
containers, this is the commit of JAX-Toolbox that triggered the container build.Changes included:
/opt/pip-tools-post-install.d
, whichpip-finalize.sh
will execute the contents of after installing thepip
-managed worldinstall-protoc
to use this, sopip-finalize.sh
can forget about that detail.nvtx_gpu_proj_trace
Python code in Nsight Systems 2024.5 and 2024.6 via this.nsys-jax
installation (specifically for the containers) intoinstall-nsys-jax.sh
and thereby clean upinstall-nsight.sh
. The new script has to be told the git commit hash of JAX-Toolbox that is being built, becausensys-jax
bakes this into an installation script in its output.zip
archives to ensure the local environment matches the profile-collection environment.nsys-jax
,nsys-jax-combine
andinstall-protoc
are now handled via[project.scripts]
inpyproject.toml
instead of being standalone Python scripts. This is "more standard", and also makes it easier to share code betweennsys-jax
andnsys-jax-combine
.jax_nsys
tonsys_jax
for consistency.NSYS_JAX_DEFAULT_PREFIX
environment variable; previously the default was the current working directory, but that can be inconvenient to steer in Jupyter environments.