From 045b62dc3ee2ce23ace71a39b5e433abbbbe3900 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 18 Feb 2021 03:14:41 -0800
Subject: [PATCH] Distinguish between inference (sub)graphs and validation
 graphs by using a subgraph name prefix.

PiperOrigin-RevId: 358139741
Change-Id: I74a2964320cf5f962a885e918ef8aab43e548847
---
 tensorflow/lite/BUILD                         |   1 +
 tensorflow/lite/core/subgraph.cc              |  10 +++++
 tensorflow/lite/core/subgraph.h               |  11 +++++
 tensorflow/lite/delegates/delegate_test.cc    |  35 ++++++++++++++++
 .../lite/delegates/delegate_test_util.cc      |  38 ++++++++++++------
 .../lite/delegates/delegate_test_util.h       |   2 +
 tensorflow/lite/interpreter.cc                |   4 ++
 tensorflow/lite/interpreter_builder.cc        |   3 ++
 tensorflow/lite/model_test.cc                 |  12 ++++++
 .../2_subgraphs_dont_delegate_name.bin        | Bin 0 -> 188 bytes
 tensorflow/lite/util.cc                       |   4 ++
 tensorflow/lite/util.h                        |   8 ++++
 tensorflow/lite/util_test.cc                  |  10 +++++
 13 files changed, 126 insertions(+), 12 deletions(-)
 create mode 100644 tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin

diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index f41e5bd8d1f..1c05ff86d19 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -623,6 +623,7 @@ cc_test(
     data = [
         "testdata/0_subgraphs.bin",
         "testdata/2_subgraphs.bin",
+        "testdata/2_subgraphs_dont_delegate_name.bin",
         "testdata/add_shared_tensors.bin",
         "testdata/empty_model.bin",
         "testdata/multi_add_flex.bin",
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 7f604823324..45c6d884234 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -1619,4 +1619,14 @@ TfLiteStatus Subgraph::SetCustomAllocationForTensor(
   return kTfLiteOk;
 }
 
+void Subgraph::SetName(const char* name) {
+  if (name) {
+    name_ = name;
+  } else {
+    name_ = "";
+  }
+}
+
+const std::string& Subgraph::GetName() const { return name_; }
+
 }  // namespace tflite
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 413e73a3729..d3e82a43b3e 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -35,6 +35,11 @@ limitations under the License.
 #include "tensorflow/lite/util.h"
 
 namespace tflite {
+namespace delegates {
+namespace test_utils {
+class TestDelegate;  // Class for friend declarations.
+}  // namespace test_utils
+}  // namespace delegates
 
 class Subgraph {
  public:
@@ -342,7 +347,11 @@ class Subgraph {
   TfLiteStatus SetCustomAllocationForTensor(
       int tensor_index, const TfLiteCustomAllocation& allocation);
 
+  void SetName(const char* name);
+  const std::string& GetName() const;
+
  private:
+  friend class TestDelegate;
   // SubgraphAwareProfiler wraps an actual TFLite profiler, such as a
   // BufferedProfiler instance, and takes care of event profiling/tracing in a
   // certain subgraph.
@@ -731,6 +740,8 @@ class Subgraph {
 
   // A map of resources. Owned by interpreter and shared by multiple subgraphs.
   resource::ResourceMap* resources_ = nullptr;
+
+  std::string name_;
 };
 
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc
index 7d579c579c7..d8d9067072b 100644
--- a/tensorflow/lite/delegates/delegate_test.cc
+++ b/tensorflow/lite/delegates/delegate_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/delegates/delegate_test_util.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/interpreter_builder.h"
@@ -769,6 +770,40 @@ TEST_F(TestDelegate, DelegateCustomOpResolution) {
   ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
 }
 
+TEST_F(TestDelegate, AllSubgraphsAreDelegatedByDefault) {
+  interpreter_->AddSubgraphs(1);
+  SetUpSubgraph(interpreter_->subgraph(1));
+  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+  ASSERT_EQ(
+      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
+      kTfLiteOk);
+  for (int subgraph_index = 0; subgraph_index < 2; subgraph_index++) {
+    ASSERT_EQ(interpreter_->subgraph(subgraph_index)->execution_plan().size(),
+              1);
+    int node = interpreter_->subgraph(subgraph_index)->execution_plan()[0];
+    const auto* node_and_reg =
+        interpreter_->subgraph(subgraph_index)->node_and_registration(node);
+    EXPECT_EQ(node_and_reg->second.custom_name,
+              delegate_->FakeFusedRegistration().custom_name);
+  }
+}
+
+TEST_F(TestDelegate, ValidationSubgraphsAreNotDelegated) {
+  interpreter_->AddSubgraphs(1);
+  SetUpSubgraph(interpreter_->subgraph(1));
+  interpreter_->subgraph(1)->SetName("VALIDATION:foo");
+  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+  ASSERT_EQ(
+      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
+      kTfLiteOk);
+  ASSERT_EQ(interpreter_->subgraph(1)->execution_plan().size(), 3);
+  int node = interpreter_->subgraph(1)->execution_plan()[0];
+  const auto* node_and_reg =
+      interpreter_->subgraph(1)->node_and_registration(node);
+  EXPECT_NE(node_and_reg->second.custom_name,
+            delegate_->FakeFusedRegistration().custom_name);
+}
+
 class TestDelegateWithDynamicTensors : public ::testing::Test {
  protected:
   void SetUp() override {
diff --git a/tensorflow/lite/delegates/delegate_test_util.cc b/tensorflow/lite/delegates/delegate_test_util.cc
index 137bf62b7de..878b1062a7e 100644
--- a/tensorflow/lite/delegates/delegate_test_util.cc
+++ b/tensorflow/lite/delegates/delegate_test_util.cc
@@ -91,19 +91,33 @@ TfLiteRegistration AddOpRegistration() {
 
 void TestDelegate::SetUp() {
   interpreter_.reset(new Interpreter);
-  interpreter_->AddTensors(5);
-  interpreter_->SetInputs({0, 1});
-  interpreter_->SetOutputs({3, 4});
-  TfLiteQuantizationParams quant;
-  interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, quant);
-  interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, quant);
-  interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, quant);
-  interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, quant);
-  interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, quant);
+  SetUpSubgraph(&interpreter_->primary_subgraph());
+}
+
+void TestDelegate::SetUpSubgraph(Subgraph* subgraph) {
+  subgraph->AddTensors(5);
+  subgraph->SetInputs({0, 1});
+  subgraph->SetOutputs({3, 4});
+  std::vector<int> dims({3});
+  TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
+  subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
+                                         dims.data(), quant, false);
+  subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
+                                         dims.data(), quant, false);
+  subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
+                                         dims.data(), quant, false);
+  subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
+                                         dims.data(), quant, false);
+  subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
+                                         dims.data(), quant, false);
   TfLiteRegistration reg = AddOpRegistration();
