STT-tensorflow/tensorflow/lite/toco/python/toco_python_api.cc
A. Unique TensorFlower c1ae0ef7ce Qualify uses of std::string
PiperOrigin-RevId: 316914367
Change-Id: Iae32a48b4db10d313f9e9b72f56eb6a6ac64c55f
2020-06-17 10:25:14 -07:00

318 lines
12 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/python/toco_python_api.h"
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include "google/protobuf/text_format.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/logging/conversion_log_util.h"
#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_convert.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_tooling.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/lite/toco/types.pb.h"
namespace toco {
// Writes the pre- and post-conversion `TocoConversionLog` protos into the
// directory named by `toco_flags->conversion_summary_dir()` and emits the
// corresponding graph dumps through `toco::LogDump`.
//
// Args:
//   model_flags: model flags forwarded to `toco::Import` for both imports.
//   toco_flags: conversion flags. NOTE: mutated by this function; its input
//     format is switched to `TFLITE` before the converted model is imported.
//   input_contents_txt: serialized input model, imported with the original
//     input format of `toco_flags`.
//   output_file_contents_txt: serialized conversion result, imported as a
//     TFLite flatbuffer.
//   error_message: conversion error text; it is sanitized before being
//     stored in the post-conversion log.
//   dump_options: graphviz dump options; `dump_graphviz` is redirected to the
//     conversion summary directory.
void PopulateConversionLogHelper(const toco::ModelFlags& model_flags,
                                 toco::TocoFlags* toco_flags,
                                 const std::string& input_contents_txt,
                                 const std::string& output_file_contents_txt,
                                 const std::string& error_message,
                                 GraphVizDumpOptions* dump_options) {
  // Copy the directory once: it is used for the graphviz dump and both logs.
  const std::string summary_dir = toco_flags->conversion_summary_dir();

  // Make sure the graphviz file will be dumped under the same folder.
  dump_options->dump_graphviz = summary_dir;

  // Serialized protos are binary data, so the stream must be opened in binary
  // mode; a text-mode stream may rewrite newline bytes on some platforms and
  // corrupt the log. The stream is flushed and closed when it goes out of
  // scope, i.e. before the caller proceeds to the graph dump.
  const auto write_log = [&summary_dir](const TocoConversionLog& log,
                                        const char* file_name) {
    std::ofstream stream(summary_dir + file_name,
                         std::ios::out | std::ios::binary);
    log.SerializeToOstream(&stream);
  };

  // Here we construct the `toco::Model` class based on the input graph def,
  // it will then be used to populate the conversion log.
  // TODO(haoliang): Don't depend on `toco::Model`.
  std::unique_ptr<toco::Model> imported_model =
      toco::Import(*toco_flags, model_flags, input_contents_txt);

  // Dump pre-conversion toco logs.
  TocoConversionLog toco_log_before;
  PopulateConversionLog(*imported_model, &toco_log_before);
  write_log(toco_log_before, "/toco_log_before.pb");
  toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model);

  // Populate the post-conversion log. For convenience, the `toco::Model` is
  // re-created by importing the generated flatbuffer.
  toco_flags->set_input_format(toco::FileFormat::TFLITE);
  std::unique_ptr<toco::Model> flatbuffer_model =
      toco::Import(*toco_flags, model_flags, output_file_contents_txt);

