From 064abb342d3a7d510bd071a7aebce01de4518939 Mon Sep 17 00:00:00 2001 From: mhh001 Date: Tue, 23 Apr 2024 16:48:35 +0800 Subject: [PATCH] Add NPU support for Llava --- README.md | 8 +- docs/AscendNPU_Support.md | 83 +++++++++ llava/eval/model_qa.py | 10 +- llava/eval/model_vqa.py | 9 +- llava/eval/model_vqa_loader.py | 9 +- llava/eval/model_vqa_mmbench.py | 9 +- llava/eval/model_vqa_science.py | 9 +- llava/eval/run_llava.py | 7 +- llava/serve/cli.py | 2 +- llava/train/llama_npu_monkey_patch.py | 245 ++++++++++++++++++++++++++ llava/train/train_npu.py | 14 ++ llava/utils.py | 18 +- scripts/v1_5/finetune_npu.sh | 40 +++++ 13 files changed, 441 insertions(+), 22 deletions(-) create mode 100644 docs/AscendNPU_Support.md create mode 100644 llava/train/llama_npu_monkey_patch.py create mode 100644 llava/train/train_npu.py create mode 100644 scripts/v1_5/finetune_npu.sh diff --git a/README.md b/README.md index 794ce1b27..9e510849e 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,8 @@ If you are not using Linux, do *NOT* proceed, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md). +If you are using Ascend NPU, see instructions for [AscendNPU support](docs/AscendNPU_Support.md). + 1. Clone this repository and navigate to LLaVA folder ```bash git clone https://github.com/haotian-liu/LLaVA.git @@ -180,7 +182,7 @@ flowchart BT subgraph Demo Connections direction BT c<-->gws - + mw7b<-->c mw13b<-->c lsglw13b<-->c @@ -431,14 +433,14 @@ If you find LLaVA useful for your research and applications, please cite using t } @misc{liu2023improvedllava, - title={Improved Baselines with Visual Instruction Tuning}, + title={Improved Baselines with Visual Instruction Tuning}, author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae}, publisher={arXiv:2310.03744}, year={2023}, } @misc{liu2023llava, - title={Visual Instruction Tuning}, + title={Visual Instruction Tuning}, author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae}, publisher={NeurIPS}, year={2023}, diff --git a/docs/AscendNPU_Support.md b/docs/AscendNPU_Support.md new file mode 100644 index 000000000..5dfb345fb --- /dev/null +++ b/docs/AscendNPU_Support.md @@ -0,0 +1,83 @@ +# Run Llava on AscendNPU + + + +## Installation +1. Clone this repository and navigate to LLaVA folder +```bash +git clone https://github.com/haotian-liu/LLaVA.git +cd LLaVA +``` + +2. Install Package +```Shell +conda create -n llava python=3.10 -y +conda activate llava +pip install --upgrade pip # enable PEP 660 support +pip install -e . +``` + +3. Install additional packages for training cases +``` +pip install -e ".[train]" +``` + +4. Install Ascend Extension for PyTorch + +You can follow this [guide](https://www.hiascend.com/document/detail/en/ModelZoo/pytorchframework/ptes/ptes_00001.html) to download and install the Ascend NPU Firmware, Ascend NPU Driver, and CANN. Afterwards, you need to install additional Python packages. +```shell +pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu #For X86 +pip3 install torch==2.1.0 #For Aarch64 +pip3 install accelerate==0.28.0 decorator==5.1.1 scipy==1.13.0 attrs==23.2.0 openpyxl +``` +After installing the above Python packages, +You can follow this [README](https://github.com/Ascend/pytorch/blob/master/README.md) to install the torch_npu environment. +Then you can use Llava on Ascend NPU. 
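+
+You can quickly confirm that PyTorch sees the NPU with a one-line check (a minimal sketch, assuming the Ascend NPU driver, CANN, and `torch_npu` were installed as described above):
+```shell
+python -c "import torch, torch_npu; print(torch.npu.is_available())"
+```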
+ + + + +## Pretrain/Finetune Llava on AscendNPU +If you want to Pretrain/Finetune Llava on AscendNPU, you only need to make modifications to two lines in the Pretrain/Finetune shell script. + +As shown below: +```shell +# Firstly, add environment variables to the system via the 'source' command. +source /usr/local/Ascend/ascend-toolkit/set_env.sh +# Disable TF32 mode +--tf32 False +``` +Here is [finetune shell](scripts/v1_5/finetune_npu.sh) example on AscendNPU + + +## Inference/Evaluate Llava on AscendNPU +If you want to perform inference/evaluation, a small modification to your shell script is all that's needed. + + +As shown below, you only need to add a 'source' command in your shell script,and the usage for inference remains the same. +```shell +# textvqa.sh +source /usr/local/Ascend/ascend-toolkit/set_env.sh #Add this +python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-13b \ + --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ + --image-folder ./playground/data/eval/textvqa/train_images \ + --answers-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 + +python -m llava.eval.eval_textvqa \ + --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \ + --result-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl + +# inference.sh +source /usr/local/Ascend/ascend-toolkit/set_env.sh #Add this +python -m llava.serve.cli \ + --model-path liuhaotian/llava-v1.5-7b \ + --image-file "https://llava-vl.github.io/static/images/view.jpg" \ + +``` +*NOTE:Ascend NPU doesn't support all quantization methods. If you encounter issues during inference, you can remove the quantization.* + + + diff --git a/llava/eval/model_qa.py b/llava/eval/model_qa.py index 2e254da15..765a6980f 100644 --- a/llava/eval/model_qa.py +++ b/llava/eval/model_qa.py @@ -7,8 +7,11 @@ import shortuuid from llava.conversation import default_conversation -from llava.utils import disable_torch_init +from llava.utils import disable_torch_init, is_npu_available +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu @torch.inference_mode() def eval_model(model_name, questions_file, answers_file): @@ -17,7 +20,8 @@ def eval_model(model_name, questions_file, answers_file): model_name = os.path.expanduser(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_name, - torch_dtype=torch.float16).cuda() + torch_dtype=torch.float16).to("npu" if is_npu_available() else "cuda") + ques_file = open(os.path.expanduser(questions_file), "r") @@ -30,7 +34,7 @@ def eval_model(model_name, questions_file, answers_file): conv.append_message(conv.roles[0], qs) prompt = conv.get_prompt() inputs = tokenizer([prompt]) - input_ids = torch.as_tensor(inputs.input_ids).cuda() + input_ids = torch.as_tensor(inputs.input_ids).to("npu" if is_npu_available() else "cuda") output_ids = model.generate( input_ids, do_sample=True, diff --git a/llava/eval/model_vqa.py b/llava/eval/model_vqa.py index 938706438..99ec55120 100644 --- a/llava/eval/model_vqa.py +++ b/llava/eval/model_vqa.py @@ -8,12 +8,15 @@ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model -from llava.utils import disable_torch_init +from llava.utils import disable_torch_init, 
is_npu_available from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path from PIL import Image import math +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" @@ -53,7 +56,7 @@ def eval_model(args): conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() - input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to("npu" if is_npu_available() else "cuda") image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') image_tensor = process_images([image], image_processor, model.config)[0] @@ -61,7 +64,7 @@ def eval_model(args): with torch.inference_mode(): output_ids = model.generate( input_ids, - images=image_tensor.unsqueeze(0).half().cuda(), + images=image_tensor.unsqueeze(0).half().to("npu" if is_npu_available() else "cuda"), image_sizes=[image.size], do_sample=True if args.temperature > 0 else False, temperature=args.temperature, diff --git a/llava/eval/model_vqa_loader.py b/llava/eval/model_vqa_loader.py index d435b7d83..74c5acfcf 100644 --- a/llava/eval/model_vqa_loader.py +++ b/llava/eval/model_vqa_loader.py @@ -8,13 +8,16 @@ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model -from llava.utils import disable_torch_init +from llava.utils import disable_torch_init, is_npu_available from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path from torch.utils.data import Dataset, DataLoader from PIL import Image import math +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" @@ -99,12 +102,12 @@ def eval_model(args): idx = line["question_id"] cur_prompt = line["text"] - input_ids = input_ids.to(device='cuda', non_blocking=True) + input_ids = input_ids.to(device="npu" if is_npu_available() else "cuda", non_blocking=True) with torch.inference_mode(): output_ids = model.generate( input_ids, - images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + images=image_tensor.to(dtype=torch.float16, device="npu" if is_npu_available() else "cuda",non_blocking=True), image_sizes=image_sizes, do_sample=True if args.temperature > 0 else False, temperature=args.temperature, diff --git a/llava/eval/model_vqa_mmbench.py b/llava/eval/model_vqa_mmbench.py index bd7a4c808..1e708c746 100644 --- a/llava/eval/model_vqa_mmbench.py +++ b/llava/eval/model_vqa_mmbench.py @@ -9,12 +9,15 @@ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model -from llava.utils import disable_torch_init +from llava.utils import disable_torch_init, is_npu_available from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path from PIL import Image import math +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu all_options = ['A', 'B', 'C', 'D'] @@ -103,14 +106,14 @@ def eval_model(args): 
conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() - input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device="npu" if is_npu_available() else "cuda") image_tensor = process_images([image], image_processor, model.config)[0] with torch.inference_mode(): output_ids = model.generate( input_ids, - images=image_tensor.unsqueeze(0).half().cuda(), + images=image_tensor.unsqueeze(0).half().to(device="npu" if is_npu_available() else "cuda"), image_sizes=[image.size], do_sample=True if args.temperature > 0 else False, temperature=args.temperature, diff --git a/llava/eval/model_vqa_science.py b/llava/eval/model_vqa_science.py index 90fc681a2..4dd7f1ddc 100644 --- a/llava/eval/model_vqa_science.py +++ b/llava/eval/model_vqa_science.py @@ -8,12 +8,15 @@ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model -from llava.utils import disable_torch_init +from llava.utils import disable_torch_init, is_npu_available from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path from PIL import Image import math +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" @@ -48,7 +51,7 @@ def eval_model(args): image_file = line["image"] image = Image.open(os.path.join(args.image_folder, image_file)) image_tensor = process_images([image], image_processor, model.config)[0] - images = image_tensor.unsqueeze(0).half().cuda() + images = image_tensor.unsqueeze(0).half().to(device="npu" if is_npu_available() else "cuda") image_sizes = [image.size] if getattr(model.config, 'mm_use_im_start_end', False): qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs @@ -68,7 +71,7 @@ def eval_model(args): conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() - input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device="npu" if is_npu_available() else "cuda") with torch.inference_mode(): output_ids = model.generate( diff --git a/llava/eval/run_llava.py b/llava/eval/run_llava.py index 24b0fffcc..0ba5439a7 100644 --- a/llava/eval/run_llava.py +++ b/llava/eval/run_llava.py @@ -10,7 +10,7 @@ ) from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model -from llava.utils import disable_torch_init +from llava.utils import disable_torch_init, is_npu_available from llava.mm_utils import ( process_images, tokenizer_image_token, @@ -24,6 +24,9 @@ from io import BytesIO import re +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu def image_parser(args): out = args.image_file.split(args.sep) @@ -108,7 +111,7 @@ def eval_model(args): input_ids = ( tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") .unsqueeze(0) - .cuda() + .to(device="npu" if is_npu_available() else "cuda") ) with torch.inference_mode(): diff --git a/llava/serve/cli.py b/llava/serve/cli.py index 5ecb30d56..d38527c4d 100644 --- 
a/llava/serve/cli.py +++ b/llava/serve/cli.py @@ -82,7 +82,7 @@ def main(args): else: inp = DEFAULT_IMAGE_TOKEN + '\n' + inp image = None - + conv.append_message(conv.roles[0], inp) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() diff --git a/llava/train/llama_npu_monkey_patch.py b/llava/train/llama_npu_monkey_patch.py new file mode 100644 index 000000000..de0757a8c --- /dev/null +++ b/llava/train/llama_npu_monkey_patch.py @@ -0,0 +1,245 @@ +import warnings +import math +from typing import Optional, Tuple + +import torch +import transformers.models.llama.modeling_llama +import torch.nn.functional as F +from transformers.cache_utils import Cache, DynamicCache +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding +from torch import nn +from einops import rearrange +import torch_npu + + +def forward_rmsnorm(self, hidden_states): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + + +class FlashLlamaRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, :, None, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, :, None, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + return ( + self.cos_cached[:, :seq_len, :, ...].to(dtype=x.dtype), + self.sin_cached[:, :seq_len, :, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+ Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_fused_rotary_pos_emb(q, k, cos, sin, position_ids): + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + +def _init_rope(self): + self.rotary_emb = FlashLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + # use torch_npu flash attention + if not use_cache and query_states.dtype in (torch.float16, torch.bfloat16): + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) # BSND + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) # BSND + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) # BSND + + kv_seq_len = key_states.shape[1] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_fused_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + q, k, v = [rearrange(x, 'b s h d -> b s (h d)').contiguous() for x in + (query_states, key_states, value_states)] # BSH + 
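+        # Fused-attention path: compute the standard 1/sqrt(head_dim) softmax scale, squeeze a
+        # batch-broadcast attention mask down to 2D and cast it to bool, then call
+        # torch_npu.npu_fusion_attention on the "BSH"-layout tensors prepared above
+        # (keep_prob=1, i.e. no attention dropout).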
scale = 1 / math.sqrt(self.head_dim) + + attention_mask_shape = attention_mask.shape + if attention_mask_shape[0] == 1: + attention_mask = attention_mask.view((attention_mask_shape[-2], attention_mask_shape[-1])) + if not isinstance(attention_mask.type(), torch.BoolTensor): + attention_mask = attention_mask.bool() + + attn_output = torch_npu.npu_fusion_attention( + q, k, v, self.num_heads, + pse=None, + padding_mask=None, + atten_mask=attention_mask, + scale=scale, + keep_prob=1, + input_layout="BSH", + pre_tockens=65536, + next_tockens=0, + inner_precise=0)[0] + + else: + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def 
replace_with_torch_npu_flash_attention(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = attention_forward + transformers.models.llama.modeling_llama.LlamaAttention._init_rope = _init_rope + + +def replace_with_torch_npu_rmsnorm(): + transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = forward_rmsnorm diff --git a/llava/train/train_npu.py b/llava/train/train_npu.py new file mode 100644 index 000000000..06e66c003 --- /dev/null +++ b/llava/train/train_npu.py @@ -0,0 +1,14 @@ +from llava.train.llama_npu_monkey_patch import ( + replace_with_torch_npu_flash_attention, + replace_with_torch_npu_rmsnorm +) + +replace_with_torch_npu_flash_attention() +replace_with_torch_npu_rmsnorm() + +from llava.train.train import train +import torch_npu +from torch_npu.contrib import transfer_to_npu + +if __name__ == "__main__": + train() diff --git a/llava/utils.py b/llava/utils.py index 4006cf917..d766ecfe1 100644 --- a/llava/utils.py +++ b/llava/utils.py @@ -3,10 +3,11 @@ import logging.handlers import os import sys - +import importlib import requests from llava.constants import LOGDIR +import torch server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." @@ -124,3 +125,18 @@ def pretty_print_semaphore(semaphore): if semaphore is None: return "None" return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +def is_npu_available(): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch_npu + + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False diff --git a/scripts/v1_5/finetune_npu.sh b/scripts/v1_5/finetune_npu.sh new file mode 100644 index 000000000..6e2425aa7 --- /dev/null +++ b/scripts/v1_5/finetune_npu.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Default path, change it if needed. +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +deepspeed llava/train/train_npu.py \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path lmsys/vicuna-7b-v1.5 \ + --version v1 \ + --data_path ./playground/data/llava_v1_5_mix665k.json \ + --image_folder ./playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-7b-pretrain/mm_projector.bin \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/llava-v1.5-7b \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb
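+
+# Usage sketch (assumes the default paths above and NPUs visible to deepspeed):
+#   bash scripts/v1_5/finetune_npu.sh
+# Keep --tf32 False on Ascend NPU; adjust --per_device_train_batch_size to fit NPU memory.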