Add more test cases for gather_nd in generate_examples.py.

PiperOrigin-RevId: 234346858
This commit is contained in:
Haoliang Zhang 2019-02-17 00:01:05 -08:00 committed by TensorFlower Gardener
parent bf6c9fbfc2
commit 79bb3858b9

View File

@ -1388,12 +1388,26 @@ def make_gather_tests(zip_path):
def make_gather_nd_tests(zip_path):
"""Make a set of tests to do gather_nd."""
test_parameters = [{
test_parameters = [
{
"params_dtype": [tf.float32, tf.int32, tf.int64],
"params_shape": [[5, 1]],
"indices_dtype": [tf.int32, tf.int64],
"indices_shape": [[1, 1]],
},
{
"params_dtype": [tf.float32, tf.int32, tf.int64],
"params_shape": [[5, 5]],
"indices_dtype": [tf.int32, tf.int64],
"indices_shape": [[2, 1], [2, 2]],
},
{
"params_dtype": [tf.float32, tf.int32, tf.int64],
"params_shape": [[5, 5, 10]],
"indices_dtype": [tf.int32, tf.int64],
"indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]],
}]
},
]
def build_graph(parameters):
"""Build the gather_nd op testing graph."""