Hi all,
I'm currently working on a model on a single GPU and I am facing an OOM error.
All the chains are loaded into GPU memory, which exceeds the 16 GB available on the GPU.
To fix it, I have implemented batching (cf. Forum 1, Forum 2).
It now runs with 2000 samples. However, BFMI is low, so I need to increase the number of samples.
The OOM error is triggered again with 4000 samples. It seems that 2 chains are always kept in memory (1 chain being ~6 GB, 2 chains ~12 GB, hence OOM on a 16 GB GPU).
Below are the experiments I have done so far:
| XLA_PYTHON_CLIENT_PREALLOCATE | XLA_PYTHON_CLIENT_ALLOCATOR | num_samples | num_warmup | Chain method | VRAM at beginning of each sample | Successfully run? |
|---|---|---|---|---|---|---|
| False | Platform | 4000 | 10000 | sequential | 1. 744 MB, 2. 6566 MB | No |
| False | Platform | 2000 | 10000 | sequential | 1. 744 MB, 2. 3695 MB, 3. 3695 MB, 4. 3695 MB | Yes |
| Default | Default | 2000 | 10000 | sequential | 1. 13,385 MB, 2. 13,385 MB | No |
Similar to others in the forums, the 2nd batch usually causes an OOM error. I have mainly tried passing the two flags XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_ALLOCATOR, as suggested in the forums mentioned above.
When sampling the 2nd batch, not all of the GPU memory is released. I assume the remainder is the previous sampling state; it does not get released until the model has run.
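For reference, a minimal sketch of how these flags are set (the values `false` and `platform` correspond to the first two rows of the table above; they have to be set before JAX is imported, otherwise the GPU allocator is already configured):

```python
import os

# Must be set before importing JAX, otherwise the GPU allocator is already initialised.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # don't preallocate ~75% of VRAM up front
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate and free on demand (slower)

import jax
import numpyro
```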
Current implementation of the model and the batching:
Model definition and NUTS+MCMC configuration
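A minimal, hypothetical sketch of the structure (the model body, its arguments, and the number of chains are placeholders; only the NUTS/MCMC settings reflect the runs in the table):

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model(X, y=None):
    # Placeholder body -- the real model is much larger, which is why a
    # single chain already occupies ~6 GB of device memory.
    beta = numpyro.sample("beta", dist.Normal(0.0, 1.0).expand([X.shape[1]]).to_event(1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("obs", dist.Normal(X @ beta, sigma), obs=y)


kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=10_000,
    num_samples=2_000,          # 4_000 samples triggers the OOM
    num_chains=4,               # assumed; the issue only says "all the chains"
    chain_method="sequential",  # chains run one after another on the single GPU
)
```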
Batching function
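A sketch of the batching pattern, as typically done with `post_warmup_state` (the function name and the `device_get`/concatenate bookkeeping are illustrative, not necessarily identical to the actual implementation):

```python
import jax
import numpy as np


def run_batched(mcmc, rng_key, X, y, n_batches):
    """Run the sampler in batches, pulling each batch of samples to host memory."""
    host_samples = []

    # First batch: warmup + num_samples.
    mcmc.run(rng_key, X, y)
    host_samples.append(jax.device_get(mcmc.get_samples()))

    # Remaining batches: reuse the warmed-up state so warmup is skipped.
    for _ in range(n_batches - 1):
        mcmc.post_warmup_state = mcmc.last_state
        mcmc.run(mcmc.post_warmup_state.rng_key, X, y)
        host_samples.append(jax.device_get(mcmc.get_samples()))

    # Stitch the batches back together on the host (sample axis = 0).
    return {
        name: np.concatenate([batch[name] for batch in host_samples], axis=0)
        for name in host_samples[0]
    }
```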
Questions:
Thank you!