Skip to content

Commit

Permalink
Adding support for set_shape to the state tensors
Browse files Browse the repository at this point in the history
Signed-off-by: Bogdan Pereanu <[email protected]>
  • Loading branch information
pereanub committed Jan 16, 2025
1 parent 2088b8f commit 0e8fda2
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ void ZeroInferRequest::infer_async() {
const auto inputDescriptor = _metadata.inputs.at(ioIndex);
auto zeroTensor = std::dynamic_pointer_cast<ZeroTensor>(levelZeroTensor.at(SINGLE_TENSOR));

if (is_batched_input(ioIndex) || inputDescriptor.isShapeTensor || inputDescriptor.isStateInput ||
if (is_batched_input(ioIndex) || inputDescriptor.isShapeTensor ||
is_remote_tensor(levelZeroTensor.at(SINGLE_TENSOR)) || zeroTensor == nullptr) {
++ioIndex;
continue;
Expand All @@ -499,7 +499,9 @@ void ZeroInferRequest::infer_async() {
zeroTensor->get_byte_size());
closePipeline = true;

zeroTensor->reset_memory_flag();
if (!inputDescriptor.isStateInput) {
zeroTensor->reset_memory_flag();
}
}

++ioIndex;
Expand All @@ -511,8 +513,7 @@ void ZeroInferRequest::infer_async() {
const auto outputDescriptor = _metadata.outputs.at(ioIndex);
auto zeroTensor = std::dynamic_pointer_cast<ZeroTensor>(levelZeroTensor);

if (outputDescriptor.isShapeTensor || outputDescriptor.isStateOutput ||
is_remote_tensor(levelZeroTensor) || zeroTensor == nullptr) {
if (outputDescriptor.isShapeTensor || is_remote_tensor(levelZeroTensor) || zeroTensor == nullptr) {
++ioIndex;
continue;
}
Expand Down
131 changes: 131 additions & 0 deletions src/plugins/intel_npu/tests/functional/behavior/infer_request_run.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <array>
#include <exception>
#include <random>
#include <thread>

#include "base/ov_behavior_test_utils.hpp"
Expand All @@ -20,7 +21,10 @@
#include "intel_npu/npu_private_properties.hpp"
#include "openvino/core/any.hpp"
#include "openvino/core/node_vector.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/sigmoid.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/runtime/compiled_model.hpp"
#include "openvino/runtime/core.hpp"
Expand Down Expand Up @@ -124,6 +128,35 @@ class InferRequestRunTests : public ov::test::behavior::OVPluginTestBase,

return std::make_shared<Model>(res, params);
}

std::shared_ptr<ov::Model> createModelWithStates(element::Type type, const Shape& shape) {
    // Builds a small stateful model:
    //   out = Sigmoid(ReadValue("c_1-3") * (ReadValue("r_1-3") * input))
    // Both variables are re-assigned from the same product node, so each
    // inference folds the previous state values into the new ones.
    auto param = std::make_shared<ov::op::v0::Parameter>(type, shape);

    // First variable: zero-initialized, multiplied with the network input.
    auto init_r = std::make_shared<ov::op::v0::Constant>(type, shape, 0);
    auto read_r = std::make_shared<ov::op::v3::ReadValue>(init_r, "r_1-3");
    auto prod_r = std::make_shared<ov::op::v1::Multiply>(read_r, param);

    // Second variable: zero-initialized, multiplied with the first product.
    auto init_c = std::make_shared<ov::op::v0::Constant>(type, shape, 0);
    auto read_c = std::make_shared<ov::op::v3::ReadValue>(init_c, "c_1-3");
    auto prod_c = std::make_shared<ov::op::v1::Multiply>(read_c, prod_r);
    auto assign_c = std::make_shared<ov::op::v3::Assign>(prod_c, "c_1-3");

    // Both variables receive the same updated value.
    auto assign_r = std::make_shared<ov::op::v3::Assign>(prod_c, "r_1-3");

    auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(prod_c);
    sigmoid->set_friendly_name("sigmod_state");
    sigmoid->get_output_tensor(0).set_names({"sigmod_state"});

    // Control dependencies keep the read -> assign -> result ordering so the
    // Assign nodes are not pruned and execute before the output is produced.
    read_r->set_friendly_name("Memory_1");
    read_r->get_output_tensor(0).set_names({"Memory_1"});
    assign_r->add_control_dependency(read_r);
    sigmoid->add_control_dependency(assign_r);

    read_c->set_friendly_name("Memory_2");
    read_c->get_output_tensor(0).set_names({"Memory_2"});
    assign_c->add_control_dependency(read_c);
    sigmoid->add_control_dependency(assign_c);

    return std::make_shared<ov::Model>(ov::NodeVector{sigmoid}, ov::ParameterVector{param}, "add_output");
}
};

TEST_P(InferRequestRunTests, AllocatorCanDisposeBlobWhenOnlyInferRequestIsInScope) {
Expand Down Expand Up @@ -962,6 +995,104 @@ TEST_P(SetShapeInferRunTests, checkResultsAfterIOBlobReallocation) {
}
}

TEST_P(SetShapeInferRunTests, checkResultsAfterStateTensorsReallocation) {
    // Verifies that state tensors survive a set_shape() reallocation: after each
    // state tensor is resized (and its shape restored), values written through
    // the state handles must still reach the device on the next inference.
    // Skip test according to plugin specific disabledTestPatterns() (if any)
    SKIP_IF_CURRENT_TEST_IS_DISABLED()

    testing::internal::Random random(1);
    ov::Tensor input_tensor;

    auto original_shape = Shape{1, 10, 10, 10};
    auto dummy_shape = Shape{1, 50, 100, 100};
    auto shape_size = ov::shape_size(original_shape);
    auto model = createModelWithStates(element::f32, original_shape);

    auto context = core->get_default_context(target_device);

    compiled_model = core->compile_model(model, target_device, configuration);
    ov::InferRequest inference_request;
    inference_request = compiled_model.create_infer_request();

    // Fill the input with deterministic pseudo-random data (fixed seed above).
    auto input = compiled_model.input();
    OV_ASSERT_NO_THROW(input_tensor = inference_request.get_tensor(input));
    auto* input_data = input_tensor.data<float>();
    for (size_t i = 0; i < shape_size; ++i) {
        input_data[i] = static_cast<float>(random.Generate(10));
    }

    for (auto&& state : inference_request.query_state()) {
        state.reset();
    }

    OV_ASSERT_NO_THROW(inference_request.infer());

    // Both states were reset to 0, so the product feeding the Sigmoid is 0 and
    // every output element must be Sigmoid(0) == 0.5.
    auto output_tensor = inference_request.get_tensor("sigmod_state");
    auto output_data = output_tensor.data<float>();
    for (size_t i = 0; i < output_tensor.get_size(); i++) {
        EXPECT_NEAR(0.5f, output_data[i], 1e-5);
    }

    // The states themselves must read back as all zeros after the reset
    // (the Assign nodes stored the zero product).
    auto states = inference_request.query_state();
    for (auto&& state : states) {
        auto last_state = state.get_state();
        auto last_state_size = last_state.get_size();
        auto* last_state_data = static_cast<float*>(last_state.data());

        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";

        for (size_t i = 0; i < last_state_size; ++i) {
            EXPECT_NEAR(0.0, last_state_data[i], 1e-5);
        }
    }

    // Create dummy tensors to force the driver to allocate memory for the
    // initial tensors somewhere else once they get re-created by set_shape().
    std::vector<ov::Tensor> dummy_tensors;
    dummy_tensors.reserve(8);
    for (size_t i = 0; i < 8; ++i) {
        dummy_tensors.push_back(context.create_host_tensor(ov::element::f32, dummy_shape));
    }

    // Force a reallocation of every state tensor by growing it, then restore
    // the original state shape.
    for (auto&& item : inference_request.query_state()) {
        auto tensor_state = item.get_state();
        const auto state_shape = tensor_state.get_shape();
        OV_ASSERT_NO_THROW(tensor_state.set_shape({1, 50, 20, 20}));
        OV_ASSERT_NO_THROW(tensor_state.set_shape(state_shape));
    }

    for (auto&& state : inference_request.query_state()) {
        state.reset();
    }

    // Write 1.0f into every element of the (reallocated) state tensors.
    for (auto&& state : states) {
        auto last_state = state.get_state();
        auto last_state_size = last_state.get_size();
        auto* last_state_data = static_cast<float*>(last_state.data());

        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";

        for (size_t i = 0; i < last_state_size; ++i) {
            last_state_data[i] = 1.0f;
        }
    }

    OV_ASSERT_NO_THROW(inference_request.infer());

    // With both states holding 1.0f, the stored product is input * 1 * 1, so
    // each state must now mirror the input data — proving the written values
    // reached the device despite the reallocation.
    for (auto&& state : states) {
        auto last_state = state.get_state();
        auto last_state_size = last_state.get_size();
        auto* last_state_data = static_cast<float*>(last_state.data());

        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";

        for (size_t i = 0; i < last_state_size; ++i) {
            EXPECT_NEAR(input_data[i], last_state_data[i], 1e-5);
        }
    }
}

} // namespace behavior
} // namespace test
} // namespace ov

0 comments on commit 0e8fda2

Please sign in to comment.