diff --git a/.github/container/Dockerfile.maxtext b/.github/container/Dockerfile.maxtext
index 87b73efcd..6694aa821 100644
--- a/.github/container/Dockerfile.maxtext
+++ b/.github/container/Dockerfile.maxtext
@@ -2,43 +2,7 @@ ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
 ARG URLREF_MAXTEXT=https://github.com/google/maxtext.git#main
-ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
 ARG SRC_PATH_MAXTEXT=/opt/maxtext
-ARG SRC_PATH_TFTEXT=/opt/tensorflow-text
-
-###############################################################################
-## build tensorflow-text and lingvo, which do not have working arm64 pip wheels
-###############################################################################
-
-ARG BASE_IMAGE
-FROM ${BASE_IMAGE} as wheel-builder
-
-#------------------------------------------------------------------------------
-# build tensorflow-text from source
-#------------------------------------------------------------------------------
-
-# Remove TFTEXT build from source when it has py-3.12 wheels for x86/arm64
-FROM wheel-builder as tftext-builder
-ARG URLREF_TFTEXT
-ARG SRC_PATH_TFTEXT
-
-RUN pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0
-RUN git-clone.sh ${URLREF_TFTEXT} ${SRC_PATH_TFTEXT}
-RUN <<"EOF" bash -exu -o pipefail
-cd ${SRC_PATH_TFTEXT}
-
-# The tftext build script queries GitHub, but these requests are sometimes
-# throttled by GH, resulting in a corrupted uri for tensorflow in WORKSPACE.
-# A workaround (needs to be updated when the tensorflow version changes):
-sed -i "s/# Update TF dependency to installed tensorflow./commit_slug=6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/" oss_scripts/prepare_tf_dep.sh
-
-# Newer versions of LLVM make lld's --undefined-version check of lld is strict
-# by default (https://reviews.llvm.org/D135402), but the tftext build seems to
-# rely on this behavior.
-echo "write_to_bazelrc \"build --linkopt='-Wl,--undefined-version'\"" >> oss_scripts/configure.sh - -./oss_scripts/run_build.sh -EOF ############################################################################### ## Download source and add auxiliary scripts @@ -46,13 +10,7 @@ EOF FROM ${BASE_IMAGE} as mealkit ARG URLREF_MAXTEXT -ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master ARG SRC_PATH_MAXTEXT -ARG SRC_PATH_TFTEXT=/opt/tensorflow-text - -# Preserve version information of tensorflow-text -COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ -RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-maxtext.in RUN <<"EOF" bash -ex git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT} @@ -85,6 +43,3 @@ FROM mealkit as final RUN pip-finalize.sh WORKDIR ${SRC_PATH_MAXTEXT} - -# When tftext and lingvo wheels are published on pypi.org, revert this -# Dockerfile to 5c4b687b918e6569bca43758c346ad8e67460154 diff --git a/.github/container/Dockerfile.pax b/.github/container/Dockerfile.pax index 938bd853c..5a8c23ef7 100644 --- a/.github/container/Dockerfile.pax +++ b/.github/container/Dockerfile.pax @@ -3,11 +3,9 @@ ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax ARG URLREF_PAXML=https://github.com/google/paxml.git#main ARG URLREF_PRAXIS=https://github.com/google/praxis.git#main -ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master ARG URLREF_LINGVO=https://github.com/tensorflow/lingvo.git#master ARG SRC_PATH_PAXML=/opt/paxml ARG SRC_PATH_PRAXIS=/opt/praxis -ARG SRC_PATH_TFTEXT=/opt/tensorflow-text ARG SRC_PATH_LINGVO=/opt/lingvo ############################################################################### @@ -17,32 +15,6 @@ ARG SRC_PATH_LINGVO=/opt/lingvo ARG BASE_IMAGE FROM ${BASE_IMAGE} as wheel-builder -#------------------------------------------------------------------------------ -# build tensorflow-text from source -#------------------------------------------------------------------------------ - -# Remove TFTEXT build from source when it has py-3.12 wheels for x86/arm64 -FROM wheel-builder as tftext-builder -ARG URLREF_TFTEXT -ARG SRC_PATH_TFTEXT -RUN <<"EOF" bash -exu -o pipefail -pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0 -git-clone.sh ${URLREF_TFTEXT} ${SRC_PATH_TFTEXT} -cd ${SRC_PATH_TFTEXT} - -# The tftext build script queries GitHub, but these requests are sometimes -# throttled by GH, resulting in a corrupted uri for tensorflow in WORKSPACE. -# A workaround (needs to be updated when the tensorflow version changes): -sed -i "s/# Update TF dependency to installed tensorflow./commit_slug=6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/" oss_scripts/prepare_tf_dep.sh - -# Newer versions of LLVM make lld's --undefined-version check of lld is strict -# by default (https://reviews.llvm.org/D135402), but the tftext build seems to -# rely on this behavior. 
-echo "write_to_bazelrc \"build --linkopt='-Wl,--undefined-version'\"" >> oss_scripts/configure.sh - -./oss_scripts/run_build.sh -EOF - #------------------------------------------------------------------------------ # build lingvo #------------------------------------------------------------------------------ @@ -50,13 +22,8 @@ EOF # Remove Lingvo build from source when it has py-3.12 wheels for x86/arm64 FROM wheel-builder as lingvo-builder ARG URLREF_LINGVO -ARG SRC_PATH_TFTEXT ARG SRC_PATH_LINGVO -# Preserve the version of tensorflow-text -COPY --from=tftext-builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml -COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ - ENV USE_BAZEL_VERSION=7.1.2 # build lingvo @@ -89,10 +56,9 @@ EOFINNER fi -pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0 /opt/tensorflow_text*.whl +pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0 for pattern in \ "s|tensorflow=|#tensorflow=|g" \ - "s|tensorflow-text=|#tensorflow-text=|g" \ "s|dataclasses=|#dataclasses=|g" \ "s|==.*||g" \ ; do @@ -101,7 +67,7 @@ done # Lingvo support only python < 3.12, so we hack it and update dependencies # to be able to build for py-3.12 for pattern in \ - "s|tensorflow-text~=2.13.0|tensorflow-text~=2.18.0|g" \ + "s|tensorflow-text~=2.13.0|tensorflow-text~=2.18.1|g" \ "s|tensorflow~=2.13.0|tensorflow~=2.18.0|g" \ "s|python_requires='>=3.8,<3.11'|python_requires='>=3.8,<3.13'|" \ ; do @@ -128,16 +94,12 @@ ARG URLREF_PAXML ARG URLREF_PRAXIS ARG SRC_PATH_PAXML ARG SRC_PATH_PRAXIS -ARG SRC_PATH_TFTEXT # Preserve version information of tensorflow-text and lingvo COPY --from=lingvo-builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml COPY --from=lingvo-builder /tmp/lingvo/dist/lingvo*-linux*.whl /opt/ RUN echo "lingvo @ file://$(ls /opt/lingvo*.whl)" >> /opt/pip-tools.d/requirements-paxml.in -COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ -RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-paxml.in - # paxml + praxis RUN <<"EOF" bash -ex echo "tensorflow_datasets==4.9.2" >> /opt/pip-tools.d/requirements-paxml.in @@ -187,5 +149,5 @@ FROM mealkit as final RUN pip-finalize.sh -# When tftext and lingvo wheels are published on pypi.org, revert this +# When lingvo wheels are published on pypi.org, revert this # Dockerfile to 5c4b687b918e6569bca43758c346ad8e67460154 diff --git a/.github/container/Dockerfile.t5x b/.github/container/Dockerfile.t5x index ea4bbf2ec..1568ff559 100644 --- a/.github/container/Dockerfile.t5x +++ b/.github/container/Dockerfile.t5x @@ -3,64 +3,18 @@ # docker buildx build -f Dockerfile.t5x --tag t5x --build-arg BASE_IMAGE=ghcr.io/nvidia/jax:mealkit-2024-01-22 . 
 
 ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
-ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
 ARG URLREF_T5X=https://github.com/google-research/t5x.git#main
 ARG URLREF_AIRIO=https://github.com/google/airio.git#main
-ARG SRC_PATH_TFTEXT=/opt/tensorflow-text
 ARG SRC_PATH_T5X=/opt/t5x
 ARG SRC_PATH_AIRIO=/opt/airio
-
-###############################################################################
-## build several packages which do not have working arm64 pip wheels
-###############################################################################
-
-ARG BASE_IMAGE
-FROM ${BASE_IMAGE} as wheel-builder
-
-#------------------------------------------------------------------------------
-# build tensorflow-text from source
-#------------------------------------------------------------------------------
-FROM wheel-builder as tftext-builder
-ARG URLREF_TFTEXT
-ARG SRC_PATH_TFTEXT
-
-RUN pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0
-RUN <<"EOF" bash -exu -o pipefail
-git-clone.sh ${URLREF_TFTEXT} ${SRC_PATH_TFTEXT}
-cd ${SRC_PATH_TFTEXT}
-
-# The tftext build script queries GitHub, but these requests are sometimes
-# throttled by GH, resulting in a corrupted uri for tensorflow in WORKSPACE.
-# A workaround (needs to be updated when the tensorflow version changes):
-sed -i "s/# Update TF dependency to installed tensorflow./commit_slug=6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/" oss_scripts/prepare_tf_dep.sh
-
-# Newer versions of LLVM make lld's --undefined-version check of lld is strict
-# by default (https://reviews.llvm.org/D135402), but the tftext build seems to
-# rely on this behavior.
-echo "write_to_bazelrc \"build --linkopt='-Wl,--undefined-version'\"" >> oss_scripts/configure.sh
-
-./oss_scripts/run_build.sh
-EOF
-
-
-###############################################################################
-## T5X
-###############################################################################
-
 ARG BASE_IMAGE
 FROM ${BASE_IMAGE} AS mealkit
 
 ARG URLREF_T5X
 ARG URLREF_AIRIO
-ARG SRC_PATH_TFTEXT
 ARG SRC_PATH_T5X
 ARG SRC_PATH_AIRIO
 
-# Preserve version information of tensorflow-text
-COPY --from=tftext-builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml
-COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/
-RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-t5x.in
-
 RUN <<"EOF" bash -ex
 # 1. Fetch T5X
 git-clone.sh "${URLREF_T5X}" "${SRC_PATH_T5X}"
diff --git a/.github/container/jax-nccl-test b/.github/container/jax-nccl-test
index 706713baf..1d59962bf 100755
--- a/.github/container/jax-nccl-test
+++ b/.github/container/jax-nccl-test
@@ -74,9 +74,9 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    assert (
-        args.process_id is None or args.distributed
-    ), "--process-id is only relevant with --distributed"
+    assert args.process_id is None or args.distributed, (
+        "--process-id is only relevant with --distributed"
+    )
     if args.distributed:
         null_args = {
             args.coordinator_address is None,
@@ -108,7 +108,7 @@ if __name__ == "__main__":
             f"Rank {args.process_id} has local rank {local_process_id} and "
            f"devices {local_device_ids} from a total of {visible_devices} "
            f"visible on this node, {args.process_count} processes and "
-            f"{args.process_count*args.gpus_per_process} total devices.",
+            f"{args.process_count * args.gpus_per_process} total devices.",
             flush=True,
         )
         jax.distributed.initialize(
@@ -209,7 +209,7 @@ if __name__ == "__main__":
         if host_timer:
             result.block_until_ready()
             if jax.process_index() == 0:
-                print(f"First {op} duration {time.time()-start:.2f}s")
+                print(f"First {op} duration {time.time() - start:.2f}s")
         return result
 
     def device_put_local(x: jax.Array):
diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml
index b9c06e2e6..9845130e3 100644
--- a/.github/container/manifest.yaml
+++ b/.github/container/manifest.yaml
@@ -31,6 +31,7 @@ t5x:
     mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore
     mirror/patch/dali-support: file://patches/t5x/mirror-patch-dali-support.patch # pull/1393/head # https://github.com/google-research/t5x/pull/1393: Adds DALI support to t5x
     mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100)
+    mirror/patch/fix-default-vocab: file://patches/t5x/mirror-patch-fix-default-vocab.patch # pull/1609/head # https://github.com/google-research/t5x/pull/1609: Fixes seqio vocab mismatch
 paxml:
   url: https://github.com/google/paxml.git
   mirror_url: https://github.com/nvjax-svc-0/paxml.git
