Rollforward serialization tool with internal test fix

PiperOrigin-RevId: 347894186
Change-Id: I00bd6f41bca9980df6785e067e01c3b94dc81a95
Sachin Joglekar 2020-12-16 14:07:10 -08:00 committed by TensorFlower Gardener
parent 39d6ff7ac5
commit 2028a32c04
9 changed files with 439 additions and 110 deletions

View File

@@ -2256,6 +2256,7 @@ cc_library(
":builtin_ops",
":kernel_util",
":variable_op_kernels",
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"@com_google_googletest//:gtest",

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
@@ -113,10 +114,11 @@ void SubgraphBuilder::BuildAddSubgraph(Subgraph* subgraph) {
TfLiteAddParams* params =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone;
auto* add_reg = ops::builtin::Register_ADD();
add_reg->builtin_code = kTfLiteBuiltinAdd;
int node_index;
subgraph->AddNodeWithParameters(
{kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_ADD(), &node_index);
subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
params, add_reg, &node_index);
}
// Build a subgraph with a mul op. Helper function for testing.
@@ -143,10 +145,11 @@ void SubgraphBuilder::BuildMulSubgraph(Subgraph* subgraph) {
TfLiteMulParams* params =
reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
params->activation = kTfLiteActNone;
auto* mul_reg = ops::builtin::Register_MUL();
mul_reg->builtin_code = kTfLiteBuiltinMul;
int node_index;
subgraph->AddNodeWithParameters(
{kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_MUL(), &node_index);
subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
params, mul_reg, &node_index);
}
// Build a subgraph with a pad op. Helper function for testing.
@@ -172,10 +175,11 @@ void SubgraphBuilder::BuildPadSubgraph(Subgraph* subgraph) {
TfLitePadParams* params =
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
auto* pad_reg = ops::builtin::Register_PAD();
pad_reg->builtin_code = kTfLiteBuiltinPad;
int node_index;
subgraph->AddNodeWithParameters(
{kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_PAD(), &node_index);
subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
params, pad_reg, &node_index);
}
void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
@@ -205,11 +209,12 @@ void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
reinterpret_cast<TfLiteIfParams*>(malloc(sizeof(TfLiteIfParams)));
params->then_subgraph_index = 1;
params->else_subgraph_index = 2;
auto* if_reg = ops::builtin::Register_IF();
if_reg->builtin_code = kTfLiteBuiltinIf;
int node_index;
subgraph->AddNodeWithParameters(
{kCondInput, kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_IF(), &node_index);
subgraph->AddNodeWithParameters({kCondInput, kInput1, kInput2}, {kOutput}, {},
nullptr, 0, params, if_reg, &node_index);
}
void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
@@ -236,11 +241,13 @@ void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
SetupTensor(subgraph, kInput2, kTfLiteInt32);
SetupTensor(subgraph, kOutput, kTfLiteBool);
auto* le_reg = ops::builtin::Register_LESS_EQUAL();
le_reg->builtin_code = kTfLiteBuiltinLessEqual;
CreateConstantInt32Tensor(subgraph, kConstRhs, {1}, {rhs});
int node_index;
subgraph->AddNodeWithParameters(
{kInput1, kConstRhs}, {kOutput}, {}, nullptr, 0, nullptr,
::tflite::ops::builtin::Register_LESS_EQUAL(), &node_index);
subgraph->AddNodeWithParameters({kInput1, kConstRhs}, {kOutput}, {}, nullptr,
0, nullptr, le_reg, &node_index);
}
void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
@@ -277,13 +284,13 @@ void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
TfLiteAddParams* params =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone;
subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_ADD(),
auto* add_reg = ops::builtin::Register_ADD();
add_reg->builtin_code = kTfLiteBuiltinAdd;
subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params, add_reg,
&node_index);
params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone;
subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_ADD(),
subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params, add_reg,
&node_index);
}
@@ -327,14 +334,18 @@ void SubgraphBuilder::BuildPadLoopBodySubgraph(Subgraph* subgraph,
TfLiteAddParams* add_params =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
add_params->activation = kTfLiteActNone;
subgraph->AddNodeWithParameters(
{kInputCounter, kConstStep}, {kOutputCounter}, {}, nullptr, 0, add_params,
::tflite::ops::builtin::Register_ADD(), &node_index);
auto* add_reg = ops::builtin::Register_ADD();
add_reg->builtin_code = kTfLiteBuiltinAdd;
subgraph->AddNodeWithParameters({kInputCounter, kConstStep}, {kOutputCounter},
{}, nullptr, 0, add_params, add_reg,
&node_index);
TfLitePadParams* pad_params =
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
subgraph->AddNodeWithParameters(
{kInputValue, kConstPadding}, {kOutputValue}, {}, nullptr, 0, pad_params,
::tflite::ops::builtin::Register_PAD(), &node_index);
auto* pad_reg = ops::builtin::Register_PAD();
pad_reg->builtin_code = kTfLiteBuiltinPad;
subgraph->AddNodeWithParameters({kInputValue, kConstPadding}, {kOutputValue},
{}, nullptr, 0, pad_params, pad_reg,
&node_index);
}
void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
@@ -364,11 +375,12 @@ void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
reinterpret_cast<TfLiteWhileParams*>(malloc(sizeof(TfLiteWhileParams)));
params->cond_subgraph_index = 1;
params->body_subgraph_index = 2;
auto* while_reg = ops::builtin::Register_WHILE();
while_reg->builtin_code = kTfLiteBuiltinWhile;
int node_index;
subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_WHILE(),
&node_index);
while_reg, &node_index);
}
void SubgraphBuilder::BuildAssignRandomValueToVariableSubgraph(

View File

@@ -35,6 +35,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs_with_reflection",
"//tensorflow/lite/schema:schema_utils",
"@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -67,6 +68,7 @@ cc_test(
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:subgraph_test_util",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",

View File

@@ -0,0 +1,63 @@
# TFLite Serialization Tool
**NOTE:** This tool is intended for advanced users only, and should be used with
care.
The (C++) serialization library generates and writes a TFLite flatbuffer given
an `Interpreter` or `Subgraph`. Example use-cases include authoring models with
the `Interpreter` API, or updating models on-device (by modifying `tensor.data`
for relevant tensors).
## Serialization
### Writing flatbuffer to file
To write a TFLite model from an `Interpreter` (see `lite/interpreter.h`):
```
std::unique_ptr<tflite::Interpreter> interpreter;
// ...build/modify interpreter...
tflite::ModelWriter writer(interpreter.get());
std::string filename = "/tmp/model.tflite";
writer.Write(filename);
```
Note that the above API does not support custom I/O tensors or custom ops yet.
However, it does support models with control flow.
To generate/write a flatbuffer for a particular `Subgraph` (see
`lite/core/subgraph.h`) you can use `SubgraphWriter`.
```
std::unique_ptr<tflite::Interpreter> interpreter;
// ...build/modify interpreter...
// The number of subgraphs can be obtained by:
// const int num_subgraphs = interpreter_->subgraphs_size();
// Note that 0 <= subgraph_index < num_subgraphs
tflite::SubgraphWriter writer(&interpreter->subgraph(subgraph_index));
std::string filename = "/tmp/model.tflite";
writer.Write(filename);
```
`SubgraphWriter` supports custom ops and/or custom I/O tensors.
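For example, this commit's tests use `SetCustomInputOutput` to serialize only a
subset of the graph (a minimal sketch; the tensor and node indices below are
illustrative):
```
tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
// Serialize only node 1, which reads tensor 2 and writes tensor 3.
writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
                            /*execution_plan=*/{1});
// Tensors outside the chosen boundary must be marked unused.
writer.SetUnusedTensors({0, 1});
writer.Write("/tmp/model_subset.tflite");
```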
### Generating flatbuffer in-memory
Both `ModelWriter` and `SubgraphWriter` provide a `GetBuffer` method that
returns the generated flatbuffer in memory:
```
std::unique_ptr<uint8_t[]> output_buffer;
size_t output_buffer_size;
tflite::ModelWriter writer(interpreter.get());
writer.GetBuffer(&output_buffer, &output_buffer_size);
```
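The in-memory buffer can be consumed directly, e.g. with
`FlatBufferModel::BuildFromBuffer` (a sketch; note that `output_buffer` must
outlive the model):
```
std::unique_ptr<tflite::FlatBufferModel> model =
    tflite::FlatBufferModel::BuildFromBuffer(
        reinterpret_cast<const char*>(output_buffer.get()),
        output_buffer_size);
```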
## De-serialization
The flatbuffers written as above can be deserialized just like any other TFLite
model. For example:
```
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile(filename);
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
```
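The rebuilt interpreter can then be allocated and invoked as usual (sketch;
input filling omitted):
```
new_interpreter->AllocateTensors();
// ...fill inputs, e.g. via new_interpreter->typed_input_tensor<float>(0)...
new_interpreter->Invoke();
```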

View File

@@ -34,7 +34,7 @@ int main(int argc, char* argv[]) {
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
tflite::ModelWriter writer(interpreter.get());
writer.Write(argv[2]);
return 0;

View File

@@ -29,6 +29,41 @@ limitations under the License.
#include "tensorflow/lite/version.h"
namespace tflite {
namespace {
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
std::vector<OpCode>* opcodes) {
std::vector<flatbuffers::Offset<OperatorCode>> codes;
for (const auto& it : *opcodes) {
const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
codes.push_back(CreateOperatorCodeDirect(
*fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
}
return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
for (auto buffer : *buffers) {
auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
}
return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
}
TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
FILE* fp = fopen(filename.c_str(), "wb");
if (!fp) return kTfLiteError;
const int result_size = fwrite(data, 1, size, fp);
fclose(fp);
if (result_size != size) return kTfLiteError;
return kTfLiteOk;
}
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
@@ -39,6 +74,8 @@ std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
}
} // namespace
template <class T_OUTPUT, class T_INPUT>
flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
@@ -159,8 +196,8 @@ SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
// Allocate a buffer index
int buffer_index = 0; // This is null
if (tensor->allocation_type == kTfLiteMmapRo) {
buffer_index = buffers_.size();
buffers_.push_back(std::make_pair(
buffer_index = buffers_->size();
buffers_->push_back(std::make_pair(
reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
}
// Primitive type.
@@ -214,23 +251,12 @@ SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
for (auto buffer : buffers_) {
auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
}
return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
return ExportBuffersImpl(fbb, buffers_);
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
std::vector<flatbuffers::Offset<OperatorCode>> codes;
for (const auto& it : opcodes_) {
const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
codes.push_back(CreateOperatorCodeDirect(
*fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
}
return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
return CreateOpCodeTableImpl(fbb, opcodes_);
}
template <class T>
@@ -254,19 +280,9 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
size_t* size) {
if (!out || !size) return kTfLiteError;
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
{ // subgraph specific stuff
auto tensors = ExportTensors(&builder);
std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
auto inputs = ExportVector<int32_t>(&builder, written_inputs);
auto outputs = ExportVector<int32_t>(&builder, written_outputs);
subgraphs_as_vector.push_back(PopulateAndGetOffset(&builder));
auto ops = ExportOperators(&builder);
subgraphs_as_vector.push_back(
CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
buffers = ExportBuffers(&builder);
@@ -284,21 +300,23 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
return kTfLiteOk;
}
flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
flatbuffers::FlatBufferBuilder* builder) {
auto tensors = ExportTensors(builder);
std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
auto inputs = ExportVector<int32_t>(builder, written_inputs);
auto outputs = ExportVector<int32_t>(builder, written_outputs);
auto ops = ExportOperators(builder);
return CreateSubGraph(*builder, tensors, inputs, outputs, ops, /* name */ 0);
}
TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
std::unique_ptr<uint8_t[]> buffer;
size_t size;
TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
FILE* fp = fopen(filename.c_str(), "wb");
if (!fp) return kTfLiteError;
if (fwrite(buffer.get(), 1, size, fp) != size) {
fclose(fp);
return kTfLiteError;
}
if (fclose(fp)) return kTfLiteError;
return kTfLiteOk;
return WriteImpl(filename, buffer.get(), size);
}
TfLiteStatus SubgraphWriter::RegisterCustomWriter(
@@ -377,4 +395,50 @@ TfLiteStatus SubgraphWriter::SetCustomInputOutput(
return kTfLiteOk;
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
return ExportBuffersImpl(fbb, &buffers_);
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
return CreateOpCodeTableImpl(fbb, &opcodes_);
}
TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
size_t* size) {
if (!out || !size) return kTfLiteError;
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
SubgraphWriter writer(interpreter_->subgraph(i), &buffers_, &opcodes_,
&builtin_op_to_opcode_);
subgraphs_as_vector.push_back(writer.PopulateAndGetOffset(&builder));
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
buffers = ExportBuffers(&builder);
auto description = builder.CreateString("Exported from Subgraph.");
auto op_codes = CreateOpCodeTable(&builder);
auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs_as_vector),
description, buffers);
::tflite::FinishModelBuffer(builder, model);
const uint8_t* buffer = builder.GetBufferPointer();
*size = builder.GetSize();
(*out).reset(new uint8_t[*size]);
memcpy(out->get(), buffer, *size);
return kTfLiteOk;
}
TfLiteStatus ModelWriter::Write(const std::string& filename) {
std::unique_ptr<uint8_t[]> buffer;
size_t size;
TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
return WriteImpl(filename, buffer.get(), size);
}
} // namespace tflite

View File

@@ -12,37 +12,72 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Writes a flatbuffer of a currently loaded TensorFlow Lite subgraph.
//
// Usage:
// From command line:
// bazel run third_party/tensorflow/lite/experimental/writer:writer
// -- foo.tflite foo.out.tflite
//
// From C++
// std::unique_ptr<Interpreter> interpreter;
// // Build Interpreter however
// // ... <omitted>
// SubgraphWriter(&interpreter->primary_subgraph()).Write("output.tflite");
// Library to write a flatbuffer of a currently loaded TFLite model/subgraph.
#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#include <iostream>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/tools/serialization/enum_mapping.h"
#include "tensorflow/lite/version.h"
namespace tflite {
struct OpCode {
int builtin;
std::string custom;
};
// Handles writing a full TFLite model (with one or more subgraphs) to the
// serialized TFLite file format.
// TODO(b/174708523): Support custom I/O or unused tensors later.
class ModelWriter {
public:
// Construct a writer for the specified `interpreter`. Then, use
// .Write() or .GetBuffer(...) to extract the data.
explicit ModelWriter(Interpreter* interpreter) : interpreter_(interpreter) {
buffers_.push_back(std::make_pair(nullptr, 0));
}
// Get a buffer and size of a serialized flatbuffer.
TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
// Write the serialized flatbuffer to the prescribed `filename`.
TfLiteStatus Write(const std::string& filename);
private:
template <class T>
using Offset = flatbuffers::Offset<T>;
Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
flatbuffers::FlatBufferBuilder* fbb);
Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
flatbuffers::FlatBufferBuilder* fbb);
// ModelWriter does not take ownership of this object.
Interpreter* const interpreter_;
// This data corresponds to the overall model (rather than individual
// subgraphs), so we define common fields. Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
// List of used opcodes
std::vector<OpCode> opcodes_;
absl::flat_hash_map<int, int> builtin_op_to_opcode_;
};
// Handles writing a running TensorFlow Lite subgraph to the serialized TFLite
// file format.
// TODO(b/174708523): Reconcile into ModelWriter?
class SubgraphWriter {
public:
friend class ModelWriter;
typedef flatbuffers::Offset<Operator> (*CustomWriter)(
flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
@@ -55,7 +90,10 @@ class SubgraphWriter {
inputs_(subgraph->inputs()),
outputs_(subgraph->outputs()),
execution_plan_(subgraph->execution_plan()) {
buffers_.push_back(std::make_pair(nullptr, 0));
buffers_ = &buffers_data_;
opcodes_ = &opcodes_data_;
builtin_op_to_opcode_ = &builtin_op_to_opcode_data_;
buffers_->push_back(std::make_pair(nullptr, 0));
}
// Get a buffer and size of a serialized flatbuffer.
@@ -77,6 +115,28 @@ class SubgraphWriter {
const std::vector<int>& execution_plan);
private:
// Used by ModelWriter.
explicit SubgraphWriter(
Subgraph* subgraph,
std::vector<std::pair<const uint8_t*, size_t>>* external_buffers,
std::vector<OpCode>* external_opcodes,
absl::flat_hash_map<int, int>* external_builtin_op_to_opcode)
: subgraph_(subgraph),
inputs_(subgraph->inputs()),
outputs_(subgraph->outputs()),
execution_plan_(subgraph->execution_plan()) {
buffers_ = external_buffers;
opcodes_ = external_opcodes;
builtin_op_to_opcode_ = external_builtin_op_to_opcode;
buffers_->push_back(std::make_pair(nullptr, 0));
}
// Used by ModelWriter to populate data specific to this subgraph.
// Global data (like opcodes & buffers) is collected into buffers_, opcodes_,
// etc. and then written into the flatbuffer by ModelWriter.
flatbuffers::Offset<SubGraph> PopulateAndGetOffset(
flatbuffers::FlatBufferBuilder* builder);
template <class T>
using Offset = flatbuffers::Offset<T>;
template <class T_OUTPUT, class T_INPUT>
@@ -102,11 +162,11 @@ class SubgraphWriter {
int GetOpCodeForBuiltin(int builtin_op_index) {
// auto it = builtin_op_to_opcode_.find(builtin_op_index);
std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
builtin_op_to_opcode_.insert(
std::make_pair(builtin_op_index, opcodes_.size()));
std::pair<decltype(builtin_op_to_opcode_data_)::iterator, bool> result =
builtin_op_to_opcode_->insert(
std::make_pair(builtin_op_index, opcodes_->size()));
if (result.second) {
opcodes_.push_back({builtin_op_index, ""});
opcodes_->push_back({builtin_op_index, ""});
}
return result.first->second;
}
@@ -114,9 +174,9 @@ class SubgraphWriter {
int GetOpCodeForCustom(const std::string& custom_name) {
std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
custom_op_to_opcode_.insert(
std::make_pair(custom_name, opcodes_.size()));
std::make_pair(custom_name, opcodes_->size()));
if (result.second) {
opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name});
}
return result.first->second;
}
@@ -129,22 +189,26 @@ class SubgraphWriter {
std::vector<int> outputs_;
// Order of nodes to be written.
std::vector<int> execution_plan_;
// Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
// List of op codes and mappings from builtin or custom op to opcode
struct OpCode {
int builtin;
std::string custom;
};
std::set<int> unused_tensors_;
// For every tensor index in the subgraph, the corresponding index in the
// written file. This is different because temporary and unused tensors are
// not written.
std::vector<int> tensor_to_written_tensor_;
// List of used opcodes
std::vector<OpCode> opcodes_;
std::unordered_map<int, int> builtin_op_to_opcode_;
std::unordered_map<std::string, int> custom_op_to_opcode_;
std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
// We use pointers for these, since they may be provided by ModelWriter.
// Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>>* buffers_;
// List of used opcodes
std::vector<OpCode>* opcodes_;
absl::flat_hash_map<int, int>* builtin_op_to_opcode_;
// These are used if SubgraphWriter is being used directly.
std::vector<std::pair<const uint8_t*, size_t>> buffers_data_;
// List of used opcodes
std::vector<OpCode> opcodes_data_;
absl::flat_hash_map<int, int> builtin_op_to_opcode_data_;
};
} // namespace tflite

View File

@@ -15,21 +15,47 @@ limitations under the License.
#include "tensorflow/lite/tools/serialization/writer_lib.h"
#include <cstdlib>
#include <numeric>
#include <sstream>
#include <string>
#include <tuple>
#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/subgraph_test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
// Make an interpreter that has no tensors and no nodes
// TODO(b/113731921): add more tests.
TEST(Writer, FloatModelTest) {
using subgraph_test_util::CheckIntTensor;
using subgraph_test_util::FillIntTensor;
std::string CreateFilePath(const std::string& file_name) {
return std::string(getenv("TEST_TMPDIR")) + file_name;
}
// The bool param indicates whether the test uses SubgraphWriter (true) or
// ModelWriter (false).
class SingleSubgraphTest : public ::testing::TestWithParam<bool> {
protected:
void WriteToFile(Interpreter* interpreter, const std::string& filename,
bool use_subgraph_writer) {
if (use_subgraph_writer) {
SubgraphWriter writer(&interpreter->primary_subgraph());
CHECK_EQ(writer.Write(filename), kTfLiteOk);
} else {
ModelWriter writer(interpreter);
CHECK_EQ(writer.Write(filename), kTfLiteOk);
}
}
};
TEST_P(SingleSubgraphTest, InvalidDestinations) {
Interpreter interpreter;
interpreter.AddTensors(3);
float foo[] = {1, 2, 3};
@@ -52,10 +78,53 @@ TEST(Writer, FloatModelTest) {
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
reinterpret_cast<void*>(builtin_data), reg);
SubgraphWriter writer(&interpreter.primary_subgraph());
writer.Write("/tmp/test_float.tflite");
// Check if invalid filename is handled gracefully.
if (GetParam()) {
SubgraphWriter writer(&interpreter.primary_subgraph());
CHECK_EQ(writer.Write(""), kTfLiteError);
} else {
ModelWriter writer(&interpreter);
CHECK_EQ(writer.Write(""), kTfLiteError);
}
// Check if invalid buffer is handled gracefully.
size_t size;
if (GetParam()) {
SubgraphWriter writer(&interpreter.primary_subgraph());
CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
} else {
ModelWriter writer(&interpreter);
CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
}
}
TEST_P(SingleSubgraphTest, FloatModelTest) {
Interpreter interpreter;
interpreter.AddTensors(3);
float foo[] = {1, 2, 3};
interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
TfLiteQuantization());
interpreter.SetTensorParametersReadOnly(
1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
reinterpret_cast<char*>(foo), sizeof(foo));
interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
TfLiteQuantization());
interpreter.SetInputs({0, 1});
interpreter.SetOutputs({2});
const char* initial_data = "";
tflite::ops::builtin::BuiltinOpResolver resolver;
TfLiteAddParams* builtin_data =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
builtin_data->activation = kTfLiteActNone;
builtin_data->pot_scale_int16 = false;
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
reinterpret_cast<void*>(builtin_data), reg);
const std::string test_file = CreateFilePath("test_float.tflite");
WriteToFile(&interpreter, test_file, GetParam());
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile("/tmp/test_float.tflite");
FlatBufferModel::BuildFromFile(test_file.c_str());
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
@@ -63,7 +132,7 @@ TEST(Writer, FloatModelTest) {
}
// Tests writing only a portion of the subgraph.
TEST(Writer, CustomInputOutputTest) {
TEST_P(SingleSubgraphTest, CustomInputOutputTest) {
Interpreter interpreter;
interpreter.AddTensors(4);
constexpr float kFoo[] = {1, 2, 3};
@@ -94,22 +163,23 @@ TEST(Writer, CustomInputOutputTest) {
interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
// Only write the second op.
const std::string test_file = CreateFilePath("test_custom.tflite");
SubgraphWriter writer(&interpreter.primary_subgraph());
EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
/*execution_plan=*/{1}),
kTfLiteOk);
writer.SetUnusedTensors({0, 1});
writer.Write("/tmp/test_custom.tflite");
writer.Write(test_file);
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile("/tmp/test_custom.tflite");
FlatBufferModel::BuildFromFile(test_file.c_str());
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
TEST(Writer, CustomInputOutputErrorCasesTest) {
TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
Interpreter interpreter;
interpreter.AddTensors(5);
constexpr float kFoo[] = {1, 2, 3};
@@ -160,7 +230,7 @@ TEST(Writer, CustomInputOutputErrorCasesTest) {
kTfLiteOk);
}
TEST(Writer, PerTensorQuantizedModelTest) {
TEST_P(SingleSubgraphTest, PerTensorQuantizedModelTest) {
Interpreter interpreter;
interpreter.AddTensors(3);
interpreter.SetTensorParametersReadWrite(
@@ -181,16 +251,18 @@ TEST(Writer, PerTensorQuantizedModelTest) {
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
reinterpret_cast<void*>(builtin_data), reg);
SubgraphWriter writer(&interpreter.primary_subgraph());
writer.Write("/tmp/test_uint8.tflite");
const std::string test_file = CreateFilePath("test_uint8.tflite");
WriteToFile(&interpreter, test_file, GetParam());
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile("/tmp/test_uint8.tflite");
FlatBufferModel::BuildFromFile(test_file.c_str());
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
struct ReshapeTestPattern {
int num_inputs;
bool is_param_valid;
@@ -241,8 +313,8 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
SubgraphWriter writer(&interpreter.primary_subgraph());
std::stringstream ss;
ss << "/tmp/test_reshape_" << param.num_inputs << param.is_param_valid
<< ".tflite";
ss << CreateFilePath("test_reshape_") << param.num_inputs
<< param.is_param_valid << ".tflite";
std::string filename = ss.str();
writer.Write(filename);
std::unique_ptr<FlatBufferModel> model =
@@ -268,6 +340,57 @@ INSTANTIATE_TEST_SUITE_P(
std::string name = ss.str();
return name;
});
class WhileTest : public subgraph_test_util::ControlFlowOpTest {};
// The test builds a model that produces the i-th number of the
// triangular number sequence: 1, 3, 6, 10, 15, 21, 28, ...
TEST_F(WhileTest, TestTriangularNumberSequence) {
const int kSeqNumber = 4;
const int kExpectedValue = 15;
interpreter_.reset(new Interpreter);
interpreter_->AddSubgraphs(2);
builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1});
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output1, {1}, {kSeqNumber + 1});
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
CheckIntTensor(output2, {1}, {kExpectedValue});
// Now serialize & deserialize model into a new Interpreter.
ModelWriter writer(interpreter_.get());
const std::string test_file = CreateFilePath("test_while.tflite");
writer.Write(test_file);
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile(test_file.c_str());
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
// Check deserialized model.
new_interpreter->ResizeInputTensor(interpreter_->inputs()[0], {1});
new_interpreter->ResizeInputTensor(interpreter_->inputs()[1], {1});
ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[0]), {1});
FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[1]), {1});
ASSERT_EQ(new_interpreter->Invoke(), kTfLiteOk);
output1 = new_interpreter->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output1, {1}, {kSeqNumber + 1});
output2 = new_interpreter->tensor(interpreter_->outputs()[1]);
CheckIntTensor(output2, {1}, {kExpectedValue});
}
} // namespace tflite
int main(int argc, char** argv) {

View File

@@ -35,7 +35,7 @@ int main(int argc, char* argv[]) {
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
tflite::ModelWriter writer(interpreter.get());
std::unique_ptr<uint8_t[]> output_buffer;
size_t output_buffer_size;
writer.GetBuffer(&output_buffer, &output_buffer_size);