diff --git a/benchmarks/nightly/run.py b/benchmarks/nightly/run.py index 5fedc527..eaac8fc1 100644 --- a/benchmarks/nightly/run.py +++ b/benchmarks/nightly/run.py @@ -66,9 +66,10 @@ def run(): op_args = OPERATORS[op] output_file = output_dir.joinpath(f"{op}.csv") op_args.extend(["--output", str(output_file.absolute())]) - op = op_task.make_operator_instance(op_args) - op.run() + op_task.make_operator_instance(op_args) + op_task.run() output_files.append(output_file) + del op_task # Reduce all operator CSV outputs to a single output json result_json_file = reduce(output_files) diff --git a/tritonbench/components/compile_time/trace.py b/tritonbench/components/compile_time/trace.py index 5bb31684..72c7affb 100644 --- a/tritonbench/components/compile_time/trace.py +++ b/tritonbench/components/compile_time/trace.py @@ -1,9 +1,9 @@ from typing import Callable, Dict import torch -from triton.fb.triton_util import triton_add_listener, TritonHook -from tritonbench.utils.env_utils import fresh_triton_cache - +from tritonbench.utils.env_utils import fresh_triton_cache, is_fbcode +if is_fbcode(): + from triton.fb.triton_util import triton_add_listener, TritonHook def fbcode_do_compile_time_in_task(fn: Callable) -> Dict[str, float]: # not yet getting results that make sense to me diff --git a/tritonbench/utils/env_utils.py b/tritonbench/utils/env_utils.py index 18696f25..cfd5d7e4 100644 --- a/tritonbench/utils/env_utils.py +++ b/tritonbench/utils/env_utils.py @@ -1,5 +1,6 @@ """ Utils for checking and modifying the environment. +Requires PyTorch """ import logging @@ -30,6 +31,9 @@ ] +def is_fbcode() -> bool: + return not hasattr(torch.version, "git_version") + def is_cuda() -> bool: return torch.version.cuda is not None