Author: Jaesung Chung
Commit: 7292433984
Date:   2020-06-15 16:02:43 -07:00

Support flex ops in calibration optimization

This CL also makes the tool generate a user-friendly error message.

To use the correct logger on mobile, it reports through the error_reporter.

PiperOrigin-RevId: 316563081
Change-Id: Ib56f80330087750777725ed6ad3c97f54b1fa80b


/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibration/logging_op.h"
#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
#include "tensorflow/lite/tools/optimize/calibration/node_info_delegate.h"
namespace tflite {
namespace optimize {
namespace calibration {
namespace {
// Calibrator is used to hold information that can be accessed during kernel
// invocations.
// TfLite kernel invocations are C functions and cannot look at the global
// structure of the graph. Calibrator allows the kernel invoke functions to
// access the global structure of the graph and to know which node is
// currently being executed. This also allows us to write a simple kernel
// invoke wrapper (see LoggingEval) that works for most builtin ops.
class Calibrator {
public:
Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
node_ptr_opinfo_map,
std::unique_ptr<LoggingOpResolver> logging_op_resolver,
ErrorReporter* error_reporter)
: node_ptr_opinfo_map_(node_ptr_opinfo_map),
logging_op_resolver_(std::move(logging_op_resolver)),
error_reporter_(error_reporter) {
logger_ = absl::make_unique<Logger>();
}
// Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|.
KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const;
// Gets the instance of logger associated with the current context.
Logger* GetLogger() const { return logger_.get(); }
// Gets the error reporter.
ErrorReporter* GetErrorReporter() const { return error_reporter_; }
// Gets the operator information about the given TfLiteNode.
const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
return node_ptr_opinfo_map_.at(node);
}
private:
std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_;
std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
const std::unordered_map<int, OperatorInfo> index_opinfo_;
std::unique_ptr<Logger> logger_;
ErrorReporter* error_reporter_;
};
KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
auto op_info = node_ptr_opinfo_map_.at(node);
if (op_info.is_custom_op) {
return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(),
op_info.version);
}
return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
op_info.version);
}
// A registry of |Calibrator| objects per |TfLiteContext|.
// This global registry is needed to access |Calibrator| objects in the kernel
// invoke functions, i.e. |TfLiteRegistration.invoke|.
// Kernel invoke functions are C functions that have limited access to
// |TfLiteContext| and no access to the global state of the graph. That means
// that during a kernel invocation, the function cannot know which node it was
// invoked for. E.g. in the case of a model with a |Conv| op at two locations,
// there is no easy way for the Conv.invoke function to disambiguate the two
// calls.
//
// For calibration we solve this problem by creating a map of calibrators
// per |TfLiteContext|. This map is |GlobalCalibrationRegistry|.
//
// This registry is then accessed using a global getter function:
// |GetCalibratorRegistry|.
// E.g.
// TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) {
// .... code ....
// auto registry = GetCalibratorRegistry();
// auto calibrator = registry->GetCalibrator(context);
// ..... code ....
// }
//
// This way the kernel invoke functions can get the access to the Calibrator
// object associated with the |TfLiteContext|.
class GlobalCalibratorRegistry {
public:
// Gets the |Calibrator| associated with the given context; returns nullptr if
// no calibrator is associated with the context.
Calibrator* GetCalibrator(const TfLiteContext* context) const {
if (calibrator_registry_.find(context) == calibrator_registry_.cend()) {
return nullptr;
}
return calibrator_registry_.at(context).get();
}
// Removes the association between calibrator and context.
// Note: This deletes the calibrator as well.
void RemoveCalibrator(const TfLiteContext* context) {
calibrator_registry_.erase(context);
}
// Creates an instance of |Calibrator|.
// Registry owns the |Calibrator| object which can be deleted by calling
// |RemoveCalibrator|.
TfLiteStatus CreateCalibrator(
const TfLiteContext* context,
const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo,
std::unique_ptr<LoggingOpResolver> logging_op_resolver,
Calibrator** calibrator_ptr, ErrorReporter* reporter) {
if (calibrator_registry_.find(context) != calibrator_registry_.cend()) {
reporter->Report(
"Failed to create calibrator, context already registered.");
return kTfLiteError;
}
auto calibrator = absl::make_unique<Calibrator>(
node_to_opinfo, std::move(logging_op_resolver), reporter);
calibrator_registry_[context] = std::move(calibrator);
*calibrator_ptr = calibrator_registry_.at(context).get();
return kTfLiteOk;
}
private:
std::unordered_map<const TfLiteContext*, std::unique_ptr<Calibrator>>
calibrator_registry_;
};
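// Returns the process-wide registry, creating it on first use. The registry
// is never deleted, so it outlives every |TfLiteContext|.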
GlobalCalibratorRegistry* GetCalibratorRegistry() {
static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry();
return registry;
}
// Returns the logging kernel for the node if there is one, otherwise nullptr.
// TODO(jianlijianli): extend this to support multiple recipes for the same
// model.
logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
TfLiteNode* node) {
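  // A full LSTM op has 24 input tensors, so the input count is used here as a
  // heuristic to detect the LSTM op.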
const int lstm_number_input = 24;
if (node->inputs->size == lstm_number_input) {
// LSTM Op.
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
}
return nullptr;
}
// A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
// invokes the wrapped implementation and then logs the outputs.
TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(context);
if (!calibrator) {
context->ReportError(context, "No calibrator found for context.");
return kTfLiteError;
}
auto kernel_invoke = calibrator->GetKernelInvoke(node);
auto logger = calibrator->GetLogger();
auto op_info = calibrator->GetOpInfo(node);
auto error_reporter = calibrator->GetErrorReporter();
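  // Log the float input tensors before invoking the wrapped kernel.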
for (int i : op_info.loggable_inputs) {
auto tensor = context->tensors[i];
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
}
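  // Use the specialized logging kernel if one exists for this node (currently
  // only LSTM); otherwise invoke the wrapped kernel directly.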
auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
TfLiteStatus status;
if (kernel_invoke_intermediate == nullptr) {
status = kernel_invoke(context, node);
} else {
status = kernel_invoke_intermediate(context, node, calibrator->GetLogger(),
error_reporter);
}
// TODO(shashishekhar): An intermediate tensor in the graph will get logged
// twice: once as an input and a second time as an output. This doesn't change
// the min/max values but is inefficient. Using a moving average would also
// break this.
// Log input again to make sure the state tensors are captured after lstm
// cell.
for (int i : op_info.loggable_inputs) {
auto tensor = context->tensors[i];
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
}
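  // Log the output tensors after the kernel has run.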
for (int i : op_info.loggable_outputs) {
auto tensor = context->tensors[i];
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
}
return status;
}
// Returns the indices of the loggable tensors. Not all inputs and outputs
// need to be logged. For example, const weight tensors which have buffers
// associated with them don't need to be logged.
std::vector<int> GetLoggableTensorIndices(
const std::vector<int>& tensor_indices,
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) {
std::vector<int> loggable;
for (auto tensor_index : tensor_indices) {
if (tensor_index == kTfLiteOptionalTensor) {
continue;
}
auto tensor = tensors->Get(tensor_index);
auto buffer_index = tensor->buffer();
const bool has_no_buffer =
(tensor_buffers->Get(buffer_index) == nullptr) ||
(tensor_buffers->Get(buffer_index)->data() == nullptr) ||
(tensor_buffers->Get(buffer_index)->data()->size() == 0);
if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) {
loggable.push_back(tensor_index);
}
}
return loggable;
}
// Creates a mapping between the static model graph and the runtime TfLiteNode*
// nodes in the graph for the given context.
// This is done by querying the TfLiteContext for node and registrations using
// the |NodeInfoDelegateObserver|.
TfLiteStatus GetNodeOpInfoMapAndContext(
const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
tflite::Interpreter* const interpreter,
std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
const TfLiteContext** context) {
NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
node_ptr_opinfo_map);
NodeInfoDelegateParams delegate_params;
delegate_params.delegate_observer = &delegate_observer;
TfLiteDelegate logging_delegate = CreateNodeInfoDelegate(&delegate_params);
auto modify_status = interpreter->ModifyGraphWithDelegate(&logging_delegate);
if (modify_status != kTfLiteOk) {
return kTfLiteError;
}
*context = delegate_observer.GetContext();
return kTfLiteOk;
}
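// Returns the name of the operator: the custom code for a custom op, and the
// builtin operator enum name otherwise.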
string GetOpName(const tflite::OperatorCode& opcode) {
if (opcode.custom_code() != nullptr) {
return opcode.custom_code()->str();
}
return tflite::EnumNamesBuiltinOperator()[opcode.builtin_code()];
}
// A |CalibrationReader| that owns the Calibrator.
class Reader : public CalibrationReader {
public:
Reader(const TfLiteContext* context, const Logger* logger)
: CalibrationReader(logger), context_(context) {}
~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }
private:
const TfLiteContext* context_;
};
} // namespace
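// Convenience overload that extracts the |tflite::Model| and |ErrorReporter|
// from a |FlatBufferModel|.
// E.g. (a minimal sketch; the model path and resolver choice are
// illustrative):
//   auto model = FlatBufferModel::BuildFromFile("/path/to/model.tflite");
//   ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<Interpreter> interpreter;
//   std::unique_ptr<CalibrationReader> reader;
//   BuildLoggingInterpreter(*model, resolver, &interpreter, &reader);
//   // ... allocate tensors, run representative inputs, then read the
//   // collected statistics through |reader| ...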
TfLiteStatus BuildLoggingInterpreter(
const FlatBufferModel& model, const OpResolver& op_resolver,
std::unique_ptr<Interpreter>* interpreter,
std::unique_ptr<CalibrationReader>* calibration_reader) {
return BuildLoggingInterpreter(model.GetModel(), model.error_reporter(),
op_resolver, interpreter, calibration_reader);
}
TfLiteStatus BuildLoggingInterpreter(
const tflite::Model* tflite_model, ErrorReporter* error_reporter,
const OpResolver& op_resolver, std::unique_ptr<Interpreter>* interpreter,
std::unique_ptr<CalibrationReader>* calibration_reader) {
if (error_reporter == nullptr) {
// Make sure error_reporter is valid.
error_reporter = DefaultErrorReporter();
}
auto subgraphs = tflite_model->subgraphs();
auto tensor_buffers = tflite_model->buffers();
if (subgraphs->size() != 1) {
error_reporter->Report(
"Only models with a single subgraph are supported, model had %d "
"subgraphs",
subgraphs->size());
return kTfLiteError;
}
// Populate the node index to operator info map.
// We collect this information so that at runtime we can determine which
// inputs and outputs to log.
// At runtime TFLite kernel invoke functions can only look into their
// own node in the graph (TFLiteNode*) and some limited context information.
auto primary_subgraph = subgraphs->Get(0);
auto operator_codes = tflite_model->operator_codes();
auto operators = primary_subgraph->operators();
auto tensors = primary_subgraph->tensors();
std::unordered_map<int, OperatorInfo> node_to_opinfo;
BuiltinOpsSet builtin_op_and_versions;
CustomOpsSet custom_op_and_versions;
for (size_t i = 0; i < operators->size(); i++) {
OperatorInfo op_info;
op_info.node_index = i;
auto op = operators->Get(i);
auto operator_code = operator_codes->Get(op->opcode_index());
op_info.builtin_op_code = operator_code->builtin_code();
op_info.name = GetOpName(*operator_code);
op_info.is_custom_op = operator_code->custom_code() != nullptr;
op_info.version = operator_code->version();
auto op_inputs = op->inputs();
auto op_outputs = op->outputs();
op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end());
op_info.outputs = std::vector<int>(op_outputs->begin(), op_outputs->end());
op_info.loggable_inputs =
GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
op_info.loggable_outputs =
GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
if (op_info.is_custom_op) {
op_info.registration =
op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
custom_op_and_versions.insert(
{op_info.name.c_str(), operator_code->version()});
} else {
op_info.registration = op_resolver.FindOp(operator_code->builtin_code(),
operator_code->version());
builtin_op_and_versions.insert(
{op_info.builtin_op_code, operator_code->version()});
}
node_to_opinfo[i] = op_info;
}
// Prepare the logging op resolver to use |LoggingEval| for kernel
// invocations.
auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval,
error_reporter);
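// Build the interpreter with the logging resolver so that every kernel
// invocation goes through |LoggingEval|.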
tflite::InterpreterBuilder(tflite_model, *logging_op_resolver,
error_reporter)(interpreter);
if (!(*interpreter)) {
error_reporter->Report("Failed to construct interpreter");
return kTfLiteError;
}
// Compute the mapping between runtime and static graph structure, i.e.
// (TfLiteContext, TfLiteNode) -> OperatorInfo
std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map;
const TfLiteContext* context = nullptr;
TF_LITE_ENSURE_STATUS(GetNodeOpInfoMapAndContext(
    node_to_opinfo, interpreter->get(), &node_ptr_opinfo_map, &context));
Calibrator* calibrator = nullptr;
// Register a calibrator object for the context. This can be accessed
// during invocations by the logging kernels.
TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator(
context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator,
error_reporter));
*calibration_reader = std::unique_ptr<CalibrationReader>(
new Reader(context, calibrator->GetLogger()));
return kTfLiteOk;
}
} // namespace calibration
} // namespace optimize
} // namespace tflite