Add more test cases for gather_nd in generate_examples.py.
PiperOrigin-RevId: 234346858
This commit is contained in:
parent
bf6c9fbfc2
commit
79bb3858b9
@ -1388,12 +1388,26 @@ def make_gather_tests(zip_path):
|
|||||||
def make_gather_nd_tests(zip_path):
|
def make_gather_nd_tests(zip_path):
|
||||||
"""Make a set of tests to do gather_nd."""
|
"""Make a set of tests to do gather_nd."""
|
||||||
|
|
||||||
test_parameters = [{
|
test_parameters = [
|
||||||
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
{
|
||||||
"params_shape": [[5, 5, 10]],
|
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||||
"indices_dtype": [tf.int32, tf.int64],
|
"params_shape": [[5, 1]],
|
||||||
"indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]],
|
"indices_dtype": [tf.int32, tf.int64],
|
||||||
}]
|
"indices_shape": [[1, 1]],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||||
|
"params_shape": [[5, 5]],
|
||||||
|
"indices_dtype": [tf.int32, tf.int64],
|
||||||
|
"indices_shape": [[2, 1], [2, 2]],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||||
|
"params_shape": [[5, 5, 10]],
|
||||||
|
"indices_dtype": [tf.int32, tf.int64],
|
||||||
|
"indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
def build_graph(parameters):
|
def build_graph(parameters):
|
||||||
"""Build the gather_nd op testing graph."""
|
"""Build the gather_nd op testing graph."""
|
||||||
|
Loading…
Reference in New Issue
Block a user