Add custom input/output/execution_plan to SubgraphWriter.

PiperOrigin-RevId: 303785677
Change-Id: Iad91f6332510604cdb1de3423a4e1056d07d4f28
Author: Abdurrahman Akkas, 2020-03-30 11:35:27 -07:00; committed by TensorFlower Gardener
parent be8d751324
commit 2a6aa033c2
4 changed files with 192 additions and 9 deletions
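
For context, a minimal usage sketch of the new API (mirroring the CustomInputOutputTest added below; the graph ADD(0, 1) -> 2, RELU(2) -> 3 and the output path are illustrative):

  // Assume `interpreter` is an Interpreter already built with the graph
  // ADD(0, 1) -> 2 followed by RELU(2) -> 3, as in the new test below.
  SubgraphWriter writer(&interpreter.primary_subgraph());
  // Serialize only the RELU node: tensor 2 becomes the model input and
  // tensor 3 the model output.
  if (writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
                                  /*execution_plan=*/{1}) == kTfLiteOk) {
    writer.SetUnusedTensors({0, 1});  // Tensors the sub-model never touches.
    writer.Write("/tmp/partial_model.tflite");  // Illustrative path.
  }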

tensorflow/lite/experimental/writer/BUILD

@@ -32,6 +32,7 @@ cc_library(
"//tensorflow/lite:builtin_op_data",
"//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs_with_reflection",
],
)
@@ -63,7 +64,9 @@ cc_test(
deps = [
":writer_lib",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],

tensorflow/lite/experimental/writer/writer_lib.cc

@@ -17,8 +17,10 @@ limitations under the License.
#include <cstdlib>
#include <cstring>
#include <unordered_map>
+ #include <unordered_set>
#include "tensorflow/lite/builtin_op_data.h"
+ #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
@@ -50,7 +52,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
std::vector<int> operator_to_opcode;
// TODO(aselle): Augment this once we put execution plan in schema.
operator_to_opcode.resize(subgraph_->nodes_size(), -1);
- for (int op_index : subgraph_->execution_plan()) {
+ for (int op_index : execution_plan_) {
const auto* node_and_registration =
subgraph_->node_and_registration(op_index);
const TfLiteRegistration* registration = &node_and_registration->second;
@@ -63,7 +65,7 @@ SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
}
}
// second pass serialize operators
- for (int op_index : subgraph_->execution_plan()) {
+ for (int op_index : execution_plan_) {
const auto* node_and_registration =
subgraph_->node_and_registration(op_index);
const TfLiteNode& node = node_and_registration->first;
@@ -255,10 +257,8 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
{ // subgraph specific stuff
auto tensors = ExportTensors(&builder);
- std::vector<int> written_inputs =
-     RemapTensorIndicesToWritten(subgraph_->inputs());
- std::vector<int> written_outputs =
-     RemapTensorIndicesToWritten(subgraph_->outputs());
+ std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
+ std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
auto inputs = ExportVector<int32_t>(&builder, written_inputs);
auto outputs = ExportVector<int32_t>(&builder, written_outputs);
@@ -309,4 +309,63 @@ TfLiteStatus SubgraphWriter::RegisterCustomWriter(
return kTfLiteOk;
}
+ TfLiteStatus SubgraphWriter::CheckInputOutput(
+     const std::vector<int>& inputs, const std::vector<int>& outputs,
+     const std::vector<int>& execution_plan) {
+   std::unordered_set<int> known_tensors(inputs.begin(), inputs.end());
+   // Scan execution plan and confirm input tensors are known before each node
+   // executes. Then append output tensors to known tensors.
+   for (int op_index : execution_plan) {
+     const auto* node_and_registration =
+         subgraph_->node_and_registration(op_index);
+     const TfLiteNode& node = node_and_registration->first;
+     for (int tensor_index : TfLiteIntArrayView(node.inputs)) {
+       if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
+         // Skip constant tensors.
+         if (tensor->allocation_type == kTfLiteMmapRo) {
+           continue;
+         }
+       }
+       if (known_tensors.find(tensor_index) == known_tensors.end()) {
+         subgraph_->context()->ReportError(
+             subgraph_->context(),
+             "Node (%d) uses an input (%d) that is not provided.", op_index,
+             tensor_index);
+         return kTfLiteError;
+       }
+     }
+     TfLiteIntArrayView outputs(node.outputs);
+     known_tensors.insert(outputs.begin(), outputs.end());
+   }
+   // Check if outputs are known tensors or constants.
+   for (int tensor_index : outputs) {
+     if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
+       // Skip constant tensors.
+       if (tensor->allocation_type == kTfLiteMmapRo) {
+         continue;
+       }
+     }
+     if (known_tensors.find(tensor_index) == known_tensors.end()) {
+       subgraph_->context()->ReportError(
+           subgraph_->context(),
+           "Output (%d) is not produced by the execution plan.", tensor_index);
+       return kTfLiteError;
+     }
+   }
+   return kTfLiteOk;
+ }
+
+ TfLiteStatus SubgraphWriter::SetCustomInputOutput(
+     const std::vector<int>& inputs, const std::vector<int>& outputs,
+     const std::vector<int>& execution_plan) {
+   TF_LITE_ENSURE_STATUS(CheckInputOutput(inputs, outputs, execution_plan));
+   inputs_ = inputs;
+   outputs_ = outputs;
+   execution_plan_ = execution_plan;
+   return kTfLiteOk;
+ }
} // namespace tflite
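
To see how CheckInputOutput behaves on the graph used in the new tests (ADD(0, 1) -> 2 then RELU(2) -> 3, with tensor 1 a kTfLiteMmapRo constant): for inputs={2}, outputs={3}, execution_plan={1}, the known set starts as {2}; the RELU node's input 2 is known, its output 3 is appended, and the requested output 3 is found, so the result is kTfLiteOk. With execution_plan={0, 1} and the same inputs, the ADD node consumes tensor 0, which is neither a constant nor in the known set, so the result is kTfLiteError.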

tensorflow/lite/experimental/writer/writer_lib.h

@@ -30,6 +30,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
@@ -47,9 +48,13 @@ class SubgraphWriter {
flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
CustomOptionsFormat* custom_options_format);
- // Construct an subgraph writer for the specified `subgraph`. Then,
- // a uses .Write() or .GetBuffer(...) to extract the data.
- explicit SubgraphWriter(Subgraph* subgraph) : subgraph_(subgraph) {
+ // Construct a subgraph writer for the specified `subgraph`. Then, use
+ // .Write() or .GetBuffer(...) to extract the data.
+ explicit SubgraphWriter(Subgraph* subgraph)
+     : subgraph_(subgraph),
+       inputs_(subgraph->inputs()),
+       outputs_(subgraph->outputs()),
+       execution_plan_(subgraph->execution_plan()) {
buffers_.push_back(std::make_pair(nullptr, 0));
}
@@ -65,6 +70,11 @@ class SubgraphWriter {
void SetUnusedTensors(const std::set<int>& unused_tensors) {
unused_tensors_ = unused_tensors;
}
+ // Sets custom inputs, outputs, and execution_plan so that a portion of the
+ // subgraph is written to the buffer instead of the whole subgraph.
+ TfLiteStatus SetCustomInputOutput(const std::vector<int>& inputs,
+                                   const std::vector<int>& outputs,
+                                   const std::vector<int>& execution_plan);
private:
template <class T>
@@ -84,6 +94,12 @@ class SubgraphWriter {
template <class T>
std::vector<int> RemapTensorIndicesToWritten(const T& input);
+ // Checks if the given `inputs`, `outputs`, and `execution_plan` represent a
+ // valid model within the Subgraph.
+ TfLiteStatus CheckInputOutput(const std::vector<int>& inputs,
+                               const std::vector<int>& outputs,
+                               const std::vector<int>& execution_plan);
int GetOpCodeForBuiltin(int builtin_op_index) {
// auto it = builtin_op_to_opcode_.find(builtin_op_index);
std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
@@ -107,6 +123,12 @@ class SubgraphWriter {
// The subgraph we are writing
Subgraph* subgraph_;
+ // Input tensor indices to be written.
+ std::vector<int> inputs_;
+ // Output tensor indices to be written.
+ std::vector<int> outputs_;
+ // Order of nodes to be written.
+ std::vector<int> execution_plan_;
// Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
// List of op codes and mappings from builtin or custom op to opcode
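
Note that the constructor above now seeds inputs_, outputs_, and execution_plan_ from the subgraph itself, so existing callers that never invoke SetCustomInputOutput continue to serialize the whole subgraph.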

tensorflow/lite/experimental/writer/writer_lib_test.cc

@@ -14,10 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/writer/writer_lib.h"
#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
@@ -55,6 +58,102 @@ TEST(Writer, FloatModelTest) {
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
+ // Tests writing only a portion of the subgraph.
+ TEST(Writer, CustomInputOutputTest) {
+   Interpreter interpreter;
+   interpreter.AddTensors(4);
+   constexpr float kFoo[] = {1, 2, 3};
+   interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+                                            TfLiteQuantization());
+   interpreter.SetTensorParametersReadOnly(
+       1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
+       reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
+   interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+                                            TfLiteQuantization());
+   interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3},
+                                            TfLiteQuantization());
+   interpreter.SetInputs({0, 1});
+   interpreter.SetOutputs({3});
+   // Add two ops: Add and Relu
+   const char* initial_data = "";
+   tflite::ops::builtin::BuiltinOpResolver resolver;
+   TfLiteAddParams* builtin_data =
+       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+   builtin_data->activation = kTfLiteActNone;
+   const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+   interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+                                     reinterpret_cast<void*>(builtin_data), reg);
+   const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1);
+   interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
+
+   // Only write the second op.
+   SubgraphWriter writer(&interpreter.primary_subgraph());
+   EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
+                                         /*execution_plan=*/{1}),
+             kTfLiteOk);
+   writer.SetUnusedTensors({0, 1});
+   writer.Write("/tmp/test_custom.tflite");
+
+   std::unique_ptr<FlatBufferModel> model =
+       FlatBufferModel::BuildFromFile("/tmp/test_custom.tflite");
+   InterpreterBuilder builder(*model, resolver);
+   std::unique_ptr<Interpreter> new_interpreter;
+   builder(&new_interpreter);
+   ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
+ }
+
+ TEST(Writer, CustomInputOutputErrorCasesTest) {
+   Interpreter interpreter;
+   interpreter.AddTensors(5);
+   constexpr float kFoo[] = {1, 2, 3};
+   interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+                                            TfLiteQuantization());
+   interpreter.SetTensorParametersReadOnly(
+       1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
+       reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
+   interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+                                            TfLiteQuantization());
+   interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3},
+                                            TfLiteQuantization());
+   interpreter.SetTensorParametersReadWrite(4, kTfLiteFloat32, "e", {3},
+                                            TfLiteQuantization());
+   interpreter.SetInputs({0, 1});
+   interpreter.SetOutputs({4});
+   // Add three ops.
+   const char* initial_data = "";
+   tflite::ops::builtin::BuiltinOpResolver resolver;
+   TfLiteAddParams* builtin_data =
+       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+   builtin_data->activation = kTfLiteActNone;
+   const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+   interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+                                     reinterpret_cast<void*>(builtin_data), reg);
+   const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1);
+   interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
+   const TfLiteRegistration* reg3 = resolver.FindOp(BuiltinOperator_RELU6, 1);
+   interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, reg3);
+
+   SubgraphWriter writer(&interpreter.primary_subgraph());
+   // Test wrong input.
+   EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
+                                         /*execution_plan=*/{0, 1}),
+             kTfLiteError);
+   // Test wrong output.
+   EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{4},
+                                         /*execution_plan=*/{0, 1}),
+             kTfLiteError);
+   // Test a valid case.
+   EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{3},
+                                         /*execution_plan=*/{0, 1}),
+             kTfLiteOk);
+ }
TEST(Writer, PerTensorQuantizedModelTest) {
Interpreter interpreter;
interpreter.AddTensors(3);