From da3013c1fc53df9e0252c7bc8005ab69688400b8 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Tue, 11 Jun 2019 13:29:21 -0700 Subject: [PATCH] Add a batch-1 test case for batch_to_space_nd PiperOrigin-RevId: 252686434 --- tensorflow/lite/kernels/batch_to_space_nd_test.cc | 8 ++++++++ tensorflow/lite/testing/generate_examples_lib.py | 11 ++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/lite/kernels/batch_to_space_nd_test.cc index d723ed36c94..1ecdae1b8ac 100644 --- a/tensorflow/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/lite/kernels/batch_to_space_nd_test.cc @@ -121,6 +121,14 @@ TEST(BatchToSpaceNDOpTest, SimpleConstTestInt8) { {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16})); } +TEST(BatchToSpaceNDOpTest, BatchOneConstTest) { + BatchToSpaceNDOpConstModel m({1, 2, 2, 1}, {1, 1}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); +} + TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) { BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 3ab634a029a..a0d51100a1a 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -2875,7 +2875,16 @@ def make_batch_to_space_nd_tests(options): "constant_block_shape": [True, False], "constant_crops": [True, False], }, - # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others. + # Single batch (no-op) + { + "dtype": [tf.float32], + "input_shape": [[1, 3, 3, 1]], + "block_shape": [[1, 1]], + "crops": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]], + "constant_block_shape": [True], + "constant_crops": [True], + }, + # Non-4D use case: 1 batch dimension, 3 spatial dimensions, 2 others. { "dtype": [tf.float32], "input_shape": [[8, 2, 2, 2, 1, 1]],