Commit

[CLEANUP]
Kye Gomez authored and committed Sep 2, 2024
1 parent 7f3e67f commit 9063c30
Showing 1 changed file with 19 additions and 7 deletions.
example.py: 26 changes (19 additions & 7 deletions)
@@ -1,18 +1,30 @@
"""
This script demonstrates the usage of the FlashAttentionmodule from zeta.nn as an example.
This script demonstrates the usage of the FlashAttention module from zeta.nn.
"""

import torch

from zeta.nn import FlashAttention

q = torch.randn(2, 4, 6, 8)
k = torch.randn(2, 4, 10, 8)
v = torch.randn(2, 4, 10, 8)
# Set random seed for reproducibility
torch.manual_seed(42)

# Define input tensor shapes
batch_size, num_heads, seq_len_q, d_head = 2, 4, 6, 8
seq_len_kv = 10

# Create random input tensors
q = torch.randn(batch_size, num_heads, seq_len_q, d_head)
k = torch.randn(batch_size, num_heads, seq_len_kv, d_head)
v = torch.randn(batch_size, num_heads, seq_len_kv, d_head)

# Initialize FlashAttention module
attention = FlashAttention(causal=False, dropout=0.1, flash=False)
print(attention)
print("FlashAttention configuration:", attention)

# Perform attention operation
output = attention(q, k, v)

print(output.shape)
print(f"Output shape: {output.shape}")

# Optional: Add assertion to check expected output shape
assert output.shape == (batch_size, num_heads, seq_len_q, d_head), "Unexpected output shape"
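
For readers who want to run the example without installing zeta, below is a minimal sketch of the equivalent computation using only PyTorch's built-in torch.nn.functional.scaled_dot_product_attention (available in PyTorch 2.0+). It assumes that zeta's FlashAttention with causal=False and flash=False reduces to standard scaled dot-product attention; dropout is set to 0.0 here (the example uses 0.1) so the output is deterministic.

# Hypothetical stand-in for the committed example, using only PyTorch (2.0+).
# Assumption: zeta's FlashAttention(causal=False, flash=False) computes
# standard scaled dot-product attention over (batch, heads, seq, d_head) tensors.
import torch
import torch.nn.functional as F

torch.manual_seed(42)

batch_size, num_heads, seq_len_q, d_head = 2, 4, 6, 8
seq_len_kv = 10

q = torch.randn(batch_size, num_heads, seq_len_q, d_head)
k = torch.randn(batch_size, num_heads, seq_len_kv, d_head)
v = torch.randn(batch_size, num_heads, seq_len_kv, d_head)

# softmax(q @ k^T / sqrt(d_head)) @ v; dropout disabled for determinism
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

assert output.shape == (batch_size, num_heads, seq_len_q, d_head)
print(f"Output shape: {output.shape}")  # torch.Size([2, 4, 6, 8])

Note that the query length (6) and key/value length (10) differ, so the output takes its sequence dimension from the query, as the assertion checks.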
