Improve sampling behaviour with samples > 1 #16

Open
wants to merge 2 commits into main
Conversation

@elkoz commented Nov 23, 2023

This looks like a really great project, thank you for making it!

I encountered some inconveniences when trying to generate a large number of samples with protein_init not set to None, and I thought my fixes might be useful to other people.

First, there is currently a bug that causes only one sample to be generated when protein_init is not None, regardless of the samples parameter. To fix that, I expand X_unc, C_unc and S_unc along the batch dimension before passing them to the model. Second, since there is no batching during sampling, generating a large number of samples is inconvenient, so this PR also adds a batch_size parameter that makes it possible to generate a larger number of outputs with one command.
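To illustrate the two changes, here is a simplified sketch under my assumptions, not the actual diff; the helper below is hypothetical and only mirrors the idea:

import torch

def expand_and_batch(X_unc, C_unc, S_unc, samples, batch_size):
    # Repeat the single initial structure along the batch dimension so that
    # `samples` independent trajectories are actually produced (the bug was
    # that the batch dimension stayed at 1, so only one sample came back).
    X_unc = X_unc.expand(samples, -1, -1, -1).contiguous()
    C_unc = C_unc.expand(samples, -1).contiguous()
    S_unc = S_unc.expand(samples, -1).contiguous()

    # Yield chunks of at most `batch_size` samples so that a large number of
    # outputs can be generated without holding them all in memory at once.
    for b in range(0, samples, batch_size):
        yield (
            X_unc[b : b + batch_size],
            C_unc[b : b + batch_size],
            S_unc[b : b + batch_size],
        )

In the PR itself the batching loop lives directly in Chroma._sample, which slices X_unc and C_unc per batch before calling sample_sde.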

@wujiewang (Member)

Hey @elkoz, thanks for your interest and the PR! The fix looks reasonable to me and could be very useful. However, we have not figured out a protocol for accepting external PRs yet. I'll discuss with the team next week and get back to you!

@elkoz (Author) commented Nov 28, 2023

Note that once the samples > 1 behavior is fixed, applying the subsequence conditioner with samples > 1 still fails. It appears to be an issue with the conditioner itself (substructure conditioning, for example, works correctly), but I'm not sure how to fix it.

Here is the code snippet:

from chroma import api, Chroma, Protein, conditioners
from chroma.utility.chroma import plane_split_protein

API_KEY = ...
api.register_key(API_KEY)

chroma = Chroma()

device = "cuda:0"
pdb_id = "7KGK"

protein = Protein.from_PDBID(pdb_id, canonicalize=True, device=device)

X, C, _ = protein.to_XCS()
selection_string = "namesel infilling_selection"  
residues_to_design = plane_split_protein(X, C, protein, 0.5).nonzero()[:,1].tolist()
protein.sys.save_selection(gti=residues_to_design, selname="infilling_selection")

sequence_conditioner = conditioners.SubsequenceConditioner(
    design_model=chroma.design_network, protein=protein, selection=selection_string
).to(device)

infilled_proteins, trajectories = chroma.sample(
    protein_init=protein,
    conditioner=sequence_conditioner,
    langevin_factor=4.0,
    langevin_isothermal=True,
    inverse_temperature=8.0,
    steps=500,
    full_output=True,
    samples=4,
)

And here is the error trace.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/liza/chroma/test.ipynb Cell 4 line 2
     18 protein.sys.save_selection(gti=residues_to_design, selname="infilling_selection")
     20 sequence_conditioner = conditioners.SubsequenceConditioner(
     21     design_model=chroma.design_network, protein=protein, selection=selection_string
     22 ).to(device)
---> 24 infilled_proteins, trajectories = chroma.sample(
     25     protein_init=protein,
     26     conditioner=sequence_conditioner,
     27     langevin_factor=4.0,
     28     langevin_isothermal=True,
     29     inverse_temperature=8.0,
     30     steps=500,
     31     full_output=True,
     32     samples=4,
     33 )

File ~/chroma/chroma/models/chroma.py:236, in Chroma.sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, design_ban_S, design_method, design_selection, design_t, temperature_S, temperature_chi, top_p_S, regularization, potts_mcmc_depth, potts_proposal, potts_symmetry_order, verbose)
    233 design_kwargs = {k: input_args[k] for k in input_args if k in design_keys}
    235 # Perform Sampling
--> 236 sample_output = self._sample(**backbone_kwargs)
    238 if full_output:
    239     protein_sample, output_dictionary = sample_output

File ~/chroma/chroma/models/chroma.py:381, in Chroma._sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, **kwargs)
    373 outs = {
    374     "C": torch.tensor([], device=X_unc.device),
    375     "X_sample": torch.tensor([], device=X_unc.device),
   (...)
    378     "Xunc_trajectory": [torch.tensor([], device=X_unc.device) for i in range(steps)],
    379 }
    380 for b in range(num_batches):
--> 381     outs_ = self.backbone_network.sample_sde(
    382         C_unc[b * batch_size : (b + 1) * batch_size],
    383         X_init=X_unc[b * batch_size : (b + 1) * batch_size],
    384         conditioner=conditioner,
    385         tspan=tspan,
    386         langevin_isothermal=langevin_isothermal,
    387         integrate_func=integrate_func,
    388         sde_func=sde_func,
    389         langevin_factor=langevin_factor,
    390         inverse_temperature=inverse_temperature,
    391         N=steps,
    392         initialize_noise=initialize_noise,
    393         **kwargs,
    394     )
    395     outs["C"] = torch.cat([outs["C"], outs_["C"]], dim=0)
    396     outs["X_sample"] = torch.cat([outs["X_sample"], outs_["X_sample"]], dim=0)

File ~/chroma/chroma/models/graph_backbone.py:187, in GraphBackbone.__init__.<locals>.<lambda>(C, **kwargs)
    185 # Wrap sampling functions
    186 _X0_func = lambda X, C, t: self.denoise(X, C, t)
--> 187 self.sample_sde = lambda C, **kwargs: self.noise_perturb.sample_sde(
    188     _X0_func, C, **kwargs
    189 )
    190 self.sample_baoab = lambda C, **kwargs: self.noise_perturb.sample_baoab(
    191     _X0_func, C, **kwargs
    192 )
    193 self.sample_ode = lambda C, **kwargs: self.noise_perturb.sample_ode(
    194     _X0_func, C, **kwargs
    195 )

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/chroma/chroma/layers/structure/diffusion.py:1208, in DiffusionChainCov.sample_sde(self, X0_func, C, X_init, conditioner, N, tspan, inverse_temperature, langevin_factor, langevin_isothermal, sde_func, integrate_func, initialize_noise, remap_time, remove_drift_translate, remove_noise_translate, align_X0)
   1206         U_test = 0.0
   1207         t_test = torch.tensor([0.0], device=X_init.device)
-> 1208         _, Ct, _, _, _ = conditioner(X_init_test, C, O_test, U_test, t_test)
   1209 else:
   1210     Ct = C

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
    112         if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
    113             raise ValueError("S and O are both provided but don't match!")
--> 114 return func(*args, **kwargs)

File ~/chroma/chroma/layers/structure/conditioners.py:241, in SubsequenceConditioner.forward(self, X, C, O, U, t)
    239 if self.mask_condition is not None:
    240     priority = 1.0 - self.mask_condition
--> 241 out = self.design_model(X_input, C, self.S_condition, t, priority=priority)
    242 logp_S = out[\"logp_S\"]
    244 if self.mask_condition is not None:

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
    112         if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
    113             raise ValueError("S and O are both provided but don't match!")
--> 114 return func(*args, **kwargs)

File ~/chroma/chroma/models/graph_design.py:431, in GraphDesign.forward(self, X, C, S, t, sample_noise, permute_idx, priority)
    429 logp_S_potts = None
    430 if self.kwargs["predict_S_potts"]:
--> 431     logp_S_potts = self.decoder_S_potts.loss(
    432         S, node_h, edge_h, edge_idx, mask_i, mask_ij
    433     )
    435 # Sample random permutations and build autoregressive mask
    436 if permute_idx is None:

File ~/chroma/chroma/layers/structure/potts.py:542, in GraphPotts.loss(self, S, node_h, edge_h, edge_idx, mask_i, mask_ij)
    539 h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij)
    541 # Log composite likelihood
--> 542 logp_ij, mask_p_ij = self.log_composite_likelihood(
    543     S,
    544     h,
    545     J,
    546     edge_idx,
    547     mask_i,
    548     mask_ij,
    549     smoothing_alpha=self.label_smoothing if self.training else 0.0,
    550 )
    552 # Map into approximate local likelihoods
    553 logp_i = (
    554     mask_i
    555     * torch.sum(mask_p_ij * logp_ij, dim=-1)
    556     / (2.0 * torch.sum(mask_p_ij, dim=-1) + 1e-3)
    557 )

File ~/chroma/chroma/layers/structure/potts.py:458, in GraphPotts.log_composite_likelihood(self, S, h, J, edge_idx, mask_i, mask_ij, smoothing_alpha)
    454 num_batch, num_residues, num_k, num_states, _ = list(J.size())
    456 # Gather J clamped at j
    457 # [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i]
--> 458 S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx)
    459 S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, num_states, -1)
    460 # (B,i,j,A_i)

File ~/chroma/chroma/layers/graph.py:677, in collect_neighbors(node_h, edge_idx)
    675 idx_flat = edge_idx.reshape([num_batch, num_nodes * num_neighbors, 1])
    676 idx_flat = idx_flat.expand(-1, -1, num_features)
--> 677 neighbor_h = torch.gather(node_h, 1, idx_flat)
    678 neighbor_h = neighbor_h.reshape((num_batch, num_nodes, num_neighbors, num_features))
    679 return neighbor_h

RuntimeError: Size does not match at dimension 0 expected index [4, 18840, 1] to be smaller than self [1, 314, 1] apart from dimension 1
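For what it's worth, the trace suggests that self.S_condition (and likely the priority mask derived from mask_condition) keeps its original batch size of 1 while X and C arrive with batch size 4, so the gather inside the Potts decoder fails. Purely as a speculative, untested sketch of the kind of change that might be needed in SubsequenceConditioner.forward (the expand calls are my assumption, not something I have verified):

# Hypothetical, untested: broadcast the cached per-protein tensors to the
# batch size of the incoming structure before calling the design model.
num_batch = C.shape[0]
S_condition = self.S_condition.expand(num_batch, -1)
if priority is not None:
    priority = priority.expand(num_batch, -1)
out = self.design_model(X_input, C, S_condition, t, priority=priority)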

@wujiewang (Member)

Yeah, not all conditioners have been tested for batched sampling. We'll need to figure this out. Thanks for flagging!
