Skip to content

Commit

Permalink
Fix RopeFusion transformation after applying SDPA to PagedAttention c…
Browse files Browse the repository at this point in the history
…onversion (#28447)

### Details:
After internal discussion, we decided to use the changes from
#27718 as a base line
Fixed Rope pattern detection for ChatGLM.

### Tickets:
 - *CVS-158393*
  • Loading branch information
itikhono authored Jan 16, 2025
1 parent 25cd6b0 commit ed50d51
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
} else {
auto ListConstruct_452_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims / 2, 2});
auto const_target_shape_1 = makeConst({seq_len, batch, head_cnt, ndims / 2, 2});
reshape_Reshape_453 = makePattern<opset1::Reshape>(
{slice_Slice_437 | var_split_1->output(0), ListConstruct_452_Concat | const_target_shape_1});
reshape_Reshape_453 =
makePattern<opset1::Reshape>({slice_Slice_437 | var_split_1->output(0),
ListConstruct_452_Concat | const_target_shape_1 | const_target_shape_0});
}

auto x_even = makePattern<opset8::Gather>({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}});
Expand Down Expand Up @@ -588,6 +590,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
} else {
auto ListConstruct_379_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_0 = makeConst({1, -1, 1, ndims / 2, 2});
auto const_target_shape_2 = makeConst({seq_len, batch, 1, ndims / 2, 2});

auto slice_Slice_449 = makePattern<ov::opset8::Slice>({cos_sin_cache, {0}, seq_length, {1}, {0}});
Expand All @@ -596,7 +599,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
// [seq_length, 1, batch, half_rotary_dims, 2]
view_Reshape_460 =
makePattern<opset1::Reshape>({slice_StridedSlice_449 | slice_Slice_449 | var_split_2->output(0),
ListConstruct_379_Concat | const_target_shape_2},
ListConstruct_379_Concat | const_target_shape_0 | const_target_shape_2},
{{"special_zero", false}});
}

Expand All @@ -609,12 +612,17 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
auto sub_Subtract_469 = makePattern<opset1::Add>({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}});

auto y_even = makePattern<opset1::Unsqueeze>({sub_Subtract_469, -1});
auto const_y_even_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1});
auto y_even_reshape =
makePattern<opset1::Reshape>({sub_Subtract_469, const_y_even_reshape}, {{"special_zero", false}});
auto x_odd_cos = makePattern<opset1::Multiply>({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}});
auto x_even_sin = makePattern<opset1::Multiply>({x_even, sin_tab}, {{"auto_broadcast", "numpy"}});
auto add_Add_476 = makePattern<opset1::Add>({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}});
auto y_odd = makePattern<opset1::Unsqueeze>({add_Add_476, -1});
auto const_y_odd_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1});
auto y_odd_reshape = makePattern<opset1::Reshape>({add_Add_476, const_y_odd_reshape}, {{"special_zero", false}});

auto stack_481 = makePattern<opset1::Concat>({y_even, y_odd}, {{"axis", -1}});
auto stack_481 = makePattern<opset1::Concat>({y_even | y_even_reshape, y_odd | y_odd_reshape}, {{"axis", -1}});

auto ShapeOf_135133 = makePattern<opset1::ShapeOf>({stack_481});
auto flatten_Slice_497 = GenSlice(ShapeOf_135133, 0, 3, 1, 0);
Expand All @@ -629,9 +637,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
{{"special_zero", true}});
} else {
// [length, batch, head_cnt, half_rotary_dims, 2]
auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims});
const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims});
flatten_Reshape_501 = makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape_3},
{{"special_zero", true}});
flatten_Reshape_501 =
makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape_3 | const_target_shape_0},
{{"special_zero", true}});
}
auto slice_Slice_443 = GenSlice(input_key, ndims, INT_MAX, 1, 3);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1131,3 +1131,88 @@ TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul_squeeze_unsqueeze) {
}
comparator.enable(FunctionsComparator::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
disable_rt_info_check();
const int batch = -1;
const int seq_len = 1;
const int num_heads = 32;
const int num_heads_kv = 2;
const int ndims = 128;
const int rotary_ndims = 64;
const int hidden_size = ndims * (num_heads + 2 * num_heads_kv);
const int hidden_size_q = ndims * num_heads;
const int hidden_size_kv = ndims * num_heads_kv;
using namespace ov;
{
auto input =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size});
auto cos_sin = std::make_shared<ov::opset1::Parameter>(ov::element::f32,
ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2});
auto aten_slice_Slice_1 = makeOP<opset8::Slice>({cos_sin, {0}, {1}, {1}, {0}});
auto aten_view_Reshape = makeOP<opset1::Reshape>({aten_slice_Slice_1, {seq_len, batch, 1, rotary_ndims / 2, 2}},
{{"special_zero", false}});
auto aten_select_Gather_1 = makeOP<opset8::Gather>({aten_view_Reshape, 0, -1}, {{"batch_dims", 0}});
auto aten_select_Gather_3 = makeOP<opset8::Gather>({aten_view_Reshape, 1, -1}, {{"batch_dims", 0}});

