Merge pull request #14 from cms-ml/feature/passthrough_xla_flags
Forward XLA_FLAGS and TF_XLA_FLAGS in AOT compilation.
valsdav authored Mar 27, 2024
2 parents 8b24692 + b5a33b1 commit 89df2ec
Showing 1 changed file with 25 additions and 1 deletion: cmsml/scripts/compile_tf_graph.py
@@ -20,6 +20,8 @@ def compile_tf_graph(
output_serving_key: str | None = None,
compile_prefix: str | None = None,
compile_class: str | None = None,
xla_flags: list[str] | None = None,
tf_xla_flags: list[str] | None = None,
) -> None:
"""
For AOT compilation, a static memory layout at runtime is required. This function prepares the given input SavedModel
@@ -35,6 +37,8 @@ def compile_tf_graph(
An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case
*compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated files.
*xla_flags* and *tf_xla_flags* are forwarded to :py:func:`aot_compile`.
"""
tf = import_tf()[0]

@@ -96,6 +100,8 @@ def compile_tf_graph(
compile_class,
batch_sizes=batch_sizes,
serving_key=output_serving_key,
xla_flags=xla_flags,
tf_xla_flags=tf_xla_flags,
)


@@ -106,13 +112,18 @@ def aot_compile(
class_name: str,
batch_sizes: tuple[int] = (1,),
serving_key: str = r"serving_default_bs{}",
xla_flags: list[str] | None = None,
tf_xla_flags: list[str] | None = None,
) -> None:
"""
Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key*
from it, and AOT compiles it.
This process generates header and object files at *output_path*. The *class_name* is used as the class name within the
generated header to access the AOT-compiled network.
When *xla_flags* and *tf_xla_flags* are given, they are forwarded as comma-separated values to the *XLA_FLAGS*
and *TF_XLA_FLAGS* environment variables, respectively.
"""
# prepare model path
model_path = os.path.abspath(os.path.expandvars(os.path.expanduser(str(model_path))))
@@ -131,6 +142,19 @@ def aot_compile(
# get the compilation executable
exe = _which_saved_model_cli()

# amend the env when xla flags were passed
env = os.environ.copy()
if xla_flags:
xla_flags_orig = env.get("XLA_FLAGS", "")
if xla_flags_orig:
xla_flags = [xla_flags_orig.rstrip(",")] + xla_flags
env["XLA_FLAGS"] = ",".join(map(str, xla_flags))
if tf_xla_flags:
tf_xla_flags_orig = env.get("TF_XLA_FLAGS", "")
if tf_xla_flags_orig:
tf_xla_flags = [tf_xla_flags_orig.rstrip(",")] + tf_xla_flags
env["TF_XLA_FLAGS"] = ",".join(map(str, tf_xla_flags))

# compile for each batch size
for bs in sorted(set(map(int, batch_sizes))):
cmd = (
@@ -143,7 +167,7 @@
)

print(f"compiling for batch size {colored(bs, 'magenta')}")
code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path)[0]
code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path, env=env)[0]
if code != 0:
raise Exception(f"aot compilation using {exe} failed with exit code {code}")
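
For context, a minimal usage sketch of the new keyword arguments (not part of the commit). The leading model/output path arguments, the compile_prefix and compile_class values, and the flag values are illustrative assumptions, not taken from the diff:

from cmsml.scripts.compile_tf_graph import compile_tf_graph

# hypothetical call: paths, class/prefix names and flag values are placeholders
compile_tf_graph(
    "/path/to/saved_model",   # input SavedModel (assumed leading argument)
    "/path/to/output",        # output location (assumed leading argument)
    batch_sizes=(1, 2, 4),
    compile_prefix="model_bs{}",
    compile_class="Model",
    # new in this commit: forwarded to the XLA_FLAGS / TF_XLA_FLAGS env variables
    xla_flags=["--xla_cpu_enable_fast_math=false"],
    tf_xla_flags=["--tf_xla_min_cluster_size=4"],
)

If XLA_FLAGS or TF_XLA_FLAGS are already set in the calling environment, the new flags are appended rather than replacing them, as the next sketch shows.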

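A small standalone sketch mirroring the merging logic added in aot_compile; the flag names below are purely illustrative:

import os

# keep any pre-existing XLA_FLAGS value and append the new flags, joined by commas
env = os.environ.copy()
env["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4,"  # assume a pre-set value

xla_flags = ["--xla_cpu_multi_thread_eigen=false"]  # as passed via xla_flags=...
xla_flags_orig = env.get("XLA_FLAGS", "")
if xla_flags_orig:
    xla_flags = [xla_flags_orig.rstrip(",")] + xla_flags
env["XLA_FLAGS"] = ",".join(map(str, xla_flags))

print(env["XLA_FLAGS"])
# -> --xla_force_host_platform_device_count=4,--xla_cpu_multi_thread_eigen=false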
