From 83dcd591175d7f4b3dbc411330cc5a401889591f Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 16 Jan 2025 17:08:54 +0800 Subject: [PATCH] fix(test): pass LlamaConfig via config kwarg (with num_kv_heads) to LlamaRotaryEmbedding in RoPE tests Signed-off-by: Austin Liu --- test/transformers/test_rope.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index 6f7e79d83..f026f47de 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -58,7 +58,7 @@ def test_correctness( atol, rtol, ): - rotary_emb = LlamaRotaryEmbedding(LlamaConfig(head_dim=head_dim), device=device) + rotary_emb = LlamaRotaryEmbedding(config=LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), device=device) _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype) @@ -134,7 +134,7 @@ def test_functional_correctness( k1 = _k.clone().requires_grad_(True) k2 = _k.clone().requires_grad_(True) - rotary_emb = LlamaRotaryEmbedding(LlamaConfig(head_dim=head_dim), device=device) + rotary_emb = LlamaRotaryEmbedding(config=LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), device=device) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) if expand_position_ids: