From b5a33b1468496c71bb988f7500256a251f9b1013 Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Tue, 26 Mar 2024 18:16:41 +0100 Subject: [PATCH] Forward XLA_FLAGS and TF_XLA_FLAGS in AOT compilation. --- cmsml/scripts/compile_tf_graph.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/cmsml/scripts/compile_tf_graph.py b/cmsml/scripts/compile_tf_graph.py index 6d990fb..9316ec5 100644 --- a/cmsml/scripts/compile_tf_graph.py +++ b/cmsml/scripts/compile_tf_graph.py @@ -20,6 +20,8 @@ def compile_tf_graph( output_serving_key: str | None = None, compile_prefix: str | None = None, compile_class: str | None = None, + xla_flags: list[str] | None = None, + tf_xla_flags: list[str] | None = None, ) -> None: """ For AOT compilation a static memory layout at runtime is required. This function prepares the given input SavedModel @@ -35,6 +37,8 @@ def compile_tf_graph( An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case *compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated files. + + *xla_flags* and *tf_xla_flags* are forwarded to :py:func:`aot_compile`. """ tf = import_tf()[0] @@ -96,6 +100,8 @@ def compile_tf_graph( compile_class, batch_sizes=batch_sizes, serving_key=output_serving_key, + xla_flags=xla_flags, + tf_xla_flags=tf_xla_flags, ) @@ -106,6 +112,8 @@ def aot_compile( class_name: str, batch_sizes: tuple[int] = (1,), serving_key: str = r"serving_default_bs{}", + xla_flags: list[str] | None = None, + tf_xla_flags: list[str] | None = None, ) -> None: """ Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key* @@ -113,6 +121,9 @@ def aot_compile( This process generates header and object files at *output_path*. The *class_name* is used as class name within the header access the AOT-compiled network. 
+ + When *xla_flags* and *tf_xla_flags* are given, they are forwarded as comma-separated values to the *XLA_FLAGS* + and *TF_XLA_FLAGS* environment variables, respectively. """ # prepare model path model_path = os.path.abspath(os.path.expandvars(os.path.expanduser(str(model_path)))) @@ -131,6 +142,19 @@ def aot_compile( # get the compilation executable exe = _which_saved_model_cli() + # amend the env when xla flags were passed + env = os.environ.copy() + if xla_flags: + xla_flags_orig = env.get("XLA_FLAGS", "") + if xla_flags_orig: + xla_flags = [xla_flags_orig.rstrip(",")] + xla_flags + env["XLA_FLAGS"] = ",".join(map(str, xla_flags)) + if tf_xla_flags: + tf_xla_flags_orig = env.get("TF_XLA_FLAGS", "") + if tf_xla_flags_orig: + tf_xla_flags = [tf_xla_flags_orig.rstrip(",")] + tf_xla_flags + env["TF_XLA_FLAGS"] = ",".join(map(str, tf_xla_flags)) + # compile for each batch size for bs in sorted(set(map(int, batch_sizes))): cmd = ( @@ -143,7 +167,7 @@ def aot_compile( ) print(f"compiling for batch size {colored(bs, 'magenta')}") - code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path)[0] + code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path, env=env)[0] if code != 0: raise Exception(f"aot compilation using {exe} failed with exit code {code}")