
defining encoder and decoder
Johnnyboycurtis committed Jan 7, 2024
1 parent 10779c9 commit 311f5cf
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions model.py
@@ -20,7 +20,7 @@ def forward(self, images):
features = self.embed(features)
return features


'''
class DecoderRNN(nn.Module):
def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
pass
@@ -30,4 +30,41 @@ def forward(self, features, captions):
def sample(self, inputs, states=None, max_len=20):
" accepts pre-processed image tensor (inputs) and returns predicted sentence (list of tensor ids of length max_len) "
pass
pass
'''

#https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

class DecoderRNN(nn.Module):
def __init__(self, embed_size, hidden_size, vocab_size, num_layers=2, dropout=0.2):
"""Set the hyper-parameters and build the layers."""
super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)  ## word embedding: (num_embeddings=vocab_size, embedding_dim=embed_size)
# The LSTM takes word embeddings as inputs, and outputs hidden states
# with dimensionality hidden_size
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, dropout=dropout, batch_first=True) ## embedding_dim, hidden_dim
self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(hidden_size, vocab_size)

def forward(self, features, captions):
"""Decode image feature vectors and generates captions."""
features = features.view(len(features), 1, -1)
        embeddings = self.embed(captions[:, :-1])  ## drop the last token so the inputs stay aligned with the target words

inputs = torch.cat((features, embeddings), 1)
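        # inputs: (batch, caption_len, embed_size) -- the image feature vector acts as the first time step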
        lstm_out, hidden = self.lstm(inputs)
        lstm_out = self.dropout(lstm_out)
        out = self.linear(lstm_out)
return out

def sample(self, inputs, states=None, max_len=20):
" accepts pre-processed image tensor (inputs) and returns predicted sentence (list of tensor ids of length max_len) "
result = []
for _ in range(max_len):
lstm_out, states = self.lstm(inputs, states)
out = self.linear(lstm_out.view(len(lstm_out), -1))
idx = out.max(1)[1]
result.append(int(idx.item()))
inputs = self.embed(idx)
inputs = inputs.unsqueeze(1)
return result
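
# A minimal usage sketch (not part of this commit) of how the DecoderRNN above might be exercised;
# the sizes and random tensors below are placeholder assumptions for illustration only.

import torch

# Hypothetical sizes for illustration.
vocab_size, embed_size, hidden_size, batch_size, caption_len = 1000, 256, 512, 4, 12
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Training-style forward pass: image features from the encoder plus ground-truth captions.
features = torch.randn(batch_size, embed_size)                      # stand-in for EncoderCNN output
captions = torch.randint(0, vocab_size, (batch_size, caption_len))  # stand-in for tokenized captions
scores = decoder(features, captions)                                # (batch_size, caption_len, vocab_size)

# Inference: sample() expects a single pre-processed feature tensor of shape (1, 1, embed_size).
decoder.eval()  # disable dropout before sampling
word_ids = decoder.sample(torch.randn(1, 1, embed_size))
print(scores.shape, len(word_ids))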
