From c9a81733a4c0ae6b7d35d0192b710f38ba081ccc Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Tue, 30 Apr 2024 09:00:58 +0200 Subject: [PATCH 1/4] chore: add more ruff linting rules --- .pre-commit-config.yaml | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e64f6f2..e0f7c24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: rev: "v0.3.5" hooks: - id: ruff - args: ["--fix", "--show-fixes"] + args: ["--fix", "--show-fixes", "--preview"] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/pyproject.toml b/pyproject.toml index 4d551f1..68cf6d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ extend-select = [ "SIM", # flake8-simplify # TODO: in evaluation "T10", # flake8-debugger "T20", # flake8-print # TODO: in evaluation + "TCH", # flake8-type-checking # TODO: in evaluation "NPY" # NumPy specific rules ] ignore = [ From 08436ffe5a10527fb651a838568f52b153166d13 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 29 May 2024 12:57:23 +0200 Subject: [PATCH 2/4] build: update base and development dependencies --- .pre-commit-config.yaml | 11 +++++------ pyproject.toml | 18 ++++++++++++++++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0f7c24..3acdc0b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,23 +49,22 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.3.5" + rev: v0.4.6 hooks: - id: ruff args: ["--fix", "--show-fixes", "--preview"] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.9.0" + rev: v1.10.0 hooks: - id: mypy files: src|tests args: [--no-install-types] additional_dependencies: - - pytest - - typing-extensions>=4.10.0 - - types-all - + - dace==0.15.1 + - jax[cpu]==0.4.28 + - numpy==1.26.4 - repo: https://github.com/codespell-project/codespell rev: "v2.2.6" hooks: diff --git a/pyproject.toml b/pyproject.toml index 68cf6d2..4067092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,11 @@ classifiers = [ "Topic :: Scientific/Engineering", "Typing :: Typed" ] -dependencies = [] +dependencies = [ + "dace>=0.15", + "jax[cpu]>=0.4.24", + "numpy>=1.26.0" +] description = "JAX jit using DaCe (Data Centric Parallel Programming)" name = "JaCe" readme = "README.md" @@ -29,6 +33,13 @@ requires-python = ">=3.10" version = "0.1.0" license.file = "LICENSE" +[project.optional-dependencies] +cuda12 = [ + "cupy-cuda12x>=12.1.0", + "jax[cuda12]>=0.4.24", + "optuna>=3.4.0" +] + [project.urls] "Bug Tracker" = "https://github.com/GridTools/JaCe/issues" Changelog = "https://github.com/GridTools/JaCe/releases" @@ -120,7 +131,7 @@ ignore = [ 'E501', # [line-too-long] 'UP038' # [non-pep604-isinstance] ] -ignore-init-module-imports = true +# ignore-init-module-imports = true # deprecated in preview mode unfixable = [] [tool.ruff.lint.isort] @@ -146,6 +157,9 @@ section-order = [ 'local-folder' ] +[tool.ruff.lint.isort.sections] +tests = ["tests", "unit_tests", "integration_tests"] + [tool.ruff.lint.per-file-ignores] "!tests/**.py" = ["PT"] # Ignore `flake8-pytest-style` everywhere except in `tests/` "noxfile.py" = ["T20"] # Ignore `flake8-print` From e9c3ce16e8ccdd61994ab32e8fe50b935c9830cb Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 29 May 2024 17:06:58 +0200 Subject: [PATCH 3/4] wip: enhance configs of github actions and pre-commit --- .github/workflows/ci.yml | 4 ---- .pre-commit-config.yaml | 2 +- ROADMAP.md | 4 ++-- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f624ada..86a32f3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,10 +39,6 @@ jobs: python-version: ["3.10", "3.12"] runs-on: [ubuntu-latest, macos-latest, windows-latest] - include: - - python-version: pypy-3.10 - runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3acdc0b..97e8d51 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -80,7 +80,7 @@ repos: - id: disallow-caps name: Disallow improper capitalization language: pygrep - entry: PyBind|Numpy|Cmake|CCache|Github|PyTest + entry: PyBind|Numpy|Cmake|CCache|Github|PyTest|Dace|Jace exclude: .pre-commit-config.yaml - repo: https://github.com/abravalheri/validate-pyproject diff --git a/ROADMAP.md b/ROADMAP.md index e27152a..2beaa39 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -5,7 +5,7 @@ A kind of roadmap that gives a rough idea about how the project will be continue - [x] Being able to perform _some_ translations [PR#3](https://github.com/GridTools/jace/pull/3). - [ ] Basic functionalities: - [ ] Annotation `@jace.jit`. - - [ ] Composable with Jax, i.e. take the Jax derivative of a Jace annotated function. + - [ ] Composable with Jax, i.e. take the Jax derivative of a JaCe annotated function. - [ ] Implementing the `stages` model that is supported by Jax. - [ ] Handling Jax arrays as native input (only on single host). - [ ] Cache the compilation and lowering results for later reuse. @@ -56,7 +56,7 @@ These are more general topics that should be addressed at one point. # Optimization & Transformations -The SDFG generated by Jace have a very particular structure, thus we could and probably should write some highly targeted optimization passes for them. +The SDFG generated by JaCe have a very particular structure, thus we could and probably should write some highly targeted optimization passes for them. Our experiments with the prototype showed that the most important transformation is Map fusion and the one in DaCe is essentially broken. - [ ] Modified state fusion; Because of the structure we have, this could make `Simplify` much more efficient. From 4827843379a18f579a6845148754a7064b76f26c Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Thu, 30 May 2024 11:09:58 +0200 Subject: [PATCH 4/4] Ignore typing import errors from dace and jax --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4067092..3556e8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ warn_unused_ignores = true disallow_incomplete_defs = false disallow_untyped_defs = false ignore_missing_imports = true -module = "tests.*" +module = ["tests.*", "dace.*", "jax.*", "jaxlib.*"] # -- pytest -- [tool.pytest]