Batching MCMC OOM issue #1954

Open
ttansuwan opened this issue Jan 22, 2025 · 0 comments

Hi all,
I'm working on a model on a single GPU and am hitting an out-of-memory (OOM) error.
All the chains are loaded into GPU memory, which exceeds the GPU's 16 GB.

To work around it, I implemented batching (cf. Forum 1, Forum 2).
It now runs with 2000 samples. However, BFMI is low, so I need to increase the number of samples.
The OOM error is triggered again with 4000 samples. It seems that 2 chains are always kept in memory (1 chain is 6 GB, so 2 chains are 12 GB, which leads to OOM on the 16 GB GPU).
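
A rough back-of-envelope sketch of how the stored samples scale with num_samples, assuming float32 values. The dimensions are purely hypothetical (the real n_items, n_factors and n_persons are not stated in this issue), and `sample_bytes` is an illustrative helper, not part of NumPyro. The site shapes follow the plates in the model definition further down.

```python
def sample_bytes(num_samples, site_shapes, bytes_per_value=4):
    """Bytes needed to hold `num_samples` draws of every site (float32 by default)."""
    total = 0
    for shape in site_shapes.values():
        n_values = 1
        for dim in shape:
            n_values *= dim
        total += num_samples * n_values * bytes_per_value
    return total


# Hypothetical dimensions, NOT the real ones behind this issue.
n_items, n_factors, n_persons = 100, 4, 10_000

# One entry per sampled or deterministic site that MCMC collects.
site_shapes = {
    "diff": (n_items,),
    "discrim_offset": (n_items, n_factors),
    "discrim": (n_items, n_factors),
    "ability_corr": (n_persons, n_factors, n_factors),
    "ability": (n_persons, n_factors),
}

for num_samples in (2000, 4000):
    gib = sample_bytes(num_samples, site_shapes) / 2**30
    print(f"{num_samples} samples: {gib:.2f} GiB")
```

With these assumed sizes the per-person correlation matrices dominate, and the total grows linearly with num_samples, so doubling the samples doubles the footprint of the collected draws.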

Below are the experiments I have run:

Experiments so far:

| XLA_PYTHON_CLIENT_PREALLOCATE | XLA_PYTHON_CLIENT_ALLOCATOR | num_samples | num_warmup | Chain method | VRAM at beginning of each sample | Successfully run? |
|---|---|---|---|---|---|---|
| False | Platform | 4000 | 10000 | sequential | 1. 744 MB; 2. 6,566 MB | No |
| False | Platform | 2000 | 10000 | sequential | 1. 744 MB; 2. 3,695 MB; 3. 3,695 MB; 4. 3,695 MB | Yes |
| Default | Default | 2000 | 10000 | sequential | 1. 13,385 MB; 2. 13,385 MB | No |

As others have reported in the forums, the 2nd batch usually causes the OOM error. I have mainly tried setting these two flags, XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_ALLOCATOR, as suggested in the forum threads mentioned above.
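
For reference, a minimal sketch of how the two flags can be set from Python, corresponding to the "False / Platform" rows of the table. It assumes the variables are set before JAX is first imported in the process; setting them afterwards has no effect on an already initialised backend.

```python
import os

# Must run before the first `import jax` in the process.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # do not reserve a large memory pool up front
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate and free on demand (slower)

# import jax  # import JAX / NumPyro only after the variables are set
```

The same variables can also be exported in the shell before launching the script, which avoids any import-order problems.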

When sampling the 2nd batch, not all of the GPU memory has been released. I assume the remainder is the previous sampling state. It is not released until the model has run.
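
One way to check what is still resident between batches, sketched under the assumption that `jax.live_arrays()` and `Device.memory_stats()` are available in the installed JAX version. `device_report` is a hypothetical helper name, and `memory_stats()` may return None on backends that do not report allocator statistics (for example CPU).

```python
import jax
import jax.numpy as jnp


def device_report(device=None):
    """Return (number of live arrays, bytes they hold, allocator stats or None)."""
    device = device or jax.devices()[0]
    live = jax.live_arrays()                 # arrays that Python still references
    live_bytes = sum(a.nbytes for a in live)
    stats = device.memory_stats()            # backend allocator statistics, if reported
    return len(live), live_bytes, stats


x = jnp.ones((1024, 1024), dtype=jnp.float32)  # a 4 MiB buffer as a stand-in for samples
n_before, bytes_before, stats = device_report()
print(n_before, bytes_before, stats)
```

Calling such a helper right after each `mcmc.run(...)` and again after the `del` / `gc.collect()` step would show whether the leftover memory belongs to arrays that are still referenced or only to the allocator's pool.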

Current implementation of the model and the batching:

Model definition and NUTS+MCMC configuration

def model(n_items, n_factors, n_persons, responses_mask, responses=None):
    with plate("diff_dim1",  n_items, dim=-1):
        diff = numpyro.sample("diff", dist.Normal(loc=0.0, scale=1.0))

    with plate("discrim_dim1", n_factors, dim=-1):
        with plate("discrim_dim2", n_items, dim=-2):
            discrim_offset = numpyro.sample(
                "discrim_offset", dist.LogNormal(loc=0.0, scale=1.0)
            )
            # Need to use deterministic layer to apply q_matrix
            discrim = numpyro.deterministic(
                "discrim", discrim_offset * item_factors
            )

    with plate("ability_dim1", n_persons, dim=-1):
        corr = numpyro.sample("ability_corr", dist.LKJ(n_factors, jnp.ones([1])))

        # Mean would always be 0
        mu = jnp.zeros(n_factors)
        ability = numpyro.sample(
            "ability", dist.MultivariateNormal(loc=mu, covariance_matrix=corr)
        )

    # Logistic regression
    kernel = jnp.dot(ability, discrim.T) + diff
    numpyro.sample(
        "obs",
        dist.BernoulliLogits(logits=kernel).mask(responses_mask),
        obs=responses,
    )

# Defining sampler
kernel = NUTS(
    model, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False
)

mcmc = MCMC(
    kernel,
    num_warmup=10000,
    num_samples=2000,
    num_chains=1,
    progress_bar=True,
    chain_method="sequential",
    jit_model_args=True,
)

Batching function

_traces = []
_extra_fields = []

for i in range(chains):
    mcmc.run(
        random.PRNGKey(0),
        self.responses_mask,
        self.item_factors,
        self.response_trains,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )
    # Transfer to CPU
    samples = jax.device_put(
        mcmc.get_samples(group_by_chain=True), jax.devices("cpu")[0]
    )
    extra_fields = jax.device_put(
        mcmc.get_extra_fields(group_by_chain=True), jax.devices("cpu")[0]
    )

    _traces.append(samples)
    _extra_fields.append(extra_fields)
    del samples, extra_fields

    # Continue the next batch from the last state of this one (skips warmup)
    mcmc.post_warmup_state = mcmc.last_state
    gc.collect()

# Prepare the traces for arviz
trace = {}
extras = {}
for k in _traces[0].keys():
    trace[k] = np.concatenate([t[k] for t in _traces])
for j in _extra_fields[0].keys():
    extras[j] = np.concatenate([e[j] for e in _extra_fields])

idata = az.convert_to_inference_data(trace)
iextra = az.convert_to_inference_data(extras, group="sample_stats")
az.concat(idata, iextra, inplace=True)

Questions:

  1. Is there any way to reduce memory consumption and avoid the OOM error?
  2. Are there any suggestions for improving the situation, or any mistakes in the implementation?
  3. Why do I seem to have 2 chains in memory at all times, when I only keep the last state (which should be smaller)?
  4. Why do I run out of memory before hitting 16 GB? Is it fragmentation?

Thank you!
