Skip to content

Commit

Permalink
Build with uv instead of pip
Browse files Browse the repository at this point in the history
Signed-off-by: Jennifer Zhou <[email protected]>
  • Loading branch information
jennifgcrl committed Nov 8, 2024
1 parent e5ffaa7 commit e821d5a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
3 changes: 3 additions & 0 deletions build_tools/build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
get_frameworks,
cuda_path,
get_max_jobs_for_parallel_build,
install_and_import,
)

install_and_import("pybind11[global]")


class CMakeExtension(setuptools.Extension):
"""CMake extension module"""
Expand Down
41 changes: 26 additions & 15 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union
from importlib.util import find_spec


@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -254,7 +255,9 @@ def get_frameworks() -> List[str]:
_frameworks = [framework.lower() for framework in _frameworks]
for framework in _frameworks:
if framework not in supported_frameworks:
raise ValueError(f"Transformer Engine does not support framework={framework}")
raise ValueError(
f"Transformer Engine does not support framework={framework}"
)

return _frameworks

Expand Down Expand Up @@ -294,24 +297,32 @@ def copy_common_headers(
shutil.copy(path, new_path)


def pip_or_uv() -> List[str]:
if find_spec("pip") is not None:
return [sys.executable, "-m", "pip"]
else:
return ["/usr/bin/env", "uv", "pip"]


def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
subprocess.check_call([*pip_or_uv(), "install", package])
globals()[main_package] = importlib.import_module(main_package)


def uninstall_te_wheel_packages():
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"uninstall",
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_paddle",
"transformer_engine_jax",
]
)
if find_spec("pip") is not None:
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"uninstall",
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_paddle",
"transformer_engine_jax",
]
)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("jax[cuda12_local]")
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension

Expand Down

0 comments on commit e821d5a

Please sign in to comment.