Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsys-jax: re-work to be more pip-install-able #1165

Open
wants to merge 41 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
8e12eee
entrypoints for CLI tools
olupton Nov 19, 2024
6558f70
install-protoc: reorganise into jax_nsys.scripts
olupton Nov 19, 2024
01acaab
reorganise
olupton Nov 19, 2024
9fd4aa0
move nsys-jax
olupton Nov 19, 2024
ebd1d7a
move nsys-jax-combine
olupton Nov 19, 2024
782cb36
dedupe
olupton Nov 19, 2024
5100919
explain what's needed
olupton Nov 19, 2024
6203b0e
Separate nsys-jax docs page, expand docs
olupton Nov 20, 2024
cbb259e
More installation documentation
olupton Nov 20, 2024
66aff8a
reorganise to make jax_nsys a more vanilla Python package
olupton Nov 20, 2024
af063fd
write the git commit that was installed
olupton Nov 20, 2024
d20cfed
pip juggling
olupton Nov 21, 2024
4eaa96b
rework nsys-jax install
olupton Nov 21, 2024
0c80d4a
explicit path
olupton Nov 21, 2024
1431775
Respect SRC_PATH_XLA
olupton Nov 21, 2024
dbf51f4
Install Analysis.ipynb
olupton Nov 21, 2024
3dc5d0f
post-install hook
olupton Nov 21, 2024
8bbdcb0
docs
olupton Nov 21, 2024
b600944
tweaks
olupton Nov 21, 2024
6b7471a
mypy
olupton Nov 21, 2024
1262fbe
Generate install.sh inside nsys-jax
olupton Nov 21, 2024
65f91d7
search-replace jax_nsys -> nsys_jax and jax-nsys -> nsys-jax
olupton Nov 21, 2024
ca1c2b3
install script tuning
olupton Nov 21, 2024
bcdb39c
missed renaming
olupton Nov 21, 2024
3575256
tmpcommit
olupton Nov 21, 2024
e180258
analyses rename
olupton Nov 21, 2024
2d32b76
refinements
olupton Nov 22, 2024
707e016
Share run-analysis-script logic between nsys-jax and nsys-jax-combine
olupton Nov 22, 2024
62d2a37
allow overriding the default prefix via environment variable
olupton Nov 22, 2024
315926b
pick up default
olupton Nov 22, 2024
e44145a
revert experiment that seems unneeded
olupton Nov 22, 2024
7bc853d
format, start fixing CI
olupton Nov 22, 2024
b87951f
CI fixups
olupton Nov 22, 2024
5bccae4
CI fixups
olupton Nov 22, 2024
ce88a3c
CI fixups
olupton Nov 22, 2024
ec571ab
CI fixups
olupton Nov 22, 2024
c910ec2
Patch execution of nsys-jax pytest tests
olupton Nov 22, 2024
58fd4fe
Add CI job testing offline analysis of nsys-jax output archives
olupton Nov 25, 2024
faed0fc
avoid baking in PR trial merge commit hashes
olupton Nov 26, 2024
3a2c4b7
fixup
olupton Nov 26, 2024
e9394cc
fixup
olupton Nov 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 9 additions & 20 deletions .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ARG BASE_IMAGE=nvidia/cuda:12.6.2-devel-ubuntu22.04
ARG GIT_USER_NAME="JAX Toolbox"
ARG [email protected]
ARG CLANG_VERSION=18
ARG JAX_TOOLBOX_REF

###############################################################################
## Obtain GCP's NCCL TCPx plugin
Expand Down Expand Up @@ -30,6 +31,7 @@ ARG BASE_IMAGE
ARG GIT_USER_EMAIL
ARG GIT_USER_NAME
ARG CLANG_VERSION
ARG JAX_TOOLBOX_REF
ENV CUDA_BASE_IMAGE=${BASE_IMAGE}

###############################################################################
Expand Down Expand Up @@ -110,7 +112,7 @@ RUN <<"EOF" bash -ex
git config --global user.name "${GIT_USER_NAME}"
git config --global user.email "${GIT_USER_EMAIL}"
EOF
RUN mkdir -p /opt/pip-tools.d
RUN mkdir -p /opt/pip-tools.d /opt/pip-tools-post-install.d
ADD --chmod=777 \
git-clone.sh \
pip-finalize.sh \
Expand Down Expand Up @@ -141,7 +143,6 @@ COPY --from=tcpx-installer /var/lib/tcpx/lib64 ${TCPX_LIBRARY_PATH}
###############################################################################

