add support for dynamic NTK for llama (#127)
hiworldwzj authored Sep 12, 2023
1 parent 4f42716 commit 01ad8c7
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions lightllm/models/llama/model.py
@@ -58,10 +58,12 @@ def _init_custom(self):
"""
模型特殊的一些初始化
"""
self._init_to_get_rotary()
if self.config.get("use_dynamic_ntk", False):
self._init_to_get_dynamic_ntk_rotary()
else:
self._init_to_get_rotary()
return
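For context, the dispatch above and the two initializers below read their settings from the model config. A minimal sketch of the relevant fields follows; the key names are the ones the code queries, while the values are illustrative assumptions rather than part of this commit.

# Illustrative sketch only: key names match what the code reads from
# self.config; the values are example assumptions, not from this commit.
llama_config = {
    "max_position_embeddings": 2048,   # base context window (code default)
    "rope_theta": 10000.0,             # RoPE base (code default)
    "use_dynamic_ntk": True,           # selects _init_to_get_dynamic_ntk_rotary
    "rope_scaling": {"factor": 4.0},   # "factor" feeds the dynamic NTK base update
}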


     def _init_to_get_rotary(self, default_base=10000.0):
         if self.config.get("rope_scaling", {}) is None:
             rope_scaling_factor = 1.0
@@ -97,3 +99,27 @@ def _init_to_get_rotary(self, default_base=10000.0):
         self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
         self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
         return
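For reference, the two buffers cached above hold the standard RoPE angle table, the same pattern spelled out explicitly in the new method below. With base b, head dimension d = self.head_dim_, and token position m, the cached entries are

    \theta_i = b^{-2i/d}, \qquad \cos(m \theta_i), \quad \sin(m \theta_i), \qquad i = 0, 1, \dots, d/2 - 1

i.e. freqs = torch.outer(t, inv_freq) followed by the float16 cos/sin casts.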

+    def _init_to_get_dynamic_ntk_rotary(self):
+        max_position_embeddings = self.config.get("max_position_embeddings", 2048)
+        base = self.config.get("rope_theta", 10000.0)
+        scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+        max_seq_len = 32 * max_position_embeddings  # 64k
+        self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=torch.float16, device="cuda")
+        self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=torch.float16, device="cuda")
+
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_))
+        t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32)
+        freqs = torch.outer(t, inv_freq)
+        self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(torch.float16).cuda()
+        self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(torch.float16).cuda()
+
+        for seq_loc_index in range(max_position_embeddings, max_seq_len, 1):
+            new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)) ** (self.head_dim_ / (self.head_dim_ - 2))
+            inv_freq = 1.0 / (new_base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_))
+            t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32)
+            freqs = torch.outer(t, inv_freq)
+            self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(torch.float16).cuda()
+            self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(torch.float16).cuda()
+        return
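Written out, the per-position base update inside the loop above is the dynamic NTK scaling rule. With \alpha the rope_scaling "factor", L = max_position_embeddings, d = self.head_dim_, and token position s >= L:

    \mathrm{base}'(s) = \mathrm{base} \cdot \left( \frac{\alpha (s + 1)}{L} - (\alpha - 1) \right)^{\frac{d}{d-2}}, \qquad \theta'_i(s) = \mathrm{base}'(s)^{-2i/d}

Each position past the original context window therefore gets its own recomputed base, and the cos/sin caches are filled row by row out to max_seq_len = 32 * max_position_embeddings positions.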
