From 045b62dc3ee2ce23ace71a39b5e433abbbbe3900 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Thu, 18 Feb 2021 03:14:41 -0800 Subject: [PATCH] Distinguish between inference (sub)graphs and validation graphs by using a subgraph name prefix. PiperOrigin-RevId: 358139741 Change-Id: I74a2964320cf5f962a885e918ef8aab43e548847 --- tensorflow/lite/BUILD | 1 + tensorflow/lite/core/subgraph.cc | 10 +++++ tensorflow/lite/core/subgraph.h | 11 +++++ tensorflow/lite/delegates/delegate_test.cc | 35 ++++++++++++++++ .../lite/delegates/delegate_test_util.cc | 38 ++++++++++++------ .../lite/delegates/delegate_test_util.h | 2 + tensorflow/lite/interpreter.cc | 4 ++ tensorflow/lite/interpreter_builder.cc | 3 ++ tensorflow/lite/model_test.cc | 12 ++++++ .../2_subgraphs_dont_delegate_name.bin | Bin 0 -> 188 bytes tensorflow/lite/util.cc | 4 ++ tensorflow/lite/util.h | 8 ++++ tensorflow/lite/util_test.cc | 10 +++++ 13 files changed, 126 insertions(+), 12 deletions(-) create mode 100644 tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index f41e5bd8d1f..1c05ff86d19 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -623,6 +623,7 @@ cc_test( data = [ "testdata/0_subgraphs.bin", "testdata/2_subgraphs.bin", + "testdata/2_subgraphs_dont_delegate_name.bin", "testdata/add_shared_tensors.bin", "testdata/empty_model.bin", "testdata/multi_add_flex.bin", diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 7f604823324..45c6d884234 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1619,4 +1619,14 @@ TfLiteStatus Subgraph::SetCustomAllocationForTensor( return kTfLiteOk; } +void Subgraph::SetName(const char* name) { + if (name) { + name_ = name; + } else { + name_ = ""; + } +} + +const std::string& Subgraph::GetName() const { return name_; } + } // namespace tflite diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 413e73a3729..d3e82a43b3e 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -35,6 +35,11 @@ limitations under the License. #include "tensorflow/lite/util.h" namespace tflite { +namespace delegates { +namespace test_utils { +class TestDelegate; // Class for friend declarations. +} // namespace test_utils +} // namespace delegates class Subgraph { public: @@ -342,7 +347,11 @@ class Subgraph { TfLiteStatus SetCustomAllocationForTensor( int tensor_index, const TfLiteCustomAllocation& allocation); + void SetName(const char* name); + const std::string& GetName() const; + private: + friend class TestDelegate; // SubgraphAwareProfiler wraps an actual TFLite profiler, such as a // BufferedProfiler instance, and takes care of event profiling/tracing in a // certain subgraph. @@ -731,6 +740,8 @@ class Subgraph { // A map of resources. Owned by interpreter and shared by multiple subgraphs. resource::ResourceMap* resources_ = nullptr; + + std::string name_; }; } // namespace tflite diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index 7d579c579c7..d8d9067072b 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include <gtest/gtest.h> #include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/delegate_test_util.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter_builder.h" @@ -769,6 +770,40 @@ TEST_F(TestDelegate, DelegateCustomOpResolution) { ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); } +TEST_F(TestDelegate, AllSubgraphsAreDelegatedByDefault) { + interpreter_->AddSubgraphs(1); + SetUpSubgraph(interpreter_->subgraph(1)); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + for (int subgraph_index = 0; subgraph_index < 2; subgraph_index++) { + ASSERT_EQ(interpreter_->subgraph(subgraph_index)->execution_plan().size(), + 1); + int node = interpreter_->subgraph(subgraph_index)->execution_plan()[0]; + const auto* node_and_reg = + interpreter_->subgraph(subgraph_index)->node_and_registration(node); + EXPECT_EQ(node_and_reg->second.custom_name, + delegate_->FakeFusedRegistration().custom_name); + } +} + +TEST_F(TestDelegate, ValidationSubgraphsAreNotDelegated) { + interpreter_->AddSubgraphs(1); + SetUpSubgraph(interpreter_->subgraph(1)); + interpreter_->subgraph(1)->SetName("VALIDATION:foo"); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->subgraph(1)->execution_plan().size(), 3); + int node = interpreter_->subgraph(1)->execution_plan()[0]; + const auto* node_and_reg = + interpreter_->subgraph(1)->node_and_registration(node); + EXPECT_NE(node_and_reg->second.custom_name, + delegate_->FakeFusedRegistration().custom_name); +} + class TestDelegateWithDynamicTensors : public ::testing::Test { protected: void SetUp() override { diff --git a/tensorflow/lite/delegates/delegate_test_util.cc b/tensorflow/lite/delegates/delegate_test_util.cc index 137bf62b7de..878b1062a7e 100644 --- a/tensorflow/lite/delegates/delegate_test_util.cc +++ b/tensorflow/lite/delegates/delegate_test_util.cc @@ -91,19 +91,33 @@ TfLiteRegistration AddOpRegistration() { void TestDelegate::SetUp() { interpreter_.reset(new Interpreter); - interpreter_->AddTensors(5); - interpreter_->SetInputs({0, 1}); - interpreter_->SetOutputs({3, 4}); - TfLiteQuantizationParams quant; - interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, quant); - interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, quant); - interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, quant); - interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, quant); - interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, quant); + SetUpSubgraph(&interpreter_->primary_subgraph()); +} + +void TestDelegate::SetUpSubgraph(Subgraph* subgraph) { + subgraph->AddTensors(5); + subgraph->SetInputs({0, 1}); + subgraph->SetOutputs({3, 4}); + std::vector<int> dims({3}); + TfLiteQuantization quant{kTfLiteNoQuantization, nullptr}; + subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(), + dims.data(), quant, false); + subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(), + dims.data(), quant, false); + subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(), + dims.data(), quant, false); + subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(), + dims.data(), quant, false); + subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(), + dims.data(), quant, false); TfLiteRegistration reg = AddOpRegistration(); - interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); - interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); - interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + int node_index_ignored; + subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, ®, + &node_index_ignored); + subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, ®, + &node_index_ignored); + subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, ®, + &node_index_ignored); } void TestDelegate::TearDown() { diff --git a/tensorflow/lite/delegates/delegate_test_util.h b/tensorflow/lite/delegates/delegate_test_util.h index 72578e63287..914fd85f669 100644 --- a/tensorflow/lite/delegates/delegate_test_util.h +++ b/tensorflow/lite/delegates/delegate_test_util.h @@ -48,6 +48,8 @@ class TestDelegate : public ::testing::Test { return interpreter_->RemoveAllDelegates(); } + void SetUpSubgraph(Subgraph* subgraph); + protected: class SimpleDelegate { public: diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index 3924c3e0f54..dd99c5f1d2d 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/lite/external_cpu_backend_context.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/stderr_reporter.h" +#include "tensorflow/lite/util.h" // TODO(b/139446230): Move to portable platform header. #if defined(__ANDROID__) @@ -386,6 +387,9 @@ bool Interpreter::IsCancelled() { return primary_subgraph().IsCancelled(); } TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { TfLiteStatus status = kTfLiteOk; for (auto& subgraph : subgraphs_) { + if (IsValidationSubgraph(subgraph->GetName().c_str())) { + continue; + } status = subgraph->ModifyGraphWithDelegate(delegate); if (status != kTfLiteOk) { break; diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 50d96bf5aba..ce48876815c 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -765,6 +765,9 @@ TfLiteStatus InterpreterBuilder::operator()( } } modified_subgraph->SetVariables(std::move(variables)); + if (subgraph->name()) { + modified_subgraph->SetName(subgraph->name()->c_str()); + } } if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) != diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc index 311cf7a9f6f..a9a2150370e 100644 --- a/tensorflow/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -164,6 +164,18 @@ TEST(BasicFlatBufferModel, TestMultipleSubgraphs) { EXPECT_EQ(interpreter->subgraphs_size(), 2); } +TEST(BasicFlatBufferModel, TestSubgraphName) { + auto m = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/" + "2_subgraphs_dont_delegate_name.bin"); + ASSERT_TRUE(m); + std::unique_ptr<Interpreter> interpreter; + ASSERT_EQ(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk); + EXPECT_EQ(interpreter->subgraphs_size(), 2); + EXPECT_EQ(interpreter->subgraph(0)->GetName(), ""); + EXPECT_EQ(interpreter->subgraph(1)->GetName(), "VALIDATION:main"); +} + // Test what happens if we cannot bind any of the ops. TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { auto model = FlatBufferModel::BuildFromFile( diff --git a/tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin b/tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin new file mode 100644 index 0000000000000000000000000000000000000000..4b03f19f964900f8c433f2fa8b5dd026d92e1f22 GIT binary patch literal 188 zcmb1OU|<Mw^D$;%;A0SBU}4~3-~oyV0C@}y%wQIX#{iT;045+a0f<GwLVOGoP<;YG z5s*F&AU1%~3P2j9mmi439DO`p978<){j72mGxLBVsDKCV0s*Ki7=UI&?B_sL3g!U- DPml<x literal 0 HcmV?d00001 diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc index 995d52bee9b..84dbc16b607 100644 --- a/tensorflow/lite/util.cc +++ b/tensorflow/lite/util.cc @@ -172,4 +172,8 @@ std::string GetOpNameByRegistration(const TfLiteRegistration& registration) { return result; } +bool IsValidationSubgraph(const char* name) { + // NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed. + return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0; +} } // namespace tflite diff --git a/tensorflow/lite/util.h b/tensorflow/lite/util.h index 9aaab40bf49..d9d7f7a0a8e 100644 --- a/tensorflow/lite/util.h +++ b/tensorflow/lite/util.h @@ -91,6 +91,14 @@ bool IsUnresolvedCustomOp(const TfLiteRegistration& registration); // Returns a descriptive name with the given op TfLiteRegistration. std::string GetOpNameByRegistration(const TfLiteRegistration& registration); + +// The prefix of a validation subgraph name. +// WARNING: This is an experimental API and subject to change. +constexpr char kValidationSubgraphNamePrefix[] = "VALIDATION:"; + +// Checks whether the prefix of the subgraph name indicates the subgraph is a +// validation subgraph. +bool IsValidationSubgraph(const char* name); } // namespace tflite #endif // TENSORFLOW_LITE_UTIL_H_ diff --git a/tensorflow/lite/util_test.cc b/tensorflow/lite/util_test.cc index 47726bcdb17..46601b908dc 100644 --- a/tensorflow/lite/util_test.cc +++ b/tensorflow/lite/util_test.cc @@ -120,6 +120,16 @@ TEST(GetOpNameByRegistration, CustomName) { op_name = GetOpNameByRegistration(registration); EXPECT_EQ("DELEGATE TestDelegate", op_name); } + +TEST(ValidationSubgraph, NameIsDetected) { + EXPECT_FALSE(IsValidationSubgraph(nullptr)); + EXPECT_FALSE(IsValidationSubgraph("")); + EXPECT_FALSE(IsValidationSubgraph("a name")); + EXPECT_FALSE(IsValidationSubgraph("VALIDATIONfoo")); + EXPECT_TRUE(IsValidationSubgraph("VALIDATION:")); + EXPECT_TRUE(IsValidationSubgraph("VALIDATION:main")); +} + } // namespace } // namespace tflite