STT-tensorflow/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
Jaesung Chung 17f94a78e8 Refactor reading builtin code in TFLite
This change is preliminary work for resolving the builtin code shortage problem.
It introduces a new utility build target, schema_utils, which will be used for
getting and setting the builtin operator code value in the TFLite flatbuffer, so
that there is a single place that accesses the underlying fields.

See also the RFC proposal draft,
https://github.com/tensorflow/community/pull/285
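As an illustration (a sketch, not an exhaustive list of the call sites touched
by this change), importers are expected to read the builtin code through the
accessor, e.g. `auto builtin_code = tflite::GetBuiltinCode(&op_code);`, rather
than reading the operator code table's builtin_code field directly.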

PiperOrigin-RevId: 335513647
Change-Id: I810a33425bbed3489cfe4a4a98a10dc4bd67a6ba
2020-10-05 15:36:21 -07:00


/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
using llvm::ArrayRef;
using mlir::Builder;
using mlir::DenseElementsAttr;
using mlir::FuncOp;
using mlir::Location;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::Operation;
using mlir::OperationState;
using mlir::OwningModuleRef;
using mlir::RankedTensorType;
using mlir::UnrankedTensorType;
using mlir::Value;
using mlir::quant::QuantizedType;
using tflite::TensorT;
using xla::StatusOr;
namespace errors = tensorflow::errors;
namespace tfl = mlir::TFL;
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars from unranked tensors, so
// this check is stubbed out (it always returns false) until we work out a way
// to handle the distinction.
return tensor.shape.empty() && false;
}
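// Returns true if the tensor has quantization parameters (i.e. a non-empty
// zero_point list).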
bool IsQuantized(const TensorT& tensor) {
return (tensor.quantization != nullptr) &&
!tensor.quantization->zero_point.empty();
}
// Creates the MLIR NameLoc corresponding to a given tensor.
Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
if (tensor.name.empty()) {
return base;
}
return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base);
}
// Returns the correct type for a quantized tensor
// We have a special case for constants since they have a higher minimum value.
StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
bool is_constant = false) {
tflite::QuantizationParametersT& quant_params = *tensor.quantization;
if (quant_params.details.AsCustomQuantization()) {
return errors::Unimplemented("Cannot handle experimental quantization");
}
bool is_signed = true;
mlir::IntegerType storage_type;
if (tensor.type == tflite::TensorType_UINT8) {
is_signed = false;
storage_type = builder.getIntegerType(8);
} else {
auto raw_elem_type = ConvertElementType(tensor.type, builder);
if (!raw_elem_type.isa<mlir::IntegerType>()) {
return errors::InvalidArgument(
"Quantized tensors must be stored as integers");
}
storage_type = raw_elem_type.cast<mlir::IntegerType>();
}
// TFLite uses narrow-range [u]int8 for constant buffers of quantized weights.
// Since we don't know which ones are weights, we represent this optimization
// as a change in the storage bounds for the type for all constants of this
// type.
bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8);
int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
is_signed, storage_type.getWidth()) +
is_weight_buffer;
int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
is_signed, storage_type.getWidth());
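// For example, a signed 8-bit weight buffer gets the narrow storage range
// [-127, 127] instead of the default [-128, 127].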
uint32_t flags =
is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
// Reject quantized tensors that have zero scales.
for (float scale : quant_params.scale) {
if (scale == 0) {
return errors::InvalidArgument(
"Quantized tensors must have non-zero scales");
}
}
// The number of scales cannot be zero, as checked above; more than one scale
// means per-axis quantization.
if (quant_params.scale.size() != 1) {
llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
quant_params.scale.end());
return mlir::quant::UniformQuantizedPerAxisType::get(
flags, storage_type, builder.getF32Type(), scales,
quant_params.zero_point, quant_params.quantized_dimension, storage_min,
storage_max);
}
return mlir::quant::UniformQuantizedType::get(
flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
quant_params.zero_point.at(0), storage_min, storage_max);
}
// TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
// make that distinction and don't have to rely on context
// (input to main and constants must have static shape)
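// The shape_signature field (which may contain -1 for dynamic dimensions)
// takes precedence over shape when it is present.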
StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
bool shapeless_are_scalars = false,
bool is_constant = false) {
mlir::Type elem_type = ConvertElementType(tensor.type, builder);
// TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
// if it's set
if (IsQuantized(tensor)) {
TF_ASSIGN_OR_RETURN(elem_type,
GetQuantizedType(tensor, builder, is_constant));
}
if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
return RankedTensorType::get({}, elem_type);
}
if (!tensor.shape_signature.empty()) {
llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
tensor.shape_signature.end());
return RankedTensorType::get(shape, elem_type);
}
if (!tensor.shape.empty()) {
llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
tensor.shape.end());
return RankedTensorType::get(shape, elem_type);
}
return UnrankedTensorType::get(elem_type);
}
// Extracts the min/max information from the tensor and creates the quant
// stats op. If the input `tensor` has scale/zero_point, `res` should have a
// quantized type, so no stats op is required and nullptr is returned.
// If the min/max information is invalid, nullptr is returned.
mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
Value res) {
// If the `tensor` has scale/zero_point, it has already been quantized and the
// min/max stats are informational only, so ignore them.
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
// If the result element type isn't floating point, the min/max is ignored.
if (!res.getType()
.cast<mlir::ShapedType>()
.getElementType()
.isa<mlir::FloatType>()) {
return nullptr;
}
auto mins = tensor.quantization->min;
auto maxs = tensor.quantization->max;
if (mins.size() != maxs.size() || mins.empty()) return nullptr;
llvm::SmallVector<llvm::APFloat, 4> min_maxs;
min_maxs.reserve(mins.size() * 2);
for (int i = 0, end = mins.size(); i < end; ++i) {
llvm::APFloat min(mins[i]);
llvm::APFloat max(maxs[i]);
min_maxs.push_back(min);
min_maxs.push_back(max);
}
// The layer stats contain only the first min/max pair.
mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
mlir::RankedTensorType::get({2}, b.getF32Type()),
{min_maxs[0], min_maxs[1]});
mlir::ElementsAttr axis_stats;
mlir::IntegerAttr axis;
if (mins.size() > 1) {
llvm::SmallVector<int64_t, 4> axis_stats_shape{
static_cast<int64_t>(mins.size()), 2};
axis_stats = mlir::DenseFPElementsAttr::get(
mlir::RankedTensorType::get(axis_stats_shape, b.getF32Type()),
min_maxs);
// TODO(fengliuai): this quantization dimension isn't correct.
axis = b.getI64IntegerAttr(tensor.quantization->quantized_dimension);
}
return b.create<mlir::quant::StatisticsOp>(b.getUnknownLoc(), res,
layer_stats, axis_stats, axis);
}
// Returns true if this is a basic LSTM op.
bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
if (const auto* op = op_union.AsLSTMOptions()) {
return op->kernel_type == tflite::LSTMKernelType_BASIC;
} else {
return false;
}
}
// Gets the MLIR op name with the dialect name for the flatbuffer operator.
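// For example, BuiltinOperator_ADD maps to "tfl.add"; CUSTOM, IF and WHILE
// are special-cased below.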
StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
const tflite::OperatorCodeT& op_code) {
if (IsBasicLSTMOp(op.builtin_options)) {
return std::string("tfl.basic_lstm");
}
auto builtin_code = tflite::GetBuiltinCode(&op_code);
if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
return std::string("tfl.custom");
}
if (builtin_code == tflite::BuiltinOperator_IF) {
return std::string("tf.If");
}
if (builtin_code == tflite::BuiltinOperator_WHILE) {
return std::string("tf.While");
}
llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
return llvm::Twine("tfl.", op_name.lower()).str();
}
// The buffers in TFLite flatbuffers have their contents stored as a vector of
// bytes that represent little-endian values.
// The read_size parameter is present to allow reading both float16 and float32s
// without a case split.
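// For example, the four bytes {0x01, 0x00, 0x02, 0x00} read as uint16_t
// yield the values {1, 2}.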
template <typename T>
std::vector<T> ReadAsLittleEndian(ArrayRef<uint8_t> bytes) {
std::vector<T> ret;
size_t read_size = sizeof(T);
int bytes_len = bytes.size();
assert(bytes_len % read_size == 0);
int elem_count = bytes_len / read_size;
ret.reserve(elem_count);
const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
for (int i = 0; i < elem_count; i++) {
ret.push_back(
llvm::support::endian::readNext<T, llvm::support::little,
llvm::support::unaligned>(data_ptr));
}
return ret;
}
tensorflow::TensorProto ConvertTfliteConstTensor(
const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer) {
tensorflow::TensorProto ret;
ret.set_dtype(TflTypeToTfType(tensor.type));
tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape();
shape->set_unknown_rank(false);
for (auto dim : tensor.shape) {
shape->add_dim()->set_size(int64_t{dim});
}
std::string content;
content.assign(reinterpret_cast<const char*>(buffer.data()), buffer.size());
ret.set_tensor_content(content);
return ret;
}
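// Converts a little-endian buffer of raw bytes into a DenseElementsAttr of
// the given floating-point element type (16-, 32- or 64-bit).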
StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
mlir::RankedTensorType shaped_type, mlir::FloatType elem_type,
const std::vector<uint8_t>& buffer) {
size_t bytes_len = buffer.size();
// The bytes of floats are stored little-endian.
switch (elem_type.getWidth()) {
case 16: {
assert(bytes_len % 2 == 0);
int elem_count = bytes_len / 2;
std::vector<llvm::APFloat> values;
values.reserve(elem_count);
const char* data = reinterpret_cast<const char*>(buffer.data());
auto& semantics = elem_type.getFloatSemantics();
for (int i = 0; i < elem_count; i++) {
uint16_t bit_repr =
llvm::support::endian::readNext<uint16_t, llvm::support::little,
llvm::support::unaligned>(data);
llvm::APInt int_repr(16, bit_repr);
values.emplace_back(semantics, int_repr);
}
return DenseElementsAttr::get(shaped_type, values);
}
case 32: {
assert(bytes_len % 4 == 0);
int elem_count = bytes_len / 4;
std::vector<float> values;
values.reserve(elem_count);
const char* data = reinterpret_cast<const char*>(buffer.data());
for (int i = 0; i < elem_count; i++) {
uint32_t bit_repr =
llvm::support::endian::readNext<uint32_t, llvm::support::little,
llvm::support::unaligned>(data);
values.push_back(absl::bit_cast<float>(bit_repr));
}
return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
}
case 64: {
assert(bytes_len % 8 == 0);
int elem_count = bytes_len / 8;
std::vector<double> values;
values.reserve(elem_count);
const char* data = reinterpret_cast<const char*>(buffer.data());
for (int i = 0; i < elem_count; i++) {
uint64_t bit_repr =
llvm::support::endian::readNext<uint64_t, llvm::support::little,
llvm::support::unaligned>(data);
values.push_back(absl::bit_cast<double>(bit_repr));
}
return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
}
}
return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
}
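// Converts a buffer of raw bytes into a DenseElementsAttr of an integer (or
// quantized storage) element type; 1- and 8-bit values are read byte-wise,
// wider values as little-endian.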
StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
mlir::RankedTensorType shaped_type, mlir::Type elem_type,
const std::vector<uint8_t>& buffer) {
unsigned bit_width;
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
bit_width = itype.getWidth();
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
bit_width = qtype.getStorageTypeIntegralWidth();
shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(),
qtype.getStorageType());
} else {
return errors::InvalidArgument("unsupported integer constant type");
}
switch (bit_width) {
case 1: {
// vector<bool> doesn't convert to an ArrayRef
llvm::SmallVector<bool, 8> values;
values.reserve(buffer.size());
for (auto b : buffer) {
values.emplace_back(b != 0);
}
return DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values));
}
case 8: {
return DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer));
}
case 16: {
auto values = ReadAsLittleEndian<uint16_t>(buffer);
return DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values));
}
case 32: {
auto values = ReadAsLittleEndian<uint32_t>(buffer);
return DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values));
}
case 64: {
auto values = ReadAsLittleEndian<uint64_t>(buffer);
return DenseElementsAttr::get(shaped_type, ArrayRef<uint64_t>(values));
}
default:
return errors::Unimplemented("Cannot handle bit width ", bit_width);
}
}
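// Builds a tfl::ExternalConstOp that references the constant's buffer by
// index instead of embedding its contents.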
StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
int32_t buffer_index,
OpBuilder builder, Location loc) {
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
if (!shaped_type) {
return errors::Internal("Constant doesn't have a shape");
}
auto op = builder.create<tfl::ExternalConstOp>(
loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
return op.getOperation();
}
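// Builds a constant op from the tensor and its buffer contents: quantized
// tensors become tfl::QConstOp, other tensors tfl::ConstOp; string, complex
// and opaque TF element types are carried as TensorProto-based attributes.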
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
const std::vector<uint8_t>& buffer,
OpBuilder builder, Location loc) {
if (buffer.empty()) {
return errors::InvalidArgument("Constant's buffer may not be empty");
}
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
if (!shaped_type) {
return errors::Internal("Constant doesn't have a shape");
}
auto elem_type = shaped_type.getElementType();
mlir::ElementsAttr value;
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
TF_ASSIGN_OR_RETURN(value,
ConvertFloatBuffer(shaped_type, float_type, buffer));
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
TF_ASSIGN_OR_RETURN(value,
ConvertIntBuffer(shaped_type, elem_type, buffer));
} else if (elem_type.isa<mlir::TF::StringType>()) {
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
std::vector<llvm::StringRef> refs;
refs.reserve(repr.string_val_size());
for (const auto& ref : repr.string_val())
refs.push_back({ref.data(), ref.size()});
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
auto dialect = elem_type.getContext()->getLoadedDialect("tf");
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
value = mlir::OpaqueElementsAttr::get(dialect, shaped_type, mangled);
} else {
return errors::Unimplemented("Constant of unsupported type");
}
if (IsQuantized(tensor)) {
auto op = builder.create<tfl::QConstOp>(
loc, mlir::TypeAttr::get(shaped_type), value);
return op.getOperation();
}
auto op = builder.create<tfl::ConstOp>(loc, value);
return op.getOperation();
}
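// Converts the subgraph indices stored in If/While builtin options into
// symbol references to the corresponding imported functions, plus an
// is_stateless attribute.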
llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
tflite::BuiltinOptionsUnion options,
const std::vector<std::string>& func_names, Builder builder) {
if (auto* opts = options.AsIfOptions()) {
uint32_t then_idx = opts->then_subgraph_index;
auto then_attr = builder.getSymbolRefAttr(func_names.at(then_idx));
uint32_t else_idx = opts->else_subgraph_index;
auto else_attr = builder.getSymbolRefAttr(func_names.at(else_idx));
return {builder.getNamedAttr("then_branch", then_attr),
builder.getNamedAttr("else_branch", else_attr),
// TODO(b/139667752): Analyze statelessness correctly
builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
}
if (auto* opts = options.AsWhileOptions()) {
uint32_t cond_idx = opts->cond_subgraph_index;
auto cond_attr = builder.getSymbolRefAttr(func_names.at(cond_idx));
uint32_t body_idx = opts->body_subgraph_index;
auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));
return {builder.getNamedAttr("cond", cond_attr),
builder.getNamedAttr("body", body_attr),
// TODO(b/139667752): Analyze statelessness correctly
builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
}
return {};
}
// TODO(krzysd) Handle function calls
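// Converts one TFLite operator into an MLIR operation: resolves the op name,
// wires up operands (using the optional-argument marker for -1 inputs), sets
// result types, and translates builtin/custom options into attributes.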
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
const std::vector<mlir::TensorType>& intermediate_types,
Value optional_arg_marker,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
OpBuilder builder) {
llvm::SmallVector<Value, 4> operands;
llvm::SmallVector<mlir::Type, 2> outputTypes;
if (op.outputs.empty()) {
auto err = errors::InvalidArgument("operator with no outputs");
return emitError(loc, err.ToString()), err;
}
const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));
OperationState op_state(loc, op_name);
for (auto input_num : op.inputs) {
if (input_num == -1) {
assert(optional_arg_marker != nullptr);
op_state.addOperands({optional_arg_marker});
} else {
op_state.addOperands({vals_map.at(input_num)});
}
}
for (auto output_num : op.outputs) {
auto& tensor = *tensors.at(output_num);
auto type_or_err = GetTensorType(tensor, builder);
if (!type_or_err.ok()) {
return emitError(loc, type_or_err.status().ToString()),
type_or_err.status();
}
auto type = type_or_err.ConsumeValueOrDie();
if (op_name == "tfl.quantize") {
// Special case for quantize: return type must also be in qtype attribute
op_state.addAttribute("qtype", mlir::TypeAttr::get(type));
} else if (op_name == "tfl.reshape" && type.hasStaticShape() &&
op_state.operands.size() == 1) {
// Special case for reshape: the shape input is optional in the old
// converter and kernel, so we create the second operand, which is
// required by the new converter, from the result shape.
auto shape_type =
RankedTensorType::get({type.getRank()}, builder.getIntegerType(32));
mlir::SmallVector<mlir::Attribute, 4> shape;
shape.reserve(type.getRank());
for (auto s : type.getShape()) {
shape.push_back(builder.getI32IntegerAttr(static_cast<int32_t>(s)));
}
auto output_shape = DenseElementsAttr::get(shape_type, shape);
auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
op_state.addOperands({shape_op});
}
op_state.addTypes({type});
}
// The trailing inputs of a TFL op may be optional tensors, so the number of
// input operands can vary. Get the min/max number of operands from the
// TFLite op name.
// Also, since the code above special-cases the `tfl.reshape` op and adds an
// additional input, this block must come after it.
llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
int input_max_num = input_min_max.Max;
int op_input_num = op_state.operands.size();
if (input_max_num != 0 && input_max_num > op_input_num) {
// If the number of current inputs is less than the op definition requires,
// fill in the remainder with `none` values.
llvm::SmallVector<Value, 4> none_operands(
input_max_num - op_input_num,
builder.create<mlir::ConstantOp>(loc, builder.getNoneType(),
builder.getUnitAttr()));
op_state.addOperands(ArrayRef<Value>(none_operands));
}
if (op_name == "tfl.lstm") {
// TODO(b/147587779): add the right region if region is empty.
op_state.addRegion();
if (!op.intermediates.empty()) {
if (op.intermediates.size() != 5) {
auto err = errors::InvalidArgument(
"operator has intermediate tensors but the number of them is not "
"five.");
return emitError(loc, err.ToString()), err;
}
// Create intermediate values.
const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
"input_to_input_intermediate", "input_to_forget_intermediate",
"input_to_cell_intermediate", "input_to_output_intermediate",
"effective_hidden_scale_intermediate"};
for (auto type_and_name :
llvm::zip(intermediate_types, kIntermediateNames)) {
mlir::TypeAttr type_attr =
mlir::TypeAttr::get(std::get<0>(type_and_name));
auto named_attr =
builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
op_state.addAttribute(named_attr.first, named_attr.second);
}
}
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
auto builtin_code = tflite::GetBuiltinCode(&op_code);
if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
auto status = mlir::CustomOptionsToAttributes(
op_code.custom_code, op.custom_options, builder, loc, &attrs);
if (!status.ok()) {
return emitError(loc, status.ToString()), status;
}
} else {
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
}
op_state.addAttributes(attrs);
// Handle the conversion from subgraph index to functions for If and While
auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
op.builtin_options, func_names, builder);
op_state.addAttributes(function_ref_attrs);
return builder.createOperation(op_state);
}
// Returns the indices of the given tensors in the subgraph. Returns an error
// if a tensor name cannot be found in the subgraph.
StatusOr<std::vector<int>> GetTensorIndices(
const tflite::SubGraphT& subgraph,
const std::vector<std::string>& tensor_names) {
absl::flat_hash_map<std::string, int> name_to_index;
for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
}
std::vector<int> indices;
indices.reserve(tensor_names.size());
for (const auto& name : tensor_names) {
auto found = name_to_index.find(name);
if (found != name_to_index.end()) {
indices.push_back(found->second);
} else {
return errors::InvalidArgument("could not find tensor in subgraph: ",
name);
}
}
return indices;
}
// Given a list of tensor indices, returns a string of concatenated tensor names
// wrapped in a NamedAttribute.
template <typename ContainerType>
mlir::NamedAttribute BuildTFEntryFunctionAttribute(
const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
const ContainerType indices) {
auto tensor_names = llvm::map_range(
indices, [&](int i) { return subgraph.tensors.at(i)->name; });
return builder->getNamedAttr(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
// Traverses the subgraph from output_indices to input_indices and returns the
// set of ops that are visited.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
ArrayRef<int32_t> output_indices) {
// Create a map from tensor index to defining op.
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
for (const auto& op : subgraph.operators) {
for (int32_t output : op->outputs) {
if (!llvm::is_contained(input_indices, output)) {
defining_op[output] = op.get();
}
}
}
std::vector<const tflite::OperatorT*> queue;
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
}
}
// Traverse the graph towards inputs.
absl::flat_hash_set<const tflite::OperatorT*> visited;
while (!queue.empty()) {
const tflite::OperatorT* op = queue.back();
queue.pop_back();
if (!visited.insert(op).second) {
// The node has already been visited.
continue;
}
for (int32_t input : op->inputs) {
// An input tensor may not have a defining op if it is a subgraph input
// or a constant tensor.
if (auto& op = defining_op[input]) {
queue.push_back(op);
}
}
}
return visited;
}
// Adjusts the func op according to information that spans multiple ops.
static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
OpBuilder builder(func);
// When a quantized constant is imported, its quantization parameters are set
// to narrow range. Here we revert to full range if the users of the constant
// do not require narrow range.
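// For example, an i8 constant imported with storage range [-127, 127] is
// widened back to [-128, 127] when none of its users require a narrow-range
// operand.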
func.walk([&](tfl::QConstOp cst) {
Value value = cst.getResult();
Value full_range_const = value;
for (auto& use : value.getUses()) {
Operation* user = use.getOwner();
if (user->isKnownTerminator()) return;
auto qtype = mlir::quant::UniformQuantizedType::getQuantizedElementType(
value.getType());
// Only the 8-bit constants are imported with narrow range.
if (!qtype || qtype.getStorageTypeIntegralWidth() != 8) return;
auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
if (affine_user &&
affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
affine_user.RequiredNarrowRangeAffineOperand())
return;
// Create a full-range quantized constant.
if (full_range_const == value) {
mlir::quant::QuantizedType new_qtype;
if (auto per_axis =
qtype.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
new_qtype = mlir::quant::UniformQuantizedPerAxisType::get(
per_axis.getFlags(), per_axis.getStorageType(),
per_axis.getExpressedType(), per_axis.getScales(),
per_axis.getZeroPoints(), per_axis.getQuantizedDimension(),
per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax());
} else if (auto per_tensor =
qtype.dyn_cast<mlir::quant::UniformQuantizedType>()) {
new_qtype = mlir::quant::UniformQuantizedType::get(
per_tensor.getFlags(), per_tensor.getStorageType(),
per_tensor.getExpressedType(), per_tensor.getScale(),
per_tensor.getZeroPoint(), per_tensor.getStorageTypeMin() - 1,
per_tensor.getStorageTypeMax());
} else {
return;
}
auto new_output_type = new_qtype.castFromExpressedType(
mlir::quant::UniformQuantizedType::castToExpressedType(
value.getType()));
builder.setInsertionPointAfter(cst.getOperation());
auto new_op = builder.create<tfl::QConstOp>(
cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type),
cst.valueAttr());
full_range_const = new_op.output();
}
use.set(full_range_const);
}
if (cst.use_empty()) cst.erase();
});
return func;
}
// Builds a FuncOp from a TFLite SubGraph.
// The buffers are taken directly from the deserialized flatbuffer, as we do
// not have the type information needed to interpret them until this point.
// The base_loc parameter is the location of the flatbuffer as a whole
// (usually a file). The is_entry_point flag controls whether shapeless types
// are treated as scalars. If ordered_output_arrays is not empty, then the
// imported MLIR function will only return the nodes in ordered_output_arrays,
// in the same order.
StatusOr<FuncOp> ConvertSubgraph(
const tflite::SubGraphT& subgraph, llvm::StringRef name,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
Location base_loc, Builder builder, bool is_entry_point,
bool use_external_constant,
const std::vector<std::string>& ordered_input_arrays,
const std::vector<std::string>& ordered_output_arrays,
bool experimental_prune_unreachable_nodes_unconditionally) {
llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types;
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
std::vector<int> func_inputs = subgraph.inputs;
if (is_entry_point && !ordered_input_arrays.empty()) {
if (!experimental_prune_unreachable_nodes_unconditionally) {
// TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
return errors::InvalidArgument(
"input-arrays should be used with experimental pruning flag");
}
TF_ASSIGN_OR_RETURN(func_inputs,
GetTensorIndices(subgraph, ordered_input_arrays));
}
// Add state variables to inputs.
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
func_inputs.end());
for (int i = 0, end = subgraph.tensors.size(); i < end; i++) {
auto& tensor = *subgraph.tensors.at(i);
if (tensor.is_variable && !input_index_set.contains(i)) {
func_inputs.emplace_back(i);
input_index_set.insert(i);
}
}
for (auto input_or_variable : func_inputs) {
auto& tensor = *subgraph.tensors.at(input_or_variable);
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
// but we cannot differentiate scalars from unranked tensors.
// Here we reverse the default assumption that shape = [] means unranked
// when processing main().
auto type_or_err = GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/is_entry_point,
/*is_constant=*/false);
if (!type_or_err.ok()) {
emitError(func_loc, "error reading argument types")
<< type_or_err.status().ToString();
return type_or_err.status();
}
auto type = type_or_err.ConsumeValueOrDie();
input_types.push_back(type);
}
llvm::SmallVector<bool, 16> is_op_output(subgraph.tensors.size(), false);
for (auto& op : subgraph.operators) {
for (auto output : op->outputs) {
is_op_output[output] = true;
}
}
std::vector<int> func_outputs = subgraph.outputs;
if (is_entry_point && !ordered_output_arrays.empty()) {
TF_ASSIGN_OR_RETURN(func_outputs,
GetTensorIndices(subgraph, ordered_output_arrays));
}
for (auto output : func_outputs) {
const bool is_func_input = input_index_set.contains(output);
bool is_constant = !is_op_output[output] && !is_func_input;
// There are two cases in which a tensor without a shape in the flatbuffer is
// a scalar:
// 1. `is_constant` = true, meaning this tensor is created by a constant op.
// 2. `is_func_input` = true and `is_entry_point` = true, meaning this tensor
// is a function input whose type is a scalar tensor.
const bool shapeless_is_scalar =
is_constant || (is_func_input && is_entry_point);
auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
shapeless_is_scalar,
/*is_constant=*/is_constant);
if (!type_or_err.ok()) {
emitError(func_loc, "error reading return types")
<< type_or_err.status().ToString();
return type_or_err.status();
}
auto type = type_or_err.ConsumeValueOrDie();
ret_types.push_back(type);
}
auto func_type = builder.getFunctionType(input_types, ret_types);
// Construct function object
auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {});
func.addEntryBlock();
auto& body = func.getBody();
OpBuilder op_builder{body};
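// vals_map maps each flatbuffer tensor index to the MLIR Value that defines
// it; entries are filled in from function arguments, materialized constants
// and op results.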
std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
Value maybe_optional_arg_marker = nullptr;
// Get or construct MLIR values for each input
for (int i = 0, e = func_inputs.size(); i < e; i++) {
auto input_tensor = func_inputs[i];
const auto& tensor = *subgraph.tensors.at(input_tensor);
auto loc = TensorLoc(tensor, builder, base_loc);
if (vals_map[input_tensor]) {
auto err = errors::FailedPrecondition("duplicate input arguments");
return emitError(loc, err.ToString()), err;
}
Value input_value = func.getArgument(i);
// If the `tensor` has min/max but no scale/zero_point information, a stats
// op is created on the input_value, and the `tensor` is mapped to the result
// of this new stats op.
if (auto stats_op =
ConvertMinMaxToStatsOp(tensor, op_builder, input_value)) {
vals_map[input_tensor] = stats_op->getResult(0);
} else {
vals_map[input_tensor] = input_value;
}
}
// Set tf.entry_function attribute
if (is_entry_point) {
llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
if (!func_inputs.empty()) {
attributes.push_back(BuildTFEntryFunctionAttribute(
subgraph, &builder, "inputs", func_inputs));
}
if (!func_outputs.empty()) {
attributes.push_back(BuildTFEntryFunctionAttribute(
subgraph, &builder, "outputs", func_outputs));
}
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
if (experimental_prune_unreachable_nodes_unconditionally) {
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
PruneSubgraph(subgraph, func_inputs, func_outputs));
}
// Construct MLIR operators from TFLite operators
for (auto& op : subgraph.operators) {
if (experimental_prune_unreachable_nodes_unconditionally &&
!pruned_subgraph_ops.contains(op)) {
continue;
}
for (auto input_num : op->inputs) {
// The operators in a graph are topologically sorted
// and so if no previous operation has produced a tensor
// it must be a constant.
if (input_num == -1) {
if (maybe_optional_arg_marker == nullptr) {
maybe_optional_arg_marker =
op_builder
.create<mlir::ConstantOp>(base_loc, builder.getNoneType(),
builder.getUnitAttr())
.getResult();
}
} else if (!vals_map.at(input_num)) {
auto& const_tensor = *subgraph.tensors[input_num];
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
auto op_or_err =
use_external_constant
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
op_builder, const_loc)
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
op_builder, const_loc);
if (!op_or_err.ok()) {
return emitError(const_loc, op_or_err.status().ToString()),
op_or_err.status();
}
vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0);
}
}
// Intermediate tensors for tfl.lstm are used to carry quantization ranges
// in their types, so we only need to extract their types.
std::vector<mlir::TensorType> intermediate_types;
intermediate_types.reserve(5);
for (auto intermediate : op->intermediates) {
TF_ASSIGN_OR_RETURN(
auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
intermediate_types.emplace_back(type);
}
// The NameLoc corresponding to the name of the first output tensor
auto op_loc =
op->outputs.empty()
? base_loc
: TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc);
// If there's an optional argument, maybe_optional_arg_marker has been set
// to a valid Value
TF_ASSIGN_OR_RETURN(
auto* mlir_op,
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
op_codes, func_names, subgraph.tensors, op_loc, op_builder));
// Add the results to the value maps. There are two cases: 1. if the result
// tensor does not have min/max values, the original op result is used
// directly; 2. if the result tensor has min/max values, a stats op is
// created and its result is used instead.
for (auto pair : llvm::enumerate(mlir_op->getResults())) {
int output_tensor_index = op->outputs[pair.index()];
auto& tensor = *subgraph.tensors[output_tensor_index];
if (auto stats_op =
ConvertMinMaxToStatsOp(tensor, op_builder, pair.value())) {
vals_map[output_tensor_index] = stats_op->getResult(0);
} else {
vals_map[output_tensor_index] = pair.value();
}
}
}
// Construct return values
llvm::SmallVector<Value, 4> return_operands;
for (auto index : func_outputs) {
if (!vals_map.at(index)) {
auto& const_tensor = *subgraph.tensors[index];
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
auto op_or_err =
use_external_constant
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
op_builder, const_loc)
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
op_builder, const_loc);
if (!op_or_err.ok()) {
return emitError(const_loc, op_or_err.status().ToString()),
op_or_err.status();
}
vals_map[index] = op_or_err.ValueOrDie()->getResult(0);
}
return_operands.push_back(vals_map[index]);
}
op_builder.create<mlir::ReturnOp>(base_loc, return_operands);
return PostProcessFuncOp(func);
}
// TFLite subgraphs do not necessarily have names, though MLIR functions must
// have them, so we generate names here for subgraphs that are missing one.
// Note: in TFLite, the first subgraph is the entry point, and in the MLIR
// that represents TFLite, this entry point must be called "main".
// TODO(b/131175224,b/132239787) Support multiple entry points
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
if (index == 0) {
return "main";
}
if (subgraph.name.empty()) {
return llvm::formatv("fn_{0}", index).str();
}
return subgraph.name;
}
} // namespace
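// Converts a serialized TFLite flatbuffer into an MLIR module: each subgraph
// becomes a FuncOp, and the first subgraph is imported as the entry point
// "main".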
OwningModuleRef tflite::FlatBufferToMlir(
absl::string_view buffer, MLIRContext* context, Location base_loc,
bool use_external_constant,
const std::vector<std::string>& ordered_input_arrays,
const std::vector<std::string>& ordered_output_arrays,
bool experimental_prune_unreachable_nodes_unconditionally) {
context->loadDialect<
mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();
auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
if (nullptr == model_ptr) {
return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
}
std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());
auto builder = Builder(context);
std::vector<std::string> func_names;
for (auto& subgraph : model->subgraphs) {
func_names.push_back(subgraph->name);
}
auto module = mlir::ModuleOp::create(base_loc);
// We currently don't use this to make decisions, but we could
// use it in exports or if there are breaking changes
module.setAttr("tfl.schema_version",
builder.getI32IntegerAttr(model->version));
if (!model->description.empty()) {
module.setAttr("tfl.description",
builder.getStringAttr(model->description));
}
for (auto e : llvm::enumerate(model->subgraphs)) {
auto& subgraph = e.value();
std::string name = SubgraphName(e.index(), *subgraph);
auto func_or_error = ConvertSubgraph(
*subgraph, name, model->operator_codes, func_names, model->buffers,
base_loc, builder,
// TODO(b/131175224,b/132239787) Support multiple entry points
/*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
ordered_output_arrays,
experimental_prune_unreachable_nodes_unconditionally);
if (!func_or_error.ok()) {
return emitError(base_loc, "could not translate function ")
<< subgraph->name << ": "
<< func_or_error.status().error_message(),
nullptr;
}
module.push_back(func_or_error.ConsumeValueOrDie());
}
return OwningModuleRef(module);
}