From 83ea032556d2941872f5ae4d673001253e48b16c Mon Sep 17 00:00:00 2001 From: Chaitanya Joshi Date: Wed, 22 Nov 2023 20:14:25 +0000 Subject: [PATCH 1/5] Fix docstring --- chroma/layers/attention.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/chroma/layers/attention.py b/chroma/layers/attention.py index d673c38..12b0a8d 100644 --- a/chroma/layers/attention.py +++ b/chroma/layers/attention.py @@ -64,11 +64,11 @@ class MultiHeadAttention(nn.Module): for details and intuition. Args: - n_head (int): number of attention heads - d_k (int): dimension of the keys and queries in each attention head - d_v (int): dimension of the values in each attention head - d_model (int): input and output dimension for the layer - dropout (float): dropout rate, default is 0.1 + n_head (int): number of attention heads + d_k (int): dimension of the keys and queries in each attention head + d_v (int): dimension of the values in each attention head + d_model (int): input and output dimension for the layer + dropout (float): dropout rate, default is 0.1 Inputs: Q (torch.tensor): query tensor of shape ```(batch_size, sequence_length_q, d_model)``` From e088ab860e95ecc938752fab350873cb2096c6ec Mon Sep 17 00:00:00 2001 From: Chaitanya Joshi Date: Fri, 24 Nov 2023 20:22:09 +0000 Subject: [PATCH 2/5] Add missing import --- chroma/data/system.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chroma/data/system.py b/chroma/data/system.py index 4d42c5e..2a33854 100644 --- a/chroma/data/system.py +++ b/chroma/data/system.py @@ -20,7 +20,7 @@ import warnings from dataclasses import dataclass from functools import partial -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import numpy as np import torch From 05b7db22a4d8e0eae78be2184010395aa01cfb35 Mon Sep 17 00:00:00 2001 From: Chaitanya Joshi Date: Fri, 24 Nov 2023 20:36:39 +0000 Subject: [PATCH 3/5] Add missing Args to docstring --- chroma/models/graph_design.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/chroma/models/graph_design.py b/chroma/models/graph_design.py index c74e93f..9b50e80 100644 --- a/chroma/models/graph_design.py +++ b/chroma/models/graph_design.py @@ -1954,8 +1954,12 @@ def sample( smoothing values less than 1.0 are recommended. top_p (float, optional): Top-p cutoff for Nucleus Sampling, see Holtzman et al ICLR 2020. - ban_S (tuple, optional): An optional set of token indices from - `chroma.constants.AA20` to ban during sampling. + mask_S (torch.Tensor, optional): Binary tensor mask indicating + masked/banned tokens during sampling at each residue with shape + `(num_batch, num_residues, num_alphabet)`. + bias (torch.Tensor, optional): Bias for each token for at + each residue added to log probabilities with shape + `(num_batch, num_residues, num_alphabet)`. Returns: S_sample (torch.LongTensor): Sampled sequence of shape `(num_batch, From 77160686d66de533000bb9b09258d256aae9c6c0 Mon Sep 17 00:00:00 2001 From: Chaitanya Joshi Date: Fri, 24 Nov 2023 20:46:43 +0000 Subject: [PATCH 4/5] Fix typo --- chroma/data/xcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chroma/data/xcs.py b/chroma/data/xcs.py index 3883376..7ce063d 100644 --- a/chroma/data/xcs.py +++ b/chroma/data/xcs.py @@ -28,7 +28,7 @@ `C` (LongTensor), the chain map encoding per-residue chain assignments with shape `(num_batch, num_residues)`.The chain map codes positions as `0` - when masked, poitive integers for chain indices, and negative integers + when masked, positive integers for chain indices, and negative integers to represent missing residues (of the corresponding positive integers). `S` (LongTensor), the sequence of the protein as alphabet indices with From c1ffed0e2b0837f153e1d6ec79de4c8b2f15978c Mon Sep 17 00:00:00 2001 From: Chaitanya Joshi Date: Fri, 24 Nov 2023 21:35:01 +0000 Subject: [PATCH 5/5] Fix typo --- chroma/layers/structure/protein_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chroma/layers/structure/protein_graph.py b/chroma/layers/structure/protein_graph.py index f2ff0a2..ba06344 100644 --- a/chroma/layers/structure/protein_graph.py +++ b/chroma/layers/structure/protein_graph.py @@ -101,7 +101,7 @@ class ProteinFeatureGraph(nn.Module): for the the third dimension are PDB order (`[N, CA, C, O]`). C (LongTensor, optional): Chain map with shape `(num_batch, num_residues)`. The chain map codes positions as `0` - when masked, poitive integers for chain indices, and negative + when masked, positive integers for chain indices, and negative integers to represent missing residues of the corresponding positive integers. custom_D (Tensor, optional): Pre-computed custom distance map