Rollforward serialization tool with internal test fix

PiperOrigin-RevId: 347894186
Change-Id: I00bd6f41bca9980df6785e067e01c3b94dc81a95
This commit is contained in:
Sachin Joglekar 2020-12-16 14:07:10 -08:00 committed by TensorFlower Gardener
parent 39d6ff7ac5
commit 2028a32c04
9 changed files with 439 additions and 110 deletions

View File

@ -2256,6 +2256,7 @@ cc_library(
":builtin_ops", ":builtin_ops",
":kernel_util", ":kernel_util",
":variable_op_kernels", ":variable_op_kernels",
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <vector> #include <vector>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/core/subgraph.h"
@ -113,10 +114,11 @@ void SubgraphBuilder::BuildAddSubgraph(Subgraph* subgraph) {
TfLiteAddParams* params = TfLiteAddParams* params =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams))); reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone; params->activation = kTfLiteActNone;
auto* add_reg = ops::builtin::Register_ADD();
add_reg->builtin_code = kTfLiteBuiltinAdd;
int node_index; int node_index;
subgraph->AddNodeWithParameters( subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
{kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params, params, add_reg, &node_index);
::tflite::ops::builtin::Register_ADD(), &node_index);
} }
// Build a subgraph with an mul op. Helper function for testing. // Build a subgraph with an mul op. Helper function for testing.
@ -143,10 +145,11 @@ void SubgraphBuilder::BuildMulSubgraph(Subgraph* subgraph) {
TfLiteMulParams* params = TfLiteMulParams* params =
reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams))); reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
params->activation = kTfLiteActNone; params->activation = kTfLiteActNone;
auto* mul_reg = ops::builtin::Register_MUL();
mul_reg->builtin_code = kTfLiteBuiltinMul;
int node_index; int node_index;
subgraph->AddNodeWithParameters( subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
{kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params, params, mul_reg, &node_index);
::tflite::ops::builtin::Register_MUL(), &node_index);
} }
// Build a subgraph with a pad op. Helper function for testing. // Build a subgraph with a pad op. Helper function for testing.
@ -172,10 +175,11 @@ void SubgraphBuilder::BuildPadSubgraph(Subgraph* subgraph) {
TfLitePadParams* params = TfLitePadParams* params =
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams))); reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
auto* pad_reg = ops::builtin::Register_PAD();
pad_reg->builtin_code = kTfLiteBuiltinPad;
int node_index; int node_index;
subgraph->AddNodeWithParameters( subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
{kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params, params, pad_reg, &node_index);
::tflite::ops::builtin::Register_PAD(), &node_index);
} }
void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) { void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
@ -205,11 +209,12 @@ void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
reinterpret_cast<TfLiteIfParams*>(malloc(sizeof(TfLiteIfParams))); reinterpret_cast<TfLiteIfParams*>(malloc(sizeof(TfLiteIfParams)));
params->then_subgraph_index = 1; params->then_subgraph_index = 1;
params->else_subgraph_index = 2; params->else_subgraph_index = 2;
auto* if_reg = ops::builtin::Register_IF();
if_reg->builtin_code = kTfLiteBuiltinIf;
int node_index; int node_index;
subgraph->AddNodeWithParameters( subgraph->AddNodeWithParameters({kCondInput, kInput1, kInput2}, {kOutput}, {},
{kCondInput, kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params, nullptr, 0, params, if_reg, &node_index);
::tflite::ops::builtin::Register_IF(), &node_index);
} }
void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) { void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
@ -236,11 +241,13 @@ void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
SetupTensor(subgraph, kInput2, kTfLiteInt32); SetupTensor(subgraph, kInput2, kTfLiteInt32);
SetupTensor(subgraph, kOutput, kTfLiteBool); SetupTensor(subgraph, kOutput, kTfLiteBool);
auto* le_reg = ops::builtin::Register_LESS_EQUAL();
le_reg->builtin_code = kTfLiteBuiltinLessEqual;
CreateConstantInt32Tensor(subgraph, kConstRhs, {1}, {rhs}); CreateConstantInt32Tensor(subgraph, kConstRhs, {1}, {rhs});
int node_index; int node_index;
subgraph->AddNodeWithParameters( subgraph->AddNodeWithParameters({kInput1, kConstRhs}, {kOutput}, {}, nullptr,
{kInput1, kConstRhs}, {kOutput}, {}, nullptr, 0, nullptr, 0, nullptr, le_reg, &node_index);
::tflite::ops::builtin::Register_LESS_EQUAL(), &node_index);
} }
void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) { void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
@ -277,13 +284,13 @@ void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
TfLiteAddParams* params = TfLiteAddParams* params =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams))); reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone; params->activation = kTfLiteActNone;
subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params, auto* add_reg = ops::builtin::Register_ADD();
::tflite::ops::builtin::Register_ADD(), add_reg->builtin_code = kTfLiteBuiltinAdd;
subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params, add_reg,
&node_index); &node_index);
params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams))); params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone; params->activation = kTfLiteActNone;
subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params, subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params, add_reg,
::tflite::ops::builtin::Register_ADD(),
&node_index); &node_index);
} }
@ -327,14 +334,18 @@ void SubgraphBuilder::BuildPadLoopBodySubgraph(Subgraph* subgraph,
TfLiteAddParams* add_params = TfLiteAddParams* add_params =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams))); reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
add_params->activation = kTfLiteActNone; add_params->activation = kTfLiteActNone;
subgraph->AddNodeWithParameters( auto* add_reg = ops::builtin::Register_ADD();
{kInputCounter, kConstStep}, {kOutputCounter}, {}, nullptr, 0, add_params, add_reg->builtin_code = kTfLiteBuiltinAdd;
::tflite::ops::builtin::Register_ADD(), &node_index); subgraph->AddNodeWithParameters({kInputCounter, kConstStep}, {kOutputCounter},
{}, nullptr, 0, add_params, add_reg,
&node_index);
TfLitePadParams* pad_params = TfLitePadParams* pad_params =
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLiteAddParams))); reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLiteAddParams)));
subgraph->AddNodeWithParameters( auto* pad_reg = ops::builtin::Register_PAD();
{kInputValue, kConstPadding}, {kOutputValue}, {}, nullptr, 0, pad_params, pad_reg->builtin_code = kTfLiteBuiltinPad;
::tflite::ops::builtin::Register_PAD(), &node_index); subgraph->AddNodeWithParameters({kInputValue, kConstPadding}, {kOutputValue},
{}, nullptr, 0, pad_params, pad_reg,
&node_index);
} }
void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) { void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
@ -364,11 +375,12 @@ void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
reinterpret_cast<TfLiteWhileParams*>(malloc(sizeof(TfLiteWhileParams))); reinterpret_cast<TfLiteWhileParams*>(malloc(sizeof(TfLiteWhileParams)));
params->cond_subgraph_index = 1; params->cond_subgraph_index = 1;
params->body_subgraph_index = 2; params->body_subgraph_index = 2;
auto* while_reg = ops::builtin::Register_WHILE();
while_reg->builtin_code = kTfLiteBuiltinWhile;
int node_index; int node_index;
subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params, subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params,
::tflite::ops::builtin::Register_WHILE(), while_reg, &node_index);
&node_index);
} }
void SubgraphBuilder::BuildAssignRandomValueToVariableSubgraph( void SubgraphBuilder::BuildAssignRandomValueToVariableSubgraph(

View File

@ -35,6 +35,7 @@ cc_library(
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs_with_reflection", "//tensorflow/lite/schema:schema_fbs_with_reflection",
"//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/schema:schema_utils",
"@com_google_absl//absl/container:flat_hash_map",
], ],
) )
@ -67,6 +68,7 @@ cc_test(
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:subgraph_test_util",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util", "//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",

View File

@ -0,0 +1,63 @@
# TFLite Serialization Tool
**NOTE:** This tool is intended for advanced users only, and should be used with
care.
The (C++) serialization library generates and writes a TFLite flatbuffer given
an `Interpreter` or `Subgraph`. Example use-cases include authoring models with
the `Interpreter` API, or updating models on-device (by modifying `tensor.data`
for relevant tensors).
## Serialization
### Writing flatbuffer to file
To write a TFLite model from an `Interpreter` (see `lite/interpreter.h`):
`std::unique_ptr<tflite::Interpreter> interpreter; // ...build/modify
interpreter... tflite::ModelWriter writer(interpreter.get()); std::string
filename = "/tmp/model.tflite"; writer.Write(filename);`
Note that the above API does not support custom I/O tensors or custom ops yet.
However, it does support model with Control Flow.
To generate/write a flatbuffer for a particular `Subgraph` (see
`lite/core/subgraph.h`) you can use `SubgraphWriter`.
```
std::unique_ptr<tflite::Interpreter> interpreter;
// ...build/modify interpreter...
// The number of subgraphs can be obtained by:
// const int num_subgraphs = interpreter_->subgraphs_size();
// Note that 0 <= subgraph_index < num_subgraphs
tflite::SubgraphWriter writer(&interpreter->subgraph(subgraph_index));
std::string filename = "/tmp/model.tflite";
writer.Write(filename);
```
`SubgraphWriter` supports custom ops and/or custom I/O tensors.
### Generating flatbuffer in-memory
Both `ModelWriter` and `SubgraphWriter` support a `GetBuffer` method to return
the generated flatbuffer in-memory:
```
std::unique_ptr<uint8_t[]> output_buffer;
size_t output_buffer_size;
tflite::ModelWriter writer(interpreter.get());
writer.GetBuffer(&output_buffer, &output_buffer_size);
```
## De-serialization
The flatbuffers written as above can be de-serialized just like any other TFLite
model, for eg:
```
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile(filename);
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
```

View File

@ -34,7 +34,7 @@ int main(int argc, char* argv[]) {
std::unique_ptr<tflite::Interpreter> interpreter; std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver; tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
tflite::SubgraphWriter writer(&interpreter->primary_subgraph()); tflite::ModelWriter writer(interpreter.get());
writer.Write(argv[2]); writer.Write(argv[2]);
return 0; return 0;

View File

@ -29,6 +29,41 @@ limitations under the License.
#include "tensorflow/lite/version.h" #include "tensorflow/lite/version.h"
namespace tflite { namespace tflite {
namespace {
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
std::vector<OpCode>* opcodes) {
std::vector<flatbuffers::Offset<OperatorCode>> codes;
for (const auto& it : *opcodes) {
const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
codes.push_back(CreateOperatorCodeDirect(
*fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
}
return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
for (auto buffer : *buffers) {
auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
}
return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
}
TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
FILE* fp = fopen(filename.c_str(), "wb");
if (!fp) return kTfLiteError;
const int result_size = fwrite(data, 1, size, fp);
fclose(fp);
if (result_size != size) return kTfLiteError;
return kTfLiteOk;
}
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion( std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op, flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
@ -39,6 +74,8 @@ std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>()); return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
} }
} // namespace
template <class T_OUTPUT, class T_INPUT> template <class T_OUTPUT, class T_INPUT>
flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector( flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) { flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
@ -159,8 +196,8 @@ SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
// Allocate a buffer index // Allocate a buffer index
int buffer_index = 0; // This is null int buffer_index = 0; // This is null
if (tensor->allocation_type == kTfLiteMmapRo) { if (tensor->allocation_type == kTfLiteMmapRo) {
buffer_index = buffers_.size(); buffer_index = buffers_->size();
buffers_.push_back(std::make_pair( buffers_->push_back(std::make_pair(
reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes)); reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
} }
// Primitive type. // Primitive type.
@ -214,23 +251,12 @@ SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>> flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) { SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
std::vector<flatbuffers::Offset<Buffer>> buffer_vector; return ExportBuffersImpl(fbb, buffers_);
for (auto buffer : buffers_) {
auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
}
return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
} }
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>> flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) { SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
std::vector<flatbuffers::Offset<OperatorCode>> codes; return CreateOpCodeTableImpl(fbb, opcodes_);
for (const auto& it : opcodes_) {
const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
codes.push_back(CreateOperatorCodeDirect(
*fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
}
return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
} }
template <class T> template <class T>
@ -254,19 +280,9 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
size_t* size) { size_t* size) {
if (!out || !size) return kTfLiteError; if (!out || !size) return kTfLiteError;
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector; std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
{ // subgraph specific stuff subgraphs_as_vector.push_back(PopulateAndGetOffset(&builder));
auto tensors = ExportTensors(&builder);
std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
auto inputs = ExportVector<int32_t>(&builder, written_inputs);
auto outputs = ExportVector<int32_t>(&builder, written_outputs);
auto ops = ExportOperators(&builder);
subgraphs_as_vector.push_back(
CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>> flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
buffers = ExportBuffers(&builder); buffers = ExportBuffers(&builder);
@ -284,21 +300,23 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
return kTfLiteOk; return kTfLiteOk;
} }
flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
flatbuffers::FlatBufferBuilder* builder) {
auto tensors = ExportTensors(builder);
std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
auto inputs = ExportVector<int32_t>(builder, written_inputs);
auto outputs = ExportVector<int32_t>(builder, written_outputs);
auto ops = ExportOperators(builder);
return CreateSubGraph(*builder, tensors, inputs, outputs, ops, /* name */ 0);
}
TfLiteStatus SubgraphWriter::Write(const std::string& filename) { TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
std::unique_ptr<uint8_t[]> buffer; std::unique_ptr<uint8_t[]> buffer;
size_t size; size_t size;
TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size)); TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
return WriteImpl(filename, buffer.get(), size);
FILE* fp = fopen(filename.c_str(), "wb");
if (!fp) return kTfLiteError;
if (fwrite(buffer.get(), 1, size, fp) != size) {
fclose(fp);
return kTfLiteError;
}
if (fclose(fp)) return kTfLiteError;
return kTfLiteOk;
} }
TfLiteStatus SubgraphWriter::RegisterCustomWriter( TfLiteStatus SubgraphWriter::RegisterCustomWriter(
@ -377,4 +395,50 @@ TfLiteStatus SubgraphWriter::SetCustomInputOutput(
return kTfLiteOk; return kTfLiteOk;
} }
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
return ExportBuffersImpl(fbb, &buffers_);
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
return CreateOpCodeTableImpl(fbb, &opcodes_);
}
TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
size_t* size) {
if (!out || !size) return kTfLiteError;
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
SubgraphWriter writer(interpreter_->subgraph(i), &buffers_, &opcodes_,
&builtin_op_to_opcode_);
subgraphs_as_vector.push_back(writer.PopulateAndGetOffset(&builder));
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
buffers = ExportBuffers(&builder);
auto description = builder.CreateString("Exported from Subgraph.");
auto op_codes = CreateOpCodeTable(&builder);
auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs_as_vector),
description, buffers);
::tflite::FinishModelBuffer(builder, model);
const uint8_t* buffer = builder.GetBufferPointer();
*size = builder.GetSize();
(*out).reset(new uint8_t[*size]);
memcpy(out->get(), buffer, *size);
return kTfLiteOk;
}
TfLiteStatus ModelWriter::Write(const std::string& filename) {
std::unique_ptr<uint8_t[]> buffer;
size_t size;
TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
return WriteImpl(filename, buffer.get(), size);
}
} // namespace tflite } // namespace tflite

