Fix zip_test_arg_min_max
This commit is contained in:
parent
d41c879e5c
commit
23603971df
@ -38,6 +38,7 @@ def make_arg_min_max_tests(options):
|
||||
],
|
||||
"output_type": [tf.int32, tf.int64],
|
||||
"is_arg_max": [True],
|
||||
"is_last_axis": [False],
|
||||
"dynamic_range_quantize": [False, True],
|
||||
},
|
||||
{
|
||||
@ -48,7 +49,7 @@ def make_arg_min_max_tests(options):
|
||||
],
|
||||
"output_type": [tf.int32, tf.int64],
|
||||
"is_arg_max": [False, True],
|
||||
"axis": [-1],
|
||||
"is_last_axis": [True],
|
||||
"dynamic_range_quantize": [False, True],
|
||||
},
|
||||
]
|
||||
@ -59,10 +60,10 @@ def make_arg_min_max_tests(options):
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input",
|
||||
shape=parameters["input_shape"])
|
||||
if "axis" not in parameters:
|
||||
if not parameters["is_last_axis"]:
|
||||
axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0))
|
||||
else:
|
||||
axis = parameters["axis"]
|
||||
axis = -1
|
||||
if parameters["is_arg_max"]:
|
||||
out = tf.math.argmax(
|
||||
input_value, axis, output_type=parameters["output_type"])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user