Fix zip_test_arg_min_max

This commit is contained in:
Tzu-Wei Sung 2021-03-05 15:52:21 -08:00
parent d41c879e5c
commit 23603971df

View File

@ -38,6 +38,7 @@ def make_arg_min_max_tests(options):
],
"output_type": [tf.int32, tf.int64],
"is_arg_max": [True],
"is_last_axis": [False],
"dynamic_range_quantize": [False, True],
},
{
@ -48,7 +49,7 @@ def make_arg_min_max_tests(options):
],
"output_type": [tf.int32, tf.int64],
"is_arg_max": [False, True],
"axis": [-1],
"is_last_axis": [True],
"dynamic_range_quantize": [False, True],
},
]
@ -59,10 +60,10 @@ def make_arg_min_max_tests(options):
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
if "axis" not in parameters:
if not parameters["is_last_axis"]:
axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0))
else:
axis = parameters["axis"]
axis = -1
if parameters["is_arg_max"]:
out = tf.math.argmax(
input_value, axis, output_type=parameters["output_type"])