From e821d5a246170a110540dafe115c44641c86821d Mon Sep 17 00:00:00 2001 From: Jennifer Zhou Date: Fri, 8 Nov 2024 10:38:21 +0000 Subject: [PATCH] Build with uv instead of pip Signed-off-by: Jennifer Zhou --- build_tools/build_ext.py | 3 +++ build_tools/utils.py | 41 +++++++++++++++++++++++++--------------- setup.py | 1 + 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index af11ada34c..1807aaa106 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -25,8 +25,11 @@ get_frameworks, cuda_path, get_max_jobs_for_parallel_build, + install_and_import, ) +install_and_import("pybind11[global]") + class CMakeExtension(setuptools.Extension): """CMake extension module""" diff --git a/build_tools/utils.py b/build_tools/utils.py index d846b87f22..4bb18a7002 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -15,6 +15,7 @@ from pathlib import Path from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union +from importlib.util import find_spec @functools.lru_cache(maxsize=None) @@ -254,7 +255,9 @@ def get_frameworks() -> List[str]: _frameworks = [framework.lower() for framework in _frameworks] for framework in _frameworks: if framework not in supported_frameworks: - raise ValueError(f"Transformer Engine does not support framework={framework}") + raise ValueError( + f"Transformer Engine does not support framework={framework}" + ) return _frameworks @@ -294,24 +297,32 @@ def copy_common_headers( shutil.copy(path, new_path) +def pip_or_uv() -> List[str]: + if find_spec("pip") is not None: + return [sys.executable, "-m", "pip"] + else: + return ["/usr/bin/env", "uv", "pip"] + + def install_and_import(package): """Install a package via pip (if not already installed) and import into globals.""" main_package = package.split("[")[0] - subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + subprocess.check_call([*pip_or_uv(), "install", package]) globals()[main_package] = importlib.import_module(main_package) def uninstall_te_wheel_packages(): - subprocess.check_call( - [ - sys.executable, - "-m", - "pip", - "uninstall", - "-y", - "transformer_engine_cu12", - "transformer_engine_torch", - "transformer_engine_paddle", - "transformer_engine_jax", - ] - ) + if find_spec("pip") is not None: + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "uninstall", + "-y", + "transformer_engine_cu12", + "transformer_engine_torch", + "transformer_engine_paddle", + "transformer_engine_jax", + ] + ) diff --git a/setup.py b/setup.py index 3bb2fe6b95..b020fa3c4f 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ elif "paddle" in frameworks: from paddle.utils.cpp_extension import BuildExtension elif "jax" in frameworks: + install_and_import("jax[cuda12_local]") install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension