Fix BroadcastTo logic to handle 8D non-broadcasting case
The error happen with 8D inputs when non-broadcasting is requested PiperOrigin-RevId: 340586767 Change-Id: I108d54fe40291705c7230206fa955853ad831c30
This commit is contained in:
parent
d77d8b3f4a
commit
d895da1ac4
@ -214,6 +214,47 @@ TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) {
|
||||
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
|
||||
}
|
||||
|
||||
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DConstTest) {
|
||||
BroadcastToOpModel<TypeParam> m({1, 3, 1, 2, 1, 4, 1, 1}, {8},
|
||||
{2, 3, 1, 2, 2, 4, 1, 1});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 1, 2, 2, 4, 1, 1}));
|
||||
EXPECT_THAT(
|
||||
m.GetOutput(),
|
||||
ElementsAreArray({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6,
|
||||
7, 8, 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22,
|
||||
23, 24, 21, 22, 23, 24, 1, 2, 3, 4, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10,
|
||||
11, 12, 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 21, 22, 23, 24}));
|
||||
}
|
||||
|
||||
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DDynamicTest) {
|
||||
BroadcastToOpModel<TypeParam> m({2, 1, 1, 2, 1, 4, 1, 1}, {8});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetShape({2, 3, 2, 2, 2, 4, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2, 2, 2, 4, 1, 1}));
|
||||
EXPECT_THAT(
|
||||
m.GetOutput(),
|
||||
ElementsAreArray(
|
||||
{1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||
1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8,
|
||||
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16,
|
||||
9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16}));
|
||||
}
|
||||
|
||||
TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) {
|
||||
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
@ -233,6 +274,15 @@ TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TYPED_TEST(BroadcastToOpTest, NoBroadcasting8DConstTest) {
|
||||
BroadcastToOpModel<TypeParam> m({3, 1, 1, 1, 1, 1, 1, 2}, {8},
|
||||
{3, 1, 1, 1, 1, 1, 1, 2});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1, 1, 1, 1, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) {
|
||||
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
|
||||
{1, 1, 1, 1, 1, 1, 2, 2});
|
||||
|
@ -72,14 +72,21 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
|
||||
|
||||
// Get the last dimension has broadcasting. At this dimension, the data is
|
||||
// copied from input tensor to output tensor.
|
||||
int last_broadcast_dim = 0;
|
||||
for (int i = N - 1; i > 0; --i) {
|
||||
int last_broadcast_dim = -1;
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
if (input_desc.extents[i] != output_desc.extents[i]) {
|
||||
last_broadcast_dim = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If non-broadcasting, just copy data from input to output tensor.
|
||||
if (last_broadcast_dim == -1) {
|
||||
memcpy(output_data, input_data,
|
||||
unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
|
||||
return;
|
||||
}
|
||||
|
||||
// Broadcasting using memcpy.
|
||||
int indexes[N] = {0};
|
||||
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
|
||||
|
Loading…
Reference in New Issue
Block a user