Skip to content

Commit

Permalink
Merge branch 'main' into mace
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger authored Jan 12, 2025
2 parents cbb3b60 + 34066f5 commit a017189
Show file tree
Hide file tree
Showing 16 changed files with 55 additions and 229 deletions.
45 changes: 0 additions & 45 deletions .github/container/Dockerfile.maxtext
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,15 @@

ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG URLREF_MAXTEXT=https://github.com/google/maxtext.git#main
ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
ARG SRC_PATH_MAXTEXT=/opt/maxtext
ARG SRC_PATH_TFTEXT=/opt/tensorflow-text

###############################################################################
## build tensorflow-text and lingvo, which do not have working arm64 pip wheels
###############################################################################

ARG BASE_IMAGE
FROM ${BASE_IMAGE} as wheel-builder

#------------------------------------------------------------------------------
# build tensorflow-text from source
#------------------------------------------------------------------------------

# Remove TFTEXT build from source when it has py-3.12 wheels for x86/arm64
FROM wheel-builder as tftext-builder
ARG URLREF_TFTEXT
ARG SRC_PATH_TFTEXT

RUN pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0
RUN git-clone.sh ${URLREF_TFTEXT} ${SRC_PATH_TFTEXT}
RUN <<"EOF" bash -exu -o pipefail
cd ${SRC_PATH_TFTEXT}

# The tftext build script queries GitHub, but these requests are sometimes
# throttled by GH, resulting in a corrupted uri for tensorflow in WORKSPACE.
# A workaround (needs to be updated when the tensorflow version changes):
sed -i "s/# Update TF dependency to installed tensorflow./commit_slug=6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/" oss_scripts/prepare_tf_dep.sh

# Newer versions of LLVM make lld's --undefined-version check strict by
# default (https://reviews.llvm.org/D135402), but the tftext build seems to
# rely on the previous, non-strict behavior.
echo "write_to_bazelrc \"build --linkopt='-Wl,--undefined-version'\"" >> oss_scripts/configure.sh

./oss_scripts/run_build.sh
EOF

###############################################################################
## Download source and add auxiliary scripts
###############################################################################

FROM ${BASE_IMAGE} as mealkit
ARG URLREF_MAXTEXT
ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
ARG SRC_PATH_MAXTEXT
ARG SRC_PATH_TFTEXT=/opt/tensorflow-text

# Preserve version information of tensorflow-text
COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/
RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-maxtext.in

RUN <<"EOF" bash -ex
git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
Expand Down Expand Up @@ -85,6 +43,3 @@ FROM mealkit as final
RUN pip-finalize.sh

WORKDIR ${SRC_PATH_MAXTEXT}

# When tftext and lingvo wheels are published on pypi.org, revert this
# Dockerfile to 5c4b687b918e6569bca43758c346ad8e67460154
44 changes: 3 additions & 41 deletions .github/container/Dockerfile.pax
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG URLREF_PAXML=https://github.com/google/paxml.git#main
ARG URLREF_PRAXIS=https://github.com/google/praxis.git#main
ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
ARG URLREF_LINGVO=https://github.com/tensorflow/lingvo.git#master
ARG SRC_PATH_PAXML=/opt/paxml
ARG SRC_PATH_PRAXIS=/opt/praxis
ARG SRC_PATH_TFTEXT=/opt/tensorflow-text
ARG SRC_PATH_LINGVO=/opt/lingvo

###############################################################################
Expand All @@ -17,46 +15,15 @@ ARG SRC_PATH_LINGVO=/opt/lingvo
ARG BASE_IMAGE
FROM ${BASE_IMAGE} as wheel-builder

#------------------------------------------------------------------------------
# build tensorflow-text from source
#------------------------------------------------------------------------------

# Remove TFTEXT build from source when it has py-3.12 wheels for x86/arm64
FROM wheel-builder as tftext-builder
ARG URLREF_TFTEXT
ARG SRC_PATH_TFTEXT
RUN <<"EOF" bash -exu -o pipefail
pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0
git-clone.sh ${URLREF_TFTEXT} ${SRC_PATH_TFTEXT}
cd ${SRC_PATH_TFTEXT}

# The tftext build script queries GitHub, but these requests are sometimes
# throttled by GH, resulting in a corrupted uri for tensorflow in WORKSPACE.
# A workaround (needs to be updated when the tensorflow version changes):
sed -i "s/# Update TF dependency to installed tensorflow./commit_slug=6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/" oss_scripts/prepare_tf_dep.sh

# Newer versions of LLVM make lld's --undefined-version check strict by
# default (https://reviews.llvm.org/D135402), but the tftext build seems to
# rely on the previous, non-strict behavior.
echo "write_to_bazelrc \"build --linkopt='-Wl,--undefined-version'\"" >> oss_scripts/configure.sh

./oss_scripts/run_build.sh
EOF

#------------------------------------------------------------------------------
# build lingvo
#------------------------------------------------------------------------------

# Remove Lingvo build from source when it has py-3.12 wheels for x86/arm64
FROM wheel-builder as lingvo-builder
ARG URLREF_LINGVO
ARG SRC_PATH_TFTEXT
ARG SRC_PATH_LINGVO

# Preserve the version of tensorflow-text
COPY --from=tftext-builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml
COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/

ENV USE_BAZEL_VERSION=7.1.2

# build lingvo
Expand Down Expand Up @@ -89,10 +56,9 @@ EOFINNER

fi

pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0 /opt/tensorflow_text*.whl
pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0
for pattern in \
"s|tensorflow=|#tensorflow=|g" \
"s|tensorflow-text=|#tensorflow-text=|g" \
"s|dataclasses=|#dataclasses=|g" \
"s|==.*||g" \
; do
Expand All @@ -101,7 +67,7 @@ done
# Lingvo supports only Python < 3.12, so we hack it and update its dependencies
# to be able to build for py-3.12
for pattern in \
"s|tensorflow-text~=2.13.0|tensorflow-text~=2.18.0|g" \
"s|tensorflow-text~=2.13.0|tensorflow-text~=2.18.1|g" \
"s|tensorflow~=2.13.0|tensorflow~=2.18.0|g" \
"s|python_requires='>=3.8,<3.11'|python_requires='>=3.8,<3.13'|" \
; do
Expand All @@ -128,16 +94,12 @@ ARG URLREF_PAXML
ARG URLREF_PRAXIS
ARG SRC_PATH_PAXML
ARG SRC_PATH_PRAXIS
ARG SRC_PATH_TFTEXT

# Preserve version information of tensorflow-text and lingvo
COPY --from=lingvo-builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml
COPY --from=lingvo-builder /tmp/lingvo/dist/lingvo*-linux*.whl /opt/
RUN echo "lingvo @ file://$(ls /opt/lingvo*.whl)" >> /opt/pip-tools.d/requirements-paxml.in

COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/
RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-paxml.in

# paxml + praxis
RUN <<"EOF" bash -ex
echo "tensorflow_datasets==4.9.2" >> /opt/pip-tools.d/requirements-paxml.in
Expand Down Expand Up @@ -187,5 +149,5 @@ FROM mealkit as final

RUN pip-finalize.sh

# When tftext and lingvo wheels are published on pypi.org, revert this
# When lingvo wheels are published on pypi.org, revert this
# Dockerfile to 5c4b687b918e6569bca43758c346ad8e67460154
46 changes: 0 additions & 46 deletions .github/container/Dockerfile.t5x
Original file line number Diff line number Diff line change
Expand Up @@ -3,64 +3,18 @@
# docker buildx build -f Dockerfile.t5x --tag t5x --build-arg BASE_IMAGE=ghcr.io/nvidia/jax:mealkit-2024-01-22 .

ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
ARG URLREF_T5X=https://github.com/google-research/t5x.git#main
ARG URLREF_AIRIO=https://github.com/google/airio.git#main
ARG SRC_PATH_TFTEXT=/opt/tensorflow-text
ARG SRC_PATH_T5X=/opt/t5x
ARG SRC_PATH_AIRIO=/opt/airio


