Improve sampling behaviour with samples > 1 #16
base: main
Conversation
Hey @elkoz Thanks for the interest and the PR!! The fix looks reasonable to me and could be very useful. However, we have not figured out a protocol for accepting external PRs yet. Will discuss with the team next week and get back to you!
Note that even with the samples > 1 behavior fixed, applying the subsequence conditioner with samples > 1 still fails. It appears to be an issue with this particular conditioner (substructure conditioning, for example, works correctly), but I'm not sure how to fix it. Here is the code snippet:

```python
from chroma import Chroma, Protein, conditioners
from chroma.utility.chroma import plane_split_protein

API_KEY = ...
from chroma import api
api.register_key(API_KEY)

chroma = Chroma()
device = "cuda:0"

pdb_id = "7KGK"
protein = Protein.from_PDBID(pdb_id, canonicalize=True, device=device)
X, C, _ = protein.to_XCS()

# Select the residues to redesign by splitting the protein with a plane,
# and save them as a named selection.
selection_string = "namesel infilling_selection"
residues_to_design = plane_split_protein(X, C, protein, 0.5).nonzero()[:, 1].tolist()
protein.sys.save_selection(gti=residues_to_design, selname="infilling_selection")

# Build a subsequence conditioner over the saved selection.
sequence_conditioner = conditioners.SubsequenceConditioner(
    design_model=chroma.design_network, protein=protein, selection=selection_string
).to(device)

infilled_proteins, trajectories = chroma.sample(
    protein_init=protein,
    conditioner=sequence_conditioner,
    langevin_factor=4.0,
    langevin_isothermal=True,
    inverse_temperature=8.0,
    steps=500,
    full_output=True,
    samples=4,
)
```

And here is the error trace:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/liza/chroma/test.ipynb Cell 4 line 24
     18 protein.sys.save_selection(gti=residues_to_design, selname="infilling_selection")
     20 sequence_conditioner = conditioners.SubsequenceConditioner(
     21     design_model=chroma.design_network, protein=protein, selection=selection_string
     22 ).to(device)
---> 24 infilled_proteins, trajectories = chroma.sample(
     25     protein_init=protein,
     26     conditioner=sequence_conditioner,
     27     langevin_factor=4.0,
     28     langevin_isothermal=True,
     29     inverse_temperature=8.0,
     30     steps=500,
     31     full_output=True,
     32     samples=4,
     33 )

File ~/chroma/chroma/models/chroma.py:236, in Chroma.sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, design_ban_S, design_method, design_selection, design_t, temperature_S, temperature_chi, top_p_S, regularization, potts_mcmc_depth, potts_proposal, potts_symmetry_order, verbose)
    233 design_kwargs = {k: input_args[k] for k in input_args if k in design_keys}
    235 # Perform Sampling
--> 236 sample_output = self._sample(**backbone_kwargs)
    238 if full_output:
    239     protein_sample, output_dictionary = sample_output

File ~/chroma/chroma/models/chroma.py:381, in Chroma._sample(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, batch_size, **kwargs)
    373 outs = {
    374     "C": torch.tensor([], device=X_unc.device),
    375     "X_sample": torch.tensor([], device=X_unc.device),
   (...)
    378     "Xunc_trajectory": [torch.tensor([], device=X_unc.device) for i in range(steps)],
    379 }
    380 for b in range(num_batches):
--> 381     outs_ = self.backbone_network.sample_sde(
    382         C_unc[b * batch_size : (b + 1) * batch_size],
    383         X_init=X_unc[b * batch_size : (b + 1) * batch_size],
    384         conditioner=conditioner,
    385         tspan=tspan,
    386         langevin_isothermal=langevin_isothermal,
    387         integrate_func=integrate_func,
    388         sde_func=sde_func,
    389         langevin_factor=langevin_factor,
    390         inverse_temperature=inverse_temperature,
    391         N=steps,
    392         initialize_noise=initialize_noise,
    393         **kwargs,
    394     )
    395     outs["C"] = torch.cat([outs["C"], outs_["C"]], dim=0)
    396     outs["X_sample"] = torch.cat([outs["X_sample"], outs_["X_sample"]], dim=0)

File ~/chroma/chroma/models/graph_backbone.py:187, in GraphBackbone.__init__.<locals>.<lambda>(C, **kwargs)
    185 # Wrap sampling functions
    186 _X0_func = lambda X, C, t: self.denoise(X, C, t)
--> 187 self.sample_sde = lambda C, **kwargs: self.noise_perturb.sample_sde(
    188     _X0_func, C, **kwargs
    189 )
    190 self.sample_baoab = lambda C, **kwargs: self.noise_perturb.sample_baoab(
    191     _X0_func, C, **kwargs
    192 )
    193 self.sample_ode = lambda C, **kwargs: self.noise_perturb.sample_ode(
    194     _X0_func, C, **kwargs
    195 )

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/chroma/chroma/layers/structure/diffusion.py:1208, in DiffusionChainCov.sample_sde(self, X0_func, C, X_init, conditioner, N, tspan, inverse_temperature, langevin_factor, langevin_isothermal, sde_func, integrate_func, initialize_noise, remap_time, remove_drift_translate, remove_noise_translate, align_X0)
   1206     U_test = 0.0
   1207     t_test = torch.tensor([0.0], device=X_init.device)
-> 1208     _, Ct, _, _, _ = conditioner(X_init_test, C, O_test, U_test, t_test)
   1209 else:
   1210     Ct = C

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
    112 if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
    113     raise ValueError("S and O are both provided but don't match!")
--> 114 return func(*args, **kwargs)

File ~/chroma/chroma/layers/structure/conditioners.py:241, in SubsequenceConditioner.forward(self, X, C, O, U, t)
    239 if self.mask_condition is not None:
    240     priority = 1.0 - self.mask_condition
--> 241 out = self.design_model(X_input, C, self.S_condition, t, priority=priority)
    242 logp_S = out["logp_S"]
    244 if self.mask_condition is not None:

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/chroma/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/chroma/chroma/data/xcs.py:114, in validate_XCS.<locals>.decorator.<locals>.new_func(*args, **kwargs)
    112 if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
    113     raise ValueError("S and O are both provided but don't match!")
--> 114 return func(*args, **kwargs)

File ~/chroma/chroma/models/graph_design.py:431, in GraphDesign.forward(self, X, C, S, t, sample_noise, permute_idx, priority)
    429 logp_S_potts = None
    430 if self.kwargs["predict_S_potts"]:
--> 431     logp_S_potts = self.decoder_S_potts.loss(
    432         S, node_h, edge_h, edge_idx, mask_i, mask_ij
    433     )
    435 # Sample random permutations and build autoregressive mask
    436 if permute_idx is None:

File ~/chroma/chroma/layers/structure/potts.py:542, in GraphPotts.loss(self, S, node_h, edge_h, edge_idx, mask_i, mask_ij)
    539 h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij)
    541 # Log composite likelihood
--> 542 logp_ij, mask_p_ij = self.log_composite_likelihood(
    543     S,
    544     h,
    545     J,
    546     edge_idx,
    547     mask_i,
    548     mask_ij,
    549     smoothing_alpha=self.label_smoothing if self.training else 0.0,
    550 )
    552 # Map into approximate local likelihoods
    553 logp_i = (
    554     mask_i
    555     * torch.sum(mask_p_ij * logp_ij, dim=-1)
    556     / (2.0 * torch.sum(mask_p_ij, dim=-1) + 1e-3)
    557 )

File ~/chroma/chroma/layers/structure/potts.py:458, in GraphPotts.log_composite_likelihood(self, S, h, J, edge_idx, mask_i, mask_ij, smoothing_alpha)
    454 num_batch, num_residues, num_k, num_states, _ = list(J.size())
    456 # Gather J clamped at j
    457 # [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i]
--> 458 S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx)
    459 S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, num_states, -1)
    460 # (B,i,j,A_i)

File ~/chroma/chroma/layers/graph.py:677, in collect_neighbors(node_h, edge_idx)
    675 idx_flat = edge_idx.reshape([num_batch, num_nodes * num_neighbors, 1])
    676 idx_flat = idx_flat.expand(-1, -1, num_features)
--> 677 neighbor_h = torch.gather(node_h, 1, idx_flat)
    678 neighbor_h = neighbor_h.reshape((num_batch, num_nodes, num_neighbors, num_features))
    679 return neighbor_h

RuntimeError: Size does not match at dimension 0 expected index [4, 18840, 1] to be smaller than self [1, 314, 1] apart from dimension 1
```
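The final frame shows the failure mode: `torch.gather` requires `index` and `self` to agree on every dimension except the gather dimension, but here the indices carry the expanded batch of 4 while the tensor being gathered from is still batch 1 (18840 is presumably num_residues × num_neighbors flattened, i.e. 314 × 60). A minimal sketch that reproduces just the shape mismatch, using made-up zero tensors with the sizes from the message:

```python
import torch

# Stand-ins for the tensors reaching chroma.layers.graph.collect_neighbors:
# the gathered tensor stayed at batch 1 while the indices were expanded to batch 4.
node_h = torch.zeros(1, 314, 1)                         # self  [1, 314, 1]
idx_flat = torch.zeros(4, 18840, 1, dtype=torch.long)   # index [4, 18840, 1]

# Raises: RuntimeError: Size does not match at dimension 0 expected index
# [4, 18840, 1] to be smaller than self [1, 314, 1] apart from dimension 1
torch.gather(node_h, 1, idx_flat)
```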
Yeah, not all conditioners have been tested for batched sampling. We'll need to figure this out. Thanks for flagging!
This looks like a really great project, thank you for making it!

I encountered some inconveniences when trying to generate a large number of samples with `protein_init` not set to `None`, and I thought my fixes might be useful for other people.

First, there is a bug in the code right now that leads to only one sample being generated when `protein_init` is not `None`, independently of the `samples` parameter. Here I expand `X_unc`, `C_unc`, and `S_unc` along the batch dimension before passing them to the model to fix that; a sketch of the idea follows below.
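A minimal sketch of that expansion, assuming Chroma's XCS shapes (`X` is `[batch, residues, 4, 3]`, `C` and `S` are `[batch, residues]`); the helper name `expand_to_samples` is illustrative, not the PR's exact code:

```python
import torch

def expand_to_samples(X_unc: torch.Tensor, C_unc: torch.Tensor,
                      S_unc: torch.Tensor, samples: int):
    # protein_init yields batch-1 tensors; tile them along the batch
    # dimension so every requested sample is actually generated.
    X = X_unc.expand(samples, -1, -1, -1).contiguous()  # [samples, residues, 4, 3]
    C = C_unc.expand(samples, -1).contiguous()          # [samples, residues]
    S = S_unc.expand(samples, -1).contiguous()          # [samples, residues]
    return X, C, S
```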
Second, since there is no batching during sampling, generating a large number of samples is inconvenient. In this PR I added a `batch_size` parameter to make it possible to generate a larger number of outputs with one command; again, a sketch follows below.
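The batching idea mirrors the `num_batches` loop visible in the traceback above (`chroma/models/chroma.py:381`). This is a hedged sketch only: `sample_fn` stands in for `backbone_network.sample_sde`, which in the real code returns a dictionary of outputs rather than a single tensor:

```python
import math
import torch

def sample_in_batches(sample_fn, X: torch.Tensor, C: torch.Tensor, batch_size: int):
    # Draw samples in chunks of batch_size and concatenate the results,
    # so a large request does not have to fit on the GPU all at once.
    samples = C.shape[0]
    chunks = []
    for b in range(math.ceil(samples / batch_size)):
        sl = slice(b * batch_size, (b + 1) * batch_size)
        chunks.append(sample_fn(C[sl], X_init=X[sl]))
    return torch.cat(chunks, dim=0)
```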