/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Writes a flatbuffer of a currently loaded TensorFlow Lite subgraph.
//
// Usage:
//   From the command line:
//     bazel run third_party/tensorflow/lite/experimental/writer:writer
//       -- foo.tflite foo.out.tflite
//
//   From C++:
//     std::unique_ptr<Interpreter> interpreter;
//     // ... build the interpreter ...
//     SubgraphWriter(&interpreter->primary_subgraph()).Write("output.tflite");
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_

#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/version.h"

namespace tflite {

// Handles writing a currently loaded TensorFlow Lite subgraph to a serialized
// TF Lite file format.
class SubgraphWriter {
 public:
  typedef flatbuffers::Offset<Operator> (*CustomWriter)(
      flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
      flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
      CustomOptionsFormat* custom_options_format);
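  //
  // A custom writer with this signature might look like the following sketch
  // (the op name and payload are hypothetical; the exact use of the return
  // value is defined by the writer implementation):
  //
  //   flatbuffers::Offset<Operator> WriteMyCustomOp(
  //       flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph,
  //       int node_index,
  //       flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
  //       CustomOptionsFormat* custom_options_format) {
  //     const uint8_t payload[] = {1, 2, 3};  // Op-specific custom data.
  //     *output_options = fbb->CreateVector(payload, sizeof(payload));
  //     *custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
  //     return 0;
  //   }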

  // Construct a subgraph writer for the specified `subgraph`. Then, use
  // .Write() or .GetBuffer(...) to extract the data.
  explicit SubgraphWriter(Subgraph* subgraph)
      : subgraph_(subgraph),
        inputs_(subgraph->inputs()),
        outputs_(subgraph->outputs()),
        execution_plan_(subgraph->execution_plan()) {
    buffers_.push_back(std::make_pair(nullptr, 0));
  }

  // Gets the serialized flatbuffer and its size.
  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
  // Writes the serialized flatbuffer to `filename`.
  TfLiteStatus Write(const std::string& filename);
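  //
  // For example, to serialize into memory instead of a file (a sketch,
  // assuming `interpreter` is an already-built Interpreter):
  //
  //   SubgraphWriter writer(&interpreter->primary_subgraph());
  //   std::unique_ptr<uint8_t[]> buffer;
  //   size_t size = 0;
  //   if (writer.GetBuffer(&buffer, &size) == kTfLiteOk) {
  //     // `buffer` now holds `size` bytes of .tflite flatbuffer data.
  //   }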

  // Registers a custom writer for a custom op. The customization allows the
  // caller to change the custom data.
  TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
                                    CustomWriter custom_writer);
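  //
  // For example, with the hypothetical WriteMyCustomOp sketched above:
  //
  //   writer.RegisterCustomWriter("MyCustomOp", WriteMyCustomOp);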

  // Specifies unused tensors that should not be written.
  void SetUnusedTensors(const std::set<int>& unused_tensors) {
    unused_tensors_ = unused_tensors;
  }

  // Sets custom inputs, outputs, and execution_plan so that a portion of the
  // subgraph is written to the buffer instead of the whole subgraph.
  TfLiteStatus SetCustomInputOutput(const std::vector<int>& inputs,
                                    const std::vector<int>& outputs,
                                    const std::vector<int>& execution_plan);
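  //
  // For example, to write only the first two nodes of the execution plan,
  // with illustrative tensor indices for their boundary:
  //
  //   writer.SetCustomInputOutput(/*inputs=*/{0}, /*outputs=*/{3},
  //                               /*execution_plan=*/{0, 1});
  //   writer.Write("partial.tflite");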

 private:
  template <class T>
  using Offset = flatbuffers::Offset<T>;
  template <class T_OUTPUT, class T_INPUT>
  Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
      flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
  Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
      flatbuffers::FlatBufferBuilder* fbb);
  template <class T>
  std::vector<int> RemapTensorIndicesToWritten(const T& input);
  // Checks if the given `inputs`, `outputs`, and `execution_plan` represent a
  // valid model within the Subgraph.
  TfLiteStatus CheckInputOutput(const std::vector<int>& inputs,
                                const std::vector<int>& outputs,
                                const std::vector<int>& execution_plan);
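
  // Returns the index of `builtin_op_index` in the opcode table, adding a new
  // opcode entry the first time this builtin op is seen.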
  int GetOpCodeForBuiltin(int builtin_op_index) {
    std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
        builtin_op_to_opcode_.insert(
            std::make_pair(builtin_op_index, opcodes_.size()));
    if (result.second) {
      opcodes_.push_back({builtin_op_index, ""});
    }
    return result.first->second;
  }
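
  // Returns the index of the custom op `custom_name` in the opcode table,
  // adding a new opcode entry the first time this custom op is seen.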
  int GetOpCodeForCustom(const std::string& custom_name) {
    std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
        custom_op_to_opcode_.insert(
            std::make_pair(custom_name, opcodes_.size()));
    if (result.second) {
      opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
    }
    return result.first->second;
  }

  // The subgraph we are writing.
  Subgraph* subgraph_;
  // Input tensor indices to be written.
  std::vector<int> inputs_;
  // Output tensor indices to be written.
  std::vector<int> outputs_;
  // Order of nodes to be written.
  std::vector<int> execution_plan_;
  // Keeps track of byte buffers.
  std::vector<std::pair<const uint8_t*, size_t>> buffers_;

  // An opcode table entry: either a builtin op or a named custom op.
  struct OpCode {
    int builtin;
    std::string custom;
  };
  // Tensors that should not be written to the output.
  std::set<int> unused_tensors_;
  // For every tensor index in the subgraph, the corresponding index in the
  // written file. The two differ because temporary and unused tensors are not
  // written.
  std::vector<int> tensor_to_written_tensor_;

  // List of opcodes used so far.
  std::vector<OpCode> opcodes_;
  // Maps a builtin op to its index in `opcodes_`.
  std::unordered_map<int, int> builtin_op_to_opcode_;
  // Maps a custom op name to its index in `opcodes_`.
  std::unordered_map<std::string, int> custom_op_to_opcode_;
  // Maps a custom op name to its registered CustomWriter.
  std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_