STT-tensorflow/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
stevensa a0b233ac41 Merge branch 'master' into IFX/PR_tfl_converter_QAT_fixes
# Conflicts:
#	tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
2020-05-27 15:26:47 +02:00

251 lines
9.9 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>
#include "absl/strings/str_split.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using mlir::FuncOp;
using mlir::MLIRContext;
using mlir::ModuleOp;
using stream_executor::port::StatusOr;
// Debugging aid: when enabled, main() prints, after the flatbuffer has been
// written, which flatbuffer tensor each subgraph input and output maps to
// (plus the MLIR location of the returned value for outputs).
// NOLINTNEXTLINE
static llvm::cl::opt<bool> print_function_result_mapping(
    "print-function-result-mapping", llvm::cl::init(false),
    llvm::cl::desc(
        "Print the mapping of function result to flatbuffer output buffer"));
// Requested weight quantization type. main() maps "INT8" and "FLOAT16" to an
// inference type, leaves "NONE" unquantized, and rejects any other value.
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> weight_quantization(
    "weight_quantization", llvm::cl::init("NONE"),
    llvm::cl::desc("The type of the quantized weight buffer. Must be NONE, "
                   "INT8, FLOAT16."));
// Exit codes returned by main() and by PrintFunctionResultMapping(). Kept as an
// unscoped enum so the values convert implicitly to the `int` return type.
enum TranslationStatus { kTrSuccess = 0, kTrFailure = 1 };
static int PrintFunctionResultMapping(const std::string &result,
ModuleOp module) {
// Build model from the resultant string to extract the return values from
// their source of truth.
auto model =
tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
if (!model) return kTrFailure;
// Get an unknown location for where we don't have a terminator to get the
// location of the return value from.
auto unknown_loc = mlir::UnknownLoc::get(module.getContext());
auto print_buffer = [&](const tflite::SubGraph &subgraph, int id, int buffer,
std::function<mlir::Location(int)> loc) {
const auto &output_tensor = (*subgraph.tensors())[buffer];
std::cout << "\tname: '"
<< (output_tensor->name() ? output_tensor->name()->str()
: "<<unnamed>>")
<< "' buffer: " << buffer;
if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
std::cout << '\n';
};
// For every subgraph print out the name (if available), each result's output
// buffer number and location of the return value (if available).
for (auto *subgraph : *(*model)->subgraphs()) {
std::string subgraph_name =
subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";
std::cout << '\'' << subgraph_name << "' inputs:\n";
int i = 0;
for (auto input : *subgraph->inputs())
print_buffer(*subgraph, i++, input, nullptr);
std::cout << '\'' << subgraph_name << "' outputs:\n";
mlir::Operation *terminator = nullptr;
if (subgraph->name()) {
if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
terminator = fn.back().getTerminator();
}
i = 0;
for (auto output : *subgraph->outputs()) {
print_buffer(*subgraph, i, output, [&](int i) {
return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
});
}
}
return kTrSuccess;
}
// Command-line entry point of the TF -> TFLite converter.
//
// Imports either a SavedModel (when one of the saved-model import flags is set)
// or a GraphDef/MLIR source, builds the quantization spec from command-line
// flags, runs the TF -> TFL conversion pipeline and writes the result to
// `output_file_name`. Optionally prints the function result mapping afterwards.
// Returns kTrSuccess on success and kTrFailure otherwise.
int main(int argc, char **argv) {
  // TODO(jpienaar): Revise the command line option parsing here.
  tensorflow::InitMlir y(&argc, &argv);

  // TODO(antiagainst): We are pulling in multiple transformations as follows.
  // Each transformation has its own set of command-line options; options of one
  // transformation can essentially be aliases to another. For example, the
  // -tfl-annotate-inputs has -tfl-input-arrays, -tfl-input-data-types, and
  // -tfl-input-shapes, which are the same as -graphdef-to-mlir transformation's
  // -tf_input_arrays, -tf_input_data_types, and -tf_input_shapes, respectively.
  // We need to disable duplicated ones to provide a cleaner command-line option
  // interface. That also means we need to relay the value set in one option to
  // all its aliases.
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();
  mlir::registerPassManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");

  MLIRContext context;
  llvm::SourceMgr source_mgr;
  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

  StatusOr<mlir::OwningModuleRef> module;

  // TODO(b/147435528): We need to test the e2e behavior once the graph freezing
  // inside mlir is done.
  if (import_saved_model_object_graph || import_saved_model_signature_defs) {
    // Reject the conflicting flag combination up front. Storing the error in
    // `module` would not work: the import below overwrites it unconditionally.
    if (input_mlir) {
      llvm::errs() << "Importing saved model should not have input_mlir set"
                   << '\n';
      return kTrFailure;
    }

    // Object-graph import selects SavedModel version 2; signature-def import
    // selects version 1.
    const int saved_model_version = import_saved_model_object_graph ? 2 : 1;

    std::unordered_set<std::string> tags =
        absl::StrSplit(saved_model_tags, ',');
    std::vector<std::string> exported_names_vector =
        absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
    absl::Span<std::string> exported_names(exported_names_vector);
    if (exported_names.size() != 1) {
      llvm::errs() << "There should be only one exported name" << '\n';
      return kTrFailure;
    }
    module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
                                          tags, exported_names, &context);
  } else {
    module = tensorflow::LoadFromGraphdefOrMlirSource(
        input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
        debug_info_file, input_arrays, input_dtypes, input_shapes,
        output_arrays,
        /*prune_unused_nodes=*/true, &source_mgr, &context);
  }

  // If errors occur, the library call in the above already logged the error
  // message. So we can just return here.
  if (!module.ok()) return kTrFailure;

  mlir::PassManager pm(&context);
  mlir::applyPassManagerCLOptions(pm);

  // Set the quantization specifications from the command line flags.
  mlir::TFL::QuantizationSpecs quant_specs;
  // ParseInputNodeQuantSpecs returns true on failure.
  if (mlir::TFL::ParseInputNodeQuantSpecs(input_arrays, min_values, max_values,
                                          inference_type, &quant_specs)) {
    llvm::errs() << "Failed to get input quant spec." << '\n';
    return kTrFailure;
  }
  if (weight_quantization != "NONE") {
    quant_specs.weight_quantization = true;
    if (weight_quantization == "INT8") {
      quant_specs.inference_type = tensorflow::DT_QINT8;
    } else if (weight_quantization == "FLOAT16") {
      quant_specs.inference_type = tensorflow::DT_HALF;
    } else {
      llvm::errs() << "Unknown weight quantization " << weight_quantization
                   << '\n';
      return kTrFailure;
    }
  }
  if (!emit_quant_adaptor_ops) {
    quant_specs.inference_input_type = quant_specs.inference_type;
  }

  if (!quant_stats_file_name.empty()) {
    std::string error_message;
    auto file = mlir::openInputFile(quant_stats_file_name, &error_message);
    if (!file) {
      // Include the reason reported by openInputFile so the failure is
      // actionable.
      llvm::errs() << "fail to open quant stats file: "
                   << quant_stats_file_name << ": " << error_message << '\n';
      return kTrFailure;
    }
    quant_specs.serialized_quant_stats = file->getBuffer().str();
  }

  mlir::TFL::PassConfig pass_config(quant_specs);
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
  pass_config.legalize_tf_while = convert_tf_while_to_tfl_while;

  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
  // TODO(b/150901738): Move those into tf_tfl_translate.cc.
  // Convert back to outlined while format for export back to flatbuffer.
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  std::string result;
  auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
      module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, quant_specs, &result, &pm);
  if (!status.ok()) return kTrFailure;

  std::string error_msg;
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return kTrFailure;
  }
  output->os() << result;
  // Without keep() the ToolOutputFile removes the file on destruction.
  output->keep();

  // Print out debugging info related to function mapping.
  if (print_function_result_mapping)
    return PrintFunctionResultMapping(result, module.ValueOrDie().get());
  return kTrSuccess;
}