Add quantized test cases for expand_dims operator.

PiperOrigin-RevId: 292996653
Change-Id: Iafc2c802dedab1b13b95d78d8fd086069b876317
This commit is contained in:
A. Unique TensorFlower 2020-02-03 13:52:24 -08:00 committed by TensorFlower Gardener
parent 8fca65f324
commit 5ce881f411
2 changed files with 15 additions and 2 deletions

View File

@ -541,6 +541,7 @@ edgetpu_ops = [
"conv_relu1",
"conv_relu6",
"depthwiseconv", # high error
"expand_dims",
"fully_connected",
"l2norm", # high error
"maximum",

View File

@ -30,9 +30,16 @@ def make_expand_dims_tests(options):
test_parameters = [{
"input_type": [tf.float32, tf.int32],
"input_shape": [[5, 4]],
"input_shape": [[5, 4], [1, 5, 4]],
"axis_value": [0, 1, 2, -1, -2, -3],
"constant_axis": [True, False],
"fully_quantize": [False],
}, {
"input_type": [tf.float32],
"input_shape": [[5, 4], [1, 5, 4]],
"axis_value": [0, 1, 2, -1, -2, -3],
"constant_axis": [True],
"fully_quantize": [True],
}]
def build_graph(parameters):
@ -56,9 +63,14 @@ def make_expand_dims_tests(options):
return inputs, [out]
def build_inputs(parameters, sess, inputs, outputs):
"""Builds the inputs for expand_dims."""
input_values = []
input_values.append(
create_tensor_data(parameters["input_type"], parameters["input_shape"]))
create_tensor_data(
parameters["input_type"],
parameters["input_shape"],
min_value=-1,
max_value=1))
if not parameters["constant_axis"]:
input_values.append(np.array([parameters["axis_value"]], dtype=np.int32))
return input_values, sess.run(