It's quite possible that 10,000 rows is simply too large for the forward pass. One idea -- though I've never tried it -- could be to split the rows into smaller chunks and create a bootstrapped estimate of the graph by running several forward passes.
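The chunking idea could be sketched roughly as follows. This is untested against the real model: `bootstrap_edge_probs`, `chunk_size`, `n_chunks`, and `dummy_predict` are hypothetical names made up for illustration, and `dummy_predict` is only a stand-in so the snippet runs on its own. In practice you would pass something like `lambda chunk: model(x=chunk)`.

```python
import numpy as np

def bootstrap_edge_probs(x, predict_fn, chunk_size=1000, n_chunks=10, seed=0):
    """Average the edge-probability matrices predicted on random row subsets.

    predict_fn is assumed to map an array of shape (rows, d) to a (d, d) matrix.
    """
    rng = np.random.default_rng(seed)
    n_rows, n_vars = x.shape
    size = min(chunk_size, n_rows)
    total = np.zeros((n_vars, n_vars))
    for _ in range(n_chunks):
        # Draw a random subset of rows so each forward pass stays small.
        idx = rng.choice(n_rows, size=size, replace=False)
        total += np.asarray(predict_fn(x[idx]))
    return total / n_chunks

def dummy_predict(chunk):
    # Hypothetical stand-in for the model call; NOT the real predictor.
    corr = np.abs(np.corrcoef(chunk, rowvar=False))
    np.fill_diagonal(corr, 0.0)
    return corr

# Synthetic data with the same shape as in the report: 10,000 rows, 51 variables.
x = np.random.default_rng(0).normal(size=(10_000, 51))
g_prob = bootstrap_edge_probs(x, dummy_predict)
print(g_prob.shape)
```

Averaging over subsets is just one way to combine the per-chunk predictions. Whether the averaged matrix is a good estimate of the graph is an open question here, not something this sketch establishes.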
However, your error seems to occur here, after the forward pass has already finished. Can you confirm this?
You could call jax.block_until_ready before this line to check (see here). If the error really does occur after the forward pass, I don't currently know what the issue could be and would have to investigate. It would be great if you could provide a minimal example that reproduces this with random synthetic data.
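To illustrate the check: JAX dispatches computations asynchronously, so a runtime error from the forward pass can surface only when the result is later converted to a NumPy array. A minimal sketch on random synthetic data is below. The `forward` function is a hypothetical stand-in for the model's forward pass, not the actual network.

```python
import jax
import jax.numpy as jnp
import numpy as onp

@jax.jit
def forward(x):
    # Hypothetical stand-in for the forward pass; returns a (d, d) matrix.
    return jnp.tanh(x.T @ x / x.shape[0])

# Random synthetic data; kept small here so the sketch runs anywhere.
x = jnp.asarray(onp.random.default_rng(0).normal(size=(1000, 51)), dtype=jnp.float32)

out = forward(x)                   # may return before the computation has finished
out = jax.block_until_ready(out)   # wait here; a forward-pass failure would be raised at this line
out_np = onp.array(out)            # at this point only the device-to-host copy remains
print(out_np.shape)
```

If the out-of-memory error is raised at the `jax.block_until_ready` line, the forward pass itself is failing. If it is raised only at `onp.array(out)`, the problem lies in the conversion step.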
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/mcgill_fiam/0X-Causal_discovery/discovery.py", line 30, in <module>
    g_prob = model(x=x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/avici/pretrain.py", line 109, in __call__
    out = onp.array(out)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 429, in __array__
    return np.asarray(self._value, dtype=dtype, **kwds)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 628, in _value
    self._npy_value = self._single_device_array_to_np_array()
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error preparing computation: %sOut of memory allocating 332034480032 bytes.
This is on 10,000 rows with 51 variables. Can you help me with this issue?