View File

@ -12,37 +12,72 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// Writes a flatbuffer of a currently loaded TensorFlow Lite subgraph. // Library to write a flatbuffer of a currently loaded TFLite model/subgraph.
//
// Usage:
// From command line:
// bazel run third_party/tensorflow/lite/experimental/writer:writer
// -- foo.tflite foo.out.tflite
//
// From C++
// std::unique_ptr<Interpreter> interpreter;
// // Build Interpreter however
// // ... <omitted>
// SubgraphWriter(&interpreter->primary_subgraph()).Write("output.tflite");
#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ #ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ #define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#include <iostream> #include <iostream>
#include <unordered_map> #include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h" #include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h" #include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/tools/serialization/enum_mapping.h" #include "tensorflow/lite/tools/serialization/enum_mapping.h"
#include "tensorflow/lite/version.h" #include "tensorflow/lite/version.h"
namespace tflite { namespace tflite {
struct OpCode {
int builtin;
std::string custom;
};
// Handles writing a full TFLite model (with 1 or more subgraphs) to a
// serialized TF lite file format.
// TODO(b/174708523): Support custom I/O or unused tensors later.
class ModelWriter {
public:
// Construct a writer for the specified `interpreter`. Then, use
// .Write() or .GetBuffer(...) to extract the data.
explicit ModelWriter(Interpreter* interpreter) : interpreter_(interpreter) {
buffers_.push_back(std::make_pair(nullptr, 0));
}
// Get a buffer and size of a serialized flatbuffer.
TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
// Write the serialized flatbuffer to the prescribed `filename`.
TfLiteStatus Write(const std::string& filename);
private:
template <class T>
using Offset = flatbuffers::Offset<T>;
Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
flatbuffers::FlatBufferBuilder* fbb);
Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
flatbuffers::FlatBufferBuilder* fbb);
// ModelWriter does not take ownership of this object.
Interpreter* const interpreter_;
// This data corresponds to the overall model (rather than individual
// subgraphs), so we define common fields. Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
// List of used opcodes
std::vector<OpCode> opcodes_;
absl::flat_hash_map<int, int> builtin_op_to_opcode_;
};
// Handles writing TensorFlow Lite running subgraph to a serialized TF lite // Handles writing TensorFlow Lite running subgraph to a serialized TF lite
// file format. // file format.
// TODO(b/174708523): Reconcile into ModelWriter?
class SubgraphWriter { class SubgraphWriter {
public: public:
friend class ModelWriter;
typedef flatbuffers::Offset<Operator> (*CustomWriter)( typedef flatbuffers::Offset<Operator> (*CustomWriter)(
flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index, flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options, flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
@ -55,7 +90,10 @@ class SubgraphWriter {
inputs_(subgraph->inputs()), inputs_(subgraph->inputs()),
outputs_(subgraph->outputs()), outputs_(subgraph->outputs()),
execution_plan_(subgraph->execution_plan()) { execution_plan_(subgraph->execution_plan()) {
buffers_.push_back(std::make_pair(nullptr, 0)); buffers_ = &buffers_data_;
opcodes_ = &opcodes_data_;
builtin_op_to_opcode_ = &builtin_op_to_opcode_data_;
buffers_->push_back(std::make_pair(nullptr, 0));
} }
// Get a buffer and size of a serialized flatbuffer. // Get a buffer and size of a serialized flatbuffer.
@ -77,6 +115,28 @@ class SubgraphWriter {
const std::vector<int>& execution_plan); const std::vector<int>& execution_plan);
private: private:
// Used by ModelWriter.
explicit SubgraphWriter(
Subgraph* subgraph,
std::vector<std::pair<const uint8_t*, size_t>>* external_buffers,
std::vector<OpCode>* external_opcodes,
absl::flat_hash_map<int, int>* external_builtin_op_to_opcode)
: subgraph_(subgraph),
inputs_(subgraph->inputs()),
outputs_(subgraph->outputs()),
execution_plan_(subgraph->execution_plan()) {
buffers_ = external_buffers;
opcodes_ = external_opcodes;
builtin_op_to_opcode_ = external_builtin_op_to_opcode;
buffers_->push_back(std::make_pair(nullptr, 0));
}
// Used by ModelWriter to populate data specific to this subgraph.
// Global stuff (like opcodes & buffers) is populated into buffers_, opcodes_,
// etc. & populated in the Flatbuffer by ModelWriter.
flatbuffers::Offset<SubGraph> PopulateAndGetOffset(
flatbuffers::FlatBufferBuilder* builder);
template <class T> template <class T>
using Offset = flatbuffers::Offset<T>; using Offset = flatbuffers::Offset<T>;
template <class T_OUTPUT, class T_INPUT> template <class T_OUTPUT, class T_INPUT>
@ -102,11 +162,11 @@ class SubgraphWriter {
int GetOpCodeForBuiltin(int builtin_op_index) { int GetOpCodeForBuiltin(int builtin_op_index) {
// auto it = builtin_op_to_opcode_.find(builtin_op_index); // auto it = builtin_op_to_opcode_.find(builtin_op_index);
std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result = std::pair<decltype(builtin_op_to_opcode_data_)::iterator, bool> result =
builtin_op_to_opcode_.insert( builtin_op_to_opcode_->insert(
std::make_pair(builtin_op_index, opcodes_.size())); std::make_pair(builtin_op_index, opcodes_->size()));
if (result.second) { if (result.second) {
opcodes_.push_back({builtin_op_index, ""}); opcodes_->push_back({builtin_op_index, ""});
} }
return result.first->second; return result.first->second;
} }
@ -114,9 +174,9 @@ class SubgraphWriter {
int GetOpCodeForCustom(const std::string& custom_name) { int GetOpCodeForCustom(const std::string& custom_name) {
std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result = std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
custom_op_to_opcode_.insert( custom_op_to_opcode_.insert(
std::make_pair(custom_name, opcodes_.size())); std::make_pair(custom_name, opcodes_->size()));
if (result.second) { if (result.second) {
opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name}); opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name});
} }
return result.first->second; return result.first->second;
} }
@ -129,22 +189,26 @@ class SubgraphWriter {
std::vector<int> outputs_; std::vector<int> outputs_;
// Order of nodes to be written. // Order of nodes to be written.
std::vector<int> execution_plan_; std::vector<int> execution_plan_;
// Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
// List of op codes and mappings from builtin or custom op to opcode // List of op codes and mappings from builtin or custom op to opcode
struct OpCode {
int builtin;
std::string custom;
};
std::set<int> unused_tensors_; std::set<int> unused_tensors_;
// For every tensor index in the subgraph, the index in the written. // For every tensor index in the subgraph, the index in the written.
// This is different due to temporary and unused tensors not being written. // This is different due to temporary and unused tensors not being written.
std::vector<int> tensor_to_written_tensor_; std::vector<int> tensor_to_written_tensor_;
// List of used opcodes
std::vector<OpCode> opcodes_;
std::unordered_map<int, int> builtin_op_to_opcode_;
std::unordered_map<std::string, int> custom_op_to_opcode_; std::unordered_map<std::string, int> custom_op_to_opcode_;
std::unordered_map<std::string, CustomWriter> custom_op_to_writer_; std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
// We use pointers for these, since they may be provided by ModelWriter.
// Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>>* buffers_;
// List of used opcodes
std::vector<OpCode>* opcodes_;
absl::flat_hash_map<int, int>* builtin_op_to_opcode_;
// These are used if SubgraphWriter is being used directly.
std::vector<std::pair<const uint8_t*, size_t>> buffers_data_;
// List of used opcodes
std::vector<OpCode> opcodes_data_;
absl::flat_hash_map<int, int> builtin_op_to_opcode_data_;
}; };
} // namespace tflite } // namespace tflite

