Skip to content

Commit

Permalink
Dev (#48)
Browse files Browse the repository at this point in the history
* add _jit_pass_eliminate_simple_arith

* disable triton reshape for vae

* fix source dist

* bump version to 0.0.12 and fix python publish

* optimize performance

* fix python publish and some bug fixes

* fix typo

* fix CI
  • Loading branch information
chengzeyi authored Nov 24, 2023
1 parent 65c6bb1 commit 821054e
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
python -m pip install --upgrade pip
pip install build wheel setuptools torch==2.1.0
- name: Build package
run: python -m build --sdist -n
run: WITH_CUDA=0 python -m build --sdist -n
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
Expand Down
16 changes: 8 additions & 8 deletions .github/workflows/wheels_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,18 @@ jobs:
torch_version_suffix=torch$(echo ${{ inputs.torch_version }} | sed 's/\.//g')
cuda_version_suffix=${{ steps.cuda_info.outputs.CUDA_VERSION_SUFFIX }}
nightly_tag=$([[ ${VERSION_SOURCE} == 'tag' ]] && echo '' || echo '.dev'`date +%Y%m%d`)
echo "BUILD_VERSION=${version}${nightly_tag}+${torch_version_suffix}${cuda_version_suffix}" >> ${GITHUB_ENV}
echo "BUILD_VERSION=${version}${nightly_tag}+${torch_version_suffix}${cuda_version_suffix}" >> ${GITHUB_OUTPUT}
echo "BUILD_VERSION=${version}${nightly_tag}-${torch_version_suffix}+${cuda_version_suffix}" >> ${GITHUB_ENV}
echo "BUILD_VERSION=${version}${nightly_tag}-${torch_version_suffix}+${cuda_version_suffix}" >> ${GITHUB_OUTPUT}
cat ${GITHUB_ENV}
- run: echo "sfast-${BUILD_VERSION}"
- run: echo "release version"
if: ${{ !contains(steps.sfast_version.outputs.BUILD_VERSION, '.dev') }}

- name: Setup proper pytorch dependency in "requirements.txt"
run: |
sed -i '/torch/d' ./requirements.txt
echo "torch == ${{ inputs.torch_version }}" >> ./requirements.txt
cat ./requirements.txt
# - name: Setup proper pytorch dependency in "requirements.txt"
# run: |
# sed -i '/torch/d' ./requirements.txt
# echo "torch==${{ inputs.torch_version }}" >> ./requirements.txt
# cat ./requirements.txt

- if: runner.os == 'Windows'
name: (Windows) Setup Runner
Expand All @@ -181,7 +181,7 @@ jobs:
run: |
cudnn_next_version_major=$((${CUDNN_VERSION_MAJOR} + 1))
cudnn_package_name="${CUDNN_PYPI_PACKAGE}>=${CUDNN_VERSION_MAJOR}.0.0.0,<${cudnn_next_version_major}.0.0.0"
$PY -m pip install wheel setuptools ninja twine "${cudnn_package_name}" -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cuda_short_version }} --no-cache-dir
$PY -m pip install wheel setuptools ninja twine "torch==${{ inputs.torch_version }}" "${cudnn_package_name}" -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cuda_short_version }} --no-cache-dir
- name: Build wheel
run: |
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Example requirement, can be anything that pip knows
# install with `pip install -r requirements.txt`, and make sure that CI does the same
torch>=1.12
# torch>=1.12
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_extensions():
# Skip the above useless check as we will always compile with CUDA support,
# and the CI might be running on CPU-only machines.
if os.getenv("WITH_CUDA", "1") != "0":
assert CUDA_HOME is not None, "Cannot find CUDA installation."
assert CUDA_HOME is not None, "Cannot find CUDA installation. If you want to compile without CUDA, set `WITH_CUDA=0`."

cutlass_root = os.path.join(this_dir, "third_party", "cutlass")
cutlass_include = os.path.join(cutlass_root, "include")
Expand Down
11 changes: 10 additions & 1 deletion sfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ def new_lru_cache(*args, **kwargs):

setup_environment()

import sfast._C as _C
try:
import sfast._C as _C
except ImportError:
print('''
***ERROR IMPORTING sfast._C***
Unable to load stable-fast C extension.
Is is compatible with your PyTorch installation?
Or is it compatible with your CUDA version?
''')
raise

# This line will be programatically read/write by setup.py.
# Leave them at the bottom of this file and don't touch them.
Expand Down
6 changes: 3 additions & 3 deletions sfast/jit/trace_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def lazy_trace(func, *, ts_compiler=None, **kwargs_):
traced_modules = {}

name = getattr(func, '__name__', func.__class__.__name__)
wraped = func.forward if isinstance(func, torch.nn.Module) else func
module_to_be_traced = to_module(wraped)
wrapped = func.forward if isinstance(func, torch.nn.Module) else func
module_to_be_traced = to_module(wrapped)

@functools.wraps(wraped)
@functools.wraps(wrapped)
def wrapper(*args, **kwargs):
nonlocal lock, traced_modules
key = (hash_arg(args), hash_arg(kwargs))
Expand Down

0 comments on commit 821054e

Please sign in to comment.