diff --git a/tensorflow/lite/kernels/broadcast_to_test.cc b/tensorflow/lite/kernels/broadcast_to_test.cc index a36ed352055..9a352df62f0 100644 --- a/tensorflow/lite/kernels/broadcast_to_test.cc +++ b/tensorflow/lite/kernels/broadcast_to_test.cc @@ -214,6 +214,47 @@ TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) { 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12})); } +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DConstTest) { + BroadcastToOpModel m({1, 3, 1, 2, 1, 4, 1, 1}, {8}, + {2, 3, 1, 2, 2, 4, 1, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 1, 2, 2, 4, 1, 1})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, + 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, + 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, + 23, 24, 21, 22, 23, 24, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, + 11, 12, 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, 19, 20, + 17, 18, 19, 20, 21, 22, 23, 24, 21, 22, 23, 24})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast8DDynamicTest) { + BroadcastToOpModel m({2, 1, 1, 2, 1, 4, 1, 1}, {8}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetShape({2, 3, 2, 2, 2, 4, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2, 2, 2, 4, 1, 1})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray( + {1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, + 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16, + 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16, + 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16, + 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16, + 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16, + 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, 16})); +} + TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) { BroadcastToOpModel m({3, 1, 2}, {4}, {3, 3, 2, 2}); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -233,6 +274,15 @@ TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } +TYPED_TEST(BroadcastToOpTest, NoBroadcasting8DConstTest) { + BroadcastToOpModel m({3, 1, 1, 1, 1, 1, 1, 2}, {8}, + {3, 1, 1, 1, 1, 1, 1, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1, 1, 1, 1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) { BroadcastToOpModel m({1, 1, 1, 1, 1, 1, 2, 1}, {8}, {1, 1, 1, 1, 1, 1, 2, 2}); diff --git a/tensorflow/lite/kernels/internal/reference/broadcast_to.h b/tensorflow/lite/kernels/internal/reference/broadcast_to.h index 09ffa704cca..f106b2b52f6 100644 --- a/tensorflow/lite/kernels/internal/reference/broadcast_to.h +++ b/tensorflow/lite/kernels/internal/reference/broadcast_to.h @@ -72,14 +72,21 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape, // Get the last dimension has broadcasting. At this dimension, the data is // copied from input tensor to output tensor. - int last_broadcast_dim = 0; - for (int i = N - 1; i > 0; --i) { + int last_broadcast_dim = -1; + for (int i = N - 1; i >= 0; --i) { if (input_desc.extents[i] != output_desc.extents[i]) { last_broadcast_dim = i; break; } } + // If non-broadcasting, just copy data from input to output tensor. + if (last_broadcast_dim == -1) { + memcpy(output_data, input_data, + unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type)); + return; + } + // Broadcasting using memcpy. int indexes[N] = {0}; BroadcastImpl(input_desc, input_data, output_desc, output_data, indexes, 0,