View File

@ -15,21 +15,47 @@ limitations under the License.
#include "tensorflow/lite/tools/serialization/writer_lib.h" #include "tensorflow/lite/tools/serialization/writer_lib.h"
#include <cstdlib>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <string>
#include <tuple>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/subgraph_test_util.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/testing/util.h"
namespace tflite { namespace tflite {
// Make an interpreter that has no tensors and no nodes
// TODO(b/113731921): add more tests. using subgraph_test_util::CheckIntTensor;
TEST(Writer, FloatModelTest) { using subgraph_test_util::FillIntTensor;
std::string CreateFilePath(const std::string& file_name) {
return std::string(getenv("TEST_TMPDIR")) + file_name;
}
// The bool param indicates whether we use SubgraphWriter(true) or
// ModelWriter(false) for the test
class SingleSubgraphTest : public ::testing::TestWithParam<bool> {
protected:
void WriteToFile(Interpreter* interpreter, const std::string& filename,
bool use_subgraph_writer) {
if (use_subgraph_writer) {
SubgraphWriter writer(&interpreter->primary_subgraph());
CHECK_EQ(writer.Write(filename), kTfLiteOk);
} else {
ModelWriter writer(interpreter);
CHECK_EQ(writer.Write(filename), kTfLiteOk);
}
}
};
TEST_P(SingleSubgraphTest, InvalidDestinations) {
Interpreter interpreter; Interpreter interpreter;
interpreter.AddTensors(3); interpreter.AddTensors(3);
float foo[] = {1, 2, 3}; float foo[] = {1, 2, 3};
@ -52,10 +78,53 @@ TEST(Writer, FloatModelTest) {
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
reinterpret_cast<void*>(builtin_data), reg); reinterpret_cast<void*>(builtin_data), reg);
SubgraphWriter writer(&interpreter.primary_subgraph()); // Check if invalid filename is handled gracefully.
writer.Write("/tmp/test_float.tflite"); if (GetParam()) {
SubgraphWriter writer(&interpreter.primary_subgraph());
CHECK_EQ(writer.Write(""), kTfLiteError);
} else {
ModelWriter writer(&interpreter);
CHECK_EQ(writer.Write(""), kTfLiteError);
}
// Check if invalid buffer is handled gracefully.
size_t size;
if (GetParam()) {
SubgraphWriter writer(&interpreter.primary_subgraph());
CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
} else {
ModelWriter writer(&interpreter);
CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
}
}
TEST_P(SingleSubgraphTest, FloatModelTest) {
Interpreter interpreter;
interpreter.AddTensors(3);
float foo[] = {1, 2, 3};
interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
TfLiteQuantization());
interpreter.SetTensorParametersReadOnly(
1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
reinterpret_cast<char*>(foo), sizeof(foo));
interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
TfLiteQuantization());
interpreter.SetInputs({0, 1});
interpreter.SetOutputs({2});
const char* initial_data = "";
tflite::ops::builtin::BuiltinOpResolver resolver;
TfLiteAddParams* builtin_data =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
builtin_data->activation = kTfLiteActNone;
builtin_data->pot_scale_int16 = false;
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
reinterpret_cast<void*>(builtin_data), reg);
const std::string test_file = CreateFilePath("test_float.tflite");
WriteToFile(&interpreter, test_file, GetParam());
std::unique_ptr<FlatBufferModel> model = std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile("/tmp/test_float.tflite"); FlatBufferModel::BuildFromFile(test_file.c_str());
InterpreterBuilder builder(*model, resolver); InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter; std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter); builder(&new_interpreter);
@ -63,7 +132,7 @@ TEST(Writer, FloatModelTest) {
} }
// Tests writing only a portion of the subgraph. // Tests writing only a portion of the subgraph.
TEST(Writer, CustomInputOutputTest) { TEST_P(SingleSubgraphTest, CustomInputOutputTest) {
Interpreter interpreter; Interpreter interpreter;
interpreter.AddTensors(4); interpreter.AddTensors(4);
constexpr float kFoo[] = {1, 2, 3}; constexpr float kFoo[] = {1, 2, 3};
@ -94,22 +163,23 @@ TEST(Writer, CustomInputOutputTest) {
interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2); interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
// Only write the second op. // Only write the second op.
const std::string test_file = CreateFilePath("test_custom.tflite");
SubgraphWriter writer(&interpreter.primary_subgraph()); SubgraphWriter writer(&interpreter.primary_subgraph());
EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3}, EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
/*execution_plan=*/{1}), /*execution_plan=*/{1}),
kTfLiteOk); kTfLiteOk);
writer.SetUnusedTensors({0, 1}); writer.SetUnusedTensors({0, 1});
writer.Write("/tmp/test_custom.tflite"); writer.Write(test_file);
std::unique_ptr<FlatBufferModel> model = std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile("/tmp/test_custom.tflite"); FlatBufferModel::BuildFromFile(test_file.c_str());
InterpreterBuilder builder(*model, resolver); InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter; std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter); builder(&new_interpreter);
ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
} }
TEST(Writer, CustomInputOutputErrorCasesTest) { TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
Interpreter interpreter; Interpreter interpreter;
interpreter.AddTensors(5); interpreter.AddTensors(5);
constexpr float kFoo[] = {1, 2, 3}; constexpr float kFoo[] = {1, 2, 3};
@ -160,7 +230,7 @@ TEST(Writer, CustomInputOutputErrorCasesTest) {
kTfLiteOk); kTfLiteOk);
} }
TEST(Writer, PerTensorQuantizedModelTest) { TEST_P(SingleSubgraphTest, PerTensorQuantizedModelTest) {
Interpreter interpreter; Interpreter interpreter;
interpreter.AddTensors(3); interpreter.AddTensors(3);
interpreter.SetTensorParametersReadWrite( interpreter.SetTensorParametersReadWrite(
@ -181,16 +251,18 @@ TEST(Writer, PerTensorQuantizedModelTest) {
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
reinterpret_cast<void*>(builtin_data), reg); reinterpret_cast<void*>(builtin_data), reg);
SubgraphWriter writer(&interpreter.primary_subgraph()); const std::string test_file = CreateFilePath("test_uint8.tflite");
writer.Write("/tmp/test_uint8.tflite"); WriteToFile(&interpreter, test_file, GetParam());
std::unique_ptr<FlatBufferModel> model = std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile("/tmp/test_uint8.tflite"); FlatBufferModel::BuildFromFile(test_file.c_str());
InterpreterBuilder builder(*model, resolver); InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter; std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter); builder(&new_interpreter);
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk); CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
} }
INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
struct ReshapeTestPattern { struct ReshapeTestPattern {
int num_inputs; int num_inputs;
bool is_param_valid; bool is_param_valid;
@ -241,8 +313,8 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
SubgraphWriter writer(&interpreter.primary_subgraph()); SubgraphWriter writer(&interpreter.primary_subgraph());
std::stringstream ss; std::stringstream ss;
ss << "/tmp/test_reshape_" << param.num_inputs << param.is_param_valid ss << CreateFilePath("test_reshape_") << param.num_inputs
<< ".tflite"; << param.is_param_valid << ".tflite";
std::string filename = ss.str(); std::string filename = ss.str();
writer.Write(filename); writer.Write(filename);
std::unique_ptr<FlatBufferModel> model = std::unique_ptr<FlatBufferModel> model =
@ -268,6 +340,57 @@ INSTANTIATE_TEST_SUITE_P(
std::string name = ss.str(); std::string name = ss.str();
return name; return name;
}); });
class WhileTest : public subgraph_test_util::ControlFlowOpTest {};
// The test builds a model that produces the i-th number of
// triangular number sequence: 1, 3, 6, 10, 15, 21, 28.
TEST_F(WhileTest, TestTriangularNumberSequence) {
const int kSeqNumber = 4;
const int kExpectedValue = 15;
interpreter_.reset(new Interpreter);
interpreter_->AddSubgraphs(2);
builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1});
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output1, {1}, {kSeqNumber + 1});
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
CheckIntTensor(output2, {1}, {kExpectedValue});
// Now serialize & deserialize model into a new Interpreter.
ModelWriter writer(interpreter_.get());
const std::string test_file = CreateFilePath("test_while.tflite");
writer.Write(test_file);
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile(test_file.c_str());
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
// Check deserialized model.
new_interpreter->ResizeInputTensor(interpreter_->inputs()[0], {1});
new_interpreter->ResizeInputTensor(interpreter_->inputs()[1], {1});
ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[0]), {1});
FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[1]), {1});
ASSERT_EQ(new_interpreter->Invoke(), kTfLiteOk);
output1 = new_interpreter->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output1, {1}, {kSeqNumber + 1});
output2 = new_interpreter->tensor(interpreter_->outputs()[1]);
CheckIntTensor(output2, {1}, {kExpectedValue});
}
} // namespace tflite } // namespace tflite
int main(int argc, char** argv) { int main(int argc, char** argv) {

View File

@ -35,7 +35,7 @@ int main(int argc, char* argv[]) {
std::unique_ptr<tflite::Interpreter> interpreter; std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver; tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
tflite::SubgraphWriter writer(&interpreter->primary_subgraph()); tflite::ModelWriter writer(interpreter.get());
std::unique_ptr<uint8_t[]> output_buffer; std::unique_ptr<uint8_t[]> output_buffer;
size_t output_buffer_size; size_t output_buffer_size;
writer.GetBuffer(&output_buffer, &output_buffer_size); writer.GetBuffer(&output_buffer, &output_buffer_size);