Unverified commit bd9bbe09 authored by Patryk Elszkowski, committed by GitHub

New Gather op reference implementation. (#3633)


* New Gather op reference implementation.

* Unify span implementation for gather and gather_nd.

Create span.hpp for common implementation of span.

* Move span to utils directory.

* Address review comments.

* update span

* Address PR comments.
Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
Showing 460 additions and 211 deletions (+460 -211)
@@ -185,7 +185,7 @@ namespace ngraph
Strides(source_shape.size(), 1));
}
/// \brief Class allows to iterate over Tensor with reverted axies part by part.
/// \brief Class allows to iterate over Tensor with reverted axes part by part.
///
/// To create ReverseRange use _reverse_ function.
///
@@ -213,8 +213,14 @@ namespace ngraph
return ReverseRange(source_shape, reversed_axis);
}
inline ReverseRange index(const Shape& source_shape)
{
return reverse(source_shape, {});
}
} // namespace impl
using impl::Direction;
using impl::index;
using impl::reverse;
using impl::slice;
} // namespace coordinates
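For orientation, the new coordinates::index(source_shape) helper is simply reverse(source_shape, {}), i.e. a traversal with no reversed axes. A hedged usage sketch, assuming the headers from this change are on the include path; the begin_index / element_number / step fields are inferred from how the gather implementation below consumes the range:

#include "ngraph/coordinate_range.hpp"

using namespace ngraph;

// Visit every element of a row-major tensor, run by run. Each yielded value
// appears to describe one contiguous run of flat indices: where it starts
// (begin_index), how many elements it covers (element_number), and the
// distance between consecutive elements (step).
void visit_all(const Shape& shape)
{
    for (auto run : coordinates::index(shape))
    {
        auto flat = run.begin_index;
        for (size_t n = 0; n != run.element_number; flat += run.step, ++n)
        {
            // `flat` is the linear offset of the next element to visit
        }
    }
}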
......
@@ -18,8 +18,10 @@
#include <numeric>
#include "ngraph/coordinate_range.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "utils/span.hpp"
namespace ngraph
{
@@ -27,147 +29,105 @@ namespace ngraph
{
namespace reference
{
// Implement gather by calling gather_nd on sub-problems
// # prepare constant shapes for tensors used for sub problems
// indices'.shape = indices.shape[-1] + [1]
// params'.shape = params.shape[axis:]
// out'.shape = params'.shape
// out'.shape[0] = indices.shape[-1]
// # call sub-problems
// foreach (params_index, out_index) in outer "axis" dimensions
// # params_prime is shared by inner loop
// params' = param[params_index] # rank(params') == rank(params) - axis
// foreach indices_index in outer N-1 dimensions
// indices' = indices[indices_index] # rank(indices') == 2
// out_index = out_index + indices_index
// out' = out[out_index] # rank(out') == rank(params')
// gather_nd(params', indices'', out')
namespace
{
template <typename Container>
Shape to_shape(const Container& c)
{
return Shape(begin(c), end(c));
}
template <typename Container>
std::vector<size_t>
join(const Container& c1, const Container& c2, const Container& c3)
{
using container_value_type =
typename std::remove_cv<typename Container::value_type>::type;
static_assert(std::is_same<container_value_type, size_t>::value,
"Expect same type in container");
std::vector<size_t> ret;
ret.reserve(c1.size() + c2.size() + c3.size());
std::copy(begin(c1), end(c1), std::back_inserter(ret));
std::copy(begin(c2), end(c2), std::back_inserter(ret));
std::copy(begin(c3), end(c3), std::back_inserter(ret));
return ret;
}
const auto only_one = [] { return coordinates::index(Shape{1}); };
} // namespace
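For intuition, a minimal sketch of how these helpers combine (using the span()/subspan() utilities from utils/span.hpp, added below; the concrete shapes and axis value are illustrative only). Gather's output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:], which is exactly what join is used to verify in the check below:

Shape params_shape{2, 3, 4};
Shape indices_shape{5};
const size_t axis = 1;
const auto expected_out_shape = join(span(params_shape).subspan(0, axis),
                                     span(indices_shape),
                                     span(params_shape).subspan(axis + 1));
// expected_out_shape == std::vector<size_t>{2, 5, 4}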
template <typename T, typename U>
void gather(const T* params,
const U* indices,
T* out,
void gather(const T* const params,
const U* const indices,
T* const out,
const Shape& params_shape,
const Shape& indices_shape,
const Shape& out_shape,
size_t axis)
{
// prepare shape of params_prime (remove first "axis" dimensions)
const Shape params_prime_shape(params_shape.begin() + axis, params_shape.end());
// prepare shape of indices_prime
const size_t indices_ndim = indices_shape.size();
Shape indices_prime_shape;
// prepare shape of out_prime (same as params_prime except for first dim)
Shape out_prime_shape(params_prime_shape);
if (indices_ndim > 0)
{
out_prime_shape[0] = indices_shape[indices_ndim - 1];
indices_prime_shape.emplace_back(indices_shape[indices_ndim - 1]);
}
else
{
out_prime_shape[0] = 1;
}
indices_prime_shape.emplace_back(1);
using std::next;
assert(std::memset(out, 0, shape_size(out_shape) * sizeof(T)));
// Create a CoordinateTransform for "out" that visits the outer "axis" dimensions
const size_t out_ndim = out_shape.size();
const Coordinate out_outer_start_corner(out_ndim, 0);
Coordinate out_outer_end_corner(out_shape);
for (size_t i = axis; i < out_ndim; i++)
{
out_outer_end_corner[i] = 1;
}
Strides out_outer_strides(out_ndim, 1);
AxisVector out_outer_axis_order(out_ndim);
std::iota(out_outer_axis_order.begin(), out_outer_axis_order.end(), 0);
CoordinateTransform out_outer_transform(out_shape,
out_outer_start_corner,
out_outer_end_corner,
out_outer_strides,
out_outer_axis_order);
// Create a CoordinateTransform for "params" that visits the outer "axis" dimensions
const size_t params_ndim = params_shape.size();
const Coordinate params_outer_start_corner(params_ndim, 0);
Coordinate params_outer_end_corner(params_shape);
for (size_t i = axis; i < params_ndim; i++)
{
params_outer_end_corner[i] = 1;
}
const Strides params_outer_strides(params_ndim, 1);
AxisVector params_outer_axis_order(params_ndim);
std::iota(params_outer_axis_order.begin(), params_outer_axis_order.end(), 0);
const CoordinateTransform params_outer_transform(params_shape,
params_outer_start_corner,
params_outer_end_corner,
params_outer_strides,
params_outer_axis_order);
// Create a CoordinateTransform for "indices" that visits only the first element
// along inner most axis
const Coordinate indices_outer_start_corner(indices_ndim, 0);
Coordinate indices_outer_end_corner(indices_shape);
if (indices_ndim > 0)
{
indices_outer_end_corner[indices_ndim - 1] = 1;
}
const Strides indices_outer_strides(indices_ndim, 1);
AxisVector indices_outer_axis_order(indices_ndim);
std::iota(indices_outer_axis_order.begin(), indices_outer_axis_order.end(), 0);
const CoordinateTransform indices_outer_transform(indices_shape,
indices_outer_start_corner,
indices_outer_end_corner,
indices_outer_strides,
indices_outer_axis_order);
// Create an inner CoordinateTransfrom for "out"
const size_t out_inner_ndim = out_ndim - axis;
const Shape out_inner_shape(out_shape.begin() + axis, out_shape.end());
const Coordinate out_inner_start_corner(out_inner_ndim, 0);
Coordinate out_inner_end_corner(out_inner_shape);
if (indices_ndim > 0)
{
out_inner_end_corner[indices_ndim - 1] = 1;
}
for (size_t i = indices_ndim; i < out_inner_ndim; i++)
{
out_inner_end_corner[i] = 1;
}
const Strides out_inner_strides(out_inner_ndim, 1);
AxisVector out_inner_axis_order(out_inner_ndim);
std::iota(out_inner_axis_order.begin(), out_inner_axis_order.end(), 0);
const CoordinateTransform out_inner_transform(out_inner_shape,
out_inner_start_corner,
out_inner_end_corner,
out_inner_strides,
out_inner_axis_order);
auto out_outer_coord_iter = out_outer_transform.begin();
for (const Coordinate& params_outer_coord : params_outer_transform)
const auto params_axes_part = span(params_shape).subspan(0, axis);
NGRAPH_CHECK(params_shape.size() >= axis, "Not enough axes in param_shape.");
const auto remainder_part_shape = span(params_shape).subspan(axis + 1);
const auto found_out_shape =
join(params_axes_part, span(indices_shape), remainder_part_shape);
NGRAPH_CHECK(found_out_shape == out_shape,
"Output shape mismatch with calculations");
const auto batch_shape = span(params_shape).subspan(axis);
const auto batch_size = shape_size(batch_shape);
const auto copy_size = shape_size(remainder_part_shape);
const size_t copy_round_in_batch =
indices_shape.size() > 1
? shape_size(span(indices_shape.data(), indices_shape.size() - 1))
: 1;
const size_t round_batch_offset = indices_shape.empty() ? 1 : indices_shape.back();
auto dst = out;
auto gather_range = params_axes_part.empty()
? only_one()
: coordinates::index(to_shape(params_axes_part));
for (auto i : gather_range)
{
if (out_outer_coord_iter == out_outer_transform.end())
break;
const T* params_prime =
&params[params_outer_transform.index(params_outer_coord)];
T* out_outer = &out[out_outer_transform.index(*out_outer_coord_iter)];
auto out_inner_coord_iter = out_inner_transform.begin();
for (const Coordinate& indices_outer_coord : indices_outer_transform)
auto batch_index = i.begin_index;
for (size_t batch = 0; batch != i.element_number;
batch_index += i.step, ++batch)
{
if (out_inner_coord_iter == out_inner_transform.end())
break;
const U* indices_prime =
&indices[indices_outer_transform.index(indices_outer_coord)];
T* out_prime = &out_outer[out_inner_transform.index(*out_inner_coord_iter)];
gather_nd<T, U>(params_prime,
indices_prime,
out_prime,
params_prime_shape,
indices_prime_shape,
out_prime_shape);
++out_inner_coord_iter;
const auto batch_offset = batch_index * batch_size;
assert(batch_offset < shape_size(params_shape));
for (size_t round = 0; round != copy_round_in_batch; ++round)
{
const U* input_indices = indices + round * round_batch_offset;
const auto indices_no =
indices_shape.empty() ? 1 : indices_shape.back();
assert(!batch_shape.empty());
for (size_t ii = 0; ii != indices_no; ++ii)
{
const auto positive_input_index =
input_indices[ii] < 0 ? batch_shape.front() + input_indices[ii]
: input_indices[ii];
const auto src_offset =
batch_offset + copy_size * positive_input_index;
const auto src_begin = next(params, src_offset);
const auto src_end = next(src_begin, copy_size);
std::copy(src_begin, src_end, dst);
dst += copy_size;
}
}
}
++out_outer_coord_iter;
}
}
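To make the copy pattern above concrete, here is a self-contained sketch of the same strategy in plain C++ for the common 1-D indices case (gather_1d, product and all other names are hypothetical, not part of this change): walk the axes before axis, resolve each index while wrapping negatives, then copy one contiguous slice of copy_size elements per index.

#include <cassert>
#include <cstddef>
#include <vector>

// Product of dims[from..to), i.e. a stand-in for shape_size on a sub-shape.
static size_t product(const std::vector<size_t>& dims, size_t from, size_t to)
{
    size_t p = 1;
    for (size_t i = from; i < to; ++i)
        p *= dims[i];
    return p;
}

// Simplified gather for 1-D indices; negative indices wrap around the
// gathered axis, mirroring the positive_input_index logic above.
template <typename T>
std::vector<T> gather_1d(const std::vector<T>& params,
                         const std::vector<size_t>& params_shape,
                         const std::vector<int>& indices,
                         size_t axis)
{
    const size_t outer = product(params_shape, 0, axis); // axes before `axis`
    const size_t axis_len = params_shape[axis];
    const size_t copy_size = product(params_shape, axis + 1, params_shape.size());
    std::vector<T> out;
    out.reserve(outer * indices.size() * copy_size);
    for (size_t o = 0; o < outer; ++o) // plays the role of gather_range
    {
        const size_t batch_offset = o * axis_len * copy_size;
        for (const int idx : indices)
        {
            const size_t pos =
                static_cast<size_t>(idx < 0 ? static_cast<long>(axis_len) + idx : idx);
            const T* src = params.data() + batch_offset + pos * copy_size;
            out.insert(out.end(), src, src + copy_size);
        }
    }
    return out;
}

int main()
{
    // params shape {3, 2}, axis 0, indices {0, -2, 1, 2} -> out shape {4, 2},
    // mirroring the gather_2d_negative_and_positive_indices test below.
    const std::vector<float> params{1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f};
    const auto out = gather_1d(params, {3, 2}, {0, -2, 1, 2}, 0);
    assert(out.size() == 8);
    assert(out[2] == 2.0f && out[3] == 2.1f); // index -2 resolves to row 1
    return 0;
}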
} // namespace reference
......
@@ -21,6 +21,7 @@
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
#include "utils/span.hpp"
namespace ngraph
{
@@ -28,52 +29,8 @@ namespace ngraph
{
namespace reference
{
namespace
namespace details
{
template <bool check>
using Required = typename std::enable_if<check, bool>::type;
template <typename It>
struct IsRandomAccessIt
{
static constexpr bool value =
std::is_same<typename It::iterator_category,
std::random_access_iterator_tag>::value;
};
template <typename Iterator, Required<IsRandomAccessIt<Iterator>::value> = true>
class Span
{
public:
Span(Iterator begin, Iterator end)
: m_begin{begin}
, m_end{end}
{
}
Iterator begin() const { return m_begin; }
Iterator end() const { return m_end; };
typename Iterator::value_type operator[](size_t idx) const
{
return *next(m_begin, idx);
}
typename Iterator::difference_type size() const
{
return std::distance(m_begin, m_end);
}
private:
Iterator m_begin;
Iterator m_end;
};
template <typename Iterator>
Span<Iterator> span(Iterator begin, Iterator end)
{
return Span<Iterator>{begin, end};
};
template <typename Iterator>
std::vector<size_t> get_indices_offsets(const Iterator beg,
const Iterator end,
@@ -90,7 +47,7 @@
return offsets;
}
} // namespace
} // namespace details
///
/// Implementation find maximum length of *slice* of input *params* which might be
@@ -143,14 +100,14 @@ namespace ngraph
"params_shape should have enough rank to be index by indices"};
}
const auto slice_shape =
span(next(begin(params_shape), first_slice_index_in_params), end(params_shape));
const auto slice_shape = span(params_shape).subspan(first_slice_index_in_params);
const auto slice_size = shape_size(slice_shape);
const auto dims_begin = next(rbegin(params_shape), slice_shape.size());
const auto dims_end = next(dims_begin, indices_shape.back() - 1);
const auto indices_offsets = get_indices_offsets(dims_begin, dims_end, slice_size);
const auto indices_offsets =
details::get_indices_offsets(dims_begin, dims_end, slice_size);
const auto batch_offset = indices_offsets.front() * params_shape[batch_dims];
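The body of details::get_indices_offsets is collapsed in this diff; judging from its call site it presumably derives row-major strides for the dimensions addressed by the innermost indices, seeded with the contiguous slice size. A generic sketch of such a stride computation (the name and exact ordering are assumptions, not this PR's code):

#include <cstddef>
#include <vector>

// Hypothetical: each dimension's stride is the product of all faster-varying
// extents, starting from the contiguous slice size.
std::vector<size_t> row_major_strides(const std::vector<size_t>& dims, size_t slice_size)
{
    std::vector<size_t> strides(dims.size());
    size_t acc = slice_size;
    for (size_t i = dims.size(); i-- > 0;)
    {
        strides[i] = acc;
        acc *= dims[i];
    }
    return strides;
}
// e.g. dims {2, 3} with slice_size 4 -> strides {12, 4}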
......
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <iterator>
#include <limits>
#include <type_traits>
namespace ngraph
{
namespace runtime
{
namespace reference
{
namespace details
{
template <bool check>
using Required = typename std::enable_if<check, bool>::type;
template <typename It>
struct IsRandomAccessIt
{
static constexpr bool value =
std::is_same<typename It::iterator_category,
std::random_access_iterator_tag>::value;
};
template <typename... Args>
using void_t = void;
template <typename, typename = size_t>
struct is_complete : std::false_type
{
};
template <typename T>
struct is_complete<T, decltype(sizeof(T))> : std::true_type
{
};
template <typename It>
struct from_iterator
{
using stored_value = typename std::remove_pointer<
typename std::iterator_traits<It>::pointer>::type;
};
} // namespace details
/// @brief Span should mimic std::span
template <typename Element>
class Span
{
public:
static_assert(std::is_object<Element>::value,
"Element must be an object type (not a reference type or void)");
static_assert(details::is_complete<Element>::value,
"Element must be a complete type (not a forward declaration)");
static_assert(!std::is_abstract<Element>::value,
"Element cannot be an abstract class type");
constexpr Span() = default;
constexpr Span(Element* data, std::size_t size)
: m_data{data}
, m_size{size}
{
}
using value_type = Element;
using size_type = std::size_t;
constexpr Element* begin() const noexcept { return m_data; }
constexpr Element* end() const noexcept { return m_data + m_size; }
friend constexpr Element* begin(const Span& s) noexcept { return s.begin(); }
friend constexpr Element* end(const Span& s) noexcept { return s.end(); }
constexpr std::size_t size() const noexcept { return m_size; }
constexpr bool empty() const noexcept { return !m_size; }
constexpr Element& front() const noexcept { return *m_data; }
constexpr Element& back() const noexcept { return *(m_data + (m_size - 1)); }
constexpr Element& operator[](std::size_t idx) const { return *(m_data + idx); }
Element& at(std::size_t idx) const { return *(m_data + idx); }
Span subspan(std::size_t offset,
std::size_t size = std::numeric_limits<std::size_t>::max())
{
if (offset > m_size)
{
return {};
}
return {m_data + offset, std::min(size, m_size - offset)};
}
private:
Element* m_data{nullptr};
std::size_t m_size{0};
};
template <typename Iterator,
typename value = typename details::from_iterator<Iterator>::stored_value,
details::Required<details::IsRandomAccessIt<Iterator>::value> = true>
constexpr auto span(Iterator first, Iterator second) -> Span<value>
{
return Span<value>{
std::addressof(*first),
static_cast<typename Span<value>::size_type>(std::distance(first, second))};
}
template <typename Container,
// check if Container has contiguous range memory
typename = details::void_t<decltype(std::declval<Container>().data()),
decltype(std::declval<Container>().size())>>
constexpr auto span(const Container& c) -> Span<const typename Container::value_type>
{
return {c.data(), c.size()};
}
template <typename Container,
// check if Container has contiguous range memory
typename = details::void_t<decltype(std::declval<Container>().data()),
decltype(std::declval<Container>().size())>>
constexpr auto span(Container& c) -> Span<typename Container::value_type>
{
return {c.data(), c.size()};
}
template <typename Element>
constexpr auto span(const Element* data, std::size_t size) -> Span<const Element>
{
return {data, size};
}
template <typename Element>
constexpr auto span(Element* data, std::size_t size) -> Span<Element>
{
return {data, size};
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
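A quick usage sketch for the new Span, self-contained apart from the "utils/span.hpp" include path, which is assumed from this change's layout:

#include <cassert>
#include <vector>

#include "utils/span.hpp"

using ngraph::runtime::reference::span;

int main()
{
    std::vector<size_t> shape{2, 3, 4, 5};

    auto whole = span(shape); // container overload -> Span<size_t> over contiguous storage
    assert(whole.size() == 4 && whole.front() == 2);

    auto tail = whole.subspan(1); // drop the first axis -> {3, 4, 5}
    assert(tail.size() == 3 && tail.front() == 3);

    assert(whole.subspan(10).empty()); // offset past the end yields an empty span

    const float raw[] = {1.0f, 1.1f, 2.0f};
    auto from_ptr = span(raw, 3); // pointer + size overload -> Span<const float>
    assert(from_ptr.back() == 2.0f);
    return 0;
}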
@@ -72,19 +72,95 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_4d_indices_axis_0_2d_input)
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
test_case.add_input<int32_t>({0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2,
0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2,
0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2});
// clang-format off
test_case.add_input<float>({1.0f, 1.1f,
2.0f, 2.1f,
3.0f, 3.1f});
test_case.add_input<int32_t>({0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2});
test_case.add_expected_output<float>(
out_shape,
{1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f,
3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f,
2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f,
1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f,
3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
{ 1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f});
// clang-format on
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
}
@@ -100,14 +176,50 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_3d_indices_axis_0_2d_input)
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
// clang-format off
test_case.add_input<float>({1.0f, 1.1f,
2.0f, 2.1f,
3.0f, 3.1f});
test_case.add_input<int32_t>(
{0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2});
{0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2,
0, 1, 1, 2});
test_case.add_expected_output<float>(
out_shape, {1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f,
1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f,
2.0f, 2.1f, 3.0f, 3.1f, 1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
out_shape, {1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f,
1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f});
// clang-format on
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
}
@@ -123,10 +235,20 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_axis_0_2d_input)
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
// clang-format off
test_case.add_input<float>({1.0f, 1.1f,
2.0f, 2.1f,
3.0f, 3.1f});
// clang-format on
test_case.add_input<int32_t>({0, 1, 1, 2});
// clang-format off
test_case.add_expected_output<float>(out_shape,
{1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
{1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f});
// clang-format on
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
}
@@ -142,10 +264,24 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_2d_negative_and_positive_indices_axis_0_2d_i
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>({1.0f, 1.1f, 2.0f, 2.1f, 3.0f, 3.1f});
// clang-format off
test_case.add_input<float>({1.0f, 1.1f,
2.0f, 2.1f,
3.0f, 3.1f});
// clang-format on
test_case.add_input<int32_t>({0, -2, 1, 2});
// clang-format off
test_case.add_expected_output<float>(out_shape,
{1.0f, 1.1f, 2.0f, 2.1f, 2.0f, 2.1f, 3.0f, 3.1f});
{1.0f, 1.1f,
2.0f, 2.1f,
2.0f, 2.1f,
3.0f, 3.1f});
// clang-format on
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
}
@@ -197,9 +333,19 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_axis_1_2d_input)
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>({1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f});
// clang-format off
test_case.add_input<float>({1.0f, 1.1f, 1.2f,
2.0f, 2.1f, 2.2f,
3.0f, 3.1f, 3.2f});
// clang-format on
test_case.add_input<int32_t>({0, 2});
test_case.add_expected_output<float>(out_shape, {1.0f, 1.2f, 2.0f, 2.2f, 3.0f, 3.2f});
// clang-format off
test_case.add_expected_output<float>(out_shape, {1.0f, 1.2f,
2.0f, 2.2f,
3.0f, 3.2f});
// clang-format on
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
}
@@ -215,14 +361,40 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_1d_indices_axis_2_4d_input)
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>({1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f});
// clang-format off
test_case.add_input<float>({ 1.0f, 1.1f, 1.2f,
2.0f, 2.1f, 2.2f,
3.0f, 3.1f, 3.2f,
11.0f, 11.1f, 11.2f,
12.0f, 12.1f, 12.2f,
13.0f, 13.1f, 13.2f,
101.0f, 101.1f, 101.2f,
102.0f, 102.1f, 102.2f,
103.0f, 103.1f, 103.2f,
111.0f, 111.1f, 111.2f,
112.0f, 112.1f, 112.2f,
113.0f, 113.1f, 113.2f});
// clang-format on
test_case.add_input<int32_t>({0, 2});
// clang-format off
test_case.add_expected_output<float>(
out_shape, {1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f, 1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f, 1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f});
out_shape, { 1.0f, 1.1f, 1.2f,
3.0f, 3.1f, 3.2f,
11.0f, 11.1f, 11.2f,
13.0f, 13.1f, 13.2f,
101.0f, 101.1f, 101.2f,
103.0f, 103.1f, 103.2f,
111.0f, 111.1f, 111.2f,
113.0f, 113.1f, 113.2f});
// clang-format on
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
}
@@ -404,4 +576,4 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_axis_0_bool)
test_case.add_input<int64_t>({0, 1, 1, 2});
test_case.add_expected_output<char>(out_shape, {1, 1, 1, 0, 1, 0, 0, 1});
test_case.run(MIN_FLOAT_TOLERANCE_BITS);
}
\ No newline at end of file
}