From e927a4a484f97d31a4ad6781c4163324904c9778 Mon Sep 17 00:00:00 2001 From: ssbuild <462304@qq.cn> Date: Wed, 16 Aug 2023 15:40:16 +0800 Subject: [PATCH] fix chatglm rope Signed-off-by: ssbuild <462304@qq.cn> --- src/deep_training/nlp/layers/rope_scale/DynamicScaledRotary.py | 2 +- src/deep_training/nlp/layers/rope_scale/LinearScaledRotary.py | 2 +- src/deep_training/nlp/layers/rope_scale/NTKScaledRotary.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/deep_training/nlp/layers/rope_scale/DynamicScaledRotary.py b/src/deep_training/nlp/layers/rope_scale/DynamicScaledRotary.py index f8d432ee..4ba36cf4 100644 --- a/src/deep_training/nlp/layers/rope_scale/DynamicScaledRotary.py +++ b/src/deep_training/nlp/layers/rope_scale/DynamicScaledRotary.py @@ -67,7 +67,7 @@ def forward(self, x, seq_dim=1, seq_len=None): inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)) self.register_buffer("inv_freq", inv_freq) - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached or seq_len, device=x.device, dtype=self.inv_freq.dtype) freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) diff --git a/src/deep_training/nlp/layers/rope_scale/LinearScaledRotary.py b/src/deep_training/nlp/layers/rope_scale/LinearScaledRotary.py index cc4f6629..681db98d 100644 --- a/src/deep_training/nlp/layers/rope_scale/LinearScaledRotary.py +++ b/src/deep_training/nlp/layers/rope_scale/LinearScaledRotary.py @@ -53,7 +53,7 @@ def forward(self, x, seq_dim=1, seq_len=None): seq_len = x.shape[seq_dim] if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = None if self.learnable else max(seq_len,self.max_position_embeddings) - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) / self.scale + t = torch.arange(self.max_seq_len_cached or seq_len , device=x.device, dtype=self.inv_freq.dtype) / self.scale freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) diff --git a/src/deep_training/nlp/layers/rope_scale/NTKScaledRotary.py b/src/deep_training/nlp/layers/rope_scale/NTKScaledRotary.py index 79a22661..1816a7d3 100644 --- a/src/deep_training/nlp/layers/rope_scale/NTKScaledRotary.py +++ b/src/deep_training/nlp/layers/rope_scale/NTKScaledRotary.py @@ -49,7 +49,7 @@ def forward(self, x, seq_dim=1, seq_len=None): seq_len = x.shape[seq_dim] if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = None if self.learnable else max(seq_len,self.max_position_embeddings) - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached or seq_len, device=x.device, dtype=self.inv_freq.dtype) freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device)