From 5b7e435818b46740b47df526b63577dde7b66496 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 13 Apr 2023 10:29:45 +0900 Subject: [PATCH] Remove model decompression feature (#10) * Remove model decompression feature * fix * update tests --- .github/workflows/CI.yml | 2 +- Cargo.toml | 3 +-- README.md | 29 ++++++++++++++++++++++++++++- docs/source/examples.rst | 26 +++++++++++++++++++++++--- requirements-dev.txt | 7 ++++--- src/lib.rs | 13 +++---------- tests/test_vaporetto.py | 37 ++++++++++++++++++++++++++++--------- 7 files changed, 88 insertions(+), 29 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2d122be..24476c6 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,7 +31,7 @@ jobs: - name: Test package run: | python -m pip install --upgrade pip - pip install pytest mypy + pip install pytest mypy zstandard pip install vaporetto --no-index --find-links target/wheels --force-reinstall mypy --strict tests pytest tests/test_vaporetto.py diff --git a/Cargo.toml b/Cargo.toml index 907fcdb..932f9ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "python-vaporetto" -version = "0.2.1" +version = "0.3.0" edition = "2021" authors = ["Koichi Akabe "] description = "Python wrapper of Vaporetto tokenizer" @@ -17,6 +17,5 @@ crate-type = ["cdylib"] hashbrown = "0.13.2" # MIT or Apache-2.0 ouroboros = "0.15.6" # MIT or Apache-2.0 pyo3 = { version = "0.18.2", features = ["extension-module"] } # Apache-2.0 -ruzstd = "0.3.0" # MIT vaporetto_rust = { package = "vaporetto", version = "0.6.3", features = ["kytea"] } # MIT or Apache-2.0 vaporetto_rules = "0.6.3" # MIT or Apache-2.0 diff --git a/README.md b/README.md index 3ccb1f7..6b54245 100644 --- a/README.md +++ b/README.md @@ -37,12 +37,22 @@ $ pip install git+https://github.com/daac-tools/python-vaporetto python-vaporetto does not contain model files. To perform tokenization, follow [the document of Vaporetto](https://github.com/daac-tools/vaporetto) to download distribution models or train your own models beforehand. +Check the version number as shown below to use compatible models: + +```python +import vaporetto +vaporetto.VAPORETTO_VERSION +#=> "0.6.3" +``` + +Examples: + ```python # Import vaporetto module import vaporetto # Load the model file -with open('path/to/model.zst', 'rb') as fp: +with open('path/to/model', 'rb') as fp: model = fp.read() # Create an instance of the Vaporetto @@ -65,6 +75,23 @@ tokens[0].tag(1) #=> ['まぁ', '社長', 'は', '火星', '猫', 'だ'] ``` +## Note for distributed models + +The distributed models are compressed in zstd format. If you want to load these compressed models, +you must decompress them outside the API. + +```python +import vaporetto +import zstandard # zstandard package in PyPI + +dctx = zstandard.ZstdDecompressor() +with open('path/to/model.zst', 'rb') as fp: + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read(), predict_tags = True) +``` + +## Note for KyTea's models + You can also use KyTea's models as follows: ```python diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 7f6082c..59e404d 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -5,6 +5,14 @@ python-vaporetto does not contain model files. To perform tokenization, follow ` Vaporetto `_ to download distribution models or train your own models beforehand. +You can check the version number as shown below to use compatible models: + +.. code-block:: python + + >>> import vaporetto + >>> vaporetto.VAPORETTO_VERSION + '0.6.3' + Tokenize with Vaporetto model ----------------------------- @@ -13,8 +21,8 @@ The following example tokenizes a string using a Vaporetto model. .. code-block:: python >>> import vaporetto - >>> with open('path/to/model.zst', 'rb') as fp: - >>> model = fp.read() + >>> with open('path/to/model', 'rb') as fp: + ... model = fp.read() >>> tokenizer = vaporetto.Vaporetto(model, predict_tags = True) @@ -33,6 +41,18 @@ The following example tokenizes a string using a Vaporetto model. >>> [token.surface() for token in tokens] ['まぁ', '社長', 'は', '火星', '猫', 'だ'] +The distributed models are compressed in zstd format. If you want to load these compressed models, +you must decompress them outside the API: + +.. code-block:: python + + >>> import vaporetto + >>> import zstandard # zstandard package in PyPI + + >>> dctx = zstandard.ZstdDecompressor() + >>> with open('path/to/model.zst', 'rb') as fp: + ... dict_reader = dctx.stream_reader(fp) + >>> tokenizer = vaporetto.Vaporetto(dict_reader.read(), predict_tags = True) Tokenize with KyTea model ------------------------- @@ -42,6 +62,6 @@ If you want to use a KyTea model, use ``create_from_kytea_model()`` instead. .. code-block:: python >>> with open('path/to/jp-0.4.7-5.mod', 'rb') as fp: - >>> model = fp.read() + ... model = fp.read() >>> tokenizer = vaporetto.Vaporetto.create_from_kytea_model(model) diff --git a/requirements-dev.txt b/requirements-dev.txt index ffb0806..6c275b2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ -pytest>=7.2.1 +pytest>=7.3.0 pytest-benchmark>=4.0.0 -mypy>=0.991 +mypy>=1.2.0 kytea>=0.1.7 -SudachiPy>=0.6.6 +SudachiPy>=0.6.7 SudachiDict-core>=20230110 +zstandard>=0.20.0 diff --git a/src/lib.rs b/src/lib.rs index 4bc1185..3805729 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ use std::fmt::Write; -use std::io::Read; use pyo3::{exceptions::PyValueError, prelude::*, types::PyUnicode}; @@ -211,8 +210,8 @@ impl PredictorWrapper { /// /// Examples: /// >>> import vaporetto -/// >>> with open('path/to/model.zst', 'rb') as fp: -/// >>> model = fp.read() +/// >>> with open('path/to/model', 'rb') as fp: +/// ... model = fp.read() /// >>> tokenizer = vaporetto.Vaporetto(model, predict_tags = True) /// >>> tokenizer.tokenize_to_string('まぁ社長は火星猫だ') /// 'まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星/名詞/カセー 猫/名詞/ネコ だ/助動詞/ダ' @@ -318,14 +317,8 @@ impl Vaporetto { wsconst: &str, norm: bool, ) -> PyResult { - let mut buf = vec![]; let (model, _) = py.allow_threads(|| { - let mut decoder = ruzstd::StreamingDecoder::new(model) - .map_err(|e| PyValueError::new_err(e.to_string()))?; - decoder - .read_to_end(&mut buf) - .map_err(|e| PyValueError::new_err(e.to_string()))?; - Model::read_slice(&buf).map_err(|e| PyValueError::new_err(e.to_string())) + Model::read_slice(&model).map_err(|e| PyValueError::new_err(e.to_string())) })?; Self::create_internal(py, model, predict_tags, wsconst, norm) } diff --git a/tests/test_vaporetto.py b/tests/test_vaporetto.py index e6b72bc..5d81088 100644 --- a/tests/test_vaporetto.py +++ b/tests/test_vaporetto.py @@ -3,21 +3,26 @@ import pathlib import vaporetto +import zstandard MODEL_PATH = pathlib.PurePath(__file__).parent / 'data/model.zst' def test_tokenlist_empty() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read()) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read()) tokens = tokenizer.tokenize('') assert [] == list(tokens) def test_tokenlist_index() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read()) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read()) tokens = tokenizer.tokenize('まぁ社長は火星猫だ') assert 'まぁ' == tokens[0].surface() @@ -29,8 +34,10 @@ def test_tokenlist_index() -> None: def test_tokenlist_iter() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read()) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read()) tokens = tokenizer.tokenize('まぁ社長は火星猫だ') assert ['まぁ', '社長', 'は', '火星', '猫', 'だ'] == list( @@ -39,8 +46,10 @@ def test_tokenlist_iter() -> None: def test_tokenlist_iter_positions() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read()) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read()) tokens = tokenizer.tokenize('まぁ社長は火星猫だ') assert [(0, 2), (2, 4), (4, 5), (5, 7), (7, 8), (8, 9)] == list( @@ -49,16 +58,20 @@ def test_tokenlist_iter_positions() -> None: def test_wsconst() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read(), wsconst='K') + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read(), wsconst='K') tokens = tokenizer.tokenize('まぁ社長は火星猫だ') assert ['まぁ', '社長', 'は', '火星猫', 'だ'] == list(token.surface() for token in tokens) def test_tags_1() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read(), predict_tags=True) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read(), predict_tags=True) tokens = tokenizer.tokenize('まぁ社長は火星猫だ') assert ['名詞', '名詞', '助詞', '名詞', '名詞', '助動詞'] == list( @@ -67,8 +80,10 @@ def test_tags_1() -> None: def test_tags_2() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read(), predict_tags=True) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read(), predict_tags=True) tokens = tokenizer.tokenize('まぁ社長は火星猫だ') assert ['マー', 'シャチョー', 'ワ', 'カセー', 'ネコ', 'ダ'] == list( @@ -77,14 +92,18 @@ def test_tags_2() -> None: def test_tokenize_to_string_empty() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read(), predict_tags=True) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read(), predict_tags=True) assert '' == tokenizer.tokenize_to_string('') def test_tokenize_to_string() -> None: + dctx = zstandard.ZstdDecompressor() with open(MODEL_PATH, 'rb') as fp: - tokenizer = vaporetto.Vaporetto(fp.read(), predict_tags=True) + dict_reader = dctx.stream_reader(fp) + tokenizer = vaporetto.Vaporetto(dict_reader.read(), predict_tags=True) assert ( 'まぁ/名詞/マー 社長/名詞/シャチョー は/助詞/ワ 火星/名詞/カセー 猫/名詞/ネコ だ/助動詞/ダ' == tokenizer.tokenize_to_string('まぁ社長は火星猫だ')