auto attn_prim_ListUnpack =
makeOP<opset1::VariadicSplit>({input, -1, {hidden_size_q, hidden_size_kv, hidden_size_kv}});
auto attn_aten_view_Reshape_2 =
makeOP<opset1::Reshape>({attn_prim_ListUnpack->output(0), {0, 0, num_heads, ndims}},
{{"special_zero", true}});
auto VariadicSplit_29663 =
makeOP<opset1::VariadicSplit>({attn_aten_view_Reshape_2, 3, {rotary_ndims, ndims - rotary_ndims}});
auto aten_reshape_Reshape_55 =
makeOP<opset1::Reshape>({VariadicSplit_29663->output(0), {0, 0, num_heads, rotary_ndims / 2, 2}},
{{"special_zero", true}});
auto aten_select_Gather_440 = makeOP<opset8::Gather>({aten_reshape_Reshape_55, 0, -1}, {{"batch_dims", 0}});
auto aten_mul_Multiply_276 =
makeOP<opset1::Multiply>({aten_select_Gather_440, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
auto aten_select_Gather_442 = makeOP<opset8::Gather>({aten_reshape_Reshape_55, 1, -1}, {{"batch_dims", 0}});
auto aten_mul_Multiply_277 =
makeOP<opset1::Multiply>({aten_select_Gather_442, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
auto Multiply_34833 =
makeOP<opset1::Multiply>({aten_mul_Multiply_277, -1.000000f}, {{"auto_broadcast", "numpy"}});
auto aten_sub_Subtract_55 =
makeOP<opset1::Add>({aten_mul_Multiply_276, Multiply_34833}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62197 = makeOP<opset1::Reshape>({aten_sub_Subtract_55, {1, -1, num_heads, rotary_ndims / 2, 1}},
{{"special_zero", false}});
auto aten_mul_Multiply_278 =
makeOP<opset1::Multiply>({aten_select_Gather_442, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
auto aten_mul_Multiply_279 =
makeOP<opset1::Multiply>({aten_select_Gather_440, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
auto aten_add_Add_55 =
makeOP<opset1::Add>({aten_mul_Multiply_278, aten_mul_Multiply_279}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62198 = makeOP<opset1::Reshape>({aten_add_Add_55, {1, -1, num_heads, rotary_ndims / 2, 1}},
{{"special_zero", false}});
auto aten_stack_55 = makeOP<opset1::Concat>({Unsqueeze_62197, Unsqueeze_62198}, {{"axis", -1}});
auto aten_flatten_Reshape_55 =
makeOP<opset1::Reshape>({aten_stack_55, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}});
auto aten_cat_Concat_55 =
makeOP<opset1::Concat>({aten_flatten_Reshape_55, VariadicSplit_29663->output(1)}, {{"axis", -1}});

model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_55}, ov::ParameterVector{input, cos_sin});
}
manager.register_pass<ov::pass::RoPEFusion>(false);
{
auto input =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size});
auto gather_cos_sin =
std::make_shared<ov::opset1::Parameter>(ov::element::f32,
ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2});
auto rope = makeOP<ov::op::internal::RoPE>({input, gather_cos_sin, gather_cos_sin},
{{"config.slice_start", 0},
{"config.slice_stop", 4096},
{"config.input_trans0213", false},
{"config.output_trans0213", false},
{"config.is_interleaved", false},
{"config.rotary_ndims", rotary_ndims},
{"config.is_chatglm", true},
{"config.support_2d_rope", false},
{"config.is_qwen", false},
{"config.head_cnt", num_heads},
{"config.head_size", ndims},
{"config.gather_position_arg_id", 0}});
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin});
}
}

0 comments on commit ed50d51

Please sign in to comment.