###############################################################################
## build several packages which do not have working arm64 pip wheels
###############################################################################

ARG BASE_IMAGE
FROM ${BASE_IMAGE} as wheel-builder

#------------------------------------------------------------------------------
# build tensorflow-text from source
#------------------------------------------------------------------------------
FROM wheel-builder as tftext-builder
ARG URLREF_TFTEXT
ARG SRC_PATH_TFTEXT

RUN pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0
RUN <<"EOF" bash -exu -o pipefail
git-clone.sh ${URLREF_TFTEXT} ${SRC_PATH_TFTEXT}
cd ${SRC_PATH_TFTEXT}

# The tftext build script queries GitHub, but these requests are sometimes
# throttled by GH, resulting in a corrupted uri for tensorflow in WORKSPACE.
# A workaround (needs to be updated when the tensorflow version changes):
sed -i "s/# Update TF dependency to installed tensorflow./commit_slug=6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/" oss_scripts/prepare_tf_dep.sh

# Newer versions of LLVM make lld's --undefined-version check strict by
# default (https://reviews.llvm.org/D135402), but the tftext build seems to
# rely on the previous, non-strict behavior.
echo "write_to_bazelrc \"build --linkopt='-Wl,--undefined-version'\"" >> oss_scripts/configure.sh

./oss_scripts/run_build.sh
EOF


###############################################################################
## T5X
###############################################################################

ARG BASE_IMAGE
FROM ${BASE_IMAGE} AS mealkit
ARG URLREF_T5X
ARG URLREF_AIRIO
ARG SRC_PATH_TFTEXT
ARG SRC_PATH_T5X
ARG SRC_PATH_AIRIO

# Preserve version information of tensorflow-text
COPY --from=tftext-builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml
COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/
RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-t5x.in

RUN <<"EOF" bash -ex
# 1. Fetch T5X
git-clone.sh "${URLREF_T5X}" "${SRC_PATH_T5X}"
Expand Down
10 changes: 5 additions & 5 deletions .github/container/jax-nccl-test
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()