@@ -59,12 +60,6 @@ lingvo:
   tracking_ref: master
   latest_verified_commit: 05a076b0783a8bbf4a770095966c472bb37bbf65
   mode: git-clone
-tensorflow-text:
-  # Used only in ARM pax and t5x builds
-  url: https://github.com/tensorflow/text.git
-  tracking_ref: master
-  latest_verified_commit: 1779b3ae16f7bd287c4edcf66d62208dc63256f3
-  mode: git-clone
 pydantic:
   version: X.Y.Z
   mode: pip-constraint
diff --git a/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
index d8e8c6248..8159f2d28 100644
--- a/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
+++ b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
@@ -299,11 +299,11 @@
     "# Print out the largest entries adding up to at least this fraction of the total\n",
     "threshold = 0.97\n",
     "compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
-    "print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n",
+    "print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n",
     "for row in compile_summary[\n",
     "    compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
     "].itertuples():\n",
-    "    print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs*1e-3:.2f}s {row.Index}\")"
+    "    print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs * 1e-3:.2f}s {row.Index}\")"
    ]
   },
  {
@@ -585,9 +585,9 @@
     "detailed_mask = (compute_duration_rel_stds > var_threshold) & (\n",
     "    compute_duration_means > mean_threshold\n",
     ")\n",
-    "assert (\n",
-    "    detailed_mask.sum() <= detailed_limit\n",
-    "), f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
+    "assert detailed_mask.sum() <= detailed_limit, (\n",
+    "    f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
+    ")\n",
     "\n",
     "fig, axs = plt.subplots(\n",
     "    ncols=2, width_ratios=[1, 2], figsize=[15, 5], tight_layout=True\n",
diff --git a/.github/container/nsys_jax/nsys_jax/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py
index 4e72a33fb..c4e37fdf9 100644
--- a/.github/container/nsys_jax/nsys_jax/analysis.py
+++ b/.github/container/nsys_jax/nsys_jax/analysis.py
@@ -28,9 +28,9 @@ def align_profiler_data_timestamps(
     # Error if the communication frame doesn't exist at all, but not if it is empty.
     # Calling this on a profile that does not contain any communication should
     # gracefully yield empty results.
-    assert (
-        frames.communication is not None
-    ), "align_profiler_data_timestamps requires a communication frame"
+    assert frames.communication is not None, (
+        "align_profiler_data_timestamps requires a communication frame"
+    )
     if not len(frames.communication):
         # Nothing to be done, return an empty result
         return frames, {}
@@ -43,9 +43,9 @@ def align_profiler_data_timestamps(
             f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
         )
         return frames, {}
-    assert (
-        num_profiled_devices == max_collective_size
-    ), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
+    assert num_profiled_devices == max_collective_size, (
+        f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
+    )
     # Find the collectives that will be used
     align_df = comm_df[comm_df["CollectiveSize"] == max_collective_size]
     # Calculate the collectives' end times
@@ -190,19 +190,18 @@ def _get_message_size(
 ) -> tuple[int, str, int, float, float]:
     _, inst = module_proto.find_instruction(instruction)
     comm_inst = inst.communication_proto()
-    assert (
-        comm_inst.opcode
-        in {
-            "all-gather-start",
-            "all-reduce-start",
-            "all-to-all",
-            "collective-broadcast",
-            "collective-permute-start",
-            "dynamic-slice",
-            "dynamic-update-slice",
-            "reduce-scatter",
-        }
-    ), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
+    assert comm_inst.opcode in {
+        "all-gather-start",
+        "all-reduce-start",
+        "all-to-all",
+        "collective-broadcast",
+        "collective-permute-start",
+        "dynamic-slice",
+        "dynamic-update-slice",
+        "reduce-scatter",
+    }, (
+        f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
+    )
 
     def _byte_size(inst) -> int:
         size_bits = math.prod(
@@ -256,9 +255,9 @@ def _byte_size(inst) -> int:
             collective_size = iota_group_list.num_devices_per_group
         else:
             collective_sizes = set(len(group.replica_ids) for group in replica_groups)
-            assert (
-                len(collective_sizes) == 1
-            ), f"Heterogeneous collective {comm_inst} could not be interpreted"
+            assert len(collective_sizes) == 1, (
+                f"Heterogeneous collective {comm_inst} could not be interpreted"
+            )
             collective_size = next(iter(collective_sizes))
         total_msg_size = 0
         for operand_id in comm_inst.operand_ids:
diff --git a/.github/container/nsys_jax/nsys_jax/data_loaders.py b/.github/container/nsys_jax/nsys_jax/data_loaders.py
index a3e848dc8..608f5a5d6 100644
--- a/.github/container/nsys_jax/nsys_jax/data_loaders.py
+++ b/.github/container/nsys_jax/nsys_jax/data_loaders.py
@@ -56,7 +56,9 @@ def _calculate_overlap(thunk_df: pd.DataFrame) -> pd.DataFrame:
         serial_mask = (
             compute_df["ProjStartMs"].array[1:] >= compute_df["ProjEndMs"].array[:-1]
         )
-        assert serial_mask.all(), f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
+        assert serial_mask.all(), (
+            f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
+        )
         # Update the projected duration of each communication kernel to only
         # include the non-overlapped time
         for comm_thunk in comm_df.itertuples():
diff --git a/.github/container/nsys_jax/nsys_jax/protobuf.py b/.github/container/nsys_jax/nsys_jax/protobuf.py
index a43160f19..0175ed5c0 100644
--- a/.github/container/nsys_jax/nsys_jax/protobuf.py
+++ b/.github/container/nsys_jax/nsys_jax/protobuf.py
@@ -110,9 +110,9 @@ def _visit_computation(computation_id):
                 if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
                     called_inst
                 ):
-                    assert (
-                        self._comm_proto is None
-                    ), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
+                    assert self._comm_proto is None, (
+                        f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
+                    )
                     self._comm_proto = called_inst
 
         for called_id in self._proto.called_computation_ids:
diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
index 522e636f1..136843999 100644
--- a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
+++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
@@ -341,7 +341,7 @@ def copy_proto_files_to_tmp(
            if not osp.isdir(dst_dir):
                os.makedirs(dst_dir)
            shutil.copy(osp.join(root, proto), osp.join(proto_dir, proto))
-        print(f"{archive_name}: gathered .proto files in {time.time()-start:.2f}s")
+        print(f"{archive_name}: gathered .proto files in {time.time() - start:.2f}s")
         return proto_dir, proto_files
 
     def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
@@ -369,7 +369,7 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
            if osp.isdir(full_path) or not osp.exists(full_path):
                continue
            output_queue.put((ofile, full_path, COMPRESS_NONE))
-        print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
+        print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s")
 
     def compress_and_archive(prefix, file, output_queue):
         """
@@ -403,7 +403,7 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
         )
         for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
            compress_and_archive(tmp_dir, ofile, output_queue)
-        print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
+        print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s")
 
     def save_device_stream_thread_names(tmp_dir, report, output_queue):
         """
@@ -501,7 +501,7 @@ def table_columns(table_name):
         else:
            print("WARNING: NOT writing device metadata, no device activity profiled?")
         print(
-            f"{archive_name}: extracted device/thread names in {time.time()-start:.2f}s"
+            f"{archive_name}: extracted device/thread names in {time.time() - start:.2f}s"
         )
 
     def find_pb_files_in_tmp(tmp_dir):
@@ -553,7 +553,7 @@ def gather_source_files(
                continue
            assert osp.isabs(src_file), f"{src_file} is not absolute"
            output_queue.put(("sources" + src_file, src_file, COMPRESS_DEFLATE))
-        print(f"{archive_name}: gathered source code in {time.time()-start:.2f}s")
+        print(f"{archive_name}: gathered source code in {time.time() - start:.2f}s")
 
     def execute_analysis_scripts(mirror_dir, analysis_scripts):
         """
@@ -631,7 +631,7 @@ def write_output_file(to_process, mirror_dir, analysis_scripts):
            for path_in_archive, local_path in analysis_outputs:
                archive.write(filename=local_path, arcname=path_in_archive)
         os.chmod(archive_name, 0o644)
-        print(f"{archive_name}: wrote in {time.time()-start:.2f}s")
+        print(f"{archive_name}: wrote in {time.time() - start:.2f}s")
         if exit_code != 0:
            print("Exiting due to analysis script errors")
            sys.exit(exit_code)
diff --git a/.github/container/nsys_jax/nsys_jax/scripts/utils.py b/.github/container/nsys_jax/nsys_jax/scripts/utils.py
index e13845f17..d1494bd3c 100644
--- a/.github/container/nsys_jax/nsys_jax/scripts/utils.py
+++ b/.github/container/nsys_jax/nsys_jax/scripts/utils.py
@@ -40,9 +40,9 @@ def analysis_recipe_path(script):
     )
     if script_file.is_file():
         return script_file
-    assert os.path.exists(
-        script
-    ), f"{script} does not exist and is not the name of a built-in analysis script"
+    assert os.path.exists(script), (
+        f"{script} does not exist and is not the name of a built-in analysis script"
+    )
     return contextlib.nullcontext(pathlib.Path(script))
 
 
diff --git a/.github/container/test-t5x.sh b/.github/container/test-t5x.sh
index 942e4b2c4..554ba7003 100755
--- a/.github/container/test-t5x.sh
+++ b/.github/container/test-t5x.sh
@@ -175,10 +175,10 @@ seqio.TaskRegistry.add(
     ],
     output_features=dict(
         inputs=seqio.Feature(
-            vocabulary=t5.data.get_default_vocabulary(), add_eos=True, required=False
+            vocabulary=seqio.SentencePieceVocabulary(sentencepiece_model_file="gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"), add_eos=True, required=False
         ),
         targets=seqio.Feature(
-            vocabulary=t5.data.get_default_vocabulary(), add_eos=True
+            vocabulary=seqio.SentencePieceVocabulary(sentencepiece_model_file="gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"), add_eos=True
         )
     ),
     metric_fns=[]
diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py
index 925fffacd..0856fcc62 100755
--- a/.github/triage/jax_toolbox_triage/main.py
+++ b/.github/triage/jax_toolbox_triage/main.py
@@ -23,7 +23,7 @@ def main():
     logger = get_logger(args.output_prefix)
     logger.info(
         "Verbose output, including stdout/err of triage commands, will be written to "
-        f'{(args.output_prefix / "debug.log").resolve()}'
+        f"{(args.output_prefix / 'debug.log').resolve()}"
     )
     container_url = functools.partial(container_url_base, container=args.container)
     container_exists = functools.partial(
diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml
index b4f3b8143..c2b4cb4ee 100644
--- a/.github/workflows/_ci.yaml
+++ b/.github/workflows/_ci.yaml
@@ -111,7 +111,6 @@ jobs:
       DOCKERFILE: .github/container/Dockerfile.maxtext
       EXTRA_BUILD_ARGS: |
         URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }}
-        URLREF_TFTEXT=${{ fromJson(inputs.SOURCE_URLREFS).TENSORFLOW_TEXT }}
     secrets: inherit
 
  build-levanter:
@@ -143,7 +142,6 @@
       DOCKERFILE: .github/container/Dockerfile.t5x
       EXTRA_BUILD_ARGS: |
         URLREF_T5X=${{ fromJson(inputs.SOURCE_URLREFS).T5X }}
-        URLREF_TFTEXT=${{ fromJson(inputs.SOURCE_URLREFS).TENSORFLOW_TEXT }}
         URLREF_AIRIO=${{ fromJson(inputs.SOURCE_URLREFS).AIRIO }}
     secrets: inherit
 
@@ -161,7 +159,6 @@
       EXTRA_BUILD_ARGS: |
         URLREF_PAXML=${{ fromJson(inputs.SOURCE_URLREFS).PAXML }}
         URLREF_PRAXIS=${{ fromJson(inputs.SOURCE_URLREFS).PRAXIS }}
-        URLREF_TFTEXT=${{ fromJson(inputs.SOURCE_URLREFS).TENSORFLOW_TEXT }}
         URLREF_LINGVO=${{ fromJson(inputs.SOURCE_URLREFS).LINGVO }}
     secrets: inherit
 
diff --git a/rosetta/Dockerfile.gemma b/rosetta/Dockerfile.gemma
index e7db16dcc..4a0ba2965 100644
--- a/rosetta/Dockerfile.gemma
+++ b/rosetta/Dockerfile.gemma
@@ -11,40 +11,7 @@ ARG URLREF_FLAXFORMER=https://github.com/google/flaxformer.git#main
 ARG SRC_PATH_FLAXFORMER=/opt/flaxformer
 ARG URLREF_PANOPTICAPI=https://github.com/akolesnikoff/panopticapi.git#mute
 ARG SRC_PATH_PANOPTICAPI=/opt/panopticapi
-ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
-ARG SRC_PATH_TFTEXT=/opt/tensorflow-text
 
-###############################################################################
-## Build several packages which do not have working amd64/arm64 pip wheels
-###############################################################################
-
-ARG BASE_IMAGE
-FROM ${BASE_IMAGE} as wheel-builder
-
-#------------------------------------------------------------------------------
-# build tensorflow-text from source
-#------------------------------------------------------------------------------
-FROM wheel-builder as tftext-builder
-ARG URLREF_TFTEXT
-ARG SRC_PATH_TFTEXT
-
-RUN <<"EOF" bash -exu -o pipefail
-pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.18.0
-git-clone.sh ${URLREF_TFTEXT} ${SRC_PATH_TFTEXT}
-cd ${SRC_PATH_TFTEXT}
-
-# The tftext build script queries GitHub, but these requests are sometimes
-# throttled by GH, resulting in a corrupted uri for tensorflow in WORKSPACE.
-# A workaround (needs to be updated when the tensorflow version changes):
-sed -i "s/# Update TF dependency to installed tensorflow./commit_slug=6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/" oss_scripts/prepare_tf_dep.sh
-
-# Newer versions of LLVM make lld's --undefined-version check of lld is strict
-# by default (https://reviews.llvm.org/D135402), but the tftext build seems to
-# rely on this behavior.
-echo "write_to_bazelrc \"build --linkopt='-Wl,--undefined-version'\"" >> oss_scripts/configure.sh - -./oss_scripts/run_build.sh -EOF ############################################################################### ## Download source and add auxiliary scripts @@ -62,11 +29,6 @@ ARG URLREF_FLAXFORMER ARG SRC_PATH_FLAXFORMER ARG URLREF_PANOPTICAPI ARG SRC_PATH_PANOPTICAPI -ARG URLREF_TFTEXT -ARG SRC_PATH_TFTEXT - -COPY --from=tftext-builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml -COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ RUN <<"EOF" bash -ex git-clone.sh ${URLREF_GEMMA} ${SRC_PATH_GEMMA} @@ -93,7 +55,7 @@ optax protobuf tfds-nightly tensorflow -tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl) +tensorflow-text tensorflow-gan " >> /opt/pip-tools.d/requirements-gemma.in EOF diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md index 2320a7ed9..44baa19ef 100644 --- a/rosetta/rosetta/projects/maxtext/README.md +++ b/rosetta/rosetta/projects/maxtext/README.md @@ -93,7 +93,7 @@ We have run some intial performance and functionality tests with [LLaMA2-7B](htt Please refer to the [example run script](scripts/example_slurm.sub) for more details. We will continue to add more models and associated performance metrics. # Notes -1. The only changes we need to support multiprocessing is to pin tensorflow and tensorflow-text to 2.13.0 version. +1. The only changes we need to support multiprocessing is to pin tensorflow and tensorflow-text to 2.18.0 version or higher. 2. In order to remove extra copies introduced by DUS (dynamic update slice) when used in conjunction with custom NVIDIA kernels (like cuBLAS for GEMMs), the `--xla_gpu_enable_custom_fusions` and `--xla_gpu_enable_address_computation_fusion` flags were introduced. However, the current XLA has some limitation and sometimes using these flags lead to error. So, in this release, it is advised to turn off these two flags: - --xla_gpu_enable_custom_fusions=false - --xla_gpu_enable_address_computation_fusion=false