format
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Jan 8, 2025
1 parent 031707d commit bd4c9f5
Showing 1 changed file with 18 additions and 36 deletions.
54 changes: 18 additions & 36 deletions test/transformers/test_flex_attention.py
@@ -1,13 +1,13 @@
-from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16

import pytest
import torch
import torch.nn.functional as F
-from torch.nn.attention.flex_attention import (
-    create_block_mask,
-    create_mask,
-    flex_attention,
-)

+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
+from test.utils import supports_bfloat16
+from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import create_mask
+from torch.nn.attention.flex_attention import flex_attention

from liger_kernel.utils import infer_device

@@ -17,13 +17,9 @@ def causal_mask(b, h, q_idx, kv_idx):


def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
-    return (
-        ~(
-            (q_idx >= rejected_index[b])
-            & (chosen_index[b] <= kv_idx)
-            & (kv_idx < rejected_index[b])
-        )
-    ) & (q_idx >= kv_idx)
+    return (~((q_idx >= rejected_index[b]) & (chosen_index[b] <= kv_idx) & (kv_idx < rejected_index[b]))) & (
+        q_idx >= kv_idx
+    )


device = infer_device()
@@ -47,30 +43,22 @@ def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device="cud
torch.manual_seed(0)

# Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input)
-    query_torch = torch.randn(
-        B, H, S, D, device=device, dtype=dtype, requires_grad=True
-    )
+    query_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
key_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-    value_torch = torch.randn(
-        B, H, S, D, device=device, dtype=dtype, requires_grad=True
-    )
+    value_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)

query_flex = query_torch.clone().detach().requires_grad_(True)
key_flex = key_torch.clone().detach().requires_grad_(True)
value_flex = value_torch.clone().detach().requires_grad_(True)

-    block_mask = create_block_mask(
-        mask_func, B, H, S, S, device=device
-    )  # Sparsity block mask
+    block_mask = create_block_mask(mask_func, B, H, S, S, device=device)  # Sparsity block mask
mask = create_mask(mask_func, B, H, S, S, device=device) # Regular mask

# If you are using a causal mask with FA2, you can enable `is_causal`."
# e.g.,
# F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)

-    torch_out = F.scaled_dot_product_attention(
-        query_torch, key_torch, value_torch, attn_mask=mask
-    )
+    torch_out = F.scaled_dot_product_attention(query_torch, key_torch, value_torch, attn_mask=mask)

flex_out = flex_attention(query_flex, key_flex, value_flex, block_mask=block_mask)

@@ -101,9 +89,7 @@ def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device="cud
torch.bfloat16,
3e-2,
5e-1,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
(torch.float16, 1e-2, 5e-3),
(torch.float32, 1e-3, 5e-4),
@@ -166,7 +152,7 @@ def _test_correctness_prefix(
P P P R R R
-    We test them as belwo to ensure attention value equivalence:
+    We test them as below to ensure attention value equivalence:
1. prefix of shared attn (upper of C.) == prefix of chosen attn (upper of A.)
2. prefix of shared attn (upper of C.) == prefix of rejected attn (upper of B.)
P P
@@ -286,16 +272,12 @@ def causal_mask(b, h, q_idx, kv_idx):
torch.bfloat16,
3e-2,
5e-1,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
(torch.float16, 1e-2, 5e-3),
(torch.float32, 1e-3, 5e-4),
],
)
def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
"""Parametrized test for different configurations"""
-    _test_correctness_prefix(
-        B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol
-    )
+    _test_correctness_prefix(B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol)

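For readers outside the repository, the snippet below is a minimal standalone sketch (not part of this commit) of the check this test file performs: FlexAttention with a sparse block mask built from a mask_mod is compared against F.scaled_dot_product_attention with the equivalent dense mask. The shapes, tolerances, and use of torch.testing.assert_close are illustrative assumptions; the actual tests use assert_verbose_allclose, infer_device(), and also compare gradients.

# Minimal sketch: FlexAttention vs. SDPA with an equivalent mask.
# Assumes PyTorch >= 2.5 with FlexAttention available and a CUDA device;
# shapes and tolerances below are illustrative, not the test's parametrized values.
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, create_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod convention: return True where the query position may attend to the key position.
    return q_idx >= kv_idx


B, H, S, D = 2, 4, 128, 64
device = "cuda"

q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

block_mask = create_block_mask(causal_mask, B, H, S, S, device=device)  # sparse block mask for flex_attention
dense_mask = create_mask(causal_mask, B, H, S, S, device=device)  # dense boolean mask for SDPA

flex_out = flex_attention(q, k, v, block_mask=block_mask)
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=dense_mask)

# Both paths should produce the same attention output within tolerance.
torch.testing.assert_close(flex_out, sdpa_out, atol=1e-3, rtol=5e-4)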