Add more test cases for gather_nd in generate_examples.py.
PiperOrigin-RevId: 234346858
This commit is contained in:
parent
bf6c9fbfc2
commit
79bb3858b9
@ -1388,12 +1388,26 @@ def make_gather_tests(zip_path):
|
||||
def make_gather_nd_tests(zip_path):
|
||||
"""Make a set of tests to do gather_nd."""
|
||||
|
||||
test_parameters = [{
|
||||
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"params_shape": [[5, 5, 10]],
|
||||
"indices_dtype": [tf.int32, tf.int64],
|
||||
"indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]],
|
||||
}]
|
||||
test_parameters = [
|
||||
{
|
||||
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"params_shape": [[5, 1]],
|
||||
"indices_dtype": [tf.int32, tf.int64],
|
||||
"indices_shape": [[1, 1]],
|
||||
},
|
||||
{
|
||||
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"params_shape": [[5, 5]],
|
||||
"indices_dtype": [tf.int32, tf.int64],
|
||||
"indices_shape": [[2, 1], [2, 2]],
|
||||
},
|
||||
{
|
||||
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"params_shape": [[5, 5, 10]],
|
||||
"indices_dtype": [tf.int32, tf.int64],
|
||||
"indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]],
|
||||
},
|
||||
]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the gather_nd op testing graph."""
|
||||
|
Loading…
Reference in New Issue
Block a user