Fix BroadcastTo logic to handle 8D non-broadcasting case
The error happen with 8D inputs when non-broadcasting is requested PiperOrigin-RevId: 340586767 Change-Id: I108d54fe40291705c7230206fa955853ad831c30
This commit is contained in:
parent
d77d8b3f4a
commit
d895da1ac4
@ -214,6 +214,47 @@ TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) {
|
|||||||
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
|
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DConstTest) {
|
||||||
|
BroadcastToOpModel<TypeParam> m({1, 3, 1, 2, 1, 4, 1, 1}, {8},
|
||||||
|
{2, 3, 1, 2, 2, 4, 1, 1});
|
||||||
|
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||||
|
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 1, 2, 2, 4, 1, 1}));
|
||||||
|
EXPECT_THAT(
|
||||||
|
m.GetOutput(),
|
||||||
|
ElementsAreArray({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6,
|
||||||
|
7, 8, 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||||
|
13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22,
|
||||||
|
23, 24, 21, 22, 23, 24, 1, 2, 3, 4, 1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10,
|
||||||
|
11, 12, 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||||
|
17, 18, 19, 20, 21, 22, 23, 24, 21, 22, 23, 24}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DDynamicTest) {
|
||||||
|
BroadcastToOpModel<TypeParam> m({2, 1, 1, 2, 1, 4, 1, 1}, {8});
|
||||||
|
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||||
|
m.SetShape({2, 3, 2, 2, 2, 4, 1, 1});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2, 2, 2, 4, 1, 1}));
|
||||||
|
EXPECT_THAT(
|
||||||
|
m.GetOutput(),
|
||||||
|
ElementsAreArray(
|
||||||
|
{1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||||
|
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||||
|
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||||
|
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||||
|
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||||
|
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||||
|
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||||
|
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||||
|
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||||
|
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||||
|
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16}));
|
||||||
|
}
|
||||||
|
|
||||||
TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) {
|
TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) {
|
||||||
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2});
|
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2});
|
||||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||||
@ -233,6 +274,15 @@ TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) {
|
|||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TYPED_TEST(BroadcastToOpTest, NoBroadcasting8DConstTest) {
|
||||||
|
BroadcastToOpModel<TypeParam> m({3, 1, 1, 1, 1, 1, 1, 2}, {8},
|
||||||
|
{3, 1, 1, 1, 1, 1, 1, 2});
|
||||||
|
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1, 1, 1, 1, 2}));
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||||
|
}
|
||||||
|
|
||||||
TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) {
|
TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) {
|
||||||
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
|
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
|
||||||
{1, 1, 1, 1, 1, 1, 2, 2});
|
{1, 1, 1, 1, 1, 1, 2, 2});
|
||||||
|
@ -72,14 +72,21 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
|
|||||||
|
|
||||||
// Get the last dimension has broadcasting. At this dimension, the data is
|
// Get the last dimension has broadcasting. At this dimension, the data is
|
||||||
// copied from input tensor to output tensor.
|
// copied from input tensor to output tensor.
|
||||||
int last_broadcast_dim = 0;
|
int last_broadcast_dim = -1;
|
||||||
for (int i = N - 1; i > 0; --i) {
|
for (int i = N - 1; i >= 0; --i) {
|
||||||
if (input_desc.extents[i] != output_desc.extents[i]) {
|
if (input_desc.extents[i] != output_desc.extents[i]) {
|
||||||
last_broadcast_dim = i;
|
last_broadcast_dim = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If non-broadcasting, just copy data from input to output tensor.
|
||||||
|
if (last_broadcast_dim == -1) {
|
||||||
|
memcpy(output_data, input_data,
|
||||||
|
unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Broadcasting using memcpy.
|
// Broadcasting using memcpy.
|
||||||
int indexes[N] = {0};
|
int indexes[N] = {0};
|
||||||
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
|
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
|
||||||
|
Loading…
Reference in New Issue
Block a user