-
Notifications
You must be signed in to change notification settings - Fork 228
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cc16b40
commit 39a64a7
Showing
21 changed files
with
1,606 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
apps/drug_target_interaction/hybriddta/pointwise/GraphDTA/get_len.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Calculate length of each group in dataset.""" | ||
|
||
import pandas as pd | ||
|
||
|
||
def get_kiba_len(): | ||
# Get length of validation set | ||
for cv in ["CV1", "CV2", "CV3", "CV4", "CV5"]: | ||
df = pd.read_csv("../../Data/KIBA/"+cv+"/"+cv+"_KIBA_unseenP_seenD_val.csv") | ||
df = df.groupby(['Target ID']).size().reset_index(name = 'counts') | ||
f = open("../../Data/KIBA/"+cv+"/"+cv+"_val.txt",'a') | ||
for i in df['counts'].values: | ||
f.write(str(i) + "\n") | ||
|
||
|
||
# Get length of testing set | ||
df = pd.read_csv("../../Data/KIBA/test_KIBA_unseenP_seenD.csv") | ||
df = df.groupby(['Target ID']).size().reset_index(name = 'counts') | ||
f = open("../../Data/KIBA/kiba_len.txt",'a') | ||
for i in df['counts'].values: | ||
f.write(str(i) + "\n") |
74 changes: 74 additions & 0 deletions
74
apps/drug_target_interaction/hybriddta/pointwise/GraphDTA/models/gat.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""GraphDTA_GAT backbone model.""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.nn import Sequential, Linear, ReLU | ||
from torch_geometric.nn import GATConv | ||
from torch_geometric.nn import global_max_pool as gmp | ||
|
||
|
||
# GAT backbone model | ||
class GATNet(torch.nn.Module): | ||
"""GAT model. | ||
Args: | ||
data: Input data. | ||
Returns: | ||
out: Prediction results. | ||
""" | ||
def __init__(self, num_features_xd=78, n_output=1, num_features_xt=25, | ||
n_filters=32, embed_dim=128, output_dim=128, dropout=0.2): | ||
super(GATNet, self).__init__() | ||
# Basic config | ||
self.relu = nn.ReLU() | ||
self.dropout = nn.Dropout(dropout) | ||
# SMILES graph branch | ||
self.gcn1 = GATConv(num_features_xd, num_features_xd, heads=10, dropout=dropout) | ||
self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout) | ||
self.fc_g1 = nn.Linear(output_dim, output_dim) | ||
# Protein sequence branch (1d conv) | ||
self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) | ||
self.conv_xt1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) | ||
self.fc_xt1 = nn.Linear(32*121, output_dim) | ||
# Combined layers | ||
self.fc1 = nn.Linear(256, 1024) | ||
self.fc2 = nn.Linear(1024, 256) | ||
self.out = nn.Linear(256, n_output) | ||
|
||
def forward(self, data): | ||
"""tbd.""" | ||
# Get graph input | ||
x, edge_index, batch = data.x, data.edge_index, data.batch | ||
# Get protein input | ||
target = data.target | ||
|
||
x = F.dropout(x, p=0.2, training=self.training) | ||
x = F.elu(self.gcn1(x, edge_index)) | ||
|
||
x = F.dropout(x, p=0.2, training=self.training) | ||
x = self.gcn2(x, edge_index) | ||
x = self.relu(x) | ||
x = gmp(x, batch) # global max pooling | ||
|
||
x = self.fc_g1(x) | ||
x = self.relu(x) | ||
# 1d conv layers | ||
embedded_xt = self.embedding_xt(target) | ||
conv_xt = self.conv_xt1(embedded_xt) | ||
conv_xt = self.relu(conv_xt) | ||
# Flatten | ||
xt = conv_xt.view(-1, 32 * 121) | ||
xt = self.fc_xt1(xt) | ||
# Concat | ||
xc = torch.cat((x, xt), 1) | ||
# Add some dense layers | ||
xc = self.fc1(xc) | ||
xc = self.relu(xc) | ||
xc = self.dropout(xc) | ||
xc = self.fc2(xc) | ||
xc = self.relu(xc) | ||
xc = self.dropout(xc) | ||
out = self.out(xc) | ||
return out |
74 changes: 74 additions & 0 deletions
74
apps/drug_target_interaction/hybriddta/pointwise/GraphDTA/models/gat_gcn.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""GraphDTA_GATGCN backbone model.""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.nn import Sequential, Linear, ReLU | ||
from torch_geometric.nn import GCNConv, GATConv, GINConv, global_add_pool | ||
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp | ||
|
||
|
||
# GATGCN backbone model | ||
class GAT_GCN(torch.nn.Module): | ||
"""GATGCN model. | ||
Args: | ||
data: Input data. | ||
Returns: | ||
out: Prediction results. | ||
""" | ||
def __init__(self, n_output=1, num_features_xd=78, num_features_xt=25, | ||
n_filters=32, embed_dim=128, output_dim=128, dropout=0.2): | ||
super(GAT_GCN, self).__init__() | ||
# Basic config | ||
self.relu = nn.ReLU() | ||
self.dropout = nn.Dropout(dropout) | ||
self.n_output = n_output | ||
# SMILES graph branch | ||
self.conv1 = GATConv(num_features_xd, num_features_xd, heads=10) | ||
self.conv2 = GCNConv(num_features_xd*10, num_features_xd*10) | ||
self.fc_g1 = torch.nn.Linear(num_features_xd*10*2, 1500) | ||
self.fc_g2 = torch.nn.Linear(1500, output_dim) | ||
# Protein sequence branch (1d conv) | ||
self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) | ||
self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) | ||
self.fc1_xt = nn.Linear(32*121, output_dim) | ||
# Combined layers | ||
self.fc1 = nn.Linear(256, 1024) | ||
self.fc2 = nn.Linear(1024, 512) | ||
self.out = nn.Linear(512, self.n_output) # n_output = 1 for regression task | ||
|
||
def forward(self, data): | ||
"""tbd.""" | ||
# Get graph input | ||
x, edge_index, batch = data.x, data.edge_index, data.batch | ||
# Get protein input | ||
target = data.target | ||
|
||
x = self.conv1(x, edge_index) | ||
x = self.relu(x) | ||
x = self.conv2(x, edge_index) | ||
x = self.relu(x) | ||
# Apply global max pooling (gmp) and global mean pooling (gap) | ||
x = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) | ||
x = self.relu(self.fc_g1(x)) | ||
x = self.dropout(x) | ||
x = self.fc_g2(x) | ||
|
||
embedded_xt = self.embedding_xt(target) | ||
conv_xt = self.conv_xt_1(embedded_xt) | ||
# Flatten | ||
xt = conv_xt.view(-1, 32 * 121) | ||
xt = self.fc1_xt(xt) | ||
# Concat | ||
xc = torch.cat((x, xt), 1) | ||
# Add some dense layers | ||
xc = self.fc1(xc) | ||
xc = self.relu(xc) | ||
xc = self.dropout(xc) | ||
xc = self.fc2(xc) | ||
xc = self.relu(xc) | ||
xc = self.dropout(xc) | ||
out = self.out(xc) | ||
return out |
Oops, something went wrong.