From 2028a32c044d745285b7d521f8c6cbe895d1083e Mon Sep 17 00:00:00 2001
From: Sachin Joglekar
Date: Wed, 16 Dec 2020 14:07:10 -0800
Subject: [PATCH] Rollforward serialization tool with internal test fix

PiperOrigin-RevId: 347894186
Change-Id: I00bd6f41bca9980df6785e067e01c3b94dc81a95
---
 tensorflow/lite/kernels/BUILD                 |   1 +
 tensorflow/lite/kernels/subgraph_test_util.cc |  66 +++++---
 tensorflow/lite/tools/serialization/BUILD     |   2 +
 tensorflow/lite/tools/serialization/README.md |  63 +++++++
 tensorflow/lite/tools/serialization/writer.cc |   2 +-
 .../lite/tools/serialization/writer_lib.cc    | 138 +++++++++++-----
 .../lite/tools/serialization/writer_lib.h     | 120 ++++++++++----
 .../tools/serialization/writer_lib_test.cc    | 155 ++++++++++++++++--
 .../lite/tools/serialization/writer_test.cc   |   2 +-
 9 files changed, 439 insertions(+), 110 deletions(-)
 create mode 100644 tensorflow/lite/tools/serialization/README.md

diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index a0808aafc80..cdab82c0914 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -2256,6 +2256,7 @@ cc_library(
         ":builtin_ops",
         ":kernel_util",
         ":variable_op_kernels",
+        "//tensorflow/lite:builtin_ops",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
         "@com_google_googletest//:gtest",
diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc
index 6cf3e89b8c1..3f65b70bcee 100644
--- a/tensorflow/lite/kernels/subgraph_test_util.cc
+++ b/tensorflow/lite/kernels/subgraph_test_util.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include
 #include

+#include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/subgraph.h"
@@ -113,10 +114,11 @@ void SubgraphBuilder::BuildAddSubgraph(Subgraph* subgraph) {
   TfLiteAddParams* params =
       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   params->activation = kTfLiteActNone;
+  auto* add_reg = ops::builtin::Register_ADD();
+  add_reg->builtin_code = kTfLiteBuiltinAdd;
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_ADD(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
+                                  params, add_reg, &node_index);
 }

 // Build a subgraph with a mul op. Helper function for testing.
@@ -143,10 +145,11 @@ void SubgraphBuilder::BuildMulSubgraph(Subgraph* subgraph) {
   TfLiteMulParams* params =
       reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
   params->activation = kTfLiteActNone;
+  auto* mul_reg = ops::builtin::Register_MUL();
+  mul_reg->builtin_code = kTfLiteBuiltinMul;
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_MUL(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
+                                  params, mul_reg, &node_index);
 }

 // Build a subgraph with a pad op. Helper function for testing.
@@ -172,10 +175,11 @@ void SubgraphBuilder::BuildPadSubgraph(Subgraph* subgraph) {
   TfLitePadParams* params =
       reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
+  auto* pad_reg = ops::builtin::Register_PAD();
+  pad_reg->builtin_code = kTfLiteBuiltinPad;
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_PAD(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kInput2}, {kOutput}, {}, nullptr, 0,
+                                  params, pad_reg, &node_index);
 }

 void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
@@ -205,11 +209,12 @@ void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
       reinterpret_cast<TfLiteIfParams*>(malloc(sizeof(TfLiteIfParams)));
   params->then_subgraph_index = 1;
   params->else_subgraph_index = 2;
+  auto* if_reg = ops::builtin::Register_IF();
+  if_reg->builtin_code = kTfLiteBuiltinIf;

   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kCondInput, kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
-      ::tflite::ops::builtin::Register_IF(), &node_index);
+  subgraph->AddNodeWithParameters({kCondInput, kInput1, kInput2}, {kOutput}, {},
+                                  nullptr, 0, params, if_reg, &node_index);
 }

 void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
@@ -236,11 +241,13 @@ void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
   SetupTensor(subgraph, kInput2, kTfLiteInt32);
   SetupTensor(subgraph, kOutput, kTfLiteBool);

+  auto* le_reg = ops::builtin::Register_LESS_EQUAL();
+  le_reg->builtin_code = kTfLiteBuiltinLessEqual;
+
   CreateConstantInt32Tensor(subgraph, kConstRhs, {1}, {rhs});
   int node_index;
-  subgraph->AddNodeWithParameters(
-      {kInput1, kConstRhs}, {kOutput}, {}, nullptr, 0, nullptr,
-      ::tflite::ops::builtin::Register_LESS_EQUAL(), &node_index);
+  subgraph->AddNodeWithParameters({kInput1, kConstRhs}, {kOutput}, {}, nullptr,
+                                  0, nullptr, le_reg, &node_index);
 }

 void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
@@ -277,13 +284,13 @@ void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
   TfLiteAddParams* params =
       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   params->activation = kTfLiteActNone;
-  subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params,
-                                  ::tflite::ops::builtin::Register_ADD(),
+  auto* add_reg = ops::builtin::Register_ADD();
+  add_reg->builtin_code = kTfLiteBuiltinAdd;
+  subgraph->AddNodeWithParameters({0, 4}, {2}, {}, nullptr, 0, params, add_reg,
                                   &node_index);
   params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   params->activation = kTfLiteActNone;
-  subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params,
-                                  ::tflite::ops::builtin::Register_ADD(),
+  subgraph->AddNodeWithParameters({2, 1}, {3}, {}, nullptr, 0, params, add_reg,
                                   &node_index);
 }

@@ -327,14 +334,18 @@ void SubgraphBuilder::BuildPadLoopBodySubgraph(Subgraph* subgraph,
   TfLiteAddParams* add_params =
       reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
   add_params->activation = kTfLiteActNone;
-  subgraph->AddNodeWithParameters(
-      {kInputCounter, kConstStep}, {kOutputCounter}, {}, nullptr, 0, add_params,
-      ::tflite::ops::builtin::Register_ADD(), &node_index);
+  auto* add_reg = ops::builtin::Register_ADD();
+  add_reg->builtin_code = kTfLiteBuiltinAdd;
+  subgraph->AddNodeWithParameters({kInputCounter, kConstStep}, {kOutputCounter},
+                                  {}, nullptr, 0, add_params, add_reg,
+                                  &node_index);
   TfLitePadParams* pad_params =
       reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLiteAddParams)));
-  subgraph->AddNodeWithParameters(
-      {kInputValue, kConstPadding}, {kOutputValue}, {}, nullptr, 0, pad_params,
-      ::tflite::ops::builtin::Register_PAD(), &node_index);
+  auto* pad_reg = ops::builtin::Register_PAD();
+  pad_reg->builtin_code = kTfLiteBuiltinPad;
+  subgraph->AddNodeWithParameters({kInputValue, kConstPadding}, {kOutputValue},
+                                  {}, nullptr, 0, pad_params, pad_reg,
+                                  &node_index);
 }

 void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
@@ -364,11 +375,12 @@ void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
       reinterpret_cast<TfLiteWhileParams*>(malloc(sizeof(TfLiteWhileParams)));
   params->cond_subgraph_index = 1;
   params->body_subgraph_index = 2;
+  auto* while_reg = ops::builtin::Register_WHILE();
+  while_reg->builtin_code = kTfLiteBuiltinWhile;

   int node_index;
   subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params,
-                                  ::tflite::ops::builtin::Register_WHILE(),
-                                  &node_index);
+                                  while_reg, &node_index);
 }

 void SubgraphBuilder::BuildAssignRandomValueToVariableSubgraph(
diff --git a/tensorflow/lite/tools/serialization/BUILD b/tensorflow/lite/tools/serialization/BUILD
index ceb11e204d8..5472dbee1a2 100644
--- a/tensorflow/lite/tools/serialization/BUILD
+++ b/tensorflow/lite/tools/serialization/BUILD
@@ -35,6 +35,7 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs_with_reflection",
         "//tensorflow/lite/schema:schema_utils",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )

@@ -67,6 +68,7 @@ cc_test(
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
+        "//tensorflow/lite/kernels:subgraph_test_util",
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/lite/testing:util",
        "@com_google_googletest//:gtest",
diff --git a/tensorflow/lite/tools/serialization/README.md b/tensorflow/lite/tools/serialization/README.md
new file mode 100644
index 00000000000..bd6c91e2796
--- /dev/null
+++ b/tensorflow/lite/tools/serialization/README.md
@@ -0,0 +1,63 @@
+# TFLite Serialization Tool
+
+**NOTE:** This tool is intended for advanced users only, and should be used
+with care.
+
+The (C++) serialization library generates and writes a TFLite flatbuffer given
+an `Interpreter` or `Subgraph`. Example use cases include authoring models with
+the `Interpreter` API, or updating models on-device (by modifying `tensor.data`
+for the relevant tensors).
+
+## Serialization
+
+### Writing a flatbuffer to file
+
+To write a TFLite model from an `Interpreter` (see `lite/interpreter.h`):
+
+```
+std::unique_ptr<tflite::Interpreter> interpreter;
+// ...build/modify interpreter...
+tflite::ModelWriter writer(interpreter.get());
+std::string filename = "/tmp/model.tflite";
+writer.Write(filename);
+```
+
+Note that the above API does not yet support custom I/O tensors or custom ops.
+However, it does support models with control flow.
+
+To generate/write a flatbuffer for a particular `Subgraph` (see
+`lite/core/subgraph.h`), you can use `SubgraphWriter`:
+
+```
+std::unique_ptr<tflite::Interpreter> interpreter;
+// ...build/modify interpreter...
+// The number of subgraphs can be obtained by:
+// const int num_subgraphs = interpreter_->subgraphs_size();
+// Note that 0 <= subgraph_index < num_subgraphs
+tflite::SubgraphWriter writer(&interpreter->subgraph(subgraph_index));
+std::string filename = "/tmp/model.tflite";
+writer.Write(filename);
+```
+
+`SubgraphWriter` supports custom ops and/or custom I/O tensors.
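+
+For example, to serialize only a portion of a subgraph, you can mark custom
+inputs/outputs and unused tensors before writing. A minimal sketch using the
+`SetCustomInputOutput` and `SetUnusedTensors` APIs (the tensor/node indices
+and filename below are illustrative):
+
+```
+std::unique_ptr<tflite::Interpreter> interpreter;
+// ...build/modify interpreter...
+tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
+// Write only the portion of the graph that reads tensor 2, executes node 1,
+// and produces tensor 3.
+writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
+                            /*execution_plan=*/{1});
+// Mark tensors that the written portion does not use.
+writer.SetUnusedTensors({0, 1});
+writer.Write("/tmp/model_slice.tflite");
+```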
+
+### Generating a flatbuffer in memory
+
+Both `ModelWriter` and `SubgraphWriter` support a `GetBuffer` method to return
+the generated flatbuffer in memory:
+
+```
+std::unique_ptr<uint8_t[]> output_buffer;
+size_t output_buffer_size;
+tflite::ModelWriter writer(interpreter.get());
+writer.GetBuffer(&output_buffer, &output_buffer_size);
+```
+
+## De-serialization
+
+The flatbuffers written as above can be de-serialized just like any other
+TFLite model, e.g.:
+
+```
+std::unique_ptr<tflite::FlatBufferModel> model =
+    FlatBufferModel::BuildFromFile(filename);
+tflite::ops::builtin::BuiltinOpResolver resolver;
+InterpreterBuilder builder(*model, resolver);
+std::unique_ptr<tflite::Interpreter> new_interpreter;
+builder(&new_interpreter);
+```
diff --git a/tensorflow/lite/tools/serialization/writer.cc b/tensorflow/lite/tools/serialization/writer.cc
index fb816792b6a..e52114b965d 100644
--- a/tensorflow/lite/tools/serialization/writer.cc
+++ b/tensorflow/lite/tools/serialization/writer.cc
@@ -34,7 +34,7 @@ int main(int argc, char* argv[]) {
   std::unique_ptr<tflite::Interpreter> interpreter;
   tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
   tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
-  tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
+  tflite::ModelWriter writer(interpreter.get());
   writer.Write(argv[2]);

   return 0;
diff --git a/tensorflow/lite/tools/serialization/writer_lib.cc b/tensorflow/lite/tools/serialization/writer_lib.cc
index 0d831f5f9a0..7270da510d8 100644
--- a/tensorflow/lite/tools/serialization/writer_lib.cc
+++ b/tensorflow/lite/tools/serialization/writer_lib.cc
@@ -29,6 +29,41 @@ limitations under the License.
 #include "tensorflow/lite/version.h"

 namespace tflite {
+namespace {
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
+CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
+                      std::vector<OpCode>* opcodes) {
+  std::vector<flatbuffers::Offset<OperatorCode>> codes;
+  for (const auto& it : *opcodes) {
+    const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
+    codes.push_back(CreateOperatorCodeDirect(
+        *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
+  }
+  return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
+}
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
+                  std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
+  std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
+  for (auto buffer : *buffers) {
+    auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
+    buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
+  }
+  return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
+}
+
+TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
+  FILE* fp = fopen(filename.c_str(), "wb");
+  if (!fp) return kTfLiteError;
+
+  const int result_size = fwrite(data, 1, size, fp);
+  fclose(fp);
+  if (result_size != size) return kTfLiteError;
+
+  return kTfLiteOk;
+}

 std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
     flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
@@ -39,6 +74,8 @@ std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
   return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
 }

+}  // namespace
+
 template <class T_OUTPUT, class T_INPUT>
 flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
     flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
@@ -159,8 +196,8 @@ SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
     // Allocate a buffer index
     int buffer_index = 0;  // This is null
     if (tensor->allocation_type == kTfLiteMmapRo) {
-      buffer_index = buffers_.size();
-      buffers_.push_back(std::make_pair(
+      buffer_index = buffers_->size();
+      buffers_->push_back(std::make_pair(
           reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
     }
     // Primitive type.
@@ -214,23 +251,12 @@ SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {

 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
 SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
-  std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
-  for (auto buffer : buffers_) {
-    auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
-    buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
-  }
-  return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
+  return ExportBuffersImpl(fbb, buffers_);
 }

 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
 SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
-  std::vector<flatbuffers::Offset<OperatorCode>> codes;
-  for (const auto& it : opcodes_) {
-    const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
-    codes.push_back(CreateOperatorCodeDirect(
-        *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
-  }
-  return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
+  return CreateOpCodeTableImpl(fbb, opcodes_);
 }

 template <class T_OUTPUT, class T_INPUT>
@@ -254,19 +280,9 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
                                        size_t* size) {
   if (!out || !size) return kTfLiteError;
   flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
-
   std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
-  {  // subgraph specific stuff
-    auto tensors = ExportTensors(&builder);
-    std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
-    std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
-    auto inputs = ExportVector<int32_t>(&builder, written_inputs);
-    auto outputs = ExportVector<int32_t>(&builder, written_outputs);
+  subgraphs_as_vector.push_back(PopulateAndGetOffset(&builder));

-    auto ops = ExportOperators(&builder);
-    subgraphs_as_vector.push_back(
-        CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
-  }
   flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
       buffers = ExportBuffers(&builder);

@@ -284,21 +300,23 @@ TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
   return kTfLiteOk;
 }

+flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
+    flatbuffers::FlatBufferBuilder* builder) {
+  auto tensors = ExportTensors(builder);
+  std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
+  std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
+  auto inputs = ExportVector<int32_t>(builder, written_inputs);
+  auto outputs = ExportVector<int32_t>(builder, written_outputs);
+
+  auto ops = ExportOperators(builder);
+  return CreateSubGraph(*builder, tensors, inputs, outputs, ops, /* name */ 0);
+}
+
 TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
   std::unique_ptr<uint8_t[]> buffer;
   size_t size;
   TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
-
-  FILE* fp = fopen(filename.c_str(), "wb");
-  if (!fp) return kTfLiteError;
-
-  if (fwrite(buffer.get(), 1, size, fp) != size) {
-    fclose(fp);
-    return kTfLiteError;
-  }
-  if (fclose(fp)) return kTfLiteError;
-
-  return kTfLiteOk;
+  return WriteImpl(filename, buffer.get(), size);
 }

 TfLiteStatus SubgraphWriter::RegisterCustomWriter(
@@ -377,4 +395,50 @@ TfLiteStatus SubgraphWriter::SetCustomInputOutput(
   return kTfLiteOk;
 }

+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
+  return ExportBuffersImpl(fbb, &buffers_);
+}
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
+ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
+  return CreateOpCodeTableImpl(fbb, &opcodes_);
+}
+
+TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
+                                    size_t* size) {
+  if (!out || !size) return kTfLiteError;
+  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
+
+  std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
+  for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
+    SubgraphWriter writer(interpreter_->subgraph(i), &buffers_, &opcodes_,
+                          &builtin_op_to_opcode_);
+    subgraphs_as_vector.push_back(writer.PopulateAndGetOffset(&builder));
+  }
+
+  flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+      buffers = ExportBuffers(&builder);
+
+  auto description = builder.CreateString("Exported from Subgraph.");
+
+  auto op_codes = CreateOpCodeTable(&builder);
+  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+                           builder.CreateVector(subgraphs_as_vector),
+                           description, buffers);
+  ::tflite::FinishModelBuffer(builder, model);
+  const uint8_t* buffer = builder.GetBufferPointer();
+  *size = builder.GetSize();
+  (*out).reset(new uint8_t[*size]);
+  memcpy(out->get(), buffer, *size);
+  return kTfLiteOk;
+}
+
+TfLiteStatus ModelWriter::Write(const std::string& filename) {
+  std::unique_ptr<uint8_t[]> buffer;
+  size_t size;
+  TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
+  return WriteImpl(filename, buffer.get(), size);
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/serialization/writer_lib.h b/tensorflow/lite/tools/serialization/writer_lib.h
index a18a3dd0958..3119278e77a 100644
--- a/tensorflow/lite/tools/serialization/writer_lib.h
+++ b/tensorflow/lite/tools/serialization/writer_lib.h
@@ -12,37 +12,72 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// Writes a flatbuffer of a currently loaded TensorFlow Lite subgraph.
-//
-// Usage:
-//   From command line:
-//     bazel run third_party/tensorflow/lite/experimental/writer:writer
-//       -- foo.tflite foo.out.tflite
-//
-//   From C++
-//     std::unique_ptr<Interpreter> interpreter;
-//     // Build Interpreter however
-//     // ...
-//     SubgraphWriter(&interpreter->primary_subgraph()).Write("output.tflite");
+// Library to write a flatbuffer of a currently loaded TFLite model/subgraph.
+
 #ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
 #define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
 #include <iostream>
 #include <unordered_map>

+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/lite/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/schema/reflection/schema_generated.h"
 #include "tensorflow/lite/tools/serialization/enum_mapping.h"
 #include "tensorflow/lite/version.h"

 namespace tflite {

+struct OpCode {
+  int builtin;
+  std::string custom;
+};
+
+// Handles writing a full TFLite model (with 1 or more subgraphs) to a
+// serialized TFLite file format.
+// TODO(b/174708523): Support custom I/O or unused tensors later.
+class ModelWriter {
+ public:
+  // Construct a writer for the specified `interpreter`. Then, use
+  // .Write() or .GetBuffer(...) to extract the data.
+  explicit ModelWriter(Interpreter* interpreter) : interpreter_(interpreter) {
+    buffers_.push_back(std::make_pair(nullptr, 0));
+  }
+
+  // Get a buffer and size of a serialized flatbuffer.
+  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
+  // Write the serialized flatbuffer to the prescribed `filename`.
+  TfLiteStatus Write(const std::string& filename);
+
+ private:
+  template <class T>
+  using Offset = flatbuffers::Offset<T>;
+  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
+      flatbuffers::FlatBufferBuilder* fbb);
+  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
+      flatbuffers::FlatBufferBuilder* fbb);
+
+  // ModelWriter does not take ownership of this object.
+  Interpreter* const interpreter_;
+
+  // This data corresponds to the overall model (rather than individual
+  // subgraphs), so we define the common fields here.
+  // Keep track of byte buffers.
+  std::vector<std::pair<const uint8_t*, size_t>> buffers_;
+  // List of used opcodes.
+  std::vector<OpCode> opcodes_;
+  absl::flat_hash_map<int, int> builtin_op_to_opcode_;
+};
+
 // Handles writing a running TensorFlow Lite subgraph to a serialized TFLite
 // file format.
+// TODO(b/174708523): Reconcile into ModelWriter?
 class SubgraphWriter {
  public:
+  friend class ModelWriter;
+
   typedef flatbuffers::Offset<Operator> (*CustomWriter)(
       flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
       flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
@@ -55,7 +90,10 @@ class SubgraphWriter {
         inputs_(subgraph->inputs()),
         outputs_(subgraph->outputs()),
         execution_plan_(subgraph->execution_plan()) {
-    buffers_.push_back(std::make_pair(nullptr, 0));
+    buffers_ = &buffers_data_;
+    opcodes_ = &opcodes_data_;
+    builtin_op_to_opcode_ = &builtin_op_to_opcode_data_;
+    buffers_->push_back(std::make_pair(nullptr, 0));
   }

   // Get a buffer and size of a serialized flatbuffer.
@@ -77,6 +115,28 @@ class SubgraphWriter {
       const std::vector<int>& execution_plan);

  private:
+  // Used by ModelWriter.
+  explicit SubgraphWriter(
+      Subgraph* subgraph,
+      std::vector<std::pair<const uint8_t*, size_t>>* external_buffers,
+      std::vector<OpCode>* external_opcodes,
+      absl::flat_hash_map<int, int>* external_builtin_op_to_opcode)
+      : subgraph_(subgraph),
+        inputs_(subgraph->inputs()),
+        outputs_(subgraph->outputs()),
+        execution_plan_(subgraph->execution_plan()) {
+    buffers_ = external_buffers;
+    opcodes_ = external_opcodes;
+    builtin_op_to_opcode_ = external_builtin_op_to_opcode;
+    buffers_->push_back(std::make_pair(nullptr, 0));
+  }
+
+  // Used by ModelWriter to populate data specific to this subgraph. Global
+  // data (like opcodes & buffers) is accumulated into buffers_, opcodes_,
+  // etc. and then written to the flatbuffer by ModelWriter.
+  flatbuffers::Offset<SubGraph> PopulateAndGetOffset(
+      flatbuffers::FlatBufferBuilder* builder);
+
   template <class T>
   using Offset = flatbuffers::Offset<T>;
   template <class T_OUTPUT, class T_INPUT>
@@ -102,11 +162,11 @@ class SubgraphWriter {
   int GetOpCodeForBuiltin(int builtin_op_index) {
     // auto it = builtin_op_to_opcode_.find(builtin_op_index);
-    std::pair<std::unordered_map<int, int>::iterator, bool> result =
-        builtin_op_to_opcode_.insert(
-            std::make_pair(builtin_op_index, opcodes_.size()));
+    std::pair<absl::flat_hash_map<int, int>::iterator, bool> result =
+        builtin_op_to_opcode_->insert(
+            std::make_pair(builtin_op_index, opcodes_->size()));
     if (result.second) {
-      opcodes_.push_back({builtin_op_index, ""});
+      opcodes_->push_back({builtin_op_index, ""});
     }
     return result.first->second;
   }
@@ -114,9 +174,9 @@ class SubgraphWriter {
   int GetOpCodeForCustom(const std::string& custom_name) {
     std::pair<std::unordered_map<std::string, int>::iterator, bool> result =
         custom_op_to_opcode_.insert(
-            std::make_pair(custom_name, opcodes_.size()));
+            std::make_pair(custom_name, opcodes_->size()));
     if (result.second) {
-      opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
+      opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name});
     }
     return result.first->second;
   }
@@ -129,22 +189,26 @@ class SubgraphWriter {
   std::vector<int> outputs_;
   // Order of nodes to be written.
   std::vector<int> execution_plan_;
-  // Keep track of byte buffers
-  std::vector<std::pair<const uint8_t*, size_t>> buffers_;
   // List of op codes and mappings from builtin or custom op to opcode
-  struct OpCode {
-    int builtin;
-    std::string custom;
-  };
   std::set<int> unused_tensors_;
   // For every tensor index in the subgraph, the index in the written.
   // This is different due to temporary and unused tensors not being written.
   std::vector<int> tensor_to_written_tensor_;
-  // List of used opcodes
-  std::vector<OpCode> opcodes_;
-  std::unordered_map<int, int> builtin_op_to_opcode_;
   std::unordered_map<std::string, int> custom_op_to_opcode_;
   std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
+
+  // We use pointers for these, since they may be provided by ModelWriter.
+  // Keep track of byte buffers.
+  std::vector<std::pair<const uint8_t*, size_t>>* buffers_;
+  // List of used opcodes.
+  std::vector<OpCode>* opcodes_;
+  absl::flat_hash_map<int, int>* builtin_op_to_opcode_;
+
+  // These are used if SubgraphWriter is being used directly.
+  std::vector<std::pair<const uint8_t*, size_t>> buffers_data_;
+  // List of used opcodes.
+  std::vector<OpCode> opcodes_data_;
+  absl::flat_hash_map<int, int> builtin_op_to_opcode_data_;
 };

 }  // namespace tflite
diff --git a/tensorflow/lite/tools/serialization/writer_lib_test.cc b/tensorflow/lite/tools/serialization/writer_lib_test.cc
index 189b4bc106f..3f73f3c2b0f 100644
--- a/tensorflow/lite/tools/serialization/writer_lib_test.cc
+++ b/tensorflow/lite/tools/serialization/writer_lib_test.cc
@@ -15,21 +15,47 @@ limitations under the License.

 #include "tensorflow/lite/tools/serialization/writer_lib.h"

+#include <cstdlib>
 #include <numeric>
 #include <sstream>
+#include <string>
+#include <tuple>
 #include <vector>

 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/subgraph_test_util.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/testing/util.h"

 namespace tflite {
-// Make an interpreter that has no tensors and no nodes
-// TODO(b/113731921): add more tests.
-TEST(Writer, FloatModelTest) {
+
+using subgraph_test_util::CheckIntTensor;
+using subgraph_test_util::FillIntTensor;
+
+std::string CreateFilePath(const std::string& file_name) {
+  return std::string(getenv("TEST_TMPDIR")) + file_name;
+}
+
+// The bool param indicates whether we use SubgraphWriter (true) or
+// ModelWriter (false) for the test.
+class SingleSubgraphTest : public ::testing::TestWithParam<bool> {
+ protected:
+  void WriteToFile(Interpreter* interpreter, const std::string& filename,
+                   bool use_subgraph_writer) {
+    if (use_subgraph_writer) {
+      SubgraphWriter writer(&interpreter->primary_subgraph());
+      CHECK_EQ(writer.Write(filename), kTfLiteOk);
+    } else {
+      ModelWriter writer(interpreter);
+      CHECK_EQ(writer.Write(filename), kTfLiteOk);
+    }
+  }
+};
+
+TEST_P(SingleSubgraphTest, InvalidDestinations) {
   Interpreter interpreter;
   interpreter.AddTensors(3);
   float foo[] = {1, 2, 3};
@@ -52,10 +78,53 @@ TEST(Writer, FloatModelTest) {
   interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                     reinterpret_cast<void*>(builtin_data), reg);

-  SubgraphWriter writer(&interpreter.primary_subgraph());
-  writer.Write("/tmp/test_float.tflite");
+  // Check that an invalid filename is handled gracefully.
+  if (GetParam()) {
+    SubgraphWriter writer(&interpreter.primary_subgraph());
+    CHECK_EQ(writer.Write(""), kTfLiteError);
+  } else {
+    ModelWriter writer(&interpreter);
+    CHECK_EQ(writer.Write(""), kTfLiteError);
+  }
+
+  // Check that an invalid buffer is handled gracefully.
+  size_t size;
+  if (GetParam()) {
+    SubgraphWriter writer(&interpreter.primary_subgraph());
+    CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
+  } else {
+    ModelWriter writer(&interpreter);
+    CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
+  }
+}
+
+TEST_P(SingleSubgraphTest, FloatModelTest) {
+  Interpreter interpreter;
+  interpreter.AddTensors(3);
+  float foo[] = {1, 2, 3};
+  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+                                           TfLiteQuantization());
+  interpreter.SetTensorParametersReadOnly(
+      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
+      reinterpret_cast<const char*>(foo), sizeof(foo));
+  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+                                           TfLiteQuantization());
+  interpreter.SetInputs({0, 1});
+  interpreter.SetOutputs({2});
+  const char* initial_data = "";
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  TfLiteAddParams* builtin_data =
+      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+  builtin_data->activation = kTfLiteActNone;
+  builtin_data->pot_scale_int16 = false;
+  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+                                    reinterpret_cast<void*>(builtin_data), reg);
+
+  const std::string test_file = CreateFilePath("test_float.tflite");
+  WriteToFile(&interpreter, test_file, GetParam());
   std::unique_ptr<FlatBufferModel> model =
-      FlatBufferModel::BuildFromFile("/tmp/test_float.tflite");
+      FlatBufferModel::BuildFromFile(test_file.c_str());
   InterpreterBuilder builder(*model, resolver);
   std::unique_ptr<Interpreter> new_interpreter;
   builder(&new_interpreter);
@@ -63,7 +132,7 @@ TEST(Writer, FloatModelTest) {
 }

 // Tests writing only a portion of the subgraph.
-TEST(Writer, CustomInputOutputTest) {
+TEST_P(SingleSubgraphTest, CustomInputOutputTest) {
   Interpreter interpreter;
   interpreter.AddTensors(4);
   constexpr float kFoo[] = {1, 2, 3};
@@ -94,22 +163,23 @@ TEST(Writer, CustomInputOutputTest) {
   interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);

   // Only write the second op.
+  const std::string test_file = CreateFilePath("test_custom.tflite");
   SubgraphWriter writer(&interpreter.primary_subgraph());
   EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
                                         /*execution_plan=*/{1}),
             kTfLiteOk);
   writer.SetUnusedTensors({0, 1});
-  writer.Write("/tmp/test_custom.tflite");
+  writer.Write(test_file);

   std::unique_ptr<FlatBufferModel> model =
-      FlatBufferModel::BuildFromFile("/tmp/test_custom.tflite");
+      FlatBufferModel::BuildFromFile(test_file.c_str());
   InterpreterBuilder builder(*model, resolver);
   std::unique_ptr<Interpreter> new_interpreter;
   builder(&new_interpreter);
   ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
 }

-TEST(Writer, CustomInputOutputErrorCasesTest) {
+TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
   Interpreter interpreter;
   interpreter.AddTensors(5);
   constexpr float kFoo[] = {1, 2, 3};
@@ -160,7 +230,7 @@ TEST(Writer, CustomInputOutputErrorCasesTest) {
             kTfLiteOk);
 }

-TEST(Writer, PerTensorQuantizedModelTest) {
+TEST_P(SingleSubgraphTest, PerTensorQuantizedModelTest) {
   Interpreter interpreter;
   interpreter.AddTensors(3);
   interpreter.SetTensorParametersReadWrite(
@@ -181,16 +251,18 @@ TEST(Writer, PerTensorQuantizedModelTest) {
   interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                     reinterpret_cast<void*>(builtin_data), reg);

-  SubgraphWriter writer(&interpreter.primary_subgraph());
-  writer.Write("/tmp/test_uint8.tflite");
+  const std::string test_file = CreateFilePath("test_uint8.tflite");
+  WriteToFile(&interpreter, test_file, GetParam());
   std::unique_ptr<FlatBufferModel> model =
-      FlatBufferModel::BuildFromFile("/tmp/test_uint8.tflite");
+      FlatBufferModel::BuildFromFile(test_file.c_str());
   InterpreterBuilder builder(*model, resolver);
   std::unique_ptr<Interpreter> new_interpreter;
   builder(&new_interpreter);
   CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
 }

+INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
+
 struct ReshapeTestPattern {
   int num_inputs;
   bool is_param_valid;
@@ -241,8 +313,8 @@ TEST_P(ReshapeLayerTest, ReshapeLayerTest) {

   SubgraphWriter writer(&interpreter.primary_subgraph());
   std::stringstream ss;
-  ss << "/tmp/test_reshape_" << param.num_inputs << param.is_param_valid
-     << ".tflite";
+  ss << CreateFilePath("test_reshape_") << param.num_inputs
+     << param.is_param_valid << ".tflite";
   std::string filename = ss.str();
   writer.Write(filename);
   std::unique_ptr<FlatBufferModel> model =
@@ -268,6 +340,57 @@ INSTANTIATE_TEST_SUITE_P(
       std::string name = ss.str();
       return name;
     });
+
+class WhileTest : public subgraph_test_util::ControlFlowOpTest {};
+
+// This test builds a model that computes the i-th term of the
+// triangular-number sequence: 1, 3, 6, 10, 15, 21, 28, ...
+TEST_F(WhileTest, TestTriangularNumberSequence) {
+  const int kSeqNumber = 4;
+  const int kExpectedValue = 15;
+
+  interpreter_.reset(new Interpreter);
+  interpreter_->AddSubgraphs(2);
+  builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
+  builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
+  builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
+
+  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
+  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
+  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
+  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1});
+
+  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+  TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
+  CheckIntTensor(output1, {1}, {kSeqNumber + 1});
+  TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
+  CheckIntTensor(output2, {1}, {kExpectedValue});
+
+  // Now serialize & deserialize the model into a new Interpreter.
+  ModelWriter writer(interpreter_.get());
+  const std::string test_file = CreateFilePath("test_while.tflite");
+  writer.Write(test_file);
+  std::unique_ptr<FlatBufferModel> model =
+      FlatBufferModel::BuildFromFile(test_file.c_str());
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  InterpreterBuilder builder(*model, resolver);
+  std::unique_ptr<Interpreter> new_interpreter;
+  builder(&new_interpreter);
+
+  // Check the deserialized model.
+  new_interpreter->ResizeInputTensor(interpreter_->inputs()[0], {1});
+  new_interpreter->ResizeInputTensor(interpreter_->inputs()[1], {1});
+  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
+  FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[0]), {1});
+  FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[1]), {1});
+  ASSERT_EQ(new_interpreter->Invoke(), kTfLiteOk);
+  output1 = new_interpreter->tensor(interpreter_->outputs()[0]);
+  CheckIntTensor(output1, {1}, {kSeqNumber + 1});
+  output2 = new_interpreter->tensor(interpreter_->outputs()[1]);
+  CheckIntTensor(output2, {1}, {kExpectedValue});
+}
+
 }  // namespace tflite

 int main(int argc, char** argv) {
diff --git a/tensorflow/lite/tools/serialization/writer_test.cc b/tensorflow/lite/tools/serialization/writer_test.cc
index ccaab76776b..2ad77df8f7c 100644
--- a/tensorflow/lite/tools/serialization/writer_test.cc
+++ b/tensorflow/lite/tools/serialization/writer_test.cc
@@ -35,7 +35,7 @@ int main(int argc, char* argv[]) {
   std::unique_ptr<tflite::Interpreter> interpreter;
   tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
   tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
-  tflite::SubgraphWriter writer(&interpreter->primary_subgraph());
+  tflite::ModelWriter writer(interpreter.get());
   std::unique_ptr<uint8_t[]> output_buffer;
   size_t output_buffer_size;
   writer.GetBuffer(&output_buffer, &output_buffer_size);