Distinguish between inference (sub)graphs and validation graphs by using a subgraph name prefix.
PiperOrigin-RevId: 358139741 Change-Id: I74a2964320cf5f962a885e918ef8aab43e548847
This commit is contained in:
parent
129463994d
commit
045b62dc3e
@ -623,6 +623,7 @@ cc_test(
|
||||
data = [
|
||||
"testdata/0_subgraphs.bin",
|
||||
"testdata/2_subgraphs.bin",
|
||||
"testdata/2_subgraphs_dont_delegate_name.bin",
|
||||
"testdata/add_shared_tensors.bin",
|
||||
"testdata/empty_model.bin",
|
||||
"testdata/multi_add_flex.bin",
|
||||
|
@ -1619,4 +1619,14 @@ TfLiteStatus Subgraph::SetCustomAllocationForTensor(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void Subgraph::SetName(const char* name) {
|
||||
if (name) {
|
||||
name_ = name;
|
||||
} else {
|
||||
name_ = "";
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& Subgraph::GetName() const { return name_; }
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -35,6 +35,11 @@ limitations under the License.
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
namespace test_utils {
|
||||
class TestDelegate; // Class for friend declarations.
|
||||
} // namespace test_utils
|
||||
} // namespace delegates
|
||||
|
||||
class Subgraph {
|
||||
public:
|
||||
@ -342,7 +347,11 @@ class Subgraph {
|
||||
TfLiteStatus SetCustomAllocationForTensor(
|
||||
int tensor_index, const TfLiteCustomAllocation& allocation);
|
||||
|
||||
void SetName(const char* name);
|
||||
const std::string& GetName() const;
|
||||
|
||||
private:
|
||||
friend class TestDelegate;
|
||||
// SubgraphAwareProfiler wraps an actual TFLite profiler, such as a
|
||||
// BufferedProfiler instance, and takes care of event profiling/tracing in a
|
||||
// certain subgraph.
|
||||
@ -731,6 +740,8 @@ class Subgraph {
|
||||
|
||||
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
|
||||
resource::ResourceMap* resources_ = nullptr;
|
||||
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/delegates/delegate_test_util.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
@ -769,6 +770,40 @@ TEST_F(TestDelegate, DelegateCustomOpResolution) {
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
TEST_F(TestDelegate, AllSubgraphsAreDelegatedByDefault) {
|
||||
interpreter_->AddSubgraphs(1);
|
||||
SetUpSubgraph(interpreter_->subgraph(1));
|
||||
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
|
||||
ASSERT_EQ(
|
||||
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
|
||||
kTfLiteOk);
|
||||
for (int subgraph_index = 0; subgraph_index < 2; subgraph_index++) {
|
||||
ASSERT_EQ(interpreter_->subgraph(subgraph_index)->execution_plan().size(),
|
||||
1);
|
||||
int node = interpreter_->subgraph(subgraph_index)->execution_plan()[0];
|
||||
const auto* node_and_reg =
|
||||
interpreter_->subgraph(subgraph_index)->node_and_registration(node);
|
||||
EXPECT_EQ(node_and_reg->second.custom_name,
|
||||
delegate_->FakeFusedRegistration().custom_name);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TestDelegate, ValidationSubgraphsAreNotDelegated) {
|
||||
interpreter_->AddSubgraphs(1);
|
||||
SetUpSubgraph(interpreter_->subgraph(1));
|
||||
interpreter_->subgraph(1)->SetName("VALIDATION:foo");
|
||||
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
|
||||
ASSERT_EQ(
|
||||
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_->subgraph(1)->execution_plan().size(), 3);
|
||||
int node = interpreter_->subgraph(1)->execution_plan()[0];
|
||||
const auto* node_and_reg =
|
||||
interpreter_->subgraph(1)->node_and_registration(node);
|
||||
EXPECT_NE(node_and_reg->second.custom_name,
|
||||
delegate_->FakeFusedRegistration().custom_name);
|
||||
}
|
||||
|
||||
class TestDelegateWithDynamicTensors : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
|
@ -91,19 +91,33 @@ TfLiteRegistration AddOpRegistration() {
|
||||
|
||||
void TestDelegate::SetUp() {
|
||||
interpreter_.reset(new Interpreter);
|
||||
interpreter_->AddTensors(5);
|
||||
interpreter_->SetInputs({0, 1});
|
||||
interpreter_->SetOutputs({3, 4});
|
||||
TfLiteQuantizationParams quant;
|
||||
interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, quant);
|
||||
interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, quant);
|
||||
interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, quant);
|
||||
interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, quant);
|
||||
interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, quant);
|
||||
SetUpSubgraph(&interpreter_->primary_subgraph());
|
||||
}
|
||||
|
||||
void TestDelegate::SetUpSubgraph(Subgraph* subgraph) {
|
||||
subgraph->AddTensors(5);
|
||||
subgraph->SetInputs({0, 1});
|
||||
subgraph->SetOutputs({3, 4});
|
||||
std::vector<int> dims({3});
|
||||
TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
|
||||
subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
|
||||
dims.data(), quant, false);
|
||||
subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
|
||||
dims.data(), quant, false);
|
||||
subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
|
||||
dims.data(), quant, false);
|
||||
subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
|
||||
dims.data(), quant, false);
|
||||
subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
|
||||
dims.data(), quant, false);
|
||||
TfLiteRegistration reg = AddOpRegistration();
|
||||
interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®);
|
||||
interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®);
|
||||
interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®);
|
||||
int node_index_ignored;
|
||||
subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, ®,
|
||||
&node_index_ignored);
|
||||
subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, ®,
|
||||
&node_index_ignored);
|
||||
subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, ®,
|
||||
&node_index_ignored);
|
||||
}
|
||||
|
||||
void TestDelegate::TearDown() {
|
||||
|
@ -48,6 +48,8 @@ class TestDelegate : public ::testing::Test {
|
||||
return interpreter_->RemoveAllDelegates();
|
||||
}
|
||||
|
||||
void SetUpSubgraph(Subgraph* subgraph);
|
||||
|
||||
protected:
|
||||
class SimpleDelegate {
|
||||
public:
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/external_cpu_backend_context.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/stderr_reporter.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
// TODO(b/139446230): Move to portable platform header.
|
||||
#if defined(__ANDROID__)
|
||||
@ -386,6 +387,9 @@ bool Interpreter::IsCancelled() { return primary_subgraph().IsCancelled(); }
|
||||
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
|
||||
TfLiteStatus status = kTfLiteOk;
|
||||
for (auto& subgraph : subgraphs_) {
|
||||
if (IsValidationSubgraph(subgraph->GetName().c_str())) {
|
||||
continue;
|
||||
}
|
||||
status = subgraph->ModifyGraphWithDelegate(delegate);
|
||||
if (status != kTfLiteOk) {
|
||||
break;
|
||||
|
@ -765,6 +765,9 @@ TfLiteStatus InterpreterBuilder::operator()(
|
||||
}
|
||||
}
|
||||
modified_subgraph->SetVariables(std::move(variables));
|
||||
if (subgraph->name()) {
|
||||
modified_subgraph->SetName(subgraph->name()->c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
|
||||
|
@ -164,6 +164,18 @@ TEST(BasicFlatBufferModel, TestMultipleSubgraphs) {
|
||||
EXPECT_EQ(interpreter->subgraphs_size(), 2);
|
||||
}
|
||||
|
||||
TEST(BasicFlatBufferModel, TestSubgraphName) {
|
||||
auto m = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/"
|
||||
"2_subgraphs_dont_delegate_name.bin");
|
||||
ASSERT_TRUE(m);
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
ASSERT_EQ(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
|
||||
EXPECT_EQ(interpreter->subgraphs_size(), 2);
|
||||
EXPECT_EQ(interpreter->subgraph(0)->GetName(), "");
|
||||
EXPECT_EQ(interpreter->subgraph(1)->GetName(), "VALIDATION:main");
|
||||
}
|
||||
|
||||
// Test what happens if we cannot bind any of the ops.
|
||||
TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
|
BIN
tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/2_subgraphs_dont_delegate_name.bin
vendored
Normal file
Binary file not shown.
@ -172,4 +172,8 @@ std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool IsValidationSubgraph(const char* name) {
|
||||
// NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
|
||||
return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
|
||||
}
|
||||
} // namespace tflite
|
||||
|
@ -91,6 +91,14 @@ bool IsUnresolvedCustomOp(const TfLiteRegistration& registration);
|
||||
|
||||
// Returns a descriptive name with the given op TfLiteRegistration.
|
||||
std::string GetOpNameByRegistration(const TfLiteRegistration& registration);
|
||||
|
||||
// The prefix of a validation subgraph name.
|
||||
// WARNING: This is an experimental API and subject to change.
|
||||
constexpr char kValidationSubgraphNamePrefix[] = "VALIDATION:";
|
||||
|
||||
// Checks whether the prefix of the subgraph name indicates the subgraph is a
|
||||
// validation subgraph.
|
||||
bool IsValidationSubgraph(const char* name);
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_UTIL_H_
|
||||
|
@ -120,6 +120,16 @@ TEST(GetOpNameByRegistration, CustomName) {
|
||||
op_name = GetOpNameByRegistration(registration);
|
||||
EXPECT_EQ("DELEGATE TestDelegate", op_name);
|
||||
}
|
||||
|
||||
TEST(ValidationSubgraph, NameIsDetected) {
|
||||
EXPECT_FALSE(IsValidationSubgraph(nullptr));
|
||||
EXPECT_FALSE(IsValidationSubgraph(""));
|
||||
EXPECT_FALSE(IsValidationSubgraph("a name"));
|
||||
EXPECT_FALSE(IsValidationSubgraph("VALIDATIONfoo"));
|
||||
EXPECT_TRUE(IsValidationSubgraph("VALIDATION:"));
|
||||
EXPECT_TRUE(IsValidationSubgraph("VALIDATION:main"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user