Fix BroadcastTo logic to handle 8D non-broadcasting case

The error happen with 8D inputs when non-broadcasting is requested

PiperOrigin-RevId: 340586767
Change-Id: I108d54fe40291705c7230206fa955853ad831c30
This commit is contained in:
Thai Nguyen 2020-11-03 22:19:20 -08:00 committed by TensorFlower Gardener
parent d77d8b3f4a
commit d895da1ac4
2 changed files with 59 additions and 2 deletions

View File

@ -214,6 +214,47 @@ TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) {
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 3, 1, 2, 1, 4, 1, 1}, {8},
{2, 3, 1, 2, 2, 4, 1, 1});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 1, 2, 2, 4, 1, 1}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6,
7, 8, 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16,
13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22,
23, 24, 21, 22, 23, 24, 1, 2, 3, 4, 1, 2, 3, 4,
5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10,
11, 12, 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, 19, 20,
17, 18, 19, 20, 21, 22, 23, 24, 21, 22, 23, 24}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DDynamicTest) {
BroadcastToOpModel<TypeParam> m({2, 1, 1, 2, 1, 4, 1, 1}, {8});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetShape({2, 3, 2, 2, 2, 4, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2, 2, 2, 4, 1, 1}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray(
{1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16}));
}
TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) {
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
@ -233,6 +274,15 @@ TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, NoBroadcasting8DConstTest) {
BroadcastToOpModel<TypeParam> m({3, 1, 1, 1, 1, 1, 1, 2}, {8},
{3, 1, 1, 1, 1, 1, 1, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1, 1, 1, 1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) {
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
{1, 1, 1, 1, 1, 1, 2, 2});

View File

@ -72,14 +72,21 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
// Get the last dimension has broadcasting. At this dimension, the data is
// copied from input tensor to output tensor.
int last_broadcast_dim = 0;
for (int i = N - 1; i > 0; --i) {
int last_broadcast_dim = -1;
for (int i = N - 1; i >= 0; --i) {
if (input_desc.extents[i] != output_desc.extents[i]) {
last_broadcast_dim = i;
break;
}
}
// If non-broadcasting, just copy data from input to output tensor.
if (last_broadcast_dim == -1) {
memcpy(output_data, input_data,
unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
return;
}
// Broadcasting using memcpy.
int indexes[N] = {0};
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,