diff --git a/test/transformers/test_flex_attention.py b/test/transformers/test_flex_attention.py
index 3f32e23da..4fabdc1f0 100644
--- a/test/transformers/test_flex_attention.py
+++ b/test/transformers/test_flex_attention.py
@@ -1,13 +1,13 @@
-from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16
-
 import pytest
 import torch
 import torch.nn.functional as F
-from torch.nn.attention.flex_attention import (
-    create_block_mask,
-    create_mask,
-    flex_attention,
-)
+
+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
+from test.utils import supports_bfloat16
+from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import create_mask
+from torch.nn.attention.flex_attention import flex_attention
 
 from liger_kernel.utils import infer_device
 
@@ -17,13 +17,9 @@ def causal_mask(b, h, q_idx, kv_idx):
 
 
 def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
-    return (
-        ~(
-            (q_idx >= rejected_index[b])
-            & (chosen_index[b] <= kv_idx)
-            & (kv_idx < rejected_index[b])
-        )
-    ) & (q_idx >= kv_idx)
+    return (~((q_idx >= rejected_index[b]) & (chosen_index[b] <= kv_idx) & (kv_idx < rejected_index[b]))) & (
+        q_idx >= kv_idx
+    )
 
 
 device = infer_device()
@@ -47,30 +43,22 @@ def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device="cud
     torch.manual_seed(0)
 
     # Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input)
-    query_torch = torch.randn(
-        B, H, S, D, device=device, dtype=dtype, requires_grad=True
-    )
+    query_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
     key_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
-    value_torch = torch.randn(
-        B, H, S, D, device=device, dtype=dtype, requires_grad=True
-    )
+    value_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
 
     query_flex = query_torch.clone().detach().requires_grad_(True)
     key_flex = key_torch.clone().detach().requires_grad_(True)
    value_flex = value_torch.clone().detach().requires_grad_(True)
 
-    block_mask = create_block_mask(
-        mask_func, B, H, S, S, device=device
-    )  # Sparsity block mask
+    block_mask = create_block_mask(mask_func, B, H, S, S, device=device)  # Sparsity block mask
     mask = create_mask(mask_func, B, H, S, S, device=device)  # Regular mask
 
     # If you are using a causal mask with FA2, you can enable `is_causal`."
     # e.g.,
     # F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
-    torch_out = F.scaled_dot_product_attention(
-        query_torch, key_torch, value_torch, attn_mask=mask
-    )
+    torch_out = F.scaled_dot_product_attention(query_torch, key_torch, value_torch, attn_mask=mask)
 
     flex_out = flex_attention(query_flex, key_flex, value_flex, block_mask=block_mask)
 
@@ -101,9 +89,7 @@ def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device="cud
             torch.bfloat16,
             3e-2,
             5e-1,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
         ),
         (torch.float16, 1e-2, 5e-3),
         (torch.float32, 1e-3, 5e-4),
@@ -166,7 +152,7 @@ def _test_correctness_prefix(
     P P P
     R R R
 
-    We test them as belwo to ensure attention value equivalence:
+    We test them as below to ensure attention value equivalence:
     1. prefix of shared attn (upper of C.) == prefix of chosen attn (upper of A.)
     2. prefix of shared attn (upper of C.) == prefix of rejected attn (upper of B.)
     P P
@@ -286,9 +272,7 @@ def causal_mask(b, h, q_idx, kv_idx):
             torch.bfloat16,
             3e-2,
             5e-1,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
         ),
         (torch.float16, 1e-2, 5e-3),
         (torch.float32, 1e-3, 5e-4),
@@ -296,6 +280,4 @@ def causal_mask(b, h, q_idx, kv_idx):
 )
 def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
     """Parametrized test for different configurations"""
-    _test_correctness_prefix(
-        B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol
-    )
+    _test_correctness_prefix(B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol)
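For reviewers unfamiliar with the FlexAttention API this test exercises: the equivalence being asserted is that `flex_attention` driven by a sparse `BlockMask` matches `F.scaled_dot_product_attention` given the dense boolean mask built from the same `mask_mod`. A minimal standalone sketch follows; the shapes, CUDA device, and float32 tolerances are illustrative assumptions, not lifted from the test.

```python
# Minimal sketch of the FlexAttention-vs-SDPA equivalence check (assumes PyTorch >= 2.5 and CUDA).
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, create_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod convention: return True where a query position is allowed to attend to a key position
    return q_idx >= kv_idx


B, H, S, D = 2, 4, 128, 64  # illustrative shapes
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float32) for _ in range(3))

block_mask = create_block_mask(causal_mask, B, H, S, S, device="cuda")  # sparse block-level mask
dense_mask = create_mask(causal_mask, B, H, S, S, device="cuda")  # dense boolean mask for SDPA

flex_out = flex_attention(q, k, v, block_mask=block_mask)
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=dense_mask)

# Outputs should agree up to numerical tolerance (looser tolerances are needed for bf16/fp16).
torch.testing.assert_close(flex_out, sdpa_out, atol=1e-3, rtol=5e-4)
```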