  // Dump post-conversion toco logs.
  TocoConversionLog toco_log_after;
  PopulateConversionLog(*flatbuffer_model, &toco_log_after);
  // Make sure we sanitize the error message.
  toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message));
  write_log(toco_log_after, "/toco_log_after.pb");
  toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model);
}
// Converts a serialized model to a TFLite flatbuffer.
//
// NOTE(aselle): We are using raw PyObjects here because we want to make
// sure we input and output bytes rather than unicode strings for Python3.
//
// Args:
//   model_flags_proto_txt_raw: serialized `toco::ModelFlags` proto.
//   toco_flags_proto_txt_raw: serialized `toco::TocoFlags` proto.
//   input_contents_txt_raw: serialized `tensorflow::GraphDef`.
//   extended_return: when true and the MLIR converter is NOT used, a dict with
//     the keys "flatbuffer" and "arithmetic_ops" is returned instead of the
//     bare flatbuffer bytes.
//   debug_info_txt_raw: optional serialized `tensorflow::GraphDebugInfo`; may
//     be null or `Py_None`.
//   enable_mlir_converter: selects the MLIR converter instead of `Convert`.
//
// Returns:
//   A new reference to the result object, or nullptr with a Python exception
//   set when an argument is invalid or the conversion fails.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                      PyObject* toco_flags_proto_txt_raw,
                      PyObject* input_contents_txt_raw, bool extended_return,
                      PyObject* debug_info_txt_raw,
                      bool enable_mlir_converter) {
  // Use Python C API to validate and convert arguments. In py3 (bytes),
  // in py2 (str).
  auto ConvertArg = [&](PyObject* obj, bool* error) {
    char* buf;
    Py_ssize_t len;
    if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) {
      *error = true;
      return std::string();
    } else {
      *error = false;
      return std::string(buf, len);
    }
  };

  bool error;
  std::string model_flags_proto_txt =
      ConvertArg(model_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
    return nullptr;
  }
  std::string toco_flags_proto_txt =
      ConvertArg(toco_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
    return nullptr;
  }
  std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
    return nullptr;
  }

  // Use TOCO to produce new outputs.
  toco::ModelFlags model_flags;
  if (!model_flags.ParseFromString(model_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert Model to Python String.");
    return nullptr;
  }
  toco::TocoFlags toco_flags;
  if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert Toco to Python String.");
    return nullptr;
  }

  // The debug info is optional: both a null pointer and `None` mean "absent".
  tensorflow::GraphDebugInfo debug_info;
  if (debug_info_txt_raw && debug_info_txt_raw != Py_None) {
    std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
    if (error) {
      PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
      return nullptr;
    }
    if (!debug_info.ParseFromString(debug_info_txt)) {
      PyErr_SetString(PyExc_ValueError,
                      "Failed to convert DebugInfo to Python String.");
      return nullptr;
    }
  }

  // Parsed once up front: this validates the input for every conversion path
  // and the result is reused by the MLIR GraphDef path below.
  tensorflow::GraphDef graph_def;
  if (!graph_def.ParseFromString(input_contents_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert GraphDef to Python String.");
    return nullptr;
  }

  auto& dump_options = *GraphVizDumpOptions::singleton();
  if (toco_flags.has_dump_graphviz_dir()) {
    dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
  }
  if (toco_flags.has_dump_graphviz_include_video()) {
    dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
  }

  std::string output_file_contents_txt;
  tensorflow::Status status;
  // Only written by the non-MLIR `Convert` call; initialized so it never
  // holds an indeterminate value.
  int64 arithmetic_ops_count = 0;

  // Convert model.
  if (enable_mlir_converter) {
    if (!model_flags.saved_model_dir().empty()) {
      status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
          model_flags, toco_flags, &output_file_contents_txt);
    } else {
      status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
          model_flags, toco_flags, debug_info, graph_def,
          &output_file_contents_txt);
      if (!toco_flags.conversion_summary_dir().empty()) {
        PopulateConversionLogHelper(
            model_flags, &toco_flags, input_contents_txt,
            output_file_contents_txt, status.error_message(), &dump_options);
      }
    }
  } else {
    status = Convert(input_contents_txt, toco_flags, model_flags,
                     &output_file_contents_txt, &arithmetic_ops_count);
  }

  if (!status.ok()) {
    PyErr_SetString(PyExc_Exception, status.error_message().c_str());
    return nullptr;
  }

  if (extended_return && !enable_mlir_converter) {
    PyObject* dict = PyDict_New();
    if (dict == nullptr) {
      return nullptr;
    }
    PyObject* flatbuffer = ::tflite::python_utils::ConvertToPyString(
        output_file_contents_txt.data(), output_file_contents_txt.size());
    // `long` may be narrower than 64 bits, so convert through `long long`.
    PyObject* ops_count =
        PyLong_FromLongLong(static_cast<long long>(arithmetic_ops_count));
    const bool populated =
        flatbuffer != nullptr && ops_count != nullptr &&
        PyDict_SetItemString(dict, "flatbuffer", flatbuffer) == 0 &&
        PyDict_SetItemString(dict, "arithmetic_ops", ops_count) == 0;
    // PyDict_SetItemString does not steal references: the dict holds its own
    // reference to each stored value, so ours must be released here.
    Py_XDECREF(flatbuffer);
    Py_XDECREF(ops_count);
    if (!populated) {
      Py_DECREF(dict);
      return nullptr;
    }
    return dict;
  }

  // Convert arguments back to byte (py3) or str (py2)
  return ::tflite::python_utils::ConvertToPyString(
      output_file_contents_txt.data(), output_file_contents_txt.size());
}
// Builds a Python list with one dict per entry returned by
// `toco::GetPotentiallySupportedOps()`; each dict maps the key "op" to the
// operator name.
//
// Returns:
//   A new reference to the list, or nullptr (with the Python exception set by
//   the failing C API call) when an object could not be created.
PyObject* TocoGetPotentiallySupportedOps() {
  const std::vector<std::string> supported_ops =
      toco::GetPotentiallySupportedOps();
  PyObject* list = PyList_New(static_cast<Py_ssize_t>(supported_ops.size()));
  if (list == nullptr) {
    return nullptr;
  }
  for (size_t i = 0; i < supported_ops.size(); ++i) {
    PyObject* op_dict = PyDict_New();
    PyObject* op_name = PyUnicode_FromString(supported_ops[i].c_str());
    if (op_dict == nullptr || op_name == nullptr ||
        PyDict_SetItemString(op_dict, "op", op_name) != 0) {
      Py_XDECREF(op_name);
      Py_XDECREF(op_dict);
      Py_DECREF(list);
      return nullptr;
    }
    // PyDict_SetItemString does not steal the value reference, so release
    // ours; the dict keeps the string alive.
    Py_DECREF(op_name);
    // PyList_SetItem steals the reference to `op_dict`.
    PyList_SetItem(list, static_cast<Py_ssize_t>(i), op_dict);
  }
  return list;
}
// Quantizes a serialized TFLite flatbuffer with `mlir::lite::QuantizeModel`.
//
// Args:
//   data: Python bytes/str object holding the serialized TFLite model.
//   disable_per_channel: forwarded unchanged to `mlir::lite::QuantizeModel`.
//   fully_quantize: when true the model input/output type is the inference
//     tensor type; otherwise the input/output type is FLOAT32.
//   inference_type: a `toco::IODataType` value; QUANTIZED_INT16,
//     QUANTIZED_UINT8 and INT8 are accepted.
//
// Returns:
//   A new reference to the serialized quantized model, or nullptr with a
//   Python exception set on failure.
PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
                            bool fully_quantize, int inference_type) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  // Unpack into the mutable object representation expected by the quantizer.
  auto tflite_model = absl::make_unique<tflite::ModelT>();
  model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

  tflite::TensorType inference_tensor_type;
  switch (inference_type) {
    case toco::IODataType::QUANTIZED_INT16:
      inference_tensor_type = tflite::TensorType_INT16;
      break;
    case toco::IODataType::QUANTIZED_UINT8:
      inference_tensor_type = tflite::TensorType_UINT8;
      break;
    case toco::IODataType::INT8:
      inference_tensor_type = tflite::TensorType_INT8;
      break;
    default:
      // Returning nullptr without an exception set makes the interpreter
      // raise an opaque SystemError; report the actual problem instead.
      PyErr_Format(PyExc_ValueError, "Unsupported inference type: %d",
                   inference_type);
      return nullptr;
  }
  tflite::TensorType inference_io_type =
      fully_quantize ? inference_tensor_type : tflite::TensorType_FLOAT32;

  flatbuffers::FlatBufferBuilder builder;
  auto status = mlir::lite::QuantizeModel(
      *tflite_model, inference_io_type, inference_io_type,
      inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
      error_reporter.get());
  if (status != kTfLiteOk) {
    // Surface the errors collected by the reporter to Python.
    error_reporter->exception();
    return nullptr;
  }
  return tflite::python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}
