diff --git a/docs/overrides/main.html b/docs/overrides/main.html
index 7d348f3..00e0bce 100644
--- a/docs/overrides/main.html
+++ b/docs/overrides/main.html
@@ -8,7 +8,7 @@
Navigate the site here!
- v0.4.2 is out!
+ v0.4.3 is out!
diff --git a/lanfactory/__init__.py b/lanfactory/__init__.py
index fac5705..6e23957 100755
--- a/lanfactory/__init__.py
+++ b/lanfactory/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.4.2"
+__version__ = "0.4.3"
from . import config
from . import trainers
diff --git a/notebooks/test_notebooks/test_jax_network.ipynb b/notebooks/test_notebooks/test_jax_network.ipynb
index b02708d..b49daf5 100644
--- a/notebooks/test_notebooks/test_jax_network.ipynb
+++ b/notebooks/test_notebooks/test_jax_network.ipynb
@@ -4,7 +4,20 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "ImportError",
+ "evalue": "cannot import name 'onnx' from partially initialized module 'lanfactory' (most likely due to a circular import) (/users/afengler/data/software/miniconda3/envs/lan_pipe/lib/python3.10/site-packages/lanfactory/__init__.py)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mssms\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mlanfactory\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mnp\u001b[39;00m\n",
+ "File \u001b[0;32m~/data/software/miniconda3/envs/lan_pipe/lib/python3.10/site-packages/lanfactory/__init__.py:6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m \u001b[39mimport\u001b[39;00m trainers\n\u001b[1;32m 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m \u001b[39mimport\u001b[39;00m utils\n\u001b[0;32m----> 6\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m \u001b[39mimport\u001b[39;00m onnx\n\u001b[1;32m 8\u001b[0m __all__ \u001b[39m=\u001b[39m [\u001b[39m\"\u001b[39m\u001b[39mconfig\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mtrainers\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mutils\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39monnx\u001b[39m\u001b[39m\"\u001b[39m]\n",
+ "\u001b[0;31mImportError\u001b[0m: cannot import name 'onnx' from partially initialized module 'lanfactory' (most likely due to a circular import) (/users/afengler/data/software/miniconda3/envs/lan_pipe/lib/python3.10/site-packages/lanfactory/__init__.py)"
+ ]
+ }
+ ],
"source": [
"import ssms\n",
"import lanfactory\n",
diff --git a/pyproject.toml b/pyproject.toml
index 12e6e7b..371c7a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ requires = ["setuptools", "wheel"]
[project]
name= "lanfactory"
-version= "0.4.2"
+version= "0.4.3"
authors= [{name = "Alexander Fenger", email = "alexander_fengler@brown.edu"}]
description= "Package with convenience functions to train LANs"
readme = "README.md"
@@ -24,7 +24,7 @@ dependencies=[
"SciPy >= 1.6.3",
"pandas >= 1.2.4",
"torch >= 1.7",
- "jax >= 0.4.2",
+ "jax >= 0.4.14",
"flax >= 0.6.4",
"optax >= 0.1.4",
"tqdm >= 4.0.0",
diff --git a/setup.py b/setup.py
index 1341ac8..9e6ebab 100755
--- a/setup.py
+++ b/setup.py
@@ -6,5 +6,6 @@
"lanfactory.config",
"lanfactory.trainers",
"lanfactory.utils",
+ "lanfactory.onnx"
],
)