From bfa9c754f574c61e38e12573a2f63742734f1346 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 23 Oct 2019 15:59:29 -0700 Subject: [PATCH] Add more single-op tests with quantize option. PiperOrigin-RevId: 276369706 Change-Id: I0e832d8644809c7ee8bd50fa7d03b114fc80e56e --- tensorflow/lite/testing/BUILD | 39 +++- .../testing/generated_examples_zip_test.cc | 205 +++++++++++------- tensorflow/lite/testing/op_tests/binary_op.py | 82 ++++++- tensorflow/lite/testing/op_tests/concat.py | 16 +- .../lite/testing/op_tests/depthwiseconv.py | 33 ++- .../lite/testing/op_tests/fully_connected.py | 58 ++++- tensorflow/lite/testing/op_tests/l2norm.py | 10 +- tensorflow/lite/testing/op_tests/pad.py | 37 +++- tensorflow/lite/testing/op_tests/pool.py | 54 ++++- tensorflow/lite/testing/op_tests/reduce.py | 48 +++- tensorflow/lite/testing/op_tests/reshape.py | 16 +- .../lite/testing/op_tests/resize_bilinear.py | 14 +- tensorflow/lite/testing/op_tests/slice.py | 102 +++++++-- tensorflow/lite/testing/op_tests/softmax.py | 9 +- .../lite/testing/op_tests/space_to_depth.py | 13 +- tensorflow/lite/testing/op_tests/split.py | 8 +- tensorflow/lite/testing/op_tests/squeeze.py | 31 ++- .../lite/testing/op_tests/strided_slice.py | 40 +++- tensorflow/lite/testing/zip_test_utils.py | 18 +- 19 files changed, 665 insertions(+), 168 deletions(-) diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index dce1b353c40..41b70ea0889 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -22,13 +22,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -gen_zipped_test_file( - name = "zip_conv_edgetpu", - file = "conv_edgetpu.zip", - flags = " --make_edgetpu_tests", - toco = "//tensorflow/lite/toco:toco", # Unused -) - [gen_zip_test( name = "zip_test_%s" % test_name, size = "medium", @@ -485,3 +478,35 @@ tf_py_wrap_cc( ) tflite_portable_test_suite() + +edgetpu_ops = [ + "conv", # high error + "fully_connected", + "softmax", + "reshape", + "add", + "mul", + "sub", + "avg_pool", + "max_pool", + "concat", + "resize_bilinear", + "l2norm", # high error + "sum", # high error + "depthwiseconv", # high error + "space_to_depth", + "split", + "squeeze", + "pad", # high error + "slice", + "strided_slice", +] + +[gen_zipped_test_file( + name = "zip_%s_edgetpu" % op_name, + file = "%s_edgetpu.zip" % op_name, + flags = " --make_edgetpu_tests", + toco = "//tensorflow/lite/toco:toco", # Unused +) for op_name in edgetpu_ops] + +edgetpu_targets = [":zip_%s_edgetpu" % op_name for op_name in edgetpu_ops] diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc index df77b94aeab..7068f281ddf 100644 --- a/tensorflow/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/lite/testing/generated_examples_zip_test.cc @@ -57,59 +57,66 @@ tensorflow::Env* env = tensorflow::Env::Default(); // --test_arg=--ignore_known_bugs=false // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. -std::map kBrokenTests = { - // L2Norm only supports tensors with 4D or fewer. - {R"(^\/l2norm.*_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, +const std::map& GetKnownBrokenTests() { + static const std::map* const kBrokenTests = new std::map< + string, string>({ + // L2Norm only supports tensors with 4D or fewer. + {R"(^\/l2norm.*_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", + "67963684"}, - // SpaceToBatchND only supports 4D tensors. - {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, + // SpaceToBatchND only supports 4D tensors. + {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, - // BatchToSpaceND only supports 4D tensors. - {R"(^\/batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\])", "70848787"}, + // BatchToSpaceND only supports 4D tensors. + {R"(^\/batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\])", "70848787"}, - // L2Norm only works for dim=-1. - {R"(^\/l2norm.*_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2norm.*_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2norm.*_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2norm.*_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", - "67963812"}, - {R"(^\/l2norm.*_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", - "67963812"}, + // L2Norm only works for dim=-1. + {R"(^\/l2norm.*_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm.*_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm.*_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", + "67963812"}, + {R"(^\/l2norm.*_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm.*_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm.*_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm.*_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm.*_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm.*_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm.*_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm.*_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", + "67963812"}, + {R"(^\/l2norm.*_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", + "67963812"}, - // ResizeBilinear looks completely incompatible with Tensorflow - {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, + // ResizeBilinear looks completely incompatible with Tensorflow + {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, - // Transpose only supports 1D-4D input tensors. - {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, + // Transpose only supports 1D-4D input tensors. + {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, - // Relu does not support int32. - // These test cases appends a Relu after the tested ops when - // activation=True. The tests are failing since Relu doesn't support int32. - {R"(^\/div.*activation=True.*dtype=tf\.int32)", "112968789"}, - {R"(^\/floor_div.*activation=True.*dtype=tf\.int32)", "112968789"}, - {R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"}, - {R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"}, + // Relu does not support int32. + // These test cases appends a Relu after the tested ops when + // activation=True. The tests are failing since Relu doesn't support + // int32. + {R"(^\/div.*activation=True.*dtype=tf\.int32)", "112968789"}, + {R"(^\/floor_div.*activation=True.*dtype=tf\.int32)", "112968789"}, + {R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"}, + {R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"}, - {R"(^\/sub.*dtype=tf\.int64)", "119126484"}, - {R"(^\/div.*dtype=tf\.int64)", "119126484"}, - {R"(^\/mul.*dtype=tf\.int64)", "119126484"}, - {R"(^\/add.*dtype=tf\.int64)", "119126484"}, - {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"}, - {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"}, + {R"(^\/sub.*dtype=tf\.int64)", "119126484"}, + {R"(^\/div.*dtype=tf\.int64)", "119126484"}, + {R"(^\/mul.*dtype=tf\.int64)", "119126484"}, + {R"(^\/add.*dtype=tf\.int64)", "119126484"}, + {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"}, + {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"}, - // Select kernel doesn't support broadcasting yet. - {R"(^\/where.*1,2,3,1)", "134692786"}, + // Select kernel doesn't support broadcasting yet. + {R"(^\/where.*1,2,3,1)", "134692786"}, - // Strided slice doesn't support ellipsis. - {R"(strided_slice.*Ellipsis)", "138098220"}, -}; + // Strided slice doesn't support ellipsis. + {R"(strided_slice.*Ellipsis)", "138098220"}, + }); + return *kBrokenTests; +} // Additional list of tests that are expected to fail when // --test_arg=--ignore_known_bugs=false @@ -119,19 +126,45 @@ std::map kBrokenTests = { // handled separately; this list is specifically for broken cases where // execution produces broken output. // Key is a substring of the test name and value is a bug number. -std::map kBrokenNnapiTests = { - // Certain NNAPI kernels silently fail with int32 types. - {R"(^\/add.*dtype=tf\.int32)", "122987564"}, - {R"(^\/concat.*dtype=tf\.int32)", "122987564"}, - {R"(^\/mul.*dtype=tf\.int32)", "122987564"}, - {R"(^\/space_to_depth.*dtype=tf\.int32)", "122987564"}, +const std::map& GetKnownBrokenNnapiTests() { + static const std::map* const kBrokenNnapiTests = + new std::map({ + // Certain NNAPI kernels silently fail with int32 types. + {R"(^\/add.*dtype=tf\.int32)", "122987564"}, + {R"(^\/concat.*dtype=tf\.int32)", "122987564"}, + {R"(^\/mul.*dtype=tf\.int32)", "122987564"}, + {R"(^\/space_to_depth.*dtype=tf\.int32)", "122987564"}, - // Certain NNAPI fully_connected shape permutations fail. - {R"(^\/fully_connected_constant_filter=True.*shape1=\[3,3\])", "122987564"}, - {R"(^\/fully_connected_constant_filter=True.*shape1=\[4,4\])", "122987564"}, - {R"(^\/fully_connected.*shape1=\[3,3\].*transpose_b=True)", "122987564"}, - {R"(^\/fully_connected.*shape1=\[4,4\].*shape2=\[4,1\])", "122987564"}, -}; + // Certain NNAPI fully_connected shape permutations fail. + {R"(^\/fully_connected_constant_filter=True.*shape1=\[3,3\])", + "122987564"}, + {R"(^\/fully_connected_constant_filter=True.*shape1=\[4,4\])", + "122987564"}, + {R"(^\/fully_connected.*shape1=\[3,3\].*transpose_b=True)", + "122987564"}, + {R"(^\/fully_connected.*shape1=\[4,4\].*shape2=\[4,1\])", + "122987564"}, + }); + return *kBrokenNnapiTests; +} + +// List of quantize tests that are probably to fail. +// Quantized tflite models has high diff error with tensorflow models. +// Key is a substring of the test name and value is a bug number. +// TODO(b/134594898): Remove these bugs and corresponding codes or move them to +// kBrokenTests after b/134594898 is fixed. +const std::map& GetKnownQuantizeBrokenTests() { + static const std::map* const kQuantizeBrokenTests = + new std::map({ + {R"(^\/conv.*fully_quantize=True)", "134594898"}, + {R"(^\/depthwiseconv.*fully_quantize=True)", "134594898"}, + {R"(^\/mean.*fully_quantize=True)", "134594898"}, + {R"(^\/pad.*fully_quantize=True)", "134594898"}, + {R"(^\/sum.*fully_quantize=True)", "134594898"}, + {R"(^\/l2norm.*fully_quantize=True)", "134594898"}, + }); + return *kQuantizeBrokenTests; +} // Allows test data to be unarchived into a temporary directory and makes // sure those temporary directories are removed later. @@ -268,42 +301,60 @@ TEST_P(OpsTest, RunZipTests) { tflite::testing::TfLiteDriver test_driver( FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi : TfLiteDriver::DelegateType::kNone); - + bool fully_quantize = false; if (test_path.find("fully_quantize=True") != std::string::npos) { // TODO(b/134594898): Tighten this constraint. - test_driver.SetThreshold(5e-1f, 4e-1f); + test_driver.SetThreshold(0.2, 0.1); + fully_quantize = true; } test_driver.SetModelBaseDir(tflite_dir); - auto broken_tests = kBrokenTests; + auto broken_tests = GetKnownBrokenTests(); if (FLAGS_use_nnapi) { + auto kBrokenNnapiTests = GetKnownBrokenNnapiTests(); broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end()); } - - string bug_number; - for (const auto& p : broken_tests) { - if (RE2::PartialMatch(test_name, p.first)) { - bug_number = p.second; - } - } + auto quantize_broken_tests = GetKnownQuantizeBrokenTests(); bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver); string message = test_driver.GetErrorMessage(); - if (bug_number.empty()) { - if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) { - EXPECT_EQ(message, string("Failed to invoke interpreter")) << message; + + if (!fully_quantize) { + string bug_number; + for (const auto& p : broken_tests) { + if (RE2::PartialMatch(test_name, p.first)) { + bug_number = p.second; + break; + } + } + if (bug_number.empty()) { + if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) { + EXPECT_EQ(message, string("Failed to invoke interpreter")) << message; + } else { + EXPECT_TRUE(result) << message; + } } else { - EXPECT_TRUE(result) << message; + if (FLAGS_ignore_known_bugs) { + EXPECT_FALSE(result) << "Test was expected to fail but is now passing; " + "you can mark http://b/" + << bug_number << " as fixed! Yay!"; + } else { + EXPECT_TRUE(result) + << message << ": Possibly due to http://b/" << bug_number; + } } } else { - if (FLAGS_ignore_known_bugs) { - EXPECT_FALSE(result) << "Test was expected to fail but is now passing; " - "you can mark http://b/" - << bug_number << " as fixed! Yay!"; - } else { - EXPECT_TRUE(result) << message << ": Possibly due to http://b/" - << bug_number; + if (!result) { + string bug_number; + // See if the tests are potential quantize failures. + for (const auto& p : quantize_broken_tests) { + if (RE2::PartialMatch(test_name, p.first)) { + bug_number = p.second; + break; + } + } + EXPECT_FALSE(bug_number.empty()); } } } diff --git a/tensorflow/lite/testing/op_tests/binary_op.py b/tensorflow/lite/testing/op_tests/binary_op.py index 550f465407c..260ec145e70 100644 --- a/tensorflow/lite/testing/op_tests/binary_op.py +++ b/tensorflow/lite/testing/op_tests/binary_op.py @@ -23,7 +23,10 @@ from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests from tensorflow.lite.testing.zip_test_utils import register_make_test_function -def make_binary_op_tests(options, binary_operator, expected_tf_failures=0): +def make_binary_op_tests(options, + binary_operator, + allow_fully_quantize=False, + expected_tf_failures=0): """Make a set of tests to do binary ops with and without broadcast.""" test_parameters = [ @@ -33,39 +36,88 @@ def make_binary_op_tests(options, binary_operator, expected_tf_failures=0): "input_shape_1": [[1, 3, 4, 3]], "input_shape_2": [[1, 3, 4, 3]], "activation": [True], + "fully_quantize": [False], }, { "dtype": [tf.float32], "input_shape_1": [[5]], "input_shape_2": [[5]], "activation": [False, True], + "fully_quantize": [False], }, { "dtype": [tf.float32, tf.int32, tf.int64], "input_shape_1": [[1, 3, 4, 3]], "input_shape_2": [[3]], "activation": [True, False], + "fully_quantize": [False], }, { "dtype": [tf.float32, tf.int32], "input_shape_1": [[3]], "input_shape_2": [[1, 3, 4, 3]], "activation": [True, False], + "fully_quantize": [False], }, { "dtype": [tf.float32], "input_shape_1": [[]], "input_shape_2": [[]], "activation": [False], + "fully_quantize": [False], }, { "dtype": [tf.float32], "input_shape_1": [[0]], "input_shape_2": [[1]], "activation": [False], - } + "fully_quantize": [False], + }, + { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[1, 3, 4, 3]], + "activation": [False], + "fully_quantize": [True], + }, + { + "dtype": [tf.float32], + "input_shape_1": [[5]], + "input_shape_2": [[5]], + "activation": [False], + "fully_quantize": [True], + }, + { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[3]], + "activation": [False], + "fully_quantize": [True], + }, + { + "dtype": [tf.float32], + "input_shape_1": [[3]], + "input_shape_2": [[1, 3, 4, 3]], + "activation": [False], + "fully_quantize": [True], + }, + { + "dtype": [tf.float32], + "input_shape_1": [[]], + "input_shape_2": [[]], + "activation": [False], + "fully_quantize": [True], + }, ] + # test_parameters include fully_quantize option only when + # allow_fully_quantize is True. + if not allow_fully_quantize: + test_parameters = [ + test_parameter for test_parameter in test_parameters + if True not in test_parameter["fully_quantize"] + ] + def build_graph(parameters): """Builds the graph given the current parameters.""" input1 = tf.compat.v1.placeholder( @@ -83,10 +135,22 @@ def make_binary_op_tests(options, binary_operator, expected_tf_failures=0): def build_inputs(parameters, sess, inputs, outputs): """Builds operand inputs for op.""" - input1 = create_tensor_data(parameters["dtype"], - parameters["input_shape_1"]) - input2 = create_tensor_data(parameters["dtype"], - parameters["input_shape_2"]) + if allow_fully_quantize: + input1 = create_tensor_data( + parameters["dtype"], + parameters["input_shape_1"], + min_value=-1, + max_value=1) + input2 = create_tensor_data( + parameters["dtype"], + parameters["input_shape_2"], + min_value=-1, + max_value=1) + else: + input1 = create_tensor_data(parameters["dtype"], + parameters["input_shape_1"]) + input2 = create_tensor_data(parameters["dtype"], + parameters["input_shape_2"]) return [input1, input2], sess.run( outputs, feed_dict={ inputs[0]: input1, @@ -108,7 +172,7 @@ def make_binary_op_tests_func(binary_operator): @register_make_test_function() def make_add_tests(options): - make_binary_op_tests(options, tf.add) + make_binary_op_tests(options, tf.add, allow_fully_quantize=True) @register_make_test_function() @@ -118,12 +182,12 @@ def make_div_tests(options): @register_make_test_function() def make_sub_tests(options): - make_binary_op_tests(options, tf.subtract) + make_binary_op_tests(options, tf.subtract, allow_fully_quantize=True) @register_make_test_function() def make_mul_tests(options): - make_binary_op_tests(options, tf.multiply) + make_binary_op_tests(options, tf.multiply, allow_fully_quantize=True) @register_make_test_function() diff --git a/tensorflow/lite/testing/op_tests/concat.py b/tensorflow/lite/testing/op_tests/concat.py index 746afd4a02c..2c15318deda 100644 --- a/tensorflow/lite/testing/op_tests/concat.py +++ b/tensorflow/lite/testing/op_tests/concat.py @@ -32,6 +32,13 @@ def make_concat_tests(options): "num_tensors": [1, 2, 3, 4, 5, 6], "axis": [0, 1, 2, 3, -3, -2, -1], "type": [tf.float32, tf.uint8, tf.int32, tf.int64], + "fully_quantize": [False] + }, { + "base_shape": [[1, 3, 4, 3], [3, 4], [2, 3, 4, 3]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [1, 2, 3, -3, -2, -1], + "type": [tf.float32], + "fully_quantize": [True] }] def get_shape(parameters, delta): @@ -58,8 +65,11 @@ def make_concat_tests(options): def build_inputs(parameters, sess, inputs, outputs): all_values = [] for n in range(0, parameters["num_tensors"]): - input_values = create_tensor_data(parameters["type"], - get_shape(parameters, n)) + input_values = create_tensor_data( + parameters["type"], + get_shape(parameters, n), + min_value=-1, + max_value=1) all_values.append(input_values) return all_values, sess.run( outputs, feed_dict=dict(zip(inputs, all_values))) @@ -69,4 +79,4 @@ def make_concat_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=60) + expected_tf_failures=75) diff --git a/tensorflow/lite/testing/op_tests/depthwiseconv.py b/tensorflow/lite/testing/op_tests/depthwiseconv.py index 88d2ab44bd9..4741469388f 100644 --- a/tensorflow/lite/testing/op_tests/depthwiseconv.py +++ b/tensorflow/lite/testing/op_tests/depthwiseconv.py @@ -40,6 +40,7 @@ def make_depthwiseconv_tests(options): "padding": ["SAME", "VALID"], "data_format": ["NHWC"], "constant_filter": [True, False], + "fully_quantize": [False] }, { "input_shape": [[1, 3, 4, 3]], @@ -51,7 +52,20 @@ def make_depthwiseconv_tests(options): "padding": ["SAME"], "data_format": ["NHWC"], "constant_filter": [True, False], - } + "fully_quantize": [False] + }, + { + "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]], + "filter_size": [[1, 1], [1, 2], [3, 3]], + "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], + "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]], + "channel_multiplier": [1, 2], + "rate": [[1, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], + "constant_filter": [True], + "fully_quantize": [True] + }, ] def get_tensor_shapes(parameters): @@ -88,12 +102,21 @@ def make_depthwiseconv_tests(options): return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - # Build list of input values either containing 1 tensor (input) or 2 tensors - # (input, filter) based on whether filter is constant or variable input. + # pylint: disable=g-doc-return-or-yield, g-doc-args + """Build list of input values. + + It either contains 1 tensor (input) or 2 tensors (input, filter) based on + whether filter is constant or variable input. + """ + input_shape, filter_shape = get_tensor_shapes(parameters) - values = [create_tensor_data(np.float32, input_shape)] + values = [ + create_tensor_data(np.float32, input_shape, min_value=-1, max_value=1) + ] if not parameters["constant_filter"]: - values.append(create_tensor_data(np.float32, filter_shape)) + values.append( + create_tensor_data( + np.float32, filter_shape, min_value=-1, max_value=1)) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests( diff --git a/tensorflow/lite/testing/op_tests/fully_connected.py b/tensorflow/lite/testing/op_tests/fully_connected.py index d4704465d6f..c5a7dc81fea 100644 --- a/tensorflow/lite/testing/op_tests/fully_connected.py +++ b/tensorflow/lite/testing/op_tests/fully_connected.py @@ -34,30 +34,63 @@ def make_fully_connected_tests(options): "transpose_a": [True, False], "transpose_b": [True, False], "constant_filter": [True, False], + "fully_quantize": [False], }, { "shape1": [[4, 4], [1, 4], [4]], "shape2": [[4, 4], [4, 1], [4]], "transpose_a": [False], "transpose_b": [False], "constant_filter": [True, False], + "fully_quantize": [False], }, { "shape1": [[40, 37]], "shape2": [[37, 40]], "transpose_a": [False], "transpose_b": [False], "constant_filter": [True, False], + "fully_quantize": [False], }, { "shape1": [[40, 37]], "shape2": [[40, 37]], "transpose_a": [False], "transpose_b": [True], "constant_filter": [True, False], + "fully_quantize": [False], }, { "shape1": [[5, 3]], "shape2": [[5, 3]], "transpose_a": [True], "transpose_b": [False], "constant_filter": [True, False], + "fully_quantize": [False], + }, { + "shape1": [[1, 3]], + "shape2": [[3, 3]], + "transpose_a": [False], + "transpose_b": [False], + "constant_filter": [True], + "fully_quantize": [True], + }, { + "shape1": [[1, 4], [4]], + "shape2": [[4, 4], [4, 1], [4]], + "transpose_a": [False], + "transpose_b": [False], + "constant_filter": [True], + "fully_quantize": [True], + }, { + "shape1": [[1, 37], [2, 37]], + "shape2": [[37, 40]], + "transpose_a": [False], + "transpose_b": [False], + "constant_filter": [True], + "fully_quantize": [True], + }, { + "shape1": [[1, 3], [2, 3]], + "shape2": [[3, 5], [3, 1]], + "transpose_a": [False], + "transpose_b": [False], + "constant_filter": [True], + "fully_quantize": [True], }] def build_graph(parameters): @@ -68,7 +101,8 @@ def make_fully_connected_tests(options): # Get input_tensor2 either as a placeholder or constants. Also get a list of # the input tensors that are represented as placeholders. if parameters["constant_filter"]: - input_tensor2 = create_tensor_data(np.float32, parameters["shape2"]) + input_tensor2 = create_tensor_data( + np.float32, parameters["shape2"], min_value=-1, max_value=1) input_tensors = [input_tensor1] else: input_tensor2 = tf.compat.v1.placeholder( @@ -83,12 +117,22 @@ def make_fully_connected_tests(options): return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - # Build list of input values either containing 1 tensor (input_values1) or 2 - # tensors (input_values1, input_values2) based on whether the second input - # is a constant or variable input. - values = [create_tensor_data(np.float32, shape=parameters["shape1"])] + # pylint: disable=g-doc-return-or-yield, g-doc-args + """Build list of input values. + + It either contains 1 tensor (input_values1) or + 2 tensors (input_values1, input_values2) based on whether the second input + is a constant or variable input. + """ + + values = [ + create_tensor_data( + np.float32, shape=parameters["shape1"], min_value=-1, max_value=1) + ] if not parameters["constant_filter"]: - values.append(create_tensor_data(np.float32, parameters["shape2"])) + values.append( + create_tensor_data( + np.float32, parameters["shape2"], min_value=-1, max_value=1)) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests( @@ -96,4 +140,4 @@ def make_fully_connected_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=10) + expected_tf_failures=14) diff --git a/tensorflow/lite/testing/op_tests/l2norm.py b/tensorflow/lite/testing/op_tests/l2norm.py index 346810e7378..c83ec03f7ae 100644 --- a/tensorflow/lite/testing/op_tests/l2norm.py +++ b/tensorflow/lite/testing/op_tests/l2norm.py @@ -34,6 +34,12 @@ def make_l2norm_tests(options): [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], "dim": [0, 1, 2, 3, [2, 3], -2], "epsilon": [None, 1e-12, 1e-3], + "fully_quantize": [False], + }, { + "input_shape": [[5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]], + "dim": [0, 1, 2, 3, [2, 3], -2], + "epsilon": [None, 1e-12, 1e-3], + "fully_quantize": [True], }] def build_graph(parameters): @@ -48,7 +54,7 @@ def make_l2norm_tests(options): def build_inputs(parameters, sess, inputs, outputs): input_values = create_tensor_data( - np.float32, parameters["input_shape"], min_value=-4, max_value=10) + np.float32, parameters["input_shape"], min_value=-1, max_value=1) return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) @@ -57,4 +63,4 @@ def make_l2norm_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=9) + expected_tf_failures=18) diff --git a/tensorflow/lite/testing/op_tests/pad.py b/tensorflow/lite/testing/op_tests/pad.py index 610bfb27ebe..a136fee547b 100644 --- a/tensorflow/lite/testing/op_tests/pad.py +++ b/tensorflow/lite/testing/op_tests/pad.py @@ -37,6 +37,7 @@ def make_pad_tests(options): "paddings": [[[0, 0], [0, 1], [2, 3], [0, 0]], [[0, 1], [0, 0], [0, 0], [2, 3]]], "constant_paddings": [True, False], + "fully_quantize": [False] }, # 2D: { @@ -44,6 +45,7 @@ def make_pad_tests(options): "input_shape": [[1, 2]], "paddings": [[[0, 1], [2, 3]]], "constant_paddings": [True, False], + "fully_quantize": [False] }, # 1D: { @@ -51,6 +53,33 @@ def make_pad_tests(options): "input_shape": [[1]], "paddings": [[[1, 2]]], "constant_paddings": [False], + "fully_quantize": [False] + }, + # 4D: + { + "dtype": [tf.float32], + "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]], + "paddings": [[[0, 0], [0, 1], [2, 3], [0, 0]], + [[0, 1], [0, 0], [0, 0], [2, 3]], + [[0, 0], [0, 0], [0, 0], [0, 0]]], + "constant_paddings": [True], + "fully_quantize": [True] + }, + # 2D: + { + "dtype": [tf.float32], + "input_shape": [[1, 2]], + "paddings": [[[0, 1], [2, 3]]], + "constant_paddings": [True], + "fully_quantize": [True] + }, + # 1D: + { + "dtype": [tf.float32], + "input_shape": [[1]], + "paddings": [[[1, 2]]], + "constant_paddings": [True], + "fully_quantize": [True] }, ] @@ -75,8 +104,14 @@ def make_pad_tests(options): return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): + """Build inputs for pad op.""" + values = [ - create_tensor_data(parameters["dtype"], parameters["input_shape"]) + create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) ] if not parameters["constant_paddings"]: values.append(np.array(parameters["paddings"])) diff --git a/tensorflow/lite/testing/op_tests/pool.py b/tensorflow/lite/testing/op_tests/pool.py index d95857dcf55..f334d2a77ab 100644 --- a/tensorflow/lite/testing/op_tests/pool.py +++ b/tensorflow/lite/testing/op_tests/pool.py @@ -23,11 +23,12 @@ from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests from tensorflow.lite.testing.zip_test_utils import register_make_test_function -def make_pool_tests(pool_op_in): +def make_pool_tests(pool_op_in, allow_fully_quantize=False): """Make a set of tests to do average pooling. Args: pool_op_in: TensorFlow pooling operation to test i.e. `tf.nn.avg_pool2d`. + allow_fully_quantize: bool, whether fully_quantize is allowed. Returns: A function representing the true generator (after curried pool_op_in). @@ -44,14 +45,35 @@ def make_pool_tests(pool_op_in): """ # Chose a set of parameters - test_parameters = [{ - "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], - "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], - # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]). - "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], - "padding": ["SAME", "VALID"], - "data_format": ["NHWC"], # TODO(aselle): NCHW would be good - }] + test_parameters = [ + { + "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], + "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], + [1, 10, 11, 1]], + # TODO(aselle): should add a degenerate shape (e.g. [1, 0, 1, 1]). + "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + "fully_quantize": [False], + }, + { + "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], + "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], + [1, 10, 11, 1]], + # TODO(aselle): should add a degenerate shape (e.g. [1, 0, 1, 1]). + "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + "fully_quantize": [True], + } + ] + # test_parameters include fully_quantize option only when + # allow_fully_quantize is True. + if not allow_fully_quantize: + test_parameters = [ + test_parameter for test_parameter in test_parameters + if True not in test_parameter["fully_quantize"] + ] def build_graph(parameters): input_tensor = tf.compat.v1.placeholder( @@ -65,7 +87,11 @@ def make_pool_tests(pool_op_in): return [input_tensor], [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(tf.float32, parameters["input_shape"]) + if allow_fully_quantize: + input_values = create_tensor_data( + tf.float32, parameters["input_shape"], min_value=-1, max_value=1) + else: + input_values = create_tensor_data(tf.float32, parameters["input_shape"]) return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) @@ -97,9 +123,13 @@ def make_l2_pool_tests(options): @register_make_test_function() def make_avg_pool_tests(options): - make_pool_tests(tf.nn.avg_pool)(options, expected_tf_failures=80) + make_pool_tests( + tf.nn.avg_pool, allow_fully_quantize=True)( + options, expected_tf_failures=160) @register_make_test_function() def make_max_pool_tests(options): - make_pool_tests(tf.nn.max_pool)(options, expected_tf_failures=80) + make_pool_tests( + tf.nn.max_pool, allow_fully_quantize=True)( + options, expected_tf_failures=160) diff --git a/tensorflow/lite/testing/op_tests/reduce.py b/tensorflow/lite/testing/op_tests/reduce.py index 28c4415a1a1..17be06bff4a 100644 --- a/tensorflow/lite/testing/op_tests/reduce.py +++ b/tensorflow/lite/testing/op_tests/reduce.py @@ -27,7 +27,8 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function def make_reduce_tests(reduce_op, min_value=-10, max_value=10, - boolean_tensor_only=False): + boolean_tensor_only=False, + allow_fully_quantize=False): """Make a set of tests to do reduce operation. Args: @@ -35,6 +36,7 @@ def make_reduce_tests(reduce_op, min_value: min value for created tensor data. max_value: max value for created tensor data. boolean_tensor_only: If true, will only generate tensor with boolean value. + allow_fully_quantize: bool, whether fully_quantize is allowed. Returns: a function representing the true generator with `reduce_op_in` curried. @@ -54,6 +56,7 @@ def make_reduce_tests(reduce_op, ], "const_axis": [True, False], "keepdims": [True, False], + "fully_quantize": [False], }, { "input_dtype": [tf.float32], @@ -67,6 +70,7 @@ def make_reduce_tests(reduce_op, ], "const_axis": [True, False], "keepdims": [True, False], + "fully_quantize": [False], }, { "input_dtype": [tf.float32], @@ -74,6 +78,7 @@ def make_reduce_tests(reduce_op, "axis": [[]], # shape is: [0] "const_axis": [False], "keepdims": [True, False], + "fully_quantize": [False], }, { "input_dtype": [tf.float32], @@ -81,8 +86,39 @@ def make_reduce_tests(reduce_op, "axis": [None], # shape is: [] "const_axis": [True], "keepdims": [True, False], - } + "fully_quantize": [False], + }, + { + "input_dtype": [tf.float32], + "input_shape": [[3, 3, 2, 4]], + "axis": [ + 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], + [-1, 0], [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] + ], + "const_axis": [True], + "keepdims": [True, False], + "fully_quantize": [True], + }, + { + "input_dtype": [tf.float32], + "input_shape": [[1, 8, 8, 4], [1, 8, 8, 3]], + "axis": [ + 0, 1, 2, 3, [0], [1], [2], [3], [-1], [-2], [-3], [1, 2], + [0, 3], [1, 2, 3] + ], + "const_axis": [True], + "keepdims": [True, False], + "fully_quantize": [True], + }, ] + # test_parameters include fully_quantize option only when + # allow_fully_quantize is True. + if not allow_fully_quantize: + test_parameters = [ + test_parameter for test_parameter in test_parameters + if True not in test_parameter["fully_quantize"] + ] def build_graph(parameters): """Build the mean op testing graph.""" @@ -139,7 +175,13 @@ def make_mean_tests(options): @register_make_test_function() def make_sum_tests(options): """Make a set of tests to do sum.""" - return make_reduce_tests(tf.reduce_sum)(options) + return make_reduce_tests( + tf.reduce_sum, + min_value=-1, + max_value=1, + boolean_tensor_only=False, + allow_fully_quantize=True)( + options) @register_make_test_function() diff --git a/tensorflow/lite/testing/op_tests/reshape.py b/tensorflow/lite/testing/op_tests/reshape.py index 2486238b2c6..752c1fa53c1 100644 --- a/tensorflow/lite/testing/op_tests/reshape.py +++ b/tensorflow/lite/testing/op_tests/reshape.py @@ -34,11 +34,19 @@ def make_reshape_tests(options): "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], "constant_shape": [True, False], + "fully_quantize": [False], }, { "dtype": [tf.float32], "input_shape": [[1]], "output_shape": [[]], "constant_shape": [True, False], + "fully_quantize": [False], + }, { + "dtype": [tf.float32], + "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], + "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], + "constant_shape": [True], + "fully_quantize": [True], }] def build_graph(parameters): @@ -62,8 +70,14 @@ def make_reshape_tests(options): return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): + """Build inputs for reshape op.""" + values = [ - create_tensor_data(parameters["dtype"], parameters["input_shape"]) + create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) ] if not parameters["constant_shape"]: values.append(np.array(parameters["output_shape"])) diff --git a/tensorflow/lite/testing/op_tests/resize_bilinear.py b/tensorflow/lite/testing/op_tests/resize_bilinear.py index 0a87bf58972..06834d00e7f 100644 --- a/tensorflow/lite/testing/op_tests/resize_bilinear.py +++ b/tensorflow/lite/testing/op_tests/resize_bilinear.py @@ -32,6 +32,13 @@ def make_resize_bilinear_tests(options): "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]], "size": [[1, 1], [4, 3], [2, 2], [5, 6]], "align_corners": [None, True, False], + "fully_quantize": [False] + }, { + "dtype": [tf.float32], + "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]], + "size": [[1, 1], [4, 3], [2, 2], [5, 6]], + "align_corners": [None, True, False], + "fully_quantize": [True] }] def build_graph(parameters): @@ -46,8 +53,11 @@ def make_resize_bilinear_tests(options): return [input_tensor], [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) + input_values = create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) diff --git a/tensorflow/lite/testing/op_tests/slice.py b/tensorflow/lite/testing/op_tests/slice.py index e304fd3d026..adfa5781117 100644 --- a/tensorflow/lite/testing/op_tests/slice.py +++ b/tensorflow/lite/testing/op_tests/slice.py @@ -38,6 +38,8 @@ def make_slice_tests(options): "input_shape": [[12, 2, 2, 5]], "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], "size": [[8, 2, 2, 3], [11, 2, 1, 5]], + "constant_indices": [False], + "fully_quantize": [False], }, # 2-D { @@ -46,6 +48,8 @@ def make_slice_tests(options): "input_shape": [[2, 3]], "begin": [[0, 0], [1, 0]], "size": [[2, 3], [2, 2]], + "constant_indices": [False], + "fully_quantize": [False], }, # 4-D with size -1 { @@ -55,6 +59,8 @@ def make_slice_tests(options): "begin": [[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], "size": [[-1, 1, 1, 1], [1, -1, 1, 1], [1, 1, -1, 1], [1, 1, 1, -1]], + "constant_indices": [False, True], + "fully_quantize": [False], }, # last dimension out of index { @@ -63,6 +69,48 @@ def make_slice_tests(options): "input_shape": [[4, 4, 4]], "begin": [[3, 3, 4]], "size": [[-1, -1, -1]], + "constant_indices": [False, True], + "fully_quantize": [False], + }, + # 4-D + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], + "size": [[8, 2, 2, 3], [11, 2, 1, 5]], + "constant_indices": [True], + "fully_quantize": [True], + }, + # 2-D + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, 0], [1, 0]], + "size": [[2, 3], [2, 2]], + "constant_indices": [True], + "fully_quantize": [True], + }, + # 4-D with size -1 + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[4, 4, 4, 4]], + "begin": [[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], + [0, 0, 0, 1]], + "size": [[-1, 1, 1, 1], [1, -1, 1, 1], [1, 1, -1, 1], [1, 1, 1, -1]], + "constant_indices": [True], + "fully_quantize": [True], + }, + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[1, 4, 4, 4]], + "begin": [[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + "size": [[-1, 1, 1, 1], [1, -1, 1, 1], [1, 1, -1, 1], [1, 1, 1, -1]], + "constant_indices": [True], + "fully_quantize": [True], }, ] @@ -72,33 +120,45 @@ def make_slice_tests(options): dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - begin = tf.compat.v1.placeholder( - dtype=parameters["index_type"], - name="begin", - shape=[len(parameters["input_shape"])]) - size = tf.compat.v1.placeholder( - dtype=parameters["index_type"], - name="size", - shape=[len(parameters["input_shape"])]) - tensors = [input_tensor, begin, size] - out = tf.slice(input_tensor, begin, size) - return tensors, [out] + if parameters["constant_indices"]: + index_type = TF_TYPE_INFO[parameters["index_type"]][0] + begin_values = np.array(parameters["begin"]).astype(index_type) + size_values = np.array(parameters["size"]).astype(index_type) + out = tf.slice(input_tensor, begin_values, size_values) + return [input_tensor], [out] + else: + begin = tf.compat.v1.placeholder( + dtype=parameters["index_type"], + name="begin", + shape=[len(parameters["input_shape"])]) + size = tf.compat.v1.placeholder( + dtype=parameters["index_type"], + name="size", + shape=[len(parameters["input_shape"])]) + tensors = [input_tensor, begin, size] + out = tf.slice(input_tensor, begin, size) + return tensors, [out] def build_inputs(parameters, sess, inputs, outputs): """Build inputs for slice test.""" - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - index_type = TF_TYPE_INFO[parameters["index_type"]][0] - - begin_values = np.array(parameters["begin"]).astype(index_type) - size_values = np.array(parameters["size"]).astype(index_type) - values = [input_values, begin_values, size_values] - - return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + input_values = create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) + if parameters["constant_indices"]: + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + else: + index_type = TF_TYPE_INFO[parameters["index_type"]][0] + begin_values = np.array(parameters["begin"]).astype(index_type) + size_values = np.array(parameters["size"]).astype(index_type) + values = [input_values, begin_values, size_values] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests( options, test_parameters, build_graph, build_inputs, - expected_tf_failures=24) + expected_tf_failures=27) diff --git a/tensorflow/lite/testing/op_tests/softmax.py b/tensorflow/lite/testing/op_tests/softmax.py index 24d3f673477..c62a8281d80 100644 --- a/tensorflow/lite/testing/op_tests/softmax.py +++ b/tensorflow/lite/testing/op_tests/softmax.py @@ -31,10 +31,12 @@ def make_softmax_tests(options): "dtype": [tf.float32], "input_shape": [[1, 3, 4, 3], [2, 3]], "dim": [-1, 0], + "fully_quantize": [False, True], }, { "dtype": [tf.float32], "input_shape": [[4, 7]], "dim": [-1, 1], + "fully_quantize": [False, True], }] def build_graph(parameters): @@ -46,8 +48,11 @@ def make_softmax_tests(options): return [input_tensor], [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) + input_values = create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) diff --git a/tensorflow/lite/testing/op_tests/space_to_depth.py b/tensorflow/lite/testing/op_tests/space_to_depth.py index cca093acbb3..b1a0864a037 100644 --- a/tensorflow/lite/testing/op_tests/space_to_depth.py +++ b/tensorflow/lite/testing/op_tests/space_to_depth.py @@ -31,6 +31,12 @@ def make_space_to_depth_tests(options): "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64], "input_shape": [[2, 12, 24, 1]], "block_size": [2, 3, 4], + "fully_quantize": [False], + }, { + "dtype": [tf.float32], + "input_shape": [[2, 12, 24, 1], [1, 12, 24, 1]], + "block_size": [2, 3, 4], + "fully_quantize": [True], }] def build_graph(parameters): @@ -42,8 +48,11 @@ def make_space_to_depth_tests(options): return [input_tensor], [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) + input_values = create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) diff --git a/tensorflow/lite/testing/op_tests/split.py b/tensorflow/lite/testing/op_tests/split.py index e1e7a6ffd56..00f2b17af54 100644 --- a/tensorflow/lite/testing/op_tests/split.py +++ b/tensorflow/lite/testing/op_tests/split.py @@ -32,6 +32,7 @@ def make_split_tests(options): "input_shape": [[1, 3, 4, 6], [2, 4, 1], [6, 4], [8]], "num_or_size_splits": [1, 2, 3, 4, 5], "axis": [0, 1, 2, 3, -4, -3, -2, -1], + "fully_quantize": [True, False], }] def build_graph(parameters): @@ -42,7 +43,10 @@ def make_split_tests(options): return [input_tensor], [out[0]] def build_inputs(parameters, sess, inputs, outputs): - values = [create_tensor_data(np.float32, parameters["input_shape"])] + values = [ + create_tensor_data( + np.float32, parameters["input_shape"], min_value=-1, max_value=1) + ] return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests( @@ -50,4 +54,4 @@ def make_split_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=112) + expected_tf_failures=224) diff --git a/tensorflow/lite/testing/op_tests/squeeze.py b/tensorflow/lite/testing/op_tests/squeeze.py index db86ce1db52..0b6686a4d5d 100644 --- a/tensorflow/lite/testing/op_tests/squeeze.py +++ b/tensorflow/lite/testing/op_tests/squeeze.py @@ -35,14 +35,36 @@ def make_squeeze_tests(options): [-1, -2, -4, -6, -8], [0, 2, 4, 6, 7], [7, 6, 4, 2, 0], [6, 6], [0, 1, 2, 3, 4, 5, 6, 7], [-2, -3, 1, 0, 7, -5] ], + "fully_quantize": [False], }, { "dtype": [tf.int32, tf.float32, tf.int64], "input_shape": [[1]], "axis": [None, [], [0], [-1]], + "fully_quantize": [False], }, { "dtype": [tf.int32, tf.float32, tf.int64], "input_shape": [[1, 1, 1, 1, 1]], "axis": [None, [], [0], [3, 0], [-2, 0, 3, 2]], + "fully_quantize": [False], + }, { + "dtype": [tf.float32], + "input_shape": [[1, 2, 1, 3, 1, 4, 1, 1]], + "axis": [ + None, [], [0, 2], [4, 7], [-1, 0, 2, 0, 7, -6], [1], [2, 3, 2], + [-1, -2, -4, -6, -8], [0, 2, 4, 6, 7], [7, 6, 4, 2, 0], [6, 6], + [0, 1, 2, 3, 4, 5, 6, 7], [-2, -3, 1, 0, 7, -5] + ], + "fully_quantize": [True], + }, { + "dtype": [tf.float32], + "input_shape": [[1, 1, 1, 1, 1]], + "axis": [[0], [3, 0], [-2, 0, 3, 2]], + "fully_quantize": [True], + }, { + "dtype": [tf.float32], + "input_shape": [[1, 1, 5, 10], [1, 5, 1, 10]], + "axis": [[0], [3, 0], [-2, 0, 3, 2]], + "fully_quantize": [True], }] def build_graph(parameters): @@ -54,8 +76,11 @@ def make_squeeze_tests(options): return [input_tensor], [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) + input_values = create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) @@ -64,4 +89,4 @@ def make_squeeze_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=12) + expected_tf_failures=20) diff --git a/tensorflow/lite/testing/op_tests/strided_slice.py b/tensorflow/lite/testing/op_tests/strided_slice.py index 86ebfd30a3d..36defb52fdf 100644 --- a/tensorflow/lite/testing/op_tests/strided_slice.py +++ b/tensorflow/lite/testing/op_tests/strided_slice.py @@ -68,8 +68,11 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0): def build_inputs(parameters, sess, inputs, outputs): """Build inputs for stride_slice test.""" - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) + input_values = create_tensor_data( + parameters["dtype"], + parameters["input_shape"], + min_value=-1, + max_value=1) index_type = TF_TYPE_INFO[parameters["index_type"]][0] values = [input_values] if not parameters["constant_indices"]: @@ -111,6 +114,7 @@ def make_strided_slice_tests(options): "end_mask": [None], "shrink_axis_mask": [None], "constant_indices": [False, True], + "fully_quantize": [False], }, # 4-D with non-trivial begin & end. { @@ -124,6 +128,7 @@ def make_strided_slice_tests(options): "end_mask": [None, 3], "shrink_axis_mask": [None, 15, -1], "constant_indices": [True], + "fully_quantize": [False], }, # Begin, end, strides dim are different from input shape { @@ -137,6 +142,7 @@ def make_strided_slice_tests(options): "end_mask": [0], "shrink_axis_mask": [1], "constant_indices": [True], + "fully_quantize": [False], }, # 2-D { @@ -150,6 +156,7 @@ def make_strided_slice_tests(options): "end_mask": [None, 1, 2], "shrink_axis_mask": [None, 1, 2, 3, -1], "constant_indices": [False, True], + "fully_quantize": [False], }, # Negative strides { @@ -163,6 +170,35 @@ def make_strided_slice_tests(options): "end_mask": [None, 1, 2], "shrink_axis_mask": [None, 1, 2, 3, -1], "constant_indices": [False], + "fully_quantize": [False], + }, + # 4-D (cases with const indices and batchsize of 1). + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[1, 2, 2, 5]], + "strides": [None, [1, 1, 1, 1]], + "begin": [[0, 0, 0, 0], [0, 1, 1, 3]], + "end": [[1, 2, 2, 5], [1, 2, 2, 4]], + "begin_mask": [None], + "end_mask": [None], + "shrink_axis_mask": [None], + "constant_indices": [True], + "fully_quantize": [True], + }, + # Begin, end, strides dim are different from input shape + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0]], + "end": [[1]], + "strides": [None, [1]], + "begin_mask": [0], + "end_mask": [0], + "shrink_axis_mask": [1], + "constant_indices": [True], + "fully_quantize": [True], }, ] _make_strided_slice_tests(options, test_parameters, expected_tf_failures=2) diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py index 82811f183a6..6ecc659ab6f 100644 --- a/tensorflow/lite/testing/zip_test_utils.py +++ b/tensorflow/lite/testing/zip_test_utils.py @@ -351,11 +351,15 @@ def make_zip_of_tests(options, "fully_quantize", False): continue - def build_tflite_inputs(tflite_model_binary): - """Build input values and output values of the given tflite model. + def generate_inputs_outputs(tflite_model_binary, + min_value=0, + max_value=255): + """Generate input values and output values of the given tflite model. Args: tflite_model_binary: A serialized flatbuffer as a string. + min_value: min value for the input tensor. + max_value: max value for the input tensor. Returns: (input_values, output_values): input values and output values built. @@ -366,12 +370,11 @@ def make_zip_of_tests(options, input_details = interpreter.get_input_details() input_values = [] for input_detail in input_details: - # TODO(yunluli): Set proper min max value according to dtype. input_value = create_tensor_data( input_detail["dtype"], input_detail["shape"], - min_value=0, - max_value=255) + min_value=min_value, + max_value=max_value) interpreter.set_tensor(input_detail["index"], input_value) input_values.append(input_value) @@ -458,8 +461,9 @@ def make_zip_of_tests(options, if tflite_model_binary: if options.make_edgetpu_tests: - baseline_inputs, baseline_outputs = build_tflite_inputs( - tflite_model_binary) + # Set proper min max values according to input dtype. + baseline_inputs, baseline_outputs = generate_inputs_outputs( + tflite_model_binary, min_value=0, max_value=255) archive.writestr(label + ".bin", tflite_model_binary, zipfile.ZIP_DEFLATED) example = {"inputs": baseline_inputs, "outputs": baseline_outputs}