Skip to content

Commit

Permalink
Add install to colfax
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 20, 2024
1 parent fec9876 commit c1aab43
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ def test_fbgemm():
print("OK")


def install_cutlass():
from tools.cutlass_kernels.install import install_colfax_cutlass

install_colfax_cutlass()


def install_fa2(compile=False):
if compile:
# compile from source (slow)
Expand All @@ -83,12 +77,6 @@ def install_liger():
subprocess.check_call(cmd)


def install_tk():
from tools.tk.install import install_tk

install_tk()


def install_xformers():
os_env = os.environ.copy()
os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
Expand All @@ -101,7 +89,7 @@ def install_xformers():
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
parser.add_argument(
"--cutlass", action="store_true", help="Install optional CUTLASS kernels"
"--colfax", action="store_true", help="Install optional Colfax CUTLASS kernels"
)
parser.add_argument(
"--fa2", action="store_true", help="Install optional flash_attention 2 kernels"
Expand Down Expand Up @@ -139,14 +127,16 @@ def install_xformers():
if args.fa3 or args.all:
logger.info("[tritonbench] installing fa3...")
install_fa3()
if args.cutlass or args.all:
logger.info("[tritonbench] installing cutlass-kernels...")
install_cutlass()
if args.colfax or args.all:
logger.info("[tritonbench] installing colfax cutlass-kernels...")
from tools.cutlass_kernels.install import install_colfax_cutlass
install_colfax_cutlass()
if args.jax or args.all:
logger.info("[tritonbench] installing jax...")
install_jax()
if args.tk or args.all:
logger.info("[tritonbench] installing thunderkittens...")
from tools.tk.install import install_tk
install_tk()
if args.liger or args.all:
logger.info("[tritonbench] installing liger-kernels...")
Expand Down

0 comments on commit c1aab43

Please sign in to comment.