From cc4fb78c608f9287fcef9a6b2f6d14b4bfe3484f Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 10 Jan 2025 08:03:20 -0800 Subject: [PATCH] Remove deprecated flag xla_gpu_enable_triton_softmax_fusion. References: https://github.com/NVIDIA/JAX-Toolbox/blob/9dd32f50257a405ae766aea2bcbcf51c217ed75c/rosetta/docs/GPU_performance.md?plain=1#L141 https://github.com/openxla/xla/blob/092b8dd65d0d961265665a27432795d302762ae6/xla/debug_options_flags.cc#L1706-L1708 --- rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh | 2 +- rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env | 1 - rosetta/rosetta/projects/pax/xla_flags/common.env | 1 - rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env | 1 - rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env | 1 - 5 files changed, 1 insertion(+), 5 deletions(-) diff --git a/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh index a5eaf9aa0..e86ec0c7d 100644 --- a/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh +++ b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh @@ -1,2 +1,2 @@ # These XLA flags are meant to be used with the JAX version in the imagen container -export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}" +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}" diff --git a/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env index d999f5b5e..38aa396e1 100644 --- a/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env +++ b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env @@ -14,7 +14,6 @@ export XLA_FLAGS="\ --xla_gpu_enable_pipelined_reduce_scatter=true \ --xla_gpu_enable_pipelined_all_reduce=true \ --xla_gpu_enable_while_loop_double_buffering=true \ - --xla_gpu_enable_triton_softmax_fusion=false \ --xla_gpu_enable_all_gather_combine_by_dim=false \ --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ --xla_disable_hlo_passes=rematerialization \ diff --git a/rosetta/rosetta/projects/pax/xla_flags/common.env b/rosetta/rosetta/projects/pax/xla_flags/common.env index 26c819143..d81413a2f 100644 --- a/rosetta/rosetta/projects/pax/xla_flags/common.env +++ b/rosetta/rosetta/projects/pax/xla_flags/common.env @@ -4,7 +4,6 @@ export XLA_FLAGS="\ --xla_gpu_enable_latency_hiding_scheduler=true \ --xla_allow_excess_precision \ --xla_gpu_enable_highest_priority_async_stream=true \ - --xla_gpu_enable_triton_softmax_fusion=false \ --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ --xla_gpu_graph_level=0 \ " diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env index e5b97b466..7447ff71d 100644 --- a/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env @@ -4,7 +4,6 @@ export XLA_FLAGS="\ --xla_gpu_enable_latency_hiding_scheduler=true \ --xla_allow_excess_precision \ --xla_gpu_enable_highest_priority_async_stream=true \ - --xla_gpu_enable_triton_softmax_fusion=false \ --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ --xla_gpu_graph_level=0 \ --xla_gpu_enable_cudnn_fmha=false \ diff --git a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env index e48b76dcf..f600e203f 100644 --- a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env +++ b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env @@ -6,7 +6,6 @@ export XLA_FLAGS="\ --xla_gpu_enable_latency_hiding_scheduler=true \ --xla_allow_excess_precision \ --xla_gpu_enable_highest_priority_async_stream=true \ - --xla_gpu_enable_triton_softmax_fusion=false \ --xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \ --xla_gpu_graph_level=0 \ --xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \