STT-tensorflow/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
Commit 544ca41bed by Jacques Pienaar: Populate value hoisting function that LoopLike uses
* Populate the method of the LoopLike op interface and verify it with an LICM test.
  - The function to determine if values are outside the loop is still too conservative and needs to be expanded.
* Make the canonical form of TFL While match the exported form. This is not strictly required for optimization, but it makes export simpler since the canonical form already matches the export form.
  - Check that this constraint is met during translation to flatbuffer.
* Add an integration/execution test that runs the export through the interpreter to verify it behaves as expected. This was added because I initially got the export wrong and the first place the failure showed up was during execution; that condition is now checked during export instead.

TFL WhileOp still needs additional verification to check that the types are compatible.

PiperOrigin-RevId: 292958195
Change-Id: If8b2b506c406aa183f1b365ab36b03b46e20087a
2020-02-03 11:07:22 -08:00
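As an illustration of the hoisting hook described above, a LoopLike op such as TFL While could report which values are defined outside the loop roughly as follows. This is a minimal sketch under assumed MLIR interface and accessor names (`WhileOp`, `getOperation()`, `getParentRegion()`); it is not the actual TensorFlow implementation, which the notes above describe as still more conservative than this.

// Sketch only: a value may be hoisted if its defining region encloses the
// while op itself, i.e. it is defined outside the cond/body regions.
bool WhileOp::isDefinedOutsideOfLoop(mlir::Value value) {
  mlir::Region* def_region = value.getParentBlock()->getParent();
  for (mlir::Region* region = getOperation()->getParentRegion();
       region != nullptr; region = region->getParentRegion()) {
    if (region == def_region) return true;  // Value dominates the loop.
  }
  return false;  // Conservatively treat everything else as loop-defined.
}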

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tool to run a TFLite computation from a MLIR input using the TFLite
// interpreter.
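//
// Example invocation (the binary name and input path here are illustrative,
// not taken from the build files):
//   mlir-tflite-runner /tmp/model.mlir --dump-interpreter-state
//
// Each output of the module's `main` function is printed to stdout in the
// form:
//   Tensor<type: FLOAT32, shape: 2, 2, values: 1, 2, 3, 4>
// (the values shown here are made up for illustration).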
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
using llvm::cl::desc;
using llvm::cl::init;
using llvm::cl::opt;
// NOLINTNEXTLINE
static opt<std::string> input_filename(llvm::cl::Positional,
                                       desc("<input file>"), init("-"));

// NOLINTNEXTLINE
static opt<bool> dump_state("dump-interpreter-state",
                            desc("dump interpreter state post execution"),
                            init(false));
// TODO(jpienaar): Move these functions to some debug utils.
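// Returns the dimensions of `tensor` as a comma-separated string (empty if
// the tensor carries no dims).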
static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
  auto begin = tensor.dims ? tensor.dims->data : nullptr;
  auto end = tensor.dims ? tensor.dims->data + tensor.dims->size : nullptr;
  return absl::StrJoin(begin, end, ", ");
}
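// Returns the elements of `tensor`, reinterpreted as values of type T, as a
// comma-separated string ("<null>" if the tensor has no data buffer).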
template <typename T>
static std::string TfLiteTypedTensorString(const TfLiteTensor& tensor) {
  const T* data = reinterpret_cast<T*>(tensor.data.raw);
  if (!data) return "<null>";
  int count = tensor.bytes / sizeof(T);
  return absl::StrJoin(data, data + count, ", ");
}
// TODO(jpienaar): This really feels like something that should exist already.
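// Dispatches on the tensor's element type and returns its contents as a
// comma-separated string; only the types this tool currently needs (int32,
// int64, float32) are handled.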
static std::string TfLiteTensorString(const TfLiteTensor& tensor) {
  switch (tensor.type) {
    case kTfLiteInt32:
      return TfLiteTypedTensorString<int32_t>(tensor);
    case kTfLiteInt64:
      return TfLiteTypedTensorString<int64_t>(tensor);
    case kTfLiteFloat32:
      return TfLiteTypedTensorString<float>(tensor);
    default:
      LOG(QFATAL) << "Unsupported type: " << TfLiteTypeGetName(tensor.type);
  }
}
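
// Driver: parse the MLIR input, translate it to a TFLite flatbuffer, run the
// result through the TFLite interpreter, and print every output tensor.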
int main(int argc, char** argv) {
  llvm::InitLLVM y(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");

  auto file_or_err =
      llvm::MemoryBuffer::getFileOrSTDIN(input_filename.c_str());
  if (std::error_code error = file_or_err.getError()) {
    LOG(ERROR) << argv[0] << ": could not open input file '" << input_filename
               << "': " << error.message() << "\n";
    return 1;
  }

  // Load the MLIR module.
  mlir::MLIRContext context;
  llvm::SourceMgr source_mgr;
  source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
  mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
  if (!module) return 1;

  // TODO(jpienaar): Expand to support inputs.
  mlir::FuncOp main = module->lookupSymbol<mlir::FuncOp>("main");
  QCHECK(main) << "No 'main' function specified.";
  if (main.getType().getNumInputs() != 0)
    LOG(QFATAL) << "NYI: Only nullary functions supported.";

  // Convert to flatbuffer.
  std::string serialized_flatbuffer;
  if (tflite::MlirToFlatBufferTranslateFunction(
          module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
          emit_select_tf_ops, emit_custom_ops))
    return 1;

  // Create TFLite interpreter & invoke converted program.
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(serialized_flatbuffer.c_str(),
                                               serialized_flatbuffer.size());
  tflite::ops::builtin::BuiltinOpResolver builtins;
  std::unique_ptr<tflite::Interpreter> interpreter;
  QCHECK(tflite::InterpreterBuilder(*model, builtins)(&interpreter) ==
         kTfLiteOk);
  QCHECK(interpreter->AllocateTensors() == kTfLiteOk);
  QCHECK(interpreter->Invoke() == kTfLiteOk);

  // Print the resulting outputs.
  // TODO(jpienaar): Allow specifying output stream/file.
  QCHECK(interpreter->outputs().size() == main.getType().getNumResults());
  for (int index : interpreter->outputs()) {
    const auto& out = *interpreter->tensor(index);
    // Print name if named.
    if (out.name) fprintf(stdout, "%s: ", out.name);
    // Print tensor result.
    fprintf(stdout, "Tensor<type: %s, shape: %s, values: %s>\n",
            TfLiteTypeGetName(out.type), TfLiteTensorDimString(out).c_str(),
            TfLiteTensorString(out).c_str());
  }
  if (dump_state) tflite::PrintInterpreterState(interpreter.get());

  return 0;
}