Attention Scores Matrix Visualization #10

Open
bulaikexiansheng opened this issue Aug 27, 2024 · 1 comment
@bulaikexiansheng

Hi, I would like to ask why the attention mask is not used in the prefill stage.
I want to output the attention scores matrix in the prefill stage. Is the code below correct?

        if spec:  # spec decoding
            key_states, value_states = graph_cache.update(new_k_cache=key_states, new_v_cache=value_states, layer_idx=self.layer_idx)
        else:
            # update kv cache first
            key_states, value_states = kv_cache.update(key_states, value_states, layer_idx=self.layer_idx)
            if query_states.shape[1] == 1 and isinstance(graph_cache, RetrievalCache):
                if not graph_cache.init_graph:
                    # init graph cache
                    graph_cache.init_graph_cache(kv_cache, query_states, self.layer_idx)
                else:
                    # update graph cache (customized)
                    graph_cache.update_graph_cache_retrieval(kv_cache, query_states, self.layer_idx)

        # compute the raw (pre-softmax) attention scores matrix: [bsz, heads, q_len, kv_len]
        attention_scores = torch.einsum("bqhd,bkhd->bhqk", query_states, key_states)
        attention_scores = attention_scores / (self.head_dim ** 0.5)
        if attention_mask is not None:
            attention_mask = attention_mask.to(attention_scores.device)
            attention_scores += attention_mask

        attn_output = flash_attn_with_kvcache(
            q=query_states,
            k_cache=key_states,
            v_cache=value_states,
            softmax_scale=self.head_dim ** -0.5,  # flash-attn expects a plain float, not a tensor
            causal=True,
        )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, attention_scores
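For the visualization itself, here is a minimal sketch, assuming the [bsz, heads, q_len, kv_len] attention_scores tensor returned above; plot_attention and its defaults are illustrative names, not part of this repo. In the prefill stage a causal mask has to be applied before the softmax, otherwise future positions leak into the weights:

    import torch
    import matplotlib.pyplot as plt

    def plot_attention(attention_scores, batch=0, head=0, out="attn.png"):
        # attention_scores: [bsz, heads, q_len, kv_len] logits from the snippet above
        scores = attention_scores[batch, head].float()
        q_len, kv_len = scores.shape
        # causal mask: -inf strictly above the causal boundary, 0 elsewhere
        causal = torch.triu(torch.full((q_len, kv_len), float("-inf")), diagonal=kv_len - q_len + 1)
        weights = torch.softmax(scores + causal.to(scores.device), dim=-1)
        plt.imshow(weights.cpu().numpy(), cmap="viridis", aspect="auto")
        plt.xlabel("key position")
        plt.ylabel("query position")
        plt.colorbar(label="attention weight")
        plt.savefig(out, dpi=150)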
@bulaikexiansheng changed the title from "Hi, I would like to ask why the attention mask is not used in the prefill stage" to "Attention Scores Matrix Visualization" on Aug 27, 2024
@preminstrel
Contributor

Hello,

We use the flash attention function, which already applies a causal mask during the prefill phase, so no explicit attention mask is needed there.

Note that it is easy to run into OOM issues when computing the full attention matrix directly for long sequences.
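One memory-safe alternative, as a minimal sketch: compute the scores in query chunks and offload each chunk to CPU, so the full [q_len, kv_len] matrix never lives on the GPU. It assumes the same [bsz, seq_len, heads, head_dim] layout as the snippet above; attention_scores_chunked and chunk_size are illustrative names, not part of this repo:

    import math
    import torch

    def attention_scores_chunked(q, k, chunk_size=1024):
        # q, k: [bsz, seq_len, num_heads, head_dim], same layout as the flash-attn inputs
        scale = 1.0 / math.sqrt(q.shape[-1])
        k_idx = torch.arange(k.shape[1], device=q.device)
        chunks = []
        for start in range(0, q.shape[1], chunk_size):
            q_chunk = q[:, start:start + chunk_size]
            scores = torch.einsum("bqhd,bkhd->bhqk", q_chunk, k) * scale
            # causal mask: query at position i may only attend to keys at positions <= i
            q_idx = torch.arange(start, start + q_chunk.shape[1], device=q.device)
            scores = scores.masked_fill((k_idx[None, :] > q_idx[:, None])[None, None], float("-inf"))
            chunks.append(scores.float().cpu())
        return torch.cat(chunks, dim=2)

Softmax can then be applied row by row on CPU if normalized weights are needed for the plot.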

@preminstrel self-assigned this on Sep 4, 2024