// Runs `mlir::lite::SparsifyModel` on a serialized TFLite flatbuffer.
//
// Args:
//   data: Python bytes/str object holding the serialized TFLite model.
//
// Returns:
//   A new reference to the serialized result, or nullptr when the input
//   cannot be read, the model cannot be built, or sparsification fails.
PyObject* MlirSparsifyModel(PyObject* data) {
  using tflite::interpreter_wrapper::PythonErrorReporter;

  char* model_bytes = nullptr;
  Py_ssize_t model_size;
  std::unique_ptr<PythonErrorReporter> reporter(new PythonErrorReporter);

  // Borrow the raw byte buffer of the Python object.
  const int convert_result = tflite::python_utils::ConvertFromPyString(
      data, &model_bytes, &model_size);
  if (convert_result == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }

  std::unique_ptr<tflite::FlatBufferModel> flatbuffer_model =
      tflite::FlatBufferModel::BuildFromBuffer(model_bytes, model_size,
                                               reporter.get());
  if (flatbuffer_model == nullptr) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }

  // Unpack into the mutable object representation passed to the sparsifier.
  auto unpacked_model = absl::make_unique<tflite::ModelT>();
  flatbuffer_model->GetModel()->UnPackTo(unpacked_model.get(), nullptr);

  flatbuffers::FlatBufferBuilder output_builder;
  const auto sparsify_status = mlir::lite::SparsifyModel(
      *unpacked_model, &output_builder, reporter.get());
  if (sparsify_status == kTfLiteOk) {
    const char* serialized = reinterpret_cast<const char*>(
        output_builder.GetCurrentBufferPointer());
    return tflite::python_utils::ConvertToPyString(serialized,
                                                   output_builder.GetSize());
  }

  // Surface the errors collected by the reporter to Python.
  reporter->exception();
  return nullptr;
}
} // namespace toco