-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: add python bindings * feat: update python bindings * chore(ci): run python ci for all branches * fix: add tests * chore(ci): update python ci * chore(ci): rename ci workflows * chore(ci): remove macos from ci * chore(ci): update release-plz.yml
- Loading branch information
1 parent
40b3b8b
commit c09a3cc
Showing
13 changed files
with
489 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
name: Build & Test the Python Bindings | ||
|
||
defaults: | ||
run: | ||
working-directory: python | ||
|
||
on: | ||
push: | ||
pull_request: | ||
workflow_dispatch: | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
format: | ||
name: Check Python format | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
- name: Install dependencies | ||
run: pip install ruff black | ||
- name: Ruff | ||
run: ruff check . | ||
- name: Black | ||
run: black --check --diff . | ||
|
||
rustfmt: | ||
name: Check Rust format | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- run: rustup update stable && rustup default stable | ||
- run: rustup component add rustfmt | ||
- run: cargo fmt --all --check | ||
|
||
test: | ||
name: Run tests | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"] | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install locally | ||
run: pip install -e ".[test]" | ||
- name: Install additional dependencies | ||
run: pip install pytest-md pytest-emoji | ||
- uses: pavelzw/pytest-action@v2 | ||
with: | ||
emoji: false | ||
verbose: true | ||
job-summary: true | ||
- name: Test building wheels | ||
uses: PyO3/maturin-action@v1 | ||
with: | ||
working-directory: python | ||
sccache: true | ||
manylinux: auto | ||
|
||
linux: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
target: [x86_64, aarch64, armv7] | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- uses: actions/setup-python@v4 | ||
with: | ||
python-version: "3.9" | ||
- name: Build wheels | ||
uses: PyO3/maturin-action@v1 | ||
with: | ||
working-directory: python | ||
target: ${{ matrix.target }} | ||
args: --release --out dist --interpreter 3.9 pypy3.9 pypy3.10 | ||
sccache: true | ||
manylinux: auto | ||
- name: Upload wheels | ||
uses: actions/upload-artifact@v4 | ||
with: | ||
name: wheels-linux-${{ matrix.target }} | ||
path: python/dist | ||
- name: pytest | ||
if: ${{ startsWith(matrix.target, 'x86_64') }} | ||
shell: bash | ||
run: | | ||
set -e | ||
pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall | ||
pytest --import-mode=importlib | ||
- name: pytest | ||
if: ${{ !startsWith(matrix.target, 'x86') && matrix.target != 'ppc64' }} | ||
uses: uraimo/run-on-arch-action@v2.5.0 | ||
with: | ||
arch: ${{ matrix.target }} | ||
distro: ubuntu22.04 | ||
githubToken: ${{ github.token }} | ||
install: | | ||
apt-get update | ||
apt-get install -y --no-install-recommends python3 python3-pip | ||
pip3 install -U pip | ||
run: | | ||
set -e | ||
cd python | ||
pip3 install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall | ||
pytest --import-mode=importlib | ||
windows: | ||
runs-on: windows-latest | ||
strategy: | ||
matrix: | ||
target: [x64] | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- uses: actions/setup-python@v4 | ||
with: | ||
python-version: "3.9" | ||
architecture: ${{ matrix.target }} | ||
- name: Build wheels | ||
uses: PyO3/maturin-action@v1 | ||
with: | ||
working-directory: python | ||
target: ${{ matrix.target }} | ||
args: --release --out dist --interpreter 3.9 pypy3.9 pypy3.10 | ||
sccache: true | ||
- name: Upload wheels | ||
uses: actions/upload-artifact@v4 | ||
with: | ||
name: wheels-windows-${{ matrix.target }} | ||
path: python/dist | ||
- name: pytest | ||
if: ${{ !startsWith(matrix.target, 'aarch64') }} | ||
shell: bash | ||
run: | | ||
set -e | ||
pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall | ||
pytest --import-mode=importlib | ||
sdist: | ||
needs: [test] | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- name: Build sdist | ||
uses: PyO3/maturin-action@v1 | ||
with: | ||
working-directory: python | ||
command: sdist | ||
args: --out dist | ||
- name: Upload sdist | ||
uses: actions/upload-artifact@v4 | ||
with: | ||
name: wheels-sdist | ||
path: python/dist | ||
|
||
release: | ||
name: Release | ||
runs-on: ubuntu-latest | ||
if: "startsWith(github.ref, 'refs/tags/')" | ||
needs: [test, format, rustfmt, linux, windows, sdist] | ||
permissions: | ||
# Used to upload release artifacts | ||
contents: write | ||
steps: | ||
- uses: actions/download-artifact@v4 | ||
with: | ||
pattern: wheels-* | ||
merge-multiple: true | ||
- name: Publish to PyPI | ||
uses: PyO3/maturin-action@v1 | ||
env: | ||
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} | ||
with: | ||
command: upload | ||
args: --non-interactive --skip-existing * | ||
- name: Upload to GitHub Release | ||
uses: softprops/action-gh-release@v2 | ||
with: | ||
files: | | ||
*.whl | ||
*.tar.gz | ||
prerelease: ${{ contains(github.ref, 'alpha') || contains(github.ref, 'beta') }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,35 @@ | ||
[package] | ||
name = "mtc-token-healing" | ||
[workspace] | ||
members = ["python"] | ||
|
||
[workspace.package] | ||
version = "0.1.1" | ||
edition = "2021" | ||
license = "MIT OR Apache-2.0" | ||
description = "Token healing implementation" | ||
repository = "https://github.com/ModelTC/mtc-token-healing" | ||
homepage = "https://github.com/ModelTC/mtc-token-healing" | ||
documentation = "https://docs.rs/mtc-token-healing" | ||
readme = "README.md" | ||
authors = ["Chielo Newctle <[email protected]>"] | ||
exclude = ["release-plz.toml", ".github"] | ||
|
||
[package] | ||
name = "mtc-token-healing" | ||
version.workspace = true | ||
edition.workspace = true | ||
license.workspace = true | ||
description.workspace = true | ||
repository.workspace = true | ||
homepage.workspace = true | ||
documentation.workspace = true | ||
authors.workspace = true | ||
readme = "README.md" | ||
exclude = ["release-plz.toml", ".github", "python"] | ||
|
||
[dependencies] | ||
derive_more = "0.99.17" | ||
general-sam = { version = "1.0.0", features = ["trie"] } | ||
pyo3 = { version = "0.21.2", optional = true } | ||
smallvec = "1.13.2" | ||
thiserror = "1.0.59" | ||
thiserror = "1.0.60" | ||
|
||
[features] | ||
pyo3 = ["dep:pyo3"] | ||
|
@@ -26,9 +39,14 @@ clap = { version = "4.5.4", features = ["derive", "env"] } | |
color-eyre = "0.6.3" | ||
rand = "0.8.5" | ||
regex = "1.10.4" | ||
serde_json = "1.0.116" | ||
serde_json = "1.0.117" | ||
tokenizers = { version = "0.19.1", features = ["hf-hub", "http"] } | ||
tokio = { version = "1.37.0", features = ["rt-multi-thread"] } | ||
|
||
[package.metadata.docs.rs] | ||
all-features = true | ||
|
||
[profile.release] | ||
lto = true | ||
strip = true | ||
opt-level = "z" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[package] | ||
name = "mtc-token-healing-py" | ||
version.workspace = true | ||
edition.workspace = true | ||
license.workspace = true | ||
description.workspace = true | ||
repository.workspace = true | ||
homepage.workspace = true | ||
documentation.workspace = true | ||
authors.workspace = true | ||
|
||
[lib] | ||
name = "mtc_token_healing" | ||
crate-type = ["cdylib"] | ||
|
||
[dependencies] | ||
mtc-token-healing = { version = "0.1.0", path = "..", features = ["pyo3"] } | ||
pyo3 = { version = "0.21.2", features = ["extension-module", "generate-import-lib", "abi3-py39"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from .mtc_token_healing import ( | ||
BestChoice, | ||
CountInfo, | ||
InferRequest, | ||
InferResponse, | ||
Prediction, | ||
VocabPrefixAutomaton, | ||
ReorderedTokenId, | ||
SearchTree, | ||
) | ||
|
||
__all__ = [ | ||
"BestChoice", | ||
"CountInfo", | ||
"InferRequest", | ||
"InferResponse", | ||
"Prediction", | ||
"VocabPrefixAutomaton", | ||
"ReorderedTokenId", | ||
"SearchTree", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import Sequence | ||
|
||
# Token ids are plain integers.
TokenId = int

# Opaque classes exported by the native extension; this stub declares no
# Python-visible members for them.
class BestChoice: ...
class CountInfo: ...
class InferRequest: ...
class InferResponse: ...
class Prediction: ...

class VocabPrefixAutomaton:
    """Automaton constructed from a vocabulary given as a sequence of strings."""

    # `vocab` is the list of token strings.
    def __init__(self, vocab: Sequence[str]) -> None: ...

    # Returns a sequence of integers; the accompanying test expects the
    # indices of `vocab` listed in ascending order of their token strings.
    def get_order(self) -> Sequence[int]: ...

    # The accompanying test expects this to equal `len(vocab)`.
    @property
    def vocab_size(self) -> int: ...

class ReorderedTokenId: ...
class SearchTree: ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Build configuration for the Python bindings; the wheel is produced by maturin.
[build-system]
requires = ["maturin>=1.3,<2.0"]
build-backend = "maturin"

[project]
name = "mtc_token_healing"
# The extension crate enables PyO3's `abi3-py39` feature, which makes
# Python 3.9 the minimum supported interpreter, so the metadata must not
# advertise 3.8.
requires-python = ">=3.9"
classifiers = [
    "Programming Language :: Rust",
    "Programming Language :: Python :: Implementation :: CPython",
    "Programming Language :: Python :: Implementation :: PyPy",
]
# The version is not written here; it is declared dynamic and supplied by the build backend.
dynamic = ["version"]

[tool.maturin]

[project.optional-dependencies]
test = ["pytest"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
use ::mtc_token_healing::{ | ||
vocab::PyVocabPrefixAutomaton, BestChoice, CountInfo, InferRequest, InferResponse, Prediction, | ||
ReorderedTokenId, SearchTree, | ||
}; | ||
use pyo3::prelude::*; | ||
|
||
#[pymodule] | ||
fn mtc_token_healing(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { | ||
m.add_class::<BestChoice>()?; | ||
m.add_class::<CountInfo>()?; | ||
m.add_class::<InferRequest>()?; | ||
m.add_class::<InferResponse>()?; | ||
m.add_class::<Prediction>()?; | ||
m.add_class::<PyVocabPrefixAutomaton>()?; | ||
m.add_class::<ReorderedTokenId>()?; | ||
m.add_class::<SearchTree>()?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from mtc_token_healing import VocabPrefixAutomaton | ||
|
||
|
||
def test_vocab_simple():
    """Check the vocab size and token ordering reported by VocabPrefixAutomaton."""
    # Small vocabulary, deliberately not given in sorted order.
    tokens = ["bcd", "abc", "cc", "hello", "world", " ", "yes", "no", "."]
    # Indices into `tokens`, listed in ascending order of the token strings.
    expected_order = [5, 8, 1, 0, 2, 3, 7, 4, 6]

    assert len(tokens) == len(expected_order)

    automaton = VocabPrefixAutomaton(tokens)

    assert automaton.vocab_size == len(tokens)
    assert automaton.get_order() == expected_order

    # Sanity-check the fixture itself: following `expected_order` must visit
    # the tokens in strictly increasing string order.
    ordered_tokens = [tokens[idx] for idx in expected_order]
    assert all(lhs < rhs for lhs, rhs in zip(ordered_tokens, ordered_tokens[1:]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.