-  interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
-  interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
-  interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
+  int node_index_ignored;
+  subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, &reg,
+                                  &node_index_ignored);
+  subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, &reg,
+                                  &node_index_ignored);
+  subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, &reg,
+                                  &node_index_ignored);
 }
 
 void TestDelegate::TearDown() {
diff --git a/tensorflow/lite/delegates/delegate_test_util.h b/tensorflow/lite/delegates/delegate_test_util.h
index 72578e63287..914fd85f669 100644
--- a/tensorflow/lite/delegates/delegate_test_util.h
+++ b/tensorflow/lite/delegates/delegate_test_util.h
@@ -48,6 +48,8 @@ class TestDelegate : public ::testing::Test {
     return interpreter_->RemoveAllDelegates();
   }
 
+  void SetUpSubgraph(Subgraph* subgraph);
+
  protected:
   class SimpleDelegate {
    public:
diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index 3924c3e0f54..dd99c5f1d2d 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/lite/external_cpu_backend_context.h"
 #include "tensorflow/lite/minimal_logging.h"
 #include "tensorflow/lite/stderr_reporter.h"
+#include "tensorflow/lite/util.h"
 
 // TODO(b/139446230): Move to portable platform header.
 #if defined(__ANDROID__)
@@ -386,6 +387,9 @@ bool Interpreter::IsCancelled() { return primary_subgraph().IsCancelled(); }
 TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
   TfLiteStatus status = kTfLiteOk;
   for (auto& subgraph : subgraphs_) {
+    if (IsValidationSubgraph(subgraph->GetName().c_str())) {
+      continue;
+    }
     status = subgraph->ModifyGraphWithDelegate(delegate);
     if (status != kTfLiteOk) {
       break;
diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc
index 50d96bf5aba..ce48876815c 100644
--- a/tensorflow/lite/interpreter_builder.cc
+++ b/tensorflow/lite/interpreter_builder.cc
@@ -765,6 +765,9 @@ TfLiteStatus InterpreterBuilder::operator()(
       }
     }
     modified_subgraph->SetVariables(std::move(variables));
+    if (subgraph->name()) {
+      modified_subgraph->SetName(subgraph->name()->c_str());
+    }
   }
 
   if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc
index 311cf7a9f6f..a9a2150370e 100644
--- a/tensorflow/lite/model_test.cc
+++ b/tensorflow/lite/model_test.cc
@@ -164,6 +164,18 @@ TEST(BasicFlatBufferModel, TestMultipleSubgraphs) {
   EXPECT_EQ(interpreter->subgraphs_size(), 2);
 }
 
+TEST(BasicFlatBufferModel, TestSubgraphName) {
+  auto m = FlatBufferModel::BuildFromFile(
+      "tensorflow/lite/testdata/"
+      "2_subgraphs_dont_delegate_name.bin");
+  ASSERT_TRUE(m);
+  std::unique_ptr<Interpreter> interpreter;
+  ASSERT_EQ(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
+  EXPECT_EQ(interpreter->subgraphs_size(), 2);
+  EXPECT_EQ(interpreter->subgraph(0)->GetName(), "");
+  EXPECT_EQ(interpreter->subgraph(1)->GetName(), "VALIDATION:main");
+}
+
 // Test what happens if we cannot bind any of the ops.
 TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
   auto model = FlatBufferModel::BuildFromFile(
diff --git a/tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin b/tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin
new file mode 100644
index 0000000000000000000000000000000000000000..4b03f19f964900f8c433f2fa8b5dd026d92e1f22
GIT binary patch
literal 188
zcmb1OU|<Mw^D$;%;A0SBU}4~3-~oyV0C@}y%wQIX#{iT;045+a0f<GwLVOGoP<;YG
z5s*F&AU1%~3P2j9mmi439DO`p978<){j72mGxLBVsDKCV0s*Ki7=UI&?B_sL3g!U-
DPml<x

literal 0
HcmV?d00001

diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc
index 995d52bee9b..84dbc16b607 100644
--- a/tensorflow/lite/util.cc
+++ b/tensorflow/lite/util.cc
@@ -172,4 +172,8 @@ std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
   return result;
 }
 
+bool IsValidationSubgraph(const char* name) {
+  // NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
+  return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
+}
 }  // namespace tflite
diff --git a/tensorflow/lite/util.h b/tensorflow/lite/util.h
index 9aaab40bf49..d9d7f7a0a8e 100644
--- a/tensorflow/lite/util.h
+++ b/tensorflow/lite/util.h
@@ -91,6 +91,14 @@ bool IsUnresolvedCustomOp(const TfLiteRegistration& registration);
 
 // Returns a descriptive name with the given op TfLiteRegistration.
 std::string GetOpNameByRegistration(const TfLiteRegistration& registration);
+
+// The prefix of a validation subgraph name.
+// WARNING: This is an experimental API and subject to change.
+constexpr char kValidationSubgraphNamePrefix[] = "VALIDATION:";
+
+// Checks whether the prefix of the subgraph name indicates the subgraph is a
+// validation subgraph.
+bool IsValidationSubgraph(const char* name);
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_UTIL_H_
diff --git a/tensorflow/lite/util_test.cc b/tensorflow/lite/util_test.cc
index 47726bcdb17..46601b908dc 100644
--- a/tensorflow/lite/util_test.cc
+++ b/tensorflow/lite/util_test.cc
@@ -120,6 +120,16 @@ TEST(GetOpNameByRegistration, CustomName) {
   op_name = GetOpNameByRegistration(registration);
   EXPECT_EQ("DELEGATE TestDelegate", op_name);
 }
+
+TEST(ValidationSubgraph, NameIsDetected) {
+  EXPECT_FALSE(IsValidationSubgraph(nullptr));
+  EXPECT_FALSE(IsValidationSubgraph(""));
+  EXPECT_FALSE(IsValidationSubgraph("a name"));
+  EXPECT_FALSE(IsValidationSubgraph("VALIDATIONfoo"));
+  EXPECT_TRUE(IsValidationSubgraph("VALIDATION:"));
+  EXPECT_TRUE(IsValidationSubgraph("VALIDATION:main"));
+}
+
 }  // namespace
 }  // namespace tflite