Skip to content

Commit

Permalink
feat: add python bindings (#2)
Browse files Browse the repository at this point in the history
* feat: add python bindings

* feat: update python bindings

* chore(ci): run python ci for all branches

* fix: add tests

* chore(ci): update python ci

* chore(ci): rename ci workflows

* chore(ci): remove macos from ci

* chore(ci): update release-plz.yml
  • Loading branch information
ChieloNewctle authored May 9, 2024
1 parent 40b3b8b commit c09a3cc
Show file tree
Hide file tree
Showing 13 changed files with 489 additions and 25 deletions.
201 changes: 201 additions & 0 deletions .github/workflows/ci_python.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# CI for the Python bindings: format checks, tests on several interpreters,
# wheel and sdist builds, and publication when the ref is a tag.
name: Build & Test the Python Bindings

# `run` steps execute inside python/ by default. Steps that call an action
# (`uses:`) are not covered by this default, which is why the maturin steps
# below pass `working-directory: python` themselves.
defaults:
  run:
    working-directory: python

on:
  push:
  pull_request:
  workflow_dispatch:

# Read-only by default; the release job widens this for itself.
permissions:
  contents: read

jobs:
  # Lint and format check of the Python sources with ruff and black.
  format:
    name: Check Python format
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v5
      - name: Install dependencies
        run: pip install ruff black
      - name: Ruff
        run: ruff check .
      - name: Black
        run: black --check --diff .

  # Format check of the Rust sources with the stable toolchain.
  rustfmt:
    name: Check Rust format
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - run: rustup update stable && rustup default stable
      - run: rustup component add rustfmt
      - run: cargo fmt --all --check

  # Editable install + pytest on every interpreter in the matrix, followed by
  # a trial wheel build.
  test:
    name: Run tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Versions are quoted so that "3.10" is not read as the float 3.1.
        python-version: ["3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"]

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install locally
        run: pip install -e ".[test]"
      - name: Install additional dependencies
        run: pip install pytest-md pytest-emoji
      - uses: pavelzw/pytest-action@v2
        with:
          emoji: false
          verbose: true
          job-summary: true
      - name: Test building wheels
        uses: PyO3/maturin-action@v1
        with:
          working-directory: python
          sccache: true
          manylinux: auto

  # Release wheels for Linux, one matrix entry per CPU architecture.
  linux:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        target: [x86_64, aarch64, armv7]
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: actions/setup-python@v5
        with:
          python-version: "3.9"
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          working-directory: python
          target: ${{ matrix.target }}
          args: --release --out dist --interpreter 3.9 pypy3.9 pypy3.10
          sccache: true
          manylinux: auto
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: wheels-linux-${{ matrix.target }}
          path: python/dist
      # Native target: install the freshly built wheel on the runner itself.
      - name: pytest
        if: ${{ startsWith(matrix.target, 'x86_64') }}
        shell: bash
        run: |
          set -e
          pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall
          pytest --import-mode=importlib
      # Non-x86 targets: install and test the wheel through run-on-arch-action.
      # NOTE(review): the exact version pin of this action was unreadable in
      # the copy this workflow was recovered from; `v2` is the action's major
      # tag. Re-pin to the intended release if a fixed version is required.
      - name: pytest
        if: ${{ !startsWith(matrix.target, 'x86') && matrix.target != 'ppc64' }}
        uses: uraimo/run-on-arch-action@v2
        with:
          arch: ${{ matrix.target }}
          distro: ubuntu22.04
          githubToken: ${{ github.token }}
          install: |
            apt-get update
            apt-get install -y --no-install-recommends python3 python3-pip
            pip3 install -U pip
          # This script does not inherit `defaults.run`, hence the explicit cd.
          run: |
            set -e
            cd python
            pip3 install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall
            pytest --import-mode=importlib

  # Release wheels for Windows.
  windows:
    runs-on: windows-latest
    strategy:
      matrix:
        target: [x64]
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: actions/setup-python@v5
        with:
          python-version: "3.9"
          architecture: ${{ matrix.target }}
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          working-directory: python
          target: ${{ matrix.target }}
          args: --release --out dist --interpreter 3.9 pypy3.9 pypy3.10
          sccache: true
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: wheels-windows-${{ matrix.target }}
          path: python/dist
      - name: pytest
        if: ${{ !startsWith(matrix.target, 'aarch64') }}
        shell: bash
        run: |
          set -e
          pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall
          pytest --import-mode=importlib

  # Source distribution; only built once the test job has passed.
  sdist:
    needs: [test]
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Build sdist
        uses: PyO3/maturin-action@v1
        with:
          working-directory: python
          command: sdist
          args: --out dist
      - name: Upload sdist
        uses: actions/upload-artifact@v4
        with:
          name: wheels-sdist
          path: python/dist

  # Runs only for tag refs, after every other job has succeeded: collects all
  # `wheels-*` artifacts, uploads them to PyPI and attaches them to the
  # GitHub release.
  release:
    name: Release
    runs-on: ubuntu-latest
    if: "startsWith(github.ref, 'refs/tags/')"
    needs: [test, format, rustfmt, linux, windows, sdist]
    permissions:
      # Used to upload release artifacts
      contents: write
    steps:
      - uses: actions/download-artifact@v4
        with:
          pattern: wheels-*
          merge-multiple: true
      - name: Publish to PyPI
        uses: PyO3/maturin-action@v1
        env:
          MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
        with:
          command: upload
          args: --non-interactive --skip-existing *
      - name: Upload to GitHub Release
        uses: softprops/action-gh-release@v2
        with:
          files: |
            *.whl
            *.tar.gz
          # Tags containing "alpha" or "beta" are published as pre-releases.
          prerelease: ${{ contains(github.ref, 'alpha') || contains(github.ref, 'beta') }}
2 changes: 1 addition & 1 deletion .github/workflows/release-plz.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ permissions:

on:
workflow_run:
workflows: [Cargo Build & Test]
workflows: [Cargo Build & Test, Build & Test the Python Bindings]
types: [completed]
branches: [main]

Expand Down
30 changes: 24 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
[package]
name = "mtc-token-healing"
[workspace]
members = ["python"]

[workspace.package]
version = "0.1.1"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Token healing implementation"
repository = "https://github.com/ModelTC/mtc-token-healing"
homepage = "https://github.com/ModelTC/mtc-token-healing"
documentation = "https://docs.rs/mtc-token-healing"
readme = "README.md"
authors = ["Chielo Newctle <[email protected]>"]
exclude = ["release-plz.toml", ".github"]

[package]
name = "mtc-token-healing"
version.workspace = true
edition.workspace = true
license.workspace = true
description.workspace = true
repository.workspace = true
homepage.workspace = true
documentation.workspace = true
authors.workspace = true
readme = "README.md"
exclude = ["release-plz.toml", ".github", "python"]

[dependencies]
derive_more = "0.99.17"
general-sam = { version = "1.0.0", features = ["trie"] }
pyo3 = { version = "0.21.2", optional = true }
smallvec = "1.13.2"
thiserror = "1.0.59"
thiserror = "1.0.60"

[features]
pyo3 = ["dep:pyo3"]
Expand All @@ -26,9 +39,14 @@ clap = { version = "4.5.4", features = ["derive", "env"] }
color-eyre = "0.6.3"
rand = "0.8.5"
regex = "1.10.4"
serde_json = "1.0.116"
serde_json = "1.0.117"
tokenizers = { version = "0.19.1", features = ["hf-hub", "http"] }
tokio = { version = "1.37.0", features = ["rt-multi-thread"] }

[package.metadata.docs.rs]
all-features = true

[profile.release]
lto = true
strip = true
opt-level = "z"
18 changes: 18 additions & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Manifest of the crate that wraps `mtc-token-healing` as a Python extension.
[package]
name = "mtc-token-healing-py"
# Every `*.workspace = true` field below takes its value from the
# `[workspace.package]` table of the workspace root manifest.
version.workspace = true
edition.workspace = true
license.workspace = true
description.workspace = true
repository.workspace = true
homepage.workspace = true
documentation.workspace = true
authors.workspace = true

[lib]
# Library name; python/src/lib.rs declares a `#[pymodule]` function with the
# same name.
name = "mtc_token_healing"
# Built as a C-compatible dynamic library.
crate-type = ["cdylib"]

[dependencies]
# The core crate from the parent directory, with its optional `pyo3` feature on.
mtc-token-healing = { version = "0.1.0", path = "..", features = ["pyo3"] }
# `abi3-py39` selects PyO3's stable-ABI build with Python 3.9 as the minimum.
pyo3 = { version = "0.21.2", features = ["extension-module", "generate-import-lib", "abi3-py39"] }
21 changes: 21 additions & 0 deletions python/mtc_token_healing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Re-export the classes of the sibling `mtc_token_healing` submodule so that
# they can be imported directly from the package.
from .mtc_token_healing import (
    BestChoice,
    CountInfo,
    InferRequest,
    InferResponse,
    Prediction,
    VocabPrefixAutomaton,
    ReorderedTokenId,
    SearchTree,
)

# Public API of the package: exactly the names imported above.
__all__ = [
    "BestChoice",
    "CountInfo",
    "InferRequest",
    "InferResponse",
    "Prediction",
    "VocabPrefixAutomaton",
    "ReorderedTokenId",
    "SearchTree",
]
18 changes: 18 additions & 0 deletions python/mtc_token_healing/mtc_token_healing.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Sequence

# Token ids are represented as plain integers.
TokenId = int

# The classes below are declared without members in this stub.
class BestChoice: ...
class CountInfo: ...
class InferRequest: ...
class InferResponse: ...
class Prediction: ...

# Automaton constructed from a vocabulary of token strings.
class VocabPrefixAutomaton:
    # Builds the automaton from `vocab`, a sequence of token strings.
    def __init__(self, vocab: Sequence[str]) -> None: ...
    # Returns a sequence of integer indices into the vocabulary.
    # NOTE(review): the repository's test expects these indices to list the
    # vocabulary in ascending string order -- confirm against the Rust side.
    def get_order(self) -> Sequence[int]: ...
    # Size of the vocabulary; the repository's test expects it to equal
    # `len(vocab)` of the sequence passed to the constructor.
    @property
    def vocab_size(self) -> int: ...

class ReorderedTokenId: ...
class SearchTree: ...
18 changes: 18 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Packaging metadata for the Python bindings; the build backend is maturin.
[build-system]
requires = ["maturin>=1.3,<2.0"]
build-backend = "maturin"

[project]
name = "mtc_token_healing"
# The extension crate enables PyO3's `abi3-py39` feature (python/Cargo.toml)
# and CI tests CPython 3.9+ / PyPy 3.9+ only, so 3.9 is the real minimum.
requires-python = ">=3.9"
classifiers = [
    "Programming Language :: Rust",
    "Programming Language :: Python :: Implementation :: CPython",
    "Programming Language :: Python :: Implementation :: PyPy",
]
# The version is not written here; it is declared dynamic and supplied by the
# build backend.
dynamic = ["version"]

# `pip install "mtc_token_healing[test]"` pulls in the test runner.
[project.optional-dependencies]
test = ["pytest"]

# No maturin-specific overrides at the moment.
[tool.maturin]
18 changes: 18 additions & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use ::mtc_token_healing::{
vocab::PyVocabPrefixAutomaton, BestChoice, CountInfo, InferRequest, InferResponse, Prediction,
ReorderedTokenId, SearchTree,
};
use pyo3::prelude::*;

// Initialiser of the `mtc_token_healing` extension module: registers every
// class imported at the top of this file on the module object. A plain `//`
// comment is used on purpose so that no doc attribute is attached to the
// `#[pymodule]` function.
#[pymodule]
fn mtc_token_healing(_py: Python, module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_class::<BestChoice>()?;
    module.add_class::<CountInfo>()?;
    module.add_class::<InferRequest>()?;
    module.add_class::<InferResponse>()?;
    module.add_class::<Prediction>()?;
    module.add_class::<PyVocabPrefixAutomaton>()?;
    module.add_class::<ReorderedTokenId>()?;
    // The outcome of the final registration doubles as the return value.
    module.add_class::<SearchTree>()
}
15 changes: 15 additions & 0 deletions python/tests/test_vocab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from mtc_token_healing import VocabPrefixAutomaton


def test_vocab_simple():
    """Check `vocab_size` and `get_order()` on a small hand-written vocabulary."""
    tokens = ["bcd", "abc", "cc", "hello", "world", " ", "yes", "no", "."]
    expected_order = [5, 8, 1, 0, 2, 3, 7, 4, 6]

    # Sanity check of the fixture itself: one index per token.
    assert len(tokens) == len(expected_order)

    automaton = VocabPrefixAutomaton(tokens)

    assert automaton.vocab_size == len(tokens)
    assert automaton.get_order() == expected_order

    # Walking the tokens in the expected order must give a strictly
    # increasing sequence of strings.
    ranked = [tokens[idx] for idx in expected_order]
    assert all(lhs < rhs for lhs, rhs in zip(ranked, ranked[1:]))
1 change: 1 addition & 0 deletions src/choice.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::TokenId;

#[derive(Clone, Debug)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, frozen))]
pub struct BestChoice {
pub extra_token_ids: Vec<TokenId>,
pub accum_log_prob: f64,
Expand Down
Loading

0 comments on commit c09a3cc

Please sign in to comment.