diff --git a/.github/workflows/deploy-pypi.yml b/.github/workflows/deploy-pypi.yml index c88e5968..8138ae81 100644 --- a/.github/workflows/deploy-pypi.yml +++ b/.github/workflows/deploy-pypi.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: [ubuntu-latest, macOS-latest, windows-latest] steps: - uses: actions/checkout@v2 @@ -23,15 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8==4.0.1 - python -m pip install pytest==7.1.2 - python -m pip install pytest-flake8==1.1.1 - python -m pip install pydocstyle==6.1.1 - python -m pip install pytest-pydocstyle==2.3.0 - python -m pip install pytest-cov==3.0.0 - python -m pip install ray - python -m pip install 'importlib-metadata<4.3' - python -m pip install . + python -m pip install -e ".[dev]" - name: Test with pytest run: | pytest -v --flake8 --pydocstyle --cov=hiclass --cov-fail-under=90 --cov-report html diff --git a/.github/workflows/test-pr.yml b/.github/workflows/test-pr.yml index 57f15877..325c9ea7 100644 --- a/.github/workflows/test-pr.yml +++ b/.github/workflows/test-pr.yml @@ -24,14 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8==4.0.1 - python -m pip install pytest==7.1.2 - python -m pip install pytest-flake8==1.1.1 - python -m pip install pydocstyle==6.1.1 - python -m pip install pytest-pydocstyle==2.3.0 - python -m pip install pytest-cov==3.0.0 - python -m pip install ray - python -m pip install . + python -m pip install -e ".[dev]" - name: Test with pytest run: | pytest -v --flake8 --pydocstyle --cov=hiclass --cov-fail-under=90 --cov-report html diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3854349b..d2b6eade 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,6 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 24.2.0 hooks: - id: black diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dfea1c83..ec8ebfac 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,15 +13,7 @@ Please make sure all tests pass before submitting a pull request. It is also goo To test the code locally you need to install the dependencies for the library in the current environment. Additionally, you need to install the dependencies for testing. All of those dependencies can be installed with: ``` -pip install flake8==4.0.1 -pip install pytest==7.1.2 -pip install pytest-flake8==1.1.1 -pip install pydocstyle==6.1.1 -pip install pytest-pydocstyle==2.3.0 -pip install pytest-cov==3.0.0 -pip install black==22.10.0 -pip install pre-commit==2.20.0 -pip install -e . +pip install -e ".[dev]" ``` To run the tests simply execute: diff --git a/Pipfile.lock b/Pipfile.lock index 1151d408..c6179dbe 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -9,27 +9,27 @@ { "name": "pypi", "url": "https://pypi.python.org/simple", - "verify_ssl": true + "verify_ssl": true, } - ] + ], }, "default": { "joblib": { "hashes": [ "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1", - "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9" + "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9", ], "markers": "python_version >= '3.7'", - "version": "==1.3.2" + "version": "==1.3.2", }, "networkx": { "hashes": [ "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36", - "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61" + "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61", ], "index": "pypi", "markers": "python_version >= '3.8'", - "version": "==3.1" + "version": "==3.1", }, "numpy": { "hashes": [ @@ -57,11 +57,11 @@ "sha256:dfe4a913e29b418d096e696ddd422d8a5d13ffba4ea91f9f60440a3b759b0187", "sha256:eb942bfb6f84df5ce05dbf4b46673ffed0d3da59f13635ea9b926af3deb76926", "sha256:f08f2e037bba04e707eebf4bc934f1972a315c883a9e0ebfa8a7756eabf9e357", - "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760" + "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760", ], "index": "pypi", "markers": "python_version >= '3.9'", - "version": "==1.25.2" + "version": "==1.25.2", }, "scikit-learn": { "hashes": [ @@ -85,11 +85,11 @@ "sha256:c7e28d8fa47a0b30ae1bd7a079519dd852764e31708a7804da6cb6f8b36e3630", "sha256:ded35e810438a527e17623ac6deae3b360134345b7c598175ab7741720d7ffa7", "sha256:ee04835fb016e8062ee9fe9074aef9b82e430504e420bff51e3e5fffe72750ca", - "sha256:fd6e2d7389542eae01077a1ee0318c4fec20c66c957f45c7aac0c6eb0fe3c612" + "sha256:fd6e2d7389542eae01077a1ee0318c4fec20c66c957f45c7aac0c6eb0fe3c612", ], "index": "pypi", "markers": "python_version >= '3.8'", - "version": "==1.3.0" + "version": "==1.3.0", }, "scipy": { "hashes": [ @@ -117,52 +117,52 @@ "sha256:ea932570b1c2a30edafca922345854ff2cd20d43cd9123b6dacfdecebfc1a80b", "sha256:f28f1f6cfeb48339c192efc6275749b2a25a7e49c4d8369a28b6591da02fbc9a", "sha256:f73102f769ee06041a3aa26b5841359b1a93cc364ce45609657751795e8f4a4a", - "sha256:fa4909c6c20c3d91480533cddbc0e7c6d849e7d9ded692918c76ce5964997898" + "sha256:fa4909c6c20c3d91480533cddbc0e7c6d849e7d9ded692918c76ce5964997898", ], "markers": "python_version < '3.13' and python_version >= '3.9'", - "version": "==1.11.2" + "version": "==1.11.2", }, "threadpoolctl": { "hashes": [ "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032", - "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355" + "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355", ], "markers": "python_version >= '3.8'", - "version": "==3.2.0" - } + "version": "==3.2.0", + }, }, "develop": { "alabaster": { "hashes": [ "sha256:1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3", - "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2" + "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2", ], "markers": "python_version >= '3.6'", - "version": "==0.7.13" + "version": "==0.7.13", }, "babel": { "hashes": [ "sha256:b4246fb7677d3b98f501a39d43396d3cafdc8eadb045f4a31be01863f655c610", - "sha256:cc2d99999cd01d44420ae725a21c9e3711b3aadc7976d6147f622d8581963455" + "sha256:cc2d99999cd01d44420ae725a21c9e3711b3aadc7976d6147f622d8581963455", ], "markers": "python_version >= '3.7'", - "version": "==2.12.1" + "version": "==2.12.1", }, "bleach": { "hashes": [ "sha256:1a1a85c1595e07d8db14c5f09f09e6433502c51c595970edc090551f0db99414", - "sha256:33c16e3353dbd13028ab4799a0f89a83f113405c766e9c122df8a06f5b85b3f4" + "sha256:33c16e3353dbd13028ab4799a0f89a83f113405c766e9c122df8a06f5b85b3f4", ], "markers": "python_version >= '3.7'", - "version": "==6.0.0" + "version": "==6.0.0", }, "certifi": { "hashes": [ "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082", - "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9" + "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9", ], "markers": "python_version >= '3.6'", - "version": "==2023.7.22" + "version": "==2023.7.22", }, "charset-normalizer": { "hashes": [ @@ -240,15 +240,13 @@ "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1", "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c", "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac", - "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa" + "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa", ], "markers": "python_full_version >= '3.7.0'", - "version": "==3.2.0" + "version": "==3.2.0", }, "coverage": { - "extras": [ - "toml" - ], + "extras": ["toml"], "hashes": [ "sha256:07ea61bcb179f8f05ffd804d2732b09d23a1238642bf7e51dad62082b5019b34", "sha256:1084393c6bda8875c05e04fce5cfe1301a425f758eb012f010eab586f1f3905e", @@ -301,90 +299,90 @@ "sha256:e2ac9a1de294773b9fa77447ab7e529cf4fe3910f6a0832816e5f3d538cfea9a", "sha256:e61260ec93f99f2c2d93d264b564ba912bec502f679793c56f678ba5251f0393", "sha256:fac440c43e9b479d1241fe9d768645e7ccec3fb65dc3a5f6e90675e75c3f3e3a", - "sha256:fc0ed8d310afe013db1eedd37176d0839dc66c96bcfcce8f6607a73ffea2d6ba" + "sha256:fc0ed8d310afe013db1eedd37176d0839dc66c96bcfcce8f6607a73ffea2d6ba", ], "markers": "python_version >= '3.8'", - "version": "==7.3.0" + "version": "==7.3.0", }, "docutils": { "hashes": [ "sha256:0c5b78adfbf7762415433f5515cd5c9e762339e23369dbe8000d84a4bf4ab3af", - "sha256:c2de3a60e9e7d07be26b7f2b00ca0309c207e06c100f9cc2a94931fc75a478fc" + "sha256:c2de3a60e9e7d07be26b7f2b00ca0309c207e06c100f9cc2a94931fc75a478fc", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", - "version": "==0.16" + "version": "==0.16", }, "flake8": { "hashes": [ "sha256:d5b3857f07c030bdb5bf41c7f53799571d75c4491748a3adcd47de929e34cd23", - "sha256:ffdfce58ea94c6580c77888a86506937f9a1a227dfcd15f245d694ae20a6b6e5" + "sha256:ffdfce58ea94c6580c77888a86506937f9a1a227dfcd15f245d694ae20a6b6e5", ], "markers": "python_full_version >= '3.8.1'", - "version": "==6.1.0" + "version": "==6.1.0", }, "idna": { "hashes": [ "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4", - "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2" + "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2", ], "markers": "python_version >= '3.5'", - "version": "==3.4" + "version": "==3.4", }, "imagesize": { "hashes": [ "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", - "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a" + "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==1.4.1" + "version": "==1.4.1", }, "importlib-metadata": { "hashes": [ "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb", - "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743" + "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743", ], "markers": "python_version >= '3.8'", - "version": "==6.8.0" + "version": "==6.8.0", }, "iniconfig": { "hashes": [ "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", - "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374" + "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", ], "markers": "python_version >= '3.7'", - "version": "==2.0.0" + "version": "==2.0.0", }, "jaraco.classes": { "hashes": [ "sha256:10afa92b6743f25c0cf5f37c6bb6e18e2c5bb84a16527ccfc0040ea377e7aaeb", - "sha256:c063dd08e89217cee02c8d5e5ec560f2c8ce6cdc2fcdc2e68f7b2e5547ed3621" + "sha256:c063dd08e89217cee02c8d5e5ec560f2c8ce6cdc2fcdc2e68f7b2e5547ed3621", ], "markers": "python_version >= '3.8'", - "version": "==3.3.0" + "version": "==3.3.0", }, "jinja2": { "hashes": [ "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852", - "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61" + "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61", ], "markers": "python_version >= '3.7'", - "version": "==3.1.2" + "version": "==3.1.2", }, "keyring": { "hashes": [ "sha256:4901caaf597bfd3bbd78c9a0c7c4c29fcd8310dab2cffefe749e916b6527acd6", - "sha256:ca0746a19ec421219f4d713f848fa297a661a8a8c1504867e55bfb5e09091509" + "sha256:ca0746a19ec421219f4d713f848fa297a661a8a8c1504867e55bfb5e09091509", ], "markers": "python_version >= '3.8'", - "version": "==24.2.0" + "version": "==24.2.0", }, "markdown-it-py": { "hashes": [ "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", - "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb" + "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", ], "markers": "python_version >= '3.8'", - "version": "==3.0.0" + "version": "==3.0.0", }, "markupsafe": { "hashes": [ @@ -437,116 +435,116 @@ "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9", "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57", "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc", - "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2" + "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2", ], "markers": "python_version >= '3.7'", - "version": "==2.1.3" + "version": "==2.1.3", }, "mccabe": { "hashes": [ "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325", - "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e" + "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", ], "markers": "python_version >= '3.6'", - "version": "==0.7.0" + "version": "==0.7.0", }, "mdurl": { "hashes": [ "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", - "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba" + "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", ], "markers": "python_version >= '3.7'", - "version": "==0.1.2" + "version": "==0.1.2", }, "more-itertools": { "hashes": [ "sha256:626c369fa0eb37bac0291bce8259b332fd59ac792fa5497b59837309cd5b114a", - "sha256:64e0735fcfdc6f3464ea133afe8ea4483b1c5fe3a3d69852e6503b43a0b222e6" + "sha256:64e0735fcfdc6f3464ea133afe8ea4483b1c5fe3a3d69852e6503b43a0b222e6", ], "markers": "python_version >= '3.8'", - "version": "==10.1.0" + "version": "==10.1.0", }, "packaging": { "hashes": [ "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61", - "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f" + "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f", ], "markers": "python_version >= '3.7'", - "version": "==23.1" + "version": "==23.1", }, "pkginfo": { "hashes": [ "sha256:4b7a555a6d5a22169fcc9cf7bfd78d296b0361adad412a346c1226849af5e546", - "sha256:8fd5896e8718a4372f0ea9cc9d96f6417c9b986e23a4d116dda26b62cc29d046" + "sha256:8fd5896e8718a4372f0ea9cc9d96f6417c9b986e23a4d116dda26b62cc29d046", ], "markers": "python_version >= '3.6'", - "version": "==1.9.6" + "version": "==1.9.6", }, "pluggy": { "hashes": [ "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12", - "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7" + "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7", ], "markers": "python_version >= '3.8'", - "version": "==1.3.0" + "version": "==1.3.0", }, "pycodestyle": { "hashes": [ "sha256:259bcc17857d8a8b3b4a2327324b79e5f020a13c16074670f9c8c8f872ea76d0", - "sha256:5d1013ba8dc7895b548be5afb05740ca82454fd899971563d2ef625d090326f8" + "sha256:5d1013ba8dc7895b548be5afb05740ca82454fd899971563d2ef625d090326f8", ], "markers": "python_version >= '3.8'", - "version": "==2.11.0" + "version": "==2.11.0", }, "pydocstyle": { "hashes": [ "sha256:118762d452a49d6b05e194ef344a55822987a462831ade91ec5c06fd2169d019", - "sha256:7ce43f0c0ac87b07494eb9c0b462c0b73e6ff276807f204d6b53edc72b7e44e1" + "sha256:7ce43f0c0ac87b07494eb9c0b462c0b73e6ff276807f204d6b53edc72b7e44e1", ], "markers": "python_version >= '3.6'", - "version": "==6.3.0" + "version": "==6.3.0", }, "pyflakes": { "hashes": [ "sha256:4132f6d49cb4dae6819e5379898f2b8cce3c5f23994194c24b77d5da2e36f774", - "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc" + "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc", ], "markers": "python_version >= '3.8'", - "version": "==3.1.0" + "version": "==3.1.0", }, "pygments": { "hashes": [ "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692", - "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29" + "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29", ], "markers": "python_version >= '3.7'", - "version": "==2.16.1" + "version": "==2.16.1", }, "pytest": { "hashes": [ "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32", - "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a" + "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a", ], "index": "pypi", "markers": "python_version >= '3.7'", - "version": "==7.4.0" + "version": "==7.4.0", }, "pytest-cov": { "hashes": [ "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6", - "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a" + "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a", ], "index": "pypi", "markers": "python_version >= '3.7'", - "version": "==4.1.0" + "version": "==4.1.0", }, "pytest-flake8": { "hashes": [ "sha256:ba4f243de3cb4c2486ed9e70752c80dd4b636f7ccb27d4eba763c35ed0cd316e", - "sha256:e0661a786f8cbf976c185f706fdaf5d6df0b1667c3bcff8e823ba263618627e7" + "sha256:e0661a786f8cbf976c185f706fdaf5d6df0b1667c3bcff8e823ba263618627e7", ], "index": "pypi", - "version": "==1.1.1" + "version": "==1.1.1", }, "pytest-pydocstyle": { "hashes": [ @@ -554,193 +552,193 @@ ], "index": "pypi", "markers": "python_version ~= '3.7'", - "version": "==2.3.2" + "version": "==2.3.2", }, "readme-renderer": { "hashes": [ "sha256:4f4b11e5893f5a5d725f592c5a343e0dc74f5f273cb3dcf8c42d9703a27073f7", - "sha256:a38243d5b6741b700a850026e62da4bd739edc7422071e95fd5c4bb60171df86" + "sha256:a38243d5b6741b700a850026e62da4bd739edc7422071e95fd5c4bb60171df86", ], "markers": "python_version >= '3.8'", - "version": "==41.0" + "version": "==41.0", }, "requests": { "hashes": [ "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f", - "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1" + "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1", ], "markers": "python_version >= '3.7'", - "version": "==2.31.0" + "version": "==2.31.0", }, "requests-toolbelt": { "hashes": [ "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6", - "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06" + "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==1.0.0" + "version": "==1.0.0", }, "rfc3986": { "hashes": [ "sha256:50b1502b60e289cb37883f3dfd34532b8873c7de9f49bb546641ce9cbd256ebd", - "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c" + "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c", ], "markers": "python_version >= '3.7'", - "version": "==2.0.0" + "version": "==2.0.0", }, "rich": { "hashes": [ "sha256:146a90b3b6b47cac4a73c12866a499e9817426423f57c5a66949c086191a8808", - "sha256:fb9d6c0a0f643c99eed3875b5377a184132ba9be4d61516a55273d3554d75a39" + "sha256:fb9d6c0a0f643c99eed3875b5377a184132ba9be4d61516a55273d3554d75a39", ], "markers": "python_full_version >= '3.7.0'", - "version": "==13.5.2" + "version": "==13.5.2", }, "setuptools": { "hashes": [ "sha256:3d4dfa6d95f1b101d695a6160a7626e15583af71a5f52176efa5d39a054d475d", - "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b" + "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b", ], "markers": "python_version >= '3.8'", - "version": "==68.1.2" + "version": "==68.1.2", }, "six": { "hashes": [ "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", - "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254" + "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==1.16.0" + "version": "==1.16.0", }, "snowballstemmer": { "hashes": [ "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1", - "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a" + "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a", ], - "version": "==2.2.0" + "version": "==2.2.0", }, "sphinx": { "hashes": [ "sha256:23c846a1841af998cb736218539bb86d16f5eb95f5760b1966abcd2d584e62b8", - "sha256:3d513088236eef51e5b0adb78b0492eb22cc3b8ccdb0b36dd021173b365d4454" + "sha256:3d513088236eef51e5b0adb78b0492eb22cc3b8ccdb0b36dd021173b365d4454", ], "index": "pypi", "markers": "python_version >= '3.6'", - "version": "==4.1.1" + "version": "==4.1.1", }, "sphinx-rtd-theme": { "hashes": [ "sha256:32bd3b5d13dc8186d7a42fc816a23d32e83a4827d7d9882948e7b837c232da5a", - "sha256:4a05bdbe8b1446d77a01e20a23ebc6777c74f43237035e76be89699308987d6f" + "sha256:4a05bdbe8b1446d77a01e20a23ebc6777c74f43237035e76be89699308987d6f", ], "index": "pypi", - "version": "==0.5.2" + "version": "==0.5.2", }, "sphinxcontrib-applehelp": { "hashes": [ "sha256:29d341f67fb0f6f586b23ad80e072c8e6ad0b48417db2bde114a4c9746feb228", - "sha256:828f867945bbe39817c210a1abfd1bc4895c8b73fcaade56d45357a348a07d7e" + "sha256:828f867945bbe39817c210a1abfd1bc4895c8b73fcaade56d45357a348a07d7e", ], "markers": "python_version >= '3.8'", - "version": "==1.0.4" + "version": "==1.0.4", }, "sphinxcontrib-devhelp": { "hashes": [ "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e", - "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4" + "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4", ], "markers": "python_version >= '3.5'", - "version": "==1.0.2" + "version": "==1.0.2", }, "sphinxcontrib-htmlhelp": { "hashes": [ "sha256:0cbdd302815330058422b98a113195c9249825d681e18f11e8b1f78a2f11efff", - "sha256:c38cb46dccf316c79de6e5515e1770414b797162b23cd3d06e67020e1d2a6903" + "sha256:c38cb46dccf316c79de6e5515e1770414b797162b23cd3d06e67020e1d2a6903", ], "markers": "python_version >= '3.8'", - "version": "==2.0.1" + "version": "==2.0.1", }, "sphinxcontrib-jsmath": { "hashes": [ "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", - "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8" + "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8", ], "markers": "python_version >= '3.5'", - "version": "==1.0.1" + "version": "==1.0.1", }, "sphinxcontrib-qthelp": { "hashes": [ "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72", - "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6" + "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6", ], "markers": "python_version >= '3.5'", - "version": "==1.0.3" + "version": "==1.0.3", }, "sphinxcontrib-serializinghtml": { "hashes": [ "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd", - "sha256:aa5f6de5dfdf809ef505c4895e51ef5c9eac17d0f287933eb49ec495280b6952" + "sha256:aa5f6de5dfdf809ef505c4895e51ef5c9eac17d0f287933eb49ec495280b6952", ], "markers": "python_version >= '3.5'", - "version": "==1.1.5" + "version": "==1.1.5", }, "twine": { "hashes": [ "sha256:929bc3c280033347a00f847236564d1c52a3e61b1ac2516c97c48f3ceab756d8", - "sha256:9e102ef5fdd5a20661eb88fad46338806c3bd32cf1db729603fe3697b1bc83c8" + "sha256:9e102ef5fdd5a20661eb88fad46338806c3bd32cf1db729603fe3697b1bc83c8", ], "index": "pypi", "markers": "python_version >= '3.7'", - "version": "==4.0.2" + "version": "==4.0.2", }, "urllib3": { "hashes": [ "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11", - "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4" + "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4", ], "markers": "python_version >= '3.7'", - "version": "==2.0.4" + "version": "==2.0.4", }, "webencodings": { "hashes": [ "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", - "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923" + "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", ], - "version": "==0.5.1" + "version": "==0.5.1", }, "zipp": { "hashes": [ "sha256:679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0", - "sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147" + "sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147", ], "markers": "python_version >= '3.8'", - "version": "==3.16.2" - } + "version": "==3.16.2", + }, }, "extras": { "aiosignal": { "hashes": [ "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc", - "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17" + "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", ], "markers": "python_version >= '3.7'", - "version": "==1.3.1" + "version": "==1.3.1", }, "attrs": { "hashes": [ "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04", - "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015" + "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015", ], "markers": "python_version >= '3.7'", - "version": "==23.1.0" + "version": "==23.1.0", }, "certifi": { "hashes": [ "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082", - "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9" + "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9", ], "markers": "python_version >= '3.6'", - "version": "==2023.7.22" + "version": "==2023.7.22", }, "charset-normalizer": { "hashes": [ @@ -818,26 +816,26 @@ "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1", "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c", "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac", - "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa" + "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa", ], "markers": "python_full_version >= '3.7.0'", - "version": "==3.2.0" + "version": "==3.2.0", }, "click": { "hashes": [ "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", - "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de" + "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", ], "markers": "python_version >= '3.7'", - "version": "==8.1.7" + "version": "==8.1.7", }, "filelock": { "hashes": [ "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d", - "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb" + "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb", ], "markers": "python_version >= '3.8'", - "version": "==3.12.3" + "version": "==3.12.3", }, "frozenlist": { "hashes": [ @@ -901,10 +899,10 @@ "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f", "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3", "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1", - "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e" + "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e", ], "markers": "python_version >= '3.8'", - "version": "==1.4.0" + "version": "==1.4.0", }, "grpcio": { "hashes": [ @@ -952,34 +950,34 @@ "sha256:fada6b07ec4f0befe05218181f4b85176f11d531911b64c715d1875c4736d73a", "sha256:fd173b4cf02b20f60860dc2ffe30115c18972d7d6d2d69df97ac38dee03be5bf", "sha256:fe752639919aad9ffb0dee0d87f29a6467d1ef764f13c4644d212a9a853a078d", - "sha256:fee387d2fab144e8a34e0e9c5ca0f45c9376b99de45628265cfa9886b1dbe62b" + "sha256:fee387d2fab144e8a34e0e9c5ca0f45c9376b99de45628265cfa9886b1dbe62b", ], "markers": "python_version >= '3.10'", - "version": "==1.57.0" + "version": "==1.57.0", }, "idna": { "hashes": [ "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4", - "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2" + "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2", ], "markers": "python_version >= '3.5'", - "version": "==3.4" + "version": "==3.4", }, "jsonschema": { "hashes": [ "sha256:043dc26a3845ff09d20e4420d6012a9c91c9aa8999fa184e7efcfeccb41e32cb", - "sha256:6e1e7569ac13be8139b2dd2c21a55d350066ee3f80df06c608b398cdc6f30e8f" + "sha256:6e1e7569ac13be8139b2dd2c21a55d350066ee3f80df06c608b398cdc6f30e8f", ], "markers": "python_version >= '3.8'", - "version": "==4.19.0" + "version": "==4.19.0", }, "jsonschema-specifications": { "hashes": [ "sha256:05adf340b659828a004220a9613be00fa3f223f2b82002e273dee62fd50524b1", - "sha256:c91a50404e88a1f6ba40636778e2ee08f6e24c5613fe4c53ac24578a5a7f72bb" + "sha256:c91a50404e88a1f6ba40636778e2ee08f6e24c5613fe4c53ac24578a5a7f72bb", ], "markers": "python_version >= '3.8'", - "version": "==2023.7.1" + "version": "==2023.7.1", }, "msgpack": { "hashes": [ @@ -1045,9 +1043,9 @@ "sha256:ed40e926fa2f297e8a653c954b732f125ef97bdd4c889f243182299de27e2aa9", "sha256:ef8108f8dedf204bb7b42994abf93882da1159728a2d4c5e82012edd92c9da9f", "sha256:f933bbda5a3ee63b8834179096923b094b76f0c7a73c1cfe8f07ad608c58844b", - "sha256:fe5c63197c55bce6385d9aee16c4d0641684628f63ace85f73571e65ad1c1e8d" + "sha256:fe5c63197c55bce6385d9aee16c4d0641684628f63ace85f73571e65ad1c1e8d", ], - "version": "==1.0.5" + "version": "==1.0.5", }, "numpy": { "hashes": [ @@ -1075,19 +1073,19 @@ "sha256:dfe4a913e29b418d096e696ddd422d8a5d13ffba4ea91f9f60440a3b759b0187", "sha256:eb942bfb6f84df5ce05dbf4b46673ffed0d3da59f13635ea9b926af3deb76926", "sha256:f08f2e037bba04e707eebf4bc934f1972a315c883a9e0ebfa8a7756eabf9e357", - "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760" + "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760", ], "index": "pypi", "markers": "python_version >= '3.9'", - "version": "==1.25.2" + "version": "==1.25.2", }, "packaging": { "hashes": [ "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61", - "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f" + "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f", ], "markers": "python_version >= '3.7'", - "version": "==23.1" + "version": "==23.1", }, "protobuf": { "hashes": [ @@ -1103,10 +1101,10 @@ "sha256:839952e759fc40b5d46be319a265cf94920174d88de31657d5622b5d8d6be5cd", "sha256:bb7aa97c252279da65584af0456f802bd4b2de429eb945bbc9b3d61a42a8cd16", "sha256:c00c3c7eb9ad3833806e21e86dca448f46035242a680f81c3fe068ff65e79c74", - "sha256:c5cdd486af081bf752225b26809d2d0a85e575b80a84cde5172a05bbb1990099" + "sha256:c5cdd486af081bf752225b26809d2d0a85e575b80a84cde5172a05bbb1990099", ], "markers": "python_version >= '3.7'", - "version": "==4.24.2" + "version": "==4.24.2", }, "pyyaml": { "hashes": [ @@ -1159,10 +1157,10 @@ "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c", "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585", "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d", - "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f" + "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f", ], "markers": "python_version >= '3.6'", - "version": "==6.0.1" + "version": "==6.0.1", }, "ray": { "hashes": [ @@ -1189,26 +1187,26 @@ "sha256:bca66c8e8163f06dc5443623e7b221660529a39574a589ba9257f2188ea8bf6b", "sha256:bdeacaafcbb97e5f1c3c3349e7fcc0c40f691cea2bf057027c5491ea1ac929b0", "sha256:dff21468d621c8dac95b3df320e6c6121f6618f6827243fd75a057c8815c2498", - "sha256:e0f8eaf4c4592335722dad474685c2ffc98207b997e47a24b297a60db389a4cb" + "sha256:e0f8eaf4c4592335722dad474685c2ffc98207b997e47a24b297a60db389a4cb", ], "index": "pypi", - "version": "==2.6.3" + "version": "==2.6.3", }, "referencing": { "hashes": [ "sha256:449b6669b6121a9e96a7f9e410b245d471e8d48964c67113ce9afe50c8dd7bdf", - "sha256:794ad8003c65938edcdbc027f1933215e0d0ccc0291e3ce20a4d87432b59efc0" + "sha256:794ad8003c65938edcdbc027f1933215e0d0ccc0291e3ce20a4d87432b59efc0", ], "markers": "python_version >= '3.8'", - "version": "==0.30.2" + "version": "==0.30.2", }, "requests": { "hashes": [ "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f", - "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1" + "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1", ], "markers": "python_version >= '3.7'", - "version": "==2.31.0" + "version": "==2.31.0", }, "rpds-py": { "hashes": [ @@ -1308,18 +1306,18 @@ "sha256:f9e7e493ded7042712a374471203dd43ae3fff5b81e3de1a0513fa241af9fd41", "sha256:fc72ae476732cdb7b2c1acb5af23b478b8a0d4b6fcf19b90dd150291e0d5b26b", "sha256:fccbf0cd3411719e4c9426755df90bf3449d9fc5a89f077f4a7f1abd4f70c910", - "sha256:ffcf18ad3edf1c170e27e88b10282a2c449aa0358659592462448d71b2000cfc" + "sha256:ffcf18ad3edf1c170e27e88b10282a2c449aa0358659592462448d71b2000cfc", ], "markers": "python_version >= '3.8'", - "version": "==0.10.0" + "version": "==0.10.0", }, "urllib3": { "hashes": [ "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11", - "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4" + "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4", ], "markers": "python_version >= '3.7'", - "version": "==2.0.4" - } - } + "version": "==2.0.4", + }, + }, } diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index bd07da11..23e422ab 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -1,7 +1,9 @@ """Shared code for all classifiers.""" import abc +import hashlib import logging +import pickle import networkx as nx import numpy as np @@ -68,6 +70,7 @@ def __init__( n_jobs: int = 1, bert: bool = False, classifier_abbreviation: str = "", + tmp_dir: str = None, ): """ Initialize a local hierarchical classifier. @@ -93,6 +96,9 @@ def __init__( If True, skip scikit-learn's checks and sample_weight passing for BERT. classifier_abbreviation : str, default="" The abbreviation of the local hierarchical classifier to be displayed during logging. + tmp_dir : str, default=None + Temporary directory to persist local classifiers that are trained. If the job needs to be restarted, + it will skip the pre-trained local classifier found in the temporary directory. """ self.local_classifier = local_classifier self.verbose = verbose @@ -101,6 +107,7 @@ def __init__( self.n_jobs = n_jobs self.bert = bert self.classifier_abbreviation = classifier_abbreviation + self.tmp_dir = tmp_dir def fit(self, X, y, sample_weight=None): """ @@ -341,7 +348,9 @@ def _fit_node_classifier( @staticmethod def _fit_classifier(self, node): - raise NotImplementedError("Method should be implemented in the LCPN and LCPPN") + raise NotImplementedError( + "Method should be implemented in the LCPN, LCPPN or LCPL" + ) def _clean_up(self): self.logger_.info("Cleaning up variables that can take a lot of disk space") @@ -349,3 +358,18 @@ def _clean_up(self): del self.y_ if self.sample_weight_ is not None: del self.sample_weight_ + + def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False): + raise NotImplementedError( + "Method should be implemented in the LCPN, LCPPN or LCPL" + ) + + def _save_tmp(self, name, classifier): + if self.tmp_dir: + md5 = hashlib.md5(str(name).encode("utf-8")).hexdigest() + filename = f"{self.tmp_dir}/{md5}.sav" + with open(filename, "wb") as file: + pickle.dump((name, classifier), file) + self.logger_.info( + f"Stored trained model for local classifier {str(name).split(self.separator_)[-1]} in file {filename}" + ) diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 0c174809..907e61cf 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -4,7 +4,10 @@ Numeric and string output labels are both handled. """ +import hashlib +import pickle from copy import deepcopy +from os.path import exists import numpy as np from joblib import Parallel, delayed @@ -49,6 +52,7 @@ def __init__( replace_classifiers: bool = True, n_jobs: int = 1, bert: bool = False, + tmp_dir: str = None, ): """ Initialize a local classifier per level. @@ -72,6 +76,9 @@ def __init__( If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`. bert : bool, default=False If True, skip scikit-learn's checks and sample_weight passing for BERT. + tmp_dir : str, default=None + Temporary directory to persist local classifiers that are trained. If the job needs to be restarted, + it will skip the pre-trained local classifier found in the temporary directory. """ super().__init__( local_classifier=local_classifier, @@ -81,6 +88,7 @@ def __init__( n_jobs=n_jobs, classifier_abbreviation="LCPL", bert=bert, + tmp_dir=tmp_dir, ) def fit(self, X, y, sample_weight=None): @@ -246,7 +254,16 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False): @staticmethod def _fit_classifier(self, level, separator): classifier = self.local_classifiers_[level] - + if self.tmp_dir: + md5 = hashlib.md5(str(level).encode("utf-8")).hexdigest() + filename = f"{self.tmp_dir}/{md5}.sav" + if exists(filename): + (_, classifier) = pickle.load(open(filename, "rb")) + self.logger_.info( + f"Loaded trained model for local classifier {level} from file {filename}" + ) + return classifier + self.logger_.info(f"Training local classifier {level}") X, y, sample_weight = self._remove_empty_leaves( separator, self.X_, self.y_[:, level], self.sample_weight_ ) @@ -261,6 +278,7 @@ def _fit_classifier(self, level, separator): classifier.fit(X, y) else: classifier.fit(X, y) + self._save_tmp(level, classifier) return classifier @staticmethod diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 6debe71b..1382c72e 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -4,7 +4,10 @@ Numeric and string output labels are both handled. """ +import hashlib +import pickle from copy import deepcopy +from os.path import exists import networkx as nx import numpy as np @@ -44,6 +47,7 @@ def __init__( replace_classifiers: bool = True, n_jobs: int = 1, bert: bool = False, + tmp_dir: str = None, ): """ Initialize a local classifier per node. @@ -78,6 +82,9 @@ def __init__( If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`. bert : bool, default=False If True, skip scikit-learn's checks and sample_weight passing for BERT. + tmp_dir : str, default=None + Temporary directory to persist local classifiers that are trained. If the job needs to be restarted, + it will skip the pre-trained local classifier found in the temporary directory. """ super().__init__( local_classifier=local_classifier, @@ -87,6 +94,7 @@ def __init__( n_jobs=n_jobs, classifier_abbreviation="LCPN", bert=bert, + tmp_dir=tmp_dir, ) self.binary_policy = binary_policy @@ -237,6 +245,16 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False): @staticmethod def _fit_classifier(self, node): classifier = self.hierarchy_.nodes[node]["classifier"] + if self.tmp_dir: + md5 = hashlib.md5(node.encode("utf-8")).hexdigest() + filename = f"{self.tmp_dir}/{md5}.sav" + if exists(filename): + (_, classifier) = pickle.load(open(filename, "rb")) + self.logger_.info( + f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" + ) + return classifier + self.logger_.info(f"Training local classifier {node}") X, y, sample_weight = self.binary_policy_.get_binary_examples(node) unique_y = np.unique(y) if len(unique_y) == 1 and self.replace_classifiers: @@ -248,6 +266,7 @@ def _fit_classifier(self, node): classifier.fit(X, y) else: classifier.fit(X, y) + self._save_tmp(node, classifier) return classifier def _clean_up(self): diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 50a30ca0..47f77475 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -4,7 +4,10 @@ Numeric and string output labels are both handled. """ +import hashlib +import pickle from copy import deepcopy +from os.path import exists import networkx as nx import numpy as np @@ -42,6 +45,7 @@ def __init__( replace_classifiers: bool = True, n_jobs: int = 1, bert: bool = False, + tmp_dir: str = None, ): """ Initialize a local classifier per parent node. @@ -65,6 +69,9 @@ def __init__( If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`. bert : bool, default=False If True, skip scikit-learn's checks and sample_weight passing for BERT. + tmp_dir : str, default=None + Temporary directory to persist local classifiers that are trained. If the job needs to be restarted, + it will skip the pre-trained local classifier found in the temporary directory. """ super().__init__( local_classifier=local_classifier, @@ -74,6 +81,7 @@ def __init__( n_jobs=n_jobs, classifier_abbreviation="LCPPN", bert=bert, + tmp_dir=tmp_dir, ) def fit(self, X, y, sample_weight=None): @@ -206,6 +214,16 @@ def _get_successors(self, node): @staticmethod def _fit_classifier(self, node): classifier = self.hierarchy_.nodes[node]["classifier"] + if self.tmp_dir: + md5 = hashlib.md5(node.encode("utf-8")).hexdigest() + filename = f"{self.tmp_dir}/{md5}.sav" + if exists(filename): + (_, classifier) = pickle.load(open(filename, "rb")) + self.logger_.info( + f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" + ) + return classifier + self.logger_.info(f"Training local classifier {node}") # get children examples X, y, sample_weight = self._get_successors(node) unique_y = np.unique(y) @@ -218,6 +236,7 @@ def _fit_classifier(self, node): classifier.fit(X, y) else: classifier.fit(X, y) + self._save_tmp(node, classifier) return classifier def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False): diff --git a/setup.cfg b/setup.cfg index fa25a94b..11c68c15 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,7 +14,7 @@ exclude = **/__init__.py, docs/source/conf.py ;file.py: error [requires] -python_version = ">=3.7,<3.12" +python_version = ">=3.8,<3.12" # See the docstring in versioneer.py for instructions. Note that you must # re-run 'versioneer.py setup' after changing this section, and commit the diff --git a/setup.py b/setup.py index c93c16f6..b2a49b97 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ URL_ISSUES = "https://github.com/scikit-learn-contrib/hiclass/issues" EMAIL = "fabio.malchermiranda@hpi.de, Niklas.Koehnecke@student.hpi.uni-potsdam.de" AUTHOR = "Fabio Malcher Miranda, Niklas Koehnecke" -REQUIRES_PYTHON = ">=3.7,<3.12" +REQUIRES_PYTHON = ">=3.8,<3.12" KEYWORDS = ["hierarchical classification"] DACS_SOFTWARE = "https://gitlab.com/dacs-hpi" # What packages are required for this module to be executed? @@ -31,7 +31,22 @@ # What packages are optional? # 'fancy feature': ['django'],} -EXTRAS = {"ray": ["ray>=1.11.0"], "xai": ["shap", "xarray"]} +EXTRAS = { + "ray": ["ray>=1.11.0"], + "xai": ["shap", "xarray"], + "dev": [ + "flake8==4.0.1", + "pytest==7.1.2", + "pytest-flake8==1.1.1", + "pydocstyle==6.1.1", + "pytest-pydocstyle==2.3.0", + "pytest-cov==3.0.0", + "pyfakefs==5.3.5", + "black==24.2.0", + "pre-commit==2.20.0", + "ray", + ], +} # The rest you shouldn't have to touch too much :) # ------------------------------------------------ @@ -140,7 +155,6 @@ def run(self): "Operating System :: Unix", "Operating System :: MacOS", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/tests/test_HierarchicalClassifier.py b/tests/test_HierarchicalClassifier.py index d800ff47..3333cf52 100644 --- a/tests/test_HierarchicalClassifier.py +++ b/tests/test_HierarchicalClassifier.py @@ -216,6 +216,11 @@ def test_fit_classifier(): HierarchicalClassifier._fit_classifier(None, None) +def test_fit_digraph(): + with pytest.raises(NotImplementedError): + HierarchicalClassifier._fit_digraph(None, None) + + def test_pre_fit_bert(): classifier = HierarchicalClassifier() classifier.logger_ = logging.getLogger("HC") diff --git a/tests/test_LocalClassifiers.py b/tests/test_LocalClassifiers.py index 37f1bf46..abd7bddf 100644 --- a/tests/test_LocalClassifiers.py +++ b/tests/test_LocalClassifiers.py @@ -1,6 +1,10 @@ +import os +import pickle + import numpy as np import pytest from numpy.testing import assert_array_equal +from pyfakefs.fake_filesystem_unittest import Patcher from sklearn.linear_model import LogisticRegression from sklearn.neighbors import KNeighborsClassifier from sklearn.utils.validation import check_is_fitted @@ -119,3 +123,23 @@ def test_predict_multiple_dim_input(classifier): clf.fit(X, y) predictions = clf.predict(X) assert predictions is not None + + +@pytest.mark.parametrize("classifier", classifiers) +def test_tmp_dir(classifier): + clf = classifier(tmp_dir=".") + with Patcher() as patcher: + x = np.array([[1, 2], [3, 4]]) + y = np.array([["a", "b"], ["c", "d"]]) + clf.fit(x, y) + if isinstance(clf, LocalClassifierPerLevel): + filename = "cfcd208495d565ef66e7dff9f98764da.sav" + expected_name = 0 + else: + filename = "0cc175b9c0f1b6a831c399e269772661.sav" + expected_name = "a" + assert patcher.fs.exists(filename) + (name, classifier) = pickle.load(open(filename, "rb")) + assert expected_name == name + check_is_fitted(classifier) + clf.fit(x, y)