assert (
args.process_id is None or args.distributed
), "--process-id is only relevant with --distributed"
assert args.process_id is None or args.distributed, (
"--process-id is only relevant with --distributed"
)
if args.distributed:
null_args = {
args.coordinator_address is None,
Expand Down Expand Up @@ -108,7 +108,7 @@ if __name__ == "__main__":
f"Rank {args.process_id} has local rank {local_process_id} and "
f"devices {local_device_ids} from a total of {visible_devices} "
f"visible on this node, {args.process_count} processes and "
f"{args.process_count*args.gpus_per_process} total devices.",
f"{args.process_count * args.gpus_per_process} total devices.",
flush=True,
)
jax.distributed.initialize(
Expand Down Expand Up @@ -209,7 +209,7 @@ if __name__ == "__main__":
if host_timer:
result.block_until_ready()
if jax.process_index() == 0:
print(f"First {op} duration {time.time()-start:.2f}s")
print(f"First {op} duration {time.time() - start:.2f}s")
return result

def device_put_local(x: jax.Array):
Expand Down
7 changes: 1 addition & 6 deletions .github/container/manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ t5x:
mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore
mirror/patch/dali-support: file://patches/t5x/mirror-patch-dali-support.patch # pull/1393/head # https://github.com/google-research/t5x/pull/1393: Adds DALI support to t5x
mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100)
mirror/patch/fix-default-vocab: file://patches/t5x/mirror-patch-fix-default-vocab.patch # pull/1609/head # https://github.com/google-research/t5x/pull/1609: Fixes seqio vocab mismatch
paxml:
url: https://github.com/google/paxml.git
mirror_url: https://github.com/nvjax-svc-0/paxml.git
Expand Down Expand Up @@ -59,12 +60,6 @@ lingvo:
tracking_ref: master
latest_verified_commit: 05a076b0783a8bbf4a770095966c472bb37bbf65
mode: git-clone
tensorflow-text:
# Used only in ARM pax and t5x builds
url: https://github.com/tensorflow/text.git
tracking_ref: master
latest_verified_commit: 1779b3ae16f7bd287c4edcf66d62208dc63256f3
mode: git-clone
pydantic:
version: X.Y.Z
mode: pip-constraint
Expand Down
10 changes: 5 additions & 5 deletions .github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,11 @@
"# Print out the largest entries adding up to at least this fraction of the total\n",
"threshold = 0.97\n",
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n",
"for row in compile_summary[\n",
" compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
"].itertuples():\n",
" print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs*1e-3:.2f}s {row.Index}\")"
" print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs * 1e-3:.2f}s {row.Index}\")"
]
},
{
Expand Down Expand Up @@ -585,9 +585,9 @@
"detailed_mask = (compute_duration_rel_stds > var_threshold) & (\n",
" compute_duration_means > mean_threshold\n",
")\n",
"assert (\n",
" detailed_mask.sum() <= detailed_limit\n",
"), f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
"assert detailed_mask.sum() <= detailed_limit, (\n",
" f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
")\n",
"\n",
"fig, axs = plt.subplots(\n",
" ncols=2, width_ratios=[1, 2], figsize=[15, 5], tight_layout=True\n",
Expand Down
43 changes: 21 additions & 22 deletions .github/container/nsys_jax/nsys_jax/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def align_profiler_data_timestamps(
# Error if the communication frame doesn't exist at all, but not if it is empty.
# Calling this on a profile that does not contain any communication should
# gracefully yield empty results.
assert (
frames.communication is not None
), "align_profiler_data_timestamps requires a communication frame"
assert frames.communication is not None, (
"align_profiler_data_timestamps requires a communication frame"
)
if not len(frames.communication):
# Nothing to be done, return an empty result
return frames, {}
Expand All @@ -43,9 +43,9 @@ def align_profiler_data_timestamps(
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
)
return frames, {}
assert (
num_profiled_devices == max_collective_size
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
assert num_profiled_devices == max_collective_size, (
f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
)
# Find the collectives that will be used
align_df = comm_df[comm_df["CollectiveSize"] == max_collective_size]
# Calculate the collectives' end times
Expand Down Expand Up @@ -190,19 +190,18 @@ def _get_message_size(
) -> tuple[int, str, int, float, float]:
_, inst = module_proto.find_instruction(instruction)
comm_inst = inst.communication_proto()
assert (
comm_inst.opcode
in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
assert comm_inst.opcode in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}, (
f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
)

def _byte_size(inst) -> int:
size_bits = math.prod(
Expand Down Expand Up @@ -256,9 +255,9 @@ def _byte_size(inst) -> int:
collective_size = iota_group_list.num_devices_per_group
else:
collective_sizes = set(len(group.replica_ids) for group in replica_groups)
assert (
len(collective_sizes) == 1
), f"Heterogeneous collective {comm_inst} could not be interpreted"
assert len(collective_sizes) == 1, (
f"Heterogeneous collective {comm_inst} could not be interpreted"
)
collective_size = next(iter(collective_sizes))
total_msg_size = 0
for operand_id in comm_inst.operand_ids:
Expand Down
4 changes: 3 additions & 1 deletion .github/container/nsys_jax/nsys_jax/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def _calculate_overlap(thunk_df: pd.DataFrame) -> pd.DataFrame:
serial_mask = (
compute_df["ProjStartMs"].array[1:] >= compute_df["ProjEndMs"].array[:-1]
)
assert serial_mask.all(), f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
assert serial_mask.all(), (
f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
)
# Update the projected duration of each communication kernel to only
# include the non-overlapped time
for comm_thunk in comm_df.itertuples():
Expand Down
Loading

0 comments on commit a017189

Please sign in to comment.