From 79bb3858b911e15dda69576050ed3ed959d95d71 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Sun, 17 Feb 2019 00:01:05 -0800 Subject: [PATCH] Add more test cases for gather_nd in generate_examples.py. PiperOrigin-RevId: 234346858 --- tensorflow/lite/testing/generate_examples.py | 26 +++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py index e8a3efdef32..a7e16e641dd 100644 --- a/tensorflow/lite/testing/generate_examples.py +++ b/tensorflow/lite/testing/generate_examples.py @@ -1388,12 +1388,26 @@ def make_gather_tests(zip_path): def make_gather_nd_tests(zip_path): """Make a set of tests to do gather_nd.""" - test_parameters = [{ - "params_dtype": [tf.float32, tf.int32, tf.int64], - "params_shape": [[5, 5, 10]], - "indices_dtype": [tf.int32, tf.int64], - "indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]], - }] + test_parameters = [ + { + "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_shape": [[5, 1]], + "indices_dtype": [tf.int32, tf.int64], + "indices_shape": [[1, 1]], + }, + { + "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_shape": [[5, 5]], + "indices_dtype": [tf.int32, tf.int64], + "indices_shape": [[2, 1], [2, 2]], + }, + { + "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_shape": [[5, 5, 10]], + "indices_dtype": [tf.int32, tf.int64], + "indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]], + }, + ] def build_graph(parameters): """Build the gather_nd op testing graph."""