ADD install-nsight.sh /usr/local/bin
ADD nsys-2024.5-tid-export.patch /opt/nvidia
RUN install-nsight.sh

###############################################################################
Expand Down Expand Up @@ -183,7 +184,7 @@ ENV PATH=/opt/amazon/efa/bin:${PATH}
ADD install-nccl-sanity-check.sh /usr/local/bin
ADD nccl-sanity-check.cu /opt
RUN install-nccl-sanity-check.sh
ADD jax-nccl-test parallel-launch /usr/local/bin
ADD jax-nccl-test parallel-launch /usr/local/bin/

###############################################################################
## Add the systemcheck to the entrypoint.
Expand All @@ -199,23 +200,11 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/
# COPY gcp-autoconfig.sh /opt/nvidia/entrypoint.d/

###############################################################################
## Add helper scripts for profiling with Nsight Systems
##
## The scripts saved to /opt/jax_nsys are embedded in the output archives
## written by nsys-jax, while the nsys-jax and nsys-jax-combine scripts are
## only used inside the containers.
###############################################################################
ADD nsys-jax nsys-jax-combine /usr/local/bin/
ADD jax_nsys/ /opt/jax_nsys
# The jax_nsys package should be installed inside the containers, so nsys-jax
# can eagerly execute analysis recipes (--nsys-jax-analysis) in the container
# environment, without an extra layer of virtual environment indirection.
RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in
# This should be embedded in output archives and be runnable inside containers
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/
# Should be available for execution inside the containers, should not be
# embedded in the output archives.
ADD jax_nsys_tests/ /opt/jax_nsys_tests
## Install the nsys-jax JAX/XLA-aware profiling scripts, patch Nsight Systems
###############################################################################

ADD install-nsys-jax.sh /usr/local/bin
RUN install-nsys-jax.sh ${JAX_TOOLBOX_REF}

###############################################################################
## Copy manifest file to the container
Expand Down
11 changes: 0 additions & 11 deletions .github/container/install-nsight.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,3 @@ apt-get install -y nsight-compute nsight-systems-cli-2024.6.1
apt-get clean

rm -rf /var/lib/apt/lists/*

for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
if [[ -d "${NSYS}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
done

# Install extra dependencies needed for `nsys recipe ...` commands. These are
# used by the nsys-jax wrapper script.
ln -s $(dirname $(realpath $(command -v nsys)))/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in
32 changes: 32 additions & 0 deletions .github/container/install-nsys-jax.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash
set -exo pipefail

REF="$1"
if [[ -z "${REF}" ]]; then
echo "$0: <git ref of JAX-Toolbox>"
exit 1
fi

# Install extra dependencies needed for `nsys recipe ...` commands. These are
# used by the nsys-jax wrapper script.
NSYS_DIR=$(dirname $(realpath $(command -v nsys)))
ln -s ${NSYS_DIR}/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in

# Install the nsys-jax package, which includes nsys-jax, nsys-jax-combine,
# install-protoc (called from pip-finalize.sh), and nsys-jax-patch-nsys as well as the
# nsys_jax Python library.
URL="git+https://github.com/NVIDIA/JAX-Toolbox.git@${REF}#subdirectory=.github/container/nsys_jax&egg=nsys-jax"
echo "-e '${URL}'" > /opt/pip-tools.d/requirements-nsys-jax.in

# protobuf will be installed at least as a dependency of nsys_jax in the base
# image, but the installed version is likely to be influenced by other packages.
echo "install-protoc /usr/local" > /opt/pip-tools-post-install.d/protoc
chmod 755 /opt/pip-tools-post-install.d/protoc

# Make sure flamegraph.pl is available
echo "install-flamegraph /usr/local" > /opt/pip-tools-post-install.d/flamegraph
chmod 755 /opt/pip-tools-post-install.d/flamegraph

# Make sure Nsight Systems Python patches are installed if needed
echo "nsys-jax-patch-nsys" > /opt/pip-tools-post-install.d/patch-nsys
chmod 755 /opt/pip-tools-post-install.d/patch-nsys
65 changes: 0 additions & 65 deletions .github/container/jax_nsys/install-protoc

This file was deleted.

38 changes: 0 additions & 38 deletions .github/container/jax_nsys/install.sh

This file was deleted.

17 changes: 0 additions & 17 deletions .github/container/jax_nsys/python/jax_nsys/pyproject.toml

This file was deleted.

Loading
Loading