From 5ce881f411542508eb1de0a8abb021c166694e81 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 3 Feb 2020 13:52:24 -0800 Subject: [PATCH] Add quantized test cases for expand_dims operator. PiperOrigin-RevId: 292996653 Change-Id: Iafc2c802dedab1b13b95d78d8fd086069b876317 --- tensorflow/lite/testing/BUILD | 1 + tensorflow/lite/testing/op_tests/expand_dims.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index 0c898ac4f25..631dc05fd6a 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -541,6 +541,7 @@ edgetpu_ops = [ "conv_relu1", "conv_relu6", "depthwiseconv", # high error + "expand_dims", "fully_connected", "l2norm", # high error "maximum", diff --git a/tensorflow/lite/testing/op_tests/expand_dims.py b/tensorflow/lite/testing/op_tests/expand_dims.py index 2f7968d6673..45ad9c6f97c 100644 --- a/tensorflow/lite/testing/op_tests/expand_dims.py +++ b/tensorflow/lite/testing/op_tests/expand_dims.py @@ -30,9 +30,16 @@ def make_expand_dims_tests(options): test_parameters = [{ "input_type": [tf.float32, tf.int32], - "input_shape": [[5, 4]], + "input_shape": [[5, 4], [1, 5, 4]], "axis_value": [0, 1, 2, -1, -2, -3], "constant_axis": [True, False], + "fully_quantize": [False], + }, { + "input_type": [tf.float32], + "input_shape": [[5, 4], [1, 5, 4]], + "axis_value": [0, 1, 2, -1, -2, -3], + "constant_axis": [True], + "fully_quantize": [True], }] def build_graph(parameters): @@ -56,9 +63,14 @@ def make_expand_dims_tests(options): return inputs, [out] def build_inputs(parameters, sess, inputs, outputs): + """Builds the inputs for expand_dims.""" input_values = [] input_values.append( - create_tensor_data(parameters["input_type"], parameters["input_shape"])) + create_tensor_data( + parameters["input_type"], + parameters["input_shape"], + min_value=-1, + max_value=1)) if not parameters["constant_axis"]: input_values.append(np.array([parameters["axis_value"]], dtype=np.int32)) return input_values, sess.run(