
Commit

fix chatglm rope
Signed-off-by: ssbuild <[email protected]>
ssbuild committed Aug 16, 2023
1 parent ef3d460 commit e927a4a
Showing 3 changed files with 3 additions and 3 deletions.
@@ -67,7 +67,7 @@ def forward(self, x, seq_dim=1, seq_len=None):
 inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
 self.register_buffer("inv_freq", inv_freq)
 
-t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+t = torch.arange(self.max_seq_len_cached or seq_len, device=x.device, dtype=self.inv_freq.dtype)
 freqs = torch.einsum('i,j->ij', t, self.inv_freq)
 # Different from paper, but it uses a different permutation in order to obtain the same calculation
 emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
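Why this change matters (a reading of the diff; the surrounding file is not shown here): max_seq_len_cached is set to max(seq_len, max_position_embeddings), but the old code built the cos/sin tables only up to the current seq_len. A later call with a longer sequence that is still below max_seq_len_cached would skip the rebuild and index past the end of the cached tables. Building t up to self.max_seq_len_cached (falling back to seq_len when the cache length is None, as in the learnable case) keeps the tables as long as the recorded cache size. Below is a minimal, self-contained sketch of that caching pattern; the class body is a simplified reconstruction around the patched line, not the actual deep_training layer.

import torch

class RotaryEmbedding(torch.nn.Module):
    # Hypothetical, simplified reconstruction around the patched line; the real
    # layer has more options (learnable frequencies, precision handling, etc.).
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            # Record the cache length first, then build the tables to that
            # length (this is what the patched torch.arange call achieves).
            self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
            t = torch.arange(self.max_seq_len_cached or seq_len,
                             device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        # Slicing stays in range because the tables cover max_seq_len_cached.
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]

# Example: cos/sin tables for a (batch, seq, dim) input.
rope = RotaryEmbedding(dim=64, max_position_embeddings=128)
cos, sin = rope(torch.zeros(1, 16, 64))   # each of shape (16, 64)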
@@ -53,7 +53,7 @@ def forward(self, x, seq_dim=1, seq_len=None):
 seq_len = x.shape[seq_dim]
 if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
 self.max_seq_len_cached = None if self.learnable else max(seq_len,self.max_position_embeddings)
-t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) / self.scale
+t = torch.arange(self.max_seq_len_cached or seq_len, device=x.device, dtype=self.inv_freq.dtype) / self.scale
 freqs = torch.einsum('i,j->ij', t, self.inv_freq)
 # Different from paper, but it uses a different permutation in order to obtain the same calculation
 emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
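This second hunk (the file path is not preserved in this rendering) divides the positions by self.scale, which is the linear / position-interpolation flavor of RoPE scaling: positions are compressed by the scale factor so a longer context maps back into the position range the model was trained on. A hedged sketch of just the frequency computation, with assumed function and parameter names:

import torch

def linear_scaled_rope_tables(dim, max_len, scale=2.0, base=10000, device="cpu"):
    # Positions are divided by `scale` (position interpolation): a sequence of
    # length scale * trained_len reuses the position range seen in training.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(max_len, device=device, dtype=inv_freq.dtype) / scale
    freqs = torch.einsum('i,j->ij', t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()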
2 changes: 1 addition & 1 deletion src/deep_training/nlp/layers/rope_scale/NTKScaledRotary.py
@@ -49,7 +49,7 @@ def forward(self, x, seq_dim=1, seq_len=None):
 seq_len = x.shape[seq_dim]
 if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
 self.max_seq_len_cached = None if self.learnable else max(seq_len,self.max_position_embeddings)
-t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+t = torch.arange(self.max_seq_len_cached or seq_len, device=x.device, dtype=self.inv_freq.dtype)
 freqs = torch.einsum('i,j->ij', t, self.inv_freq)
 # Different from paper, but it uses a different permutation in order to obtain the same calculation
 emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
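NTKScaledRotary takes the other common approach: instead of compressing positions, NTK-aware scaling enlarges the RoPE base so the low-frequency dimensions stretch to cover the longer context. The base adjustment below, base * alpha ** (dim / (dim - 2)), is the commonly used NTK formula, assumed here rather than read from the file:

import torch

def ntk_scaled_rope_tables(dim, max_len, alpha=2.0, base=10000, device="cpu"):
    # NTK-aware scaling: enlarge the base instead of dividing the positions.
    base = base * alpha ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(max_len, device=device, dtype=inv_freq.dtype)
    freqs = torch.einsum('i,j->ij', t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()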
