diff --git a/alpaca_finetuning_v1/llama/model.py b/alpaca_finetuning_v1/llama/model.py index 1e9e3ae..f5b9f6c 100755 --- a/alpaca_finetuning_v1/llama/model.py +++ b/alpaca_finetuning_v1/llama/model.py @@ -210,7 +210,7 @@ def forward(self, examples, labels): h = layer(h, start_pos, freqs_cis, mask) adapter_index = 0 - adapter = self.adapter_query.weight.reshape(-1, self.adapter_len, 4096).unsqueeze(1) + adapter = self.adapter_query.weight.reshape(-1, self.adapter_len, self.params.dim).unsqueeze(1) for layer in self.layers[-1 * self.adapter_layer :]: h = layer(h, start_pos, freqs_cis, mask, adapter[adapter_index].half()) adapter_index = adapter_index + 1