Distinguish between inference (sub)graphs and validation graphs by using a subgraph name prefix.

PiperOrigin-RevId: 358139741
Change-Id: I74a2964320cf5f962a885e918ef8aab43e548847
This commit is contained in:
A. Unique TensorFlower 2021-02-18 03:14:41 -08:00 committed by TensorFlower Gardener
parent 129463994d
commit 045b62dc3e
13 changed files with 126 additions and 12 deletions

View File

@ -623,6 +623,7 @@ cc_test(
data = [
"testdata/0_subgraphs.bin",
"testdata/2_subgraphs.bin",
"testdata/2_subgraphs_dont_delegate_name.bin",
"testdata/add_shared_tensors.bin",
"testdata/empty_model.bin",
"testdata/multi_add_flex.bin",

View File

@ -1619,4 +1619,14 @@ TfLiteStatus Subgraph::SetCustomAllocationForTensor(
return kTfLiteOk;
}
void Subgraph::SetName(const char* name) {
if (name) {
name_ = name;
} else {
name_ = "";
}
}
const std::string& Subgraph::GetName() const { return name_; }
} // namespace tflite

View File

@ -35,6 +35,11 @@ limitations under the License.
#include "tensorflow/lite/util.h"
namespace tflite {
namespace delegates {
namespace test_utils {
class TestDelegate; // Class for friend declarations.
} // namespace test_utils
} // namespace delegates
class Subgraph {
public:
@ -342,7 +347,11 @@ class Subgraph {
TfLiteStatus SetCustomAllocationForTensor(
int tensor_index, const TfLiteCustomAllocation& allocation);
void SetName(const char* name);
const std::string& GetName() const;
private:
friend class TestDelegate;
// SubgraphAwareProfiler wraps an actual TFLite profiler, such as a
// BufferedProfiler instance, and takes care of event profiling/tracing in a
// certain subgraph.
@ -731,6 +740,8 @@ class Subgraph {
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
resource::ResourceMap* resources_ = nullptr;
std::string name_;
};
} // namespace tflite

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/delegate_test_util.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
@ -769,6 +770,40 @@ TEST_F(TestDelegate, DelegateCustomOpResolution) {
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
}
TEST_F(TestDelegate, AllSubgraphsAreDelegatedByDefault) {
interpreter_->AddSubgraphs(1);
SetUpSubgraph(interpreter_->subgraph(1));
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
ASSERT_EQ(
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
kTfLiteOk);
for (int subgraph_index = 0; subgraph_index < 2; subgraph_index++) {
ASSERT_EQ(interpreter_->subgraph(subgraph_index)->execution_plan().size(),
1);
int node = interpreter_->subgraph(subgraph_index)->execution_plan()[0];
const auto* node_and_reg =
interpreter_->subgraph(subgraph_index)->node_and_registration(node);
EXPECT_EQ(node_and_reg->second.custom_name,
delegate_->FakeFusedRegistration().custom_name);
}
}
TEST_F(TestDelegate, ValidationSubgraphsAreNotDelegated) {
interpreter_->AddSubgraphs(1);
SetUpSubgraph(interpreter_->subgraph(1));
interpreter_->subgraph(1)->SetName("VALIDATION:foo");
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
ASSERT_EQ(
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
kTfLiteOk);
ASSERT_EQ(interpreter_->subgraph(1)->execution_plan().size(), 3);
int node = interpreter_->subgraph(1)->execution_plan()[0];
const auto* node_and_reg =
interpreter_->subgraph(1)->node_and_registration(node);
EXPECT_NE(node_and_reg->second.custom_name,
delegate_->FakeFusedRegistration().custom_name);
}
class TestDelegateWithDynamicTensors : public ::testing::Test {
protected:
void SetUp() override {

View File

@ -91,19 +91,33 @@ TfLiteRegistration AddOpRegistration() {
void TestDelegate::SetUp() {
interpreter_.reset(new Interpreter);
interpreter_->AddTensors(5);
interpreter_->SetInputs({0, 1});
interpreter_->SetOutputs({3, 4});
TfLiteQuantizationParams quant;
interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, quant);
interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, quant);
interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, quant);
interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, quant);
interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, quant);
SetUpSubgraph(&interpreter_->primary_subgraph());
}
void TestDelegate::SetUpSubgraph(Subgraph* subgraph) {
subgraph->AddTensors(5);
subgraph->SetInputs({0, 1});
subgraph->SetOutputs({3, 4});
std::vector<int> dims({3});
TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
TfLiteRegistration reg = AddOpRegistration();
interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
int node_index_ignored;
subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, &reg,
&node_index_ignored);
subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, &reg,
&node_index_ignored);
subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, &reg,
&node_index_ignored);
}
void TestDelegate::TearDown() {

View File

@ -48,6 +48,8 @@ class TestDelegate : public ::testing::Test {
return interpreter_->RemoveAllDelegates();
}
void SetUpSubgraph(Subgraph* subgraph);
protected:
class SimpleDelegate {
public:

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/util.h"
// TODO(b/139446230): Move to portable platform header.
#if defined(__ANDROID__)
@ -386,6 +387,9 @@ bool Interpreter::IsCancelled() { return primary_subgraph().IsCancelled(); }
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
TfLiteStatus status = kTfLiteOk;
for (auto& subgraph : subgraphs_) {
if (IsValidationSubgraph(subgraph->GetName().c_str())) {
continue;
}
status = subgraph->ModifyGraphWithDelegate(delegate);
if (status != kTfLiteOk) {
break;

View File

@ -765,6 +765,9 @@ TfLiteStatus InterpreterBuilder::operator()(
}
}
modified_subgraph->SetVariables(std::move(variables));
if (subgraph->name()) {
modified_subgraph->SetName(subgraph->name()->c_str());
}
}
if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=

View File

@ -164,6 +164,18 @@ TEST(BasicFlatBufferModel, TestMultipleSubgraphs) {
EXPECT_EQ(interpreter->subgraphs_size(), 2);
}
TEST(BasicFlatBufferModel, TestSubgraphName) {
auto m = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/"
"2_subgraphs_dont_delegate_name.bin");
ASSERT_TRUE(m);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
EXPECT_EQ(interpreter->subgraphs_size(), 2);
EXPECT_EQ(interpreter->subgraph(0)->GetName(), "");
EXPECT_EQ(interpreter->subgraph(1)->GetName(), "VALIDATION:main");
}
// Test what happens if we cannot bind any of the ops.
TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
auto model = FlatBufferModel::BuildFromFile(

Binary file not shown.

View File

@ -172,4 +172,8 @@ std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
return result;
}
bool IsValidationSubgraph(const char* name) {
// NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
}
} // namespace tflite

View File

@ -91,6 +91,14 @@ bool IsUnresolvedCustomOp(const TfLiteRegistration& registration);
// Returns a descriptive name with the given op TfLiteRegistration.
std::string GetOpNameByRegistration(const TfLiteRegistration& registration);
// The prefix of a validation subgraph name.
// WARNING: This is an experimental API and subject to change.
constexpr char kValidationSubgraphNamePrefix[] = "VALIDATION:";
// Checks whether the prefix of the subgraph name indicates the subgraph is a
// validation subgraph.
bool IsValidationSubgraph(const char* name);
} // namespace tflite
#endif // TENSORFLOW_LITE_UTIL_H_

View File

@ -120,6 +120,16 @@ TEST(GetOpNameByRegistration, CustomName) {
op_name = GetOpNameByRegistration(registration);
EXPECT_EQ("DELEGATE TestDelegate", op_name);
}
TEST(ValidationSubgraph, NameIsDetected) {
EXPECT_FALSE(IsValidationSubgraph(nullptr));
EXPECT_FALSE(IsValidationSubgraph(""));
EXPECT_FALSE(IsValidationSubgraph("a name"));
EXPECT_FALSE(IsValidationSubgraph("VALIDATIONfoo"));
EXPECT_TRUE(IsValidationSubgraph("VALIDATION:"));
EXPECT_TRUE(IsValidationSubgraph("VALIDATION:main"));
}
} // namespace
} // namespace tflite