Skip to content

Commit

Permalink
Setup for PyPi release.
Browse files Browse the repository at this point in the history
  • Loading branch information
tgale96 committed Dec 11, 2023
1 parent 26b6714 commit 108009a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 9 deletions.
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
recursive-include csrc *.h
recursive-include csrc *.cu
11 changes: 11 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
clean:
rm -rf dist/*

dist: clean
python3 setup.py sdist

upload: dist
twine upload dist/*

upload-test: dist
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Grouped GEMM

A lighweight library exposing grouped GEMM kernels in PyTorch.

# Installation

Run `pip install grouped_gemm` to install the package.
24 changes: 15 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,21 @@
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


if not torch.cuda.is_available():
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"

os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"

cwd = Path(os.path.dirname(os.path.abspath(__file__)))
_dc = torch.cuda.get_device_capability()
_dc = f"{_dc[0]}{_dc[1]}"

# DEBUG
_dc = 90

ext_modules = [
CUDAExtension(
"grouped_gemm_backend",
["csrc/ops.cu", "csrc/grouped_gemm.cu"],
include_dirs = [
f"{cwd}/third_party/cutlass/include/"
f"{cwd}/third_party/cutlass/include/",
f"{cwd}/csrc"
],
extra_compile_args={
"cxx": [
Expand All @@ -38,12 +34,22 @@
)
]

extra_deps = {}

extra_deps['dev'] = [
'absl-py',
]

extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)

setup(
name="grouped_gemm",
version="0.0.1",
author="Trevor Gale",
author_email="[email protected]",
description="GEMM Grouped",
description="Grouped GEMM",
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
url="https://github.com/tgale06/grouped_gemm",
classifiers=[
"Programming Language :: Python :: 3",
Expand All @@ -53,5 +59,5 @@
packages=find_packages(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
install_requires=["absl-py", "numpy", "torch"],
extras_require=extra_deps,
)

0 comments on commit 108009a

Please sign in to comment.