From 6e45e6ae21529ffebd4c0856cb1c2c8a278bf780 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 18 Oct 2018 17:36:26 -0700
Subject: [PATCH] Properly unittest all supported conversion modes.

PiperOrigin-RevId: 217799124
---
 .../contrib/lite/toco/tflite/export_test.cc   | 135 +++++++++++++-----
 1 file changed, 101 insertions(+), 34 deletions(-)

diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 9ee45ee64eb..6a293901485 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -30,6 +30,7 @@ using ::testing::ElementsAre;
 
 class ExportTest : public ::testing::Test {
  protected:
+  void ResetOperators() { input_model_.operators.clear(); }
   void AddTensorsByName(std::initializer_list<string> names) {
     for (const string& name : names) {
       input_model_.GetOrCreateArray(name);
@@ -93,7 +94,11 @@ class ExportTest : public ::testing::Test {
     std::vector<string> names;
 
     string result;
-    if (!Export(input_model_, &result, params).ok()) return names;
+    auto status = Export(input_model_, &result, params);
+    if (!status.ok()) {
+      LOG(INFO) << status.error_message();
+      return names;
+    }
 
     auto* model = ::tflite::GetModel(result.data());
 
@@ -166,39 +171,6 @@ TEST_F(ExportTest, Export) {
   EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
 }
 
-TEST_F(ExportTest, ExportWithOptions) {
-  // We have three main types of operators:
-  //  1) first-class TOCO/TF Lite operators, which are exported as TF Lite
-  //     builtins
-  //  2) operators that are not recognized by TOCO but are known to the
-  //     standard TensorFlow runtime. Those should be exported as FLEX
-  //     operators, when requested. Otherwise they are custom TF Lite ops.
-  //  3) operators that are not recognized and should be exported as custom
-  //     TF Lite ops. There are two subtypes: ops that would be recognized
-  //     by a custom TensorFlow runtime, and "fake" ops that have only a
-  //     TF Lite impl.
-  AddOperatorsByName({"Add", "ResizeNearestNeighbor"});
-
-  ExportParams params;
-  params.allow_custom_ops = false;
-  params.allow_flex_ops = false;
-  params.quantize_weights = false;
-
-  // Conversion fails because ResizeNearestNeighbor is unknown.
-  EXPECT_THAT(ExportAndSummarizeOperators(params), ElementsAre());
-
-  // ResizeNearestNeighbor is treated as a simple custom op (#3 above).
-  params.allow_custom_ops = true;
-  EXPECT_THAT(ExportAndSummarizeOperators(params),
-              ElementsAre("builtin:ADD", "custom:ResizeNearestNeighbor"));
-
-  // ResizeNearestNeighbor is recognized as a TensorFlow (Flex) op.
-  params.allow_custom_ops = true;
-  params.allow_flex_ops = true;
-  EXPECT_THAT(ExportAndSummarizeOperators(params),
-              ElementsAre("builtin:ADD", "custom:FlexResizeNearestNeighbor"));
-}
-
 TEST_F(ExportTest, QuantizeWeights) {
   // Sanity check for quantize_weights parameter.
   BuildQuantizableTestModel();
@@ -213,6 +185,101 @@ TEST_F(ExportTest, QuantizeWeights) {
   EXPECT_LT(quantized_result.size(), unquantized_result.size());
 }
 
+class OpSetsTest : public ExportTest {
+ public:
+  enum OpSet { kTfLiteBuiltins, kSelectTfOps, kCustomOps };
+
+  void SetAllowedOpSets(std::initializer_list<OpSet> sets) {
+    import_all_ops_as_unsupported_ = true;
+    params_.allow_custom_ops = false;
+    params_.allow_flex_ops = false;
+    params_.quantize_weights = false;
+
+    for (OpSet i : sets) {
+      switch (i) {
+        case kTfLiteBuiltins:
+          import_all_ops_as_unsupported_ = false;
+          break;
+        case kSelectTfOps:
+          params_.allow_flex_ops = true;
+          break;
+        case kCustomOps:
+          params_.allow_custom_ops = true;
+          break;
+      }
+    }
+  }
+
+  std::vector<string> ImportExport(std::initializer_list<string> op_names) {
+    ResetOperators();
+    if (!import_all_ops_as_unsupported_) {
+      AddOperatorsByName(op_names);
+    } else {
+      for (const string& name : op_names) {
+        auto* op = new TensorFlowUnsupportedOperator;
+        op->tensorflow_op = name;
+        input_model_.operators.emplace_back(op);
+      }
+    }
+    return ExportAndSummarizeOperators(params_);
+  }
+
+ private:
+  bool import_all_ops_as_unsupported_;
+  ExportParams params_;
+};
+
+TEST_F(OpSetsTest, BuiltinsOnly) {
+  // --target_op_set=TFLITE_BUILTINS
+  SetAllowedOpSets({kTfLiteBuiltins});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre());
+  EXPECT_THAT(ImportExport({"Add"}), ElementsAre("builtin:ADD"));
+
+  // --target_op_set=TFLITE_BUILTINS --allow_custom_ops
+  SetAllowedOpSets({kTfLiteBuiltins, kCustomOps});
+  EXPECT_THAT(
+      ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+      ElementsAre("builtin:ADD", "custom:AdjustHue", "custom:UnrollAndFold"));
+}
+
+TEST_F(OpSetsTest, TfSelectOnly) {
+  // --target_op_set=SELECT_TF_OPS
+  SetAllowedOpSets({kSelectTfOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre());
+
+  // TODO(b/117845348): Add should be recognized as a flex op, which should be
+  // OK even if we don't specify --allow_custom_ops
+  // EXPECT_THAT(ImportExport({"Add"}), ElementsAre("custom:FlexAdd"));
+  EXPECT_THAT(ImportExport({"Add"}), ElementsAre());
+
+  // --target_op_set=SELECT_TF_OPS --allow_custom_ops
+  SetAllowedOpSets({kSelectTfOps, kCustomOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre("custom:FlexAdd", "custom:FlexAdjustHue",
+                          "custom:UnrollAndFold"));
+}
+
+TEST_F(OpSetsTest, BuiltinsAndTfSelect) {
+  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS
+  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre());
+
+  // TODO(b/117845348): AdjustHue should be recognized as a flex op,
+  // which should be OK even if we don't specify --allow_custom_ops
+  // EXPECT_THAT(ImportExport({"Add", "AdjustHue"}),
+  //            ElementsAre("builtin:ADD", "custom:FlexAdjustHue"));
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue"}), ElementsAre());
+
+  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS --allow_custom_ops
+  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps, kCustomOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre("builtin:ADD", "custom:FlexAdjustHue",
+                          "custom:UnrollAndFold"));
+}
+
 // This test is based on a hypothetical scenario that dilation is supported
 // only in Conv version 2. So Toco populates version=1 when dialation
 // parameters are all 1, and version=2 otehrwise.