
Hi, Is there any support for Bi-LSTM #1

Open
blueardour opened this issue Jul 15, 2021 · 3 comments

@blueardour

Hi, thanks for the helpful work.

Could I ask whether there is any plan to support bidirectional LSTM with custom stacks?

@piEsposito
Owner

Hi. I was not planning on it, but if it is helpful I might as well support it.

Also, feel free to open a PR with the feature if you want.

@blueardour
Author

blueardour commented Jul 16, 2021

Hi, I tried to implement it myself; however, I cannot figure out the output format.


import torch
import torch.nn as nn
import pdb

class CustomLSTM(nn.LSTM):
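    # unrolls a single-layer (bi)LSTM step by step so the graph can be exported to ONNX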
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
        # proj_size=0 is available from PyTorch 1.8
        super(CustomLSTM, self).__init__(input_size, hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)

    def forward(self, x, init_states=None, exporting_onnx=False):
        if exporting_onnx:
            assert self.num_layers == 1
            bs, seq, _ = x.size() if self.batch_first else (x.size(1), x.size(0), x.size(2))
            sz = self.hidden_size

            if init_states is None:
                # forward direction starts from zero states (a non-None init_states is not handled by this unroll)
                h_t = torch.zeros(bs, sz, device=x.device)
                c_t = torch.zeros(bs, sz, device=x.device)
            hidden_seq_forward = []
            for t in range(seq):
                x_t = x[:, t, :] if self.batch_first else x[t, :, :]
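                # nn.LSTM stores the gate weights stacked along dim 0 in the order (i, f, g, o)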
                i_t = x_t @ self.weight_ih_l0[sz*0:sz*1,:].transpose(0, 1) + self.bias_ih_l0[sz*0:sz*1] + \
                      h_t @ self.weight_hh_l0[sz*0:sz*1,:].transpose(0, 1) + self.bias_hh_l0[sz*0:sz*1]
                f_t = x_t @ self.weight_ih_l0[sz*1:sz*2,:].transpose(0, 1) + self.bias_ih_l0[sz*1:sz*2] + \
                      h_t @ self.weight_hh_l0[sz*1:sz*2,:].transpose(0, 1) + self.bias_hh_l0[sz*1:sz*2]
                g_t = x_t @ self.weight_ih_l0[sz*2:sz*3,:].transpose(0, 1) + self.bias_ih_l0[sz*2:sz*3] + \
                      h_t @ self.weight_hh_l0[sz*2:sz*3,:].transpose(0, 1) + self.bias_hh_l0[sz*2:sz*3]
                o_t = x_t @ self.weight_ih_l0[sz*3:sz*4,:].transpose(0, 1) + self.bias_ih_l0[sz*3:sz*4] + \
                      h_t @ self.weight_hh_l0[sz*3:sz*4,:].transpose(0, 1) + self.bias_hh_l0[sz*3:sz*4]
                i_t = torch.sigmoid(i_t)
                f_t = torch.sigmoid(f_t)
                g_t = torch.tanh(g_t)
                o_t = torch.sigmoid(o_t)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)
                hidden_seq_forward.append(h_t.unsqueeze(0))

            if init_states is None:
                # reverse direction also starts from zero states
                h_t = torch.zeros(bs, sz, device=x.device)
                c_t = torch.zeros(bs, sz, device=x.device)
            hidden_seq_reverse = []
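            # reverse direction: walk the sequence from the last time step back to the first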
            for t in list(reversed(range(seq))):
                x_t = x[:, t, :] if self.batch_first else x[t, :, :]
                i_t = x_t @ self.weight_ih_l0_reverse[sz*0:sz*1,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*0:sz*1] + \
                      h_t @ self.weight_hh_l0_reverse[sz*0:sz*1,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*0:sz*1]
                f_t = x_t @ self.weight_ih_l0_reverse[sz*1:sz*2,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*1:sz*2] + \
                      h_t @ self.weight_hh_l0_reverse[sz*1:sz*2,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*1:sz*2]
                g_t = x_t @ self.weight_ih_l0_reverse[sz*2:sz*3,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*2:sz*3] + \
                      h_t @ self.weight_hh_l0_reverse[sz*2:sz*3,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*2:sz*3]
                o_t = x_t @ self.weight_ih_l0_reverse[sz*3:sz*4,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*3:sz*4] + \
                      h_t @ self.weight_hh_l0_reverse[sz*3:sz*4,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*3:sz*4]
                i_t = torch.sigmoid(i_t)
                f_t = torch.sigmoid(f_t)
                g_t = torch.tanh(g_t)
                o_t = torch.sigmoid(o_t)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t) # shape: [bs, self.hidden_size]
                hidden_seq_reverse.append(h_t.unsqueeze(0))

            # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq
            # (placeholder -- this is the part I cannot figure out; hidden_seq is still undefined here)
            hidden_seq = torch.cat(hidden_seq, dim=0) # [seq, bs, self.hidden_size]
            if self.batch_first:
                hidden_seq = hidden_seq.transpose(0, 1).contiguous()
            return hidden_seq, (_, _)  # placeholder for the final (h_n, c_n)

        else:
            return super().forward(x)



if __name__ == "__main__":
    model = CustomLSTM(100, 60, bidirectional=True)
    x = torch.rand(512, 10, 100)
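    # batch_first defaults to False, so x is interpreted as [seq=512, bs=10, input_size=100]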

    model.eval()
    y1, (hn, cn) = model(x, None, False)
    print(y1.shape)

    y2, (hn, cn) = model(x, None, True)
    print(y2.shape)
    pdb.set_trace()

Could I ask for a suggestion on how to handle the # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq step?
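
For reference, here is a minimal sketch (using only the stock nn.LSTM, same sizes as the main block above) of the output layout the custom unroll has to reproduce:

import torch
import torch.nn as nn

lstm = nn.LSTM(100, 60, bidirectional=True)
x = torch.rand(512, 10, 100)                       # [seq=512, bs=10, input=100] with batch_first=False
out, (hn, cn) = lstm(x)
print(out.shape)                                   # torch.Size([512, 10, 120]): last dim is [forward | reverse]
print(hn.shape)                                    # torch.Size([2, 10, 60]): one final hidden state per direction
fwd, rev = out.view(512, 10, 2, 60).unbind(dim=2)  # split the two directions
# rev[t] is the reverse cell's hidden state at time t, i.e. computed while walking
# the sequence from the last step back to t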

@blueardour
Author

If I use the following:

            # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq
            hidden_seq_forward = torch.cat(hidden_seq_forward, dim=0) # [seq, bs, self.hidden_size]
            hidden_seq_reverse = torch.cat(hidden_seq_reverse, dim=0) # [seq, bs, self.hidden_size]
            print(hidden_seq_forward.shape, hidden_seq_reverse.shape)
            hidden_seq = torch.cat([hidden_seq_forward, hidden_seq_reverse], dim=2)
            print(hidden_seq.shape)

it seems that y1 == y2 in the main block gives a lot of False entries.
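
One likely cause (a guess, not verified): the reverse pass appends its states in reverse time order, so the dim=2 concatenation pairs forward step t with reverse step seq-1-t, while the built-in bidirectional output time-aligns the two directions. A sketch of the fix would be to flip the reverse list before stacking:

            # hypothetical fix: time-align the reverse direction before concatenating
            hidden_seq_forward = torch.cat(hidden_seq_forward, dim=0)                 # [seq, bs, hidden]
            hidden_seq_reverse = torch.cat(list(reversed(hidden_seq_reverse)), dim=0) # [seq, bs, hidden]
            hidden_seq = torch.cat([hidden_seq_forward, hidden_seq_reverse], dim=2)   # [seq, bs, 2*hidden]

Even with matching layouts, an exact y1 == y2 check can report False purely from floating-point round-off; torch.allclose(y1, y2, atol=1e-6) is the more robust comparison.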
