Unverified Commit 7be7a8fb authored by Mateusz Tabaka's avatar Mateusz Tabaka Committed by GitHub
Browse files

[ONNX] don't hardcode shapes in Interpolate and Shape operator (#3778)

parent 00d37aaa
Showing with 1 addition and 158 deletions
+1 -158
......@@ -148,29 +148,6 @@ namespace ngraph
calculate_output_shape_based_on_scales(const Output<ngraph::Node>& data,
const Output<ngraph::Node>& scales)
{
const auto& data_shape = data.get_partial_shape();
const auto& scales_shape = scales.get_partial_shape();
if (ngraph::op::is_constant(scales.get_node()) && data_shape.is_static())
{
const auto scales_const =
as_type_ptr<default_opset::Constant>(scales.get_node_shared_ptr());
const auto scales_vector = scales_const->cast_vector<float>();
const auto data_static_shape = data_shape.to_shape();
std::vector<int64_t> output_shape;
for (size_t i = 0; i < data_static_shape.size(); ++i)
{
output_shape.push_back(
std::floor(data_static_shape.at(i) * scales_vector.at(i)));
}
auto output_shape_const = default_opset::Constant::create(
element::u64, Shape({output_shape.size()}), output_shape);
return output_shape_const;
}
const auto shape_of_data = std::make_shared<default_opset::Convert>(
std::make_shared<default_opset::ShapeOf>(data), scales.get_element_type());
const auto multiply =
......@@ -185,33 +162,7 @@ namespace ngraph
calculate_scales_based_on_sizes(const Output<ngraph::Node>& data,
const Output<ngraph::Node>& sizes)
{
const auto& data_shape = data.get_partial_shape();
const auto& sizes_shape = sizes.get_partial_shape();
const float epsilon = 1.0e-5;
if (ngraph::op::is_constant(sizes.get_node()) && data_shape.is_static())
{
const auto sizes_const =
as_type_ptr<default_opset::Constant>(sizes.get_node_shared_ptr());
const auto sizes_vector = sizes_const->cast_vector<int64_t>();
const auto data_static_shape = data_shape.to_shape();
std::vector<float> scales;
for (size_t i = 0; i < data_static_shape.size(); ++i)
{
float scale = static_cast<float>(sizes_vector.at(i)) /
static_cast<float>(data_static_shape.at(i)) +
epsilon;
scales.push_back(scale);
}
auto scales_const = default_opset::Constant::create(
element::f32, Shape({scales.size()}), scales);
return scales_const;
}
const auto shape_of_data = std::make_shared<default_opset::Convert>(
std::make_shared<default_opset::ShapeOf>(data), ngraph::element::f32);
const auto converted_sizes =
......
......@@ -33,20 +33,7 @@ namespace ngraph
OutputVector shape(const Node& node)
{
const auto data = node.get_ng_inputs().at(0);
const auto data_shape = data.get_partial_shape();
if (data_shape.is_static())
{
const auto static_data_shape = data_shape.to_shape();
return {default_opset::Constant::create(ngraph::element::i64,
Shape{static_data_shape.size()},
static_data_shape)};
}
else
{
return {std::make_shared<default_opset::ShapeOf>(data)};
}
return {std::make_shared<default_opset::ShapeOf>(data)};
}
} // namespace set_1
......
ir_version: 6
producer_name: "test_model"
graph {
node {
output: "scales"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 4
data_type: 1
float_data: 4.0
float_data: 3.0
float_data: 2.0
float_data: 1.0
name: "scales_const"
}
type: TENSOR
}
}
node {
input: "X"
input: "scales"
output: "output"
op_type: "Resize"
attribute {
name: "mode"
s: "nearest"
type: STRING
}
}
name: "test_model"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 6
}
dim {
dim_value: 6
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 10
}
......@@ -1122,21 +1122,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_square)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_resize10_import_only)
{
const auto resize_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/resize_opset10.prototxt"));
// Input data shape (1, 2, 3, 4)
// Scales input constant values {4, 3, 2, 1}
Shape expected_output_shape{4, 6, 6, 4};
EXPECT_EQ(resize_fn->get_output_size(), 1);
EXPECT_EQ(resize_fn->get_output_shape(0), expected_output_shape);
EXPECT_EQ(count_ops_of_type<op::v0::Interpolate>(resize_fn), 1);
EXPECT_EQ(count_ops_of_type<onnx_import::default_opset::Constant>(resize_fn), 1);
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize11_empty_constant_as_input)
{
// this model contains a Constant node with an empty underlying tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment