Add custom input/output/execution_plan to SubgraphWriter.
PiperOrigin-RevId: 303785677 Change-Id: Iad91f6332510604cdb1de3423a4e1056d07d4f28
This commit is contained in:
parent
be8d751324
commit
2a6aa033c2
@ -32,6 +32,7 @@ cc_library(
|
||||
"//tensorflow/lite:builtin_op_data",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs_with_reflection",
|
||||
],
|
||||
)
|
||||
@ -63,7 +64,9 @@ cc_test(
|
||||
deps = [
|
||||
":writer_lib",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/lite/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
|
||||
@ -50,7 +52,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
|
||||
std::vector<int> operator_to_opcode;
|
||||
// TODO(aselle): Augment this once we put execution plan in schema.
|
||||
operator_to_opcode.resize(subgraph_->nodes_size(), -1);
|
||||
for (int op_index : subgraph_->execution_plan()) {
|
||||
for (int op_index : execution_plan_) {
|
||||
const auto* node_and_registration =
|
||||
subgraph_->node_and_registration(op_index);
|
||||
const TfLiteRegistration* registration = &node_and_registration->second;
|
||||
@ -63,7 +65,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
|
||||
}
|
||||
}
|
||||
// second pass serialize operators
|
||||
for (int op_index : subgraph_->execution_plan()) {
|
||||
for (int op_index : execution_plan_) {
|
||||
const auto* node_and_registration =
|
||||
subgraph_->node_and_registration(op_index);
|
||||
const TfLiteNode& node = node_and_registration->first;
|
||||
@ -255,10 +257,8 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
|
||||
std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
|
||||
{ // subgraph specific stuff
|
||||
auto tensors = ExportTensors(&builder);
|
||||
std::vector<int> written_inputs =
|
||||
RemapTensorIndicesToWritten(subgraph_->inputs());
|
||||
std::vector<int> written_outputs =
|
||||
RemapTensorIndicesToWritten(subgraph_->outputs());
|
||||
std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
|
||||
std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
|
||||
auto inputs = ExportVector<int32_t>(&builder, written_inputs);
|
||||
auto outputs = ExportVector<int32_t>(&builder, written_outputs);
|
||||
|
||||
@ -309,4 +309,63 @@ TfLiteStatus SubgraphWriter::RegisterCustomWriter(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SubgraphWriter::CheckInputOutput(
|
||||
const std::vector<int>& inputs, const std::vector<int>& outputs,
|
||||
const std::vector<int>& execution_plan) {
|
||||
std::unordered_set<int> known_tensors(inputs.begin(), inputs.end());
|
||||
// Scan execution plan and confirm input tensors are known before each node
|
||||
// executes. Then append output tensors to known tensors.
|
||||
for (int op_index : execution_plan) {
|
||||
const auto* node_and_registration =
|
||||
subgraph_->node_and_registration(op_index);
|
||||
const TfLiteNode& node = node_and_registration->first;
|
||||
for (int tensor_index : TfLiteIntArrayView(node.inputs)) {
|
||||
if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
|
||||
// Skip constant tensors.
|
||||
if (tensor->allocation_type == kTfLiteMmapRo) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (known_tensors.find(tensor_index) == known_tensors.end()) {
|
||||
subgraph_->context()->ReportError(
|
||||
subgraph_->context(),
|
||||
"Node (%d) uses an input (%d) that is not provided.", op_index,
|
||||
tensor_index);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
TfLiteIntArrayView outputs(node.outputs);
|
||||
known_tensors.insert(outputs.begin(), outputs.end());
|
||||
}
|
||||
|
||||
// Check if outputs are known tensors or constants.
|
||||
for (int tensor_index : outputs) {
|
||||
if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
|
||||
// Skip constant tensors.
|
||||
if (tensor->allocation_type == kTfLiteMmapRo) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (known_tensors.find(tensor_index) == known_tensors.end()) {
|
||||
subgraph_->context()->ReportError(
|
||||
subgraph_->context(),
|
||||
"Output (%d) is not produced by the execution plan.", tensor_index);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SubgraphWriter::SetCustomInputOutput(
|
||||
const std::vector<int>& inputs, const std::vector<int>& outputs,
|
||||
const std::vector<int>& execution_plan) {
|
||||
TF_LITE_ENSURE_STATUS(CheckInputOutput(inputs, outputs, execution_plan));
|
||||
inputs_ = inputs;
|
||||
outputs_ = outputs;
|
||||
execution_plan_ = execution_plan;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/lite/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
|
||||
@ -47,9 +48,13 @@ class SubgraphWriter {
|
||||
flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
|
||||
CustomOptionsFormat* custom_options_format);
|
||||
|
||||
// Construct an subgraph writer for the specified `subgraph`. Then,
|
||||
// a uses .Write() or .GetBuffer(...) to extract the data.
|
||||
explicit SubgraphWriter(Subgraph* subgraph) : subgraph_(subgraph) {
|
||||
// Construct a subgraph writer for the specified `subgraph`. Then, use
|
||||
// .Write() or .GetBuffer(...) to extract the data.
|
||||
explicit SubgraphWriter(Subgraph* subgraph)
|
||||
: subgraph_(subgraph),
|
||||
inputs_(subgraph->inputs()),
|
||||
outputs_(subgraph->outputs()),
|
||||
execution_plan_(subgraph->execution_plan()) {
|
||||
buffers_.push_back(std::make_pair(nullptr, 0));
|
||||
}
|
||||
|
||||
@ -65,6 +70,11 @@ class SubgraphWriter {
|
||||
void SetUnusedTensors(const std::set<int>& unused_tensors) {
|
||||
unused_tensors_ = unused_tensors;
|
||||
}
|
||||
// Sets custom inputs, outputs, and execution_plan so that a portion of the
|
||||
// subgraph is written to the buffer instead of the whole subgraph.
|
||||
TfLiteStatus SetCustomInputOutput(const std::vector<int>& inputs,
|
||||
const std::vector<int>& outputs,
|
||||
const std::vector<int>& execution_plan);
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
@ -84,6 +94,12 @@ class SubgraphWriter {
|
||||
template <class T>
|
||||
std::vector<int> RemapTensorIndicesToWritten(const T& input);
|
||||
|
||||
// Checks if given `input`, `output`, and `execution_plan` represents a valid
|
||||
// model within the Subgraph.
|
||||
TfLiteStatus CheckInputOutput(const std::vector<int>& inputs,
|
||||
const std::vector<int>& outputs,
|
||||
const std::vector<int>& execution_plan);
|
||||
|
||||
int GetOpCodeForBuiltin(int builtin_op_index) {
|
||||
// auto it = builtin_op_to_opcode_.find(builtin_op_index);
|
||||
std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
|
||||
@ -107,6 +123,12 @@ class SubgraphWriter {
|
||||
|
||||
// The subgraph we are writing
|
||||
Subgraph* subgraph_;
|
||||
// Input tensor indices to be written.
|
||||
std::vector<int> inputs_;
|
||||
// Output tensor indices to be written.
|
||||
std::vector<int> outputs_;
|
||||
// Order of nodes to be written.
|
||||
std::vector<int> execution_plan_;
|
||||
// Keep track of byte buffers
|
||||
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
|
||||
// List of op codes and mappings from builtin or custom op to opcode
|
||||
|
@ -14,10 +14,13 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/writer/writer_lib.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -55,6 +58,102 @@ TEST(Writer, FloatModelTest) {
|
||||
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
// Tests writing only a portion of the subgraph.
|
||||
TEST(Writer, CustomInputOutputTest) {
|
||||
Interpreter interpreter;
|
||||
interpreter.AddTensors(4);
|
||||
constexpr float kFoo[] = {1, 2, 3};
|
||||
interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
|
||||
TfLiteQuantization());
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
|
||||
reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
|
||||
interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
|
||||
TfLiteQuantization());
|
||||
interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3},
|
||||
TfLiteQuantization());
|
||||
interpreter.SetInputs({0, 1});
|
||||
interpreter.SetOutputs({3});
|
||||
|
||||
// Add two ops: Add and Relu
|
||||
const char* initial_data = "";
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
TfLiteAddParams* builtin_data =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
builtin_data->activation = kTfLiteActNone;
|
||||
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
|
||||
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
|
||||
reinterpret_cast<void*>(builtin_data), reg);
|
||||
|
||||
const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1);
|
||||
interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
|
||||
|
||||
// Only write the second op.
|
||||
SubgraphWriter writer(&interpreter.primary_subgraph());
|
||||
EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
|
||||
/*execution_plan=*/{1}),
|
||||
kTfLiteOk);
|
||||
writer.SetUnusedTensors({0, 1});
|
||||
writer.Write("/tmp/test_custom.tflite");
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model =
|
||||
FlatBufferModel::BuildFromFile("/tmp/test_custom.tflite");
|
||||
InterpreterBuilder builder(*model, resolver);
|
||||
std::unique_ptr<Interpreter> new_interpreter;
|
||||
builder(&new_interpreter);
|
||||
ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
TEST(Writer, CustomInputOutputErrorCasesTest) {
|
||||
Interpreter interpreter;
|
||||
interpreter.AddTensors(5);
|
||||
constexpr float kFoo[] = {1, 2, 3};
|
||||
interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
|
||||
TfLiteQuantization());
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
|
||||
reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
|
||||
interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
|
||||
TfLiteQuantization());
|
||||
interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3},
|
||||
TfLiteQuantization());
|
||||
interpreter.SetTensorParametersReadWrite(4, kTfLiteFloat32, "e", {3},
|
||||
TfLiteQuantization());
|
||||
interpreter.SetInputs({0, 1});
|
||||
interpreter.SetOutputs({4});
|
||||
|
||||
// Add three ops.
|
||||
const char* initial_data = "";
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
TfLiteAddParams* builtin_data =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
builtin_data->activation = kTfLiteActNone;
|
||||
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
|
||||
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
|
||||
reinterpret_cast<void*>(builtin_data), reg);
|
||||
|
||||
const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1);
|
||||
interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
|
||||
|
||||
const TfLiteRegistration* reg3 = resolver.FindOp(BuiltinOperator_RELU6, 1);
|
||||
interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, reg3);
|
||||
|
||||
SubgraphWriter writer(&interpreter.primary_subgraph());
|
||||
|
||||
// Test wrong input.
|
||||
EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
|
||||
/*execution_plan=*/{0, 1}),
|
||||
kTfLiteError);
|
||||
// Test wrong output.
|
||||
EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{4},
|
||||
/*execution_plan=*/{0, 1}),
|
||||
kTfLiteError);
|
||||
// Test a valid case.
|
||||
EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{3},
|
||||
/*execution_plan=*/{0, 1}),
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
TEST(Writer, PerTensorQuantizedModelTest) {
|
||||
Interpreter interpreter;
|
||||
interpreter.AddTensors(3);
|
||||
|
Loading…
Reference in New Issue
Block a user