Start open development of TF, TFLite, XLA MLIR dialects.

This change adds dialects and tools for TF, TFLite and XLA using MLIR (https://github.com/tensorflow/mlir). This is under active development and not built by default.

PiperOrigin-RevId: 255538027
This commit is contained in:
Jacques Pienaar 2019-06-27 21:48:02 -07:00 committed by TensorFlower Gardener
parent 3398e887f5
commit eab4b9c4cc
219 changed files with 45933 additions and 0 deletions

View File

@ -0,0 +1,56 @@
# Description:
# TensorFlow/TensorFlow Lite/XLA MLIR dialects and tools.
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
# Package-wide defaults: targets in this BUILD file are publicly visible unless
# they set their own visibility.
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
# To reference all tablegen files here when checking for updates to them.
# NOTE(review): Bazel globs do not descend into sub-packages, so `.td` files
# under directories that have their own BUILD file are presumably not matched
# by this pattern - confirm that is intended.
filegroup(
name = "td_files",
srcs = glob(["**/*.td"]),
)
# Library built from tf_mlir_opt_main.cc. Its deps pull in the TFLite,
# TensorFlow and XLA dialect targets together with their registration and pass
# libraries, plus MLIR's opt library and test dialect/transforms. The `tf-opt`
# binary in this file links against it.
cc_library(
name = "tf_mlir_opt_main",
srcs = ["tf_mlir_opt_main.cc"],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/xla",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/core:lib",
"@llvm//:support",
"@local_config_mlir//:MlirOptLib",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:Support",
"@local_config_mlir//test:TestDialect",
"@local_config_mlir//test:TestTransforms",
],
)
# The `tf-opt` binary. It has no sources of its own; everything comes from
# :tf_mlir_opt_main.
tf_cc_binary(
name = "tf-opt",
deps = [
":tf_mlir_opt_main",
],
)
# Files in this package whose names start with "runlit" and end with "py".
# Referenced as test data by the lit rule in glob_lit_tests.bzl.
filegroup(
name = "litfiles",
srcs = glob(["runlit*py"]),
)

View File

@ -0,0 +1,11 @@
# MLIR dialects and utilities for TensorFlow, TensorFlow Lite and XLA.
This module contains the MLIR
([Multi-Level Intermediate Representation](https://github.com/tensorflow/mlir))
dialects and utilities for
1. TensorFlow
2. XLA
3. TF Lite
See [MLIR repo](https://github.com/tensorflow/mlir) for complete documentation.

View File

@ -0,0 +1,47 @@
# Test definitions for Lit, the LLVM test runner.
#
# This is reusing the LLVM Lit test runner in the interim until the new build
# rules are upstreamed.
# TODO(b/136126535): remove this custom rule.
"""Lit runner globbing test
"""
def glob_lit_tests(
        exclude = None,
        test_file_exts = ["mlir"],
        default_size = "small",
        size_override = None,
        data = None,
        per_test_extra_data = None,
        default_tags = None,
        tags_override = None,
        driver = None,
        features = []):
    """Creates all plausible Lit tests (and their inputs) under this directory.

    This interim implementation declares a single `py_test` target named
    "glob_lit_tests" that runs lit over `tensorflow/compiler/mlir`. Only
    `data` and `test_file_exts` are consumed; the remaining parameters are
    accepted so callers can keep a stable interface, but are currently ignored.

    Args:
      exclude: [str], paths to exclude (for tests and inputs). Currently unused.
      test_file_exts: [str], extensions for files that are tests.
      default_size: str, the test size for targets not in "size_override".
          Currently unused.
      size_override: {str: str}, sizes to use for specific tests. Currently
          unused.
      data: [str], additional input data to the test. May be None.
      per_test_extra_data: {str: [str]}, extra data to attach to a given file.
          Currently unused.
      default_tags: [str], additional tags to attach to the test. Currently
          unused.
      tags_override: {str: str}, tags to add to specific tests. Currently
          unused.
      driver: str, label of the driver shell script. Currently unused.
      features: [str], list of extra features to enable. Currently unused.
    """

    # `data` defaults to None, and `None + [...]` is an error, so normalize it
    # to an empty list before concatenating.
    extra_data = data or []

    native.py_test(
        name = "glob_lit_tests",
        srcs = ["@llvm//:lit"],
        args = [
            "tensorflow/compiler/mlir --config-prefix=runlit",
        ],
        data = extra_data + [
            "//tensorflow/compiler/mlir:litfiles",
            "@llvm//:FileCheck",
            "@llvm//:count",
            "@llvm//:not",
        ] + native.glob(["*." + ext for ext in test_file_exts]),
        main = "lit.py",
    )

View File

@ -0,0 +1,518 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_native_cc_binary")
load(
"@local_config_mlir//:tblgen.bzl",
"gentbl",
)
# Package-wide defaults for the targets in this BUILD file.
package(
default_visibility = [
# TODO(jpienaar): Make the visibility more restrictive.
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
# Named set of packages.
# NOTE(review): no target in the visible part of this file references
# ":friends"; presumably it is meant for the visibility tightening mentioned in
# the TODO above - confirm.
package_group(
name = "friends",
packages = [
"//learning/brain/experimental/mlir/...",
"//learning/brain/google/xla/...",
"//tensorflow/compiler/mlir/...",
],
)
# TableGen inputs shared by the generators below: this dialect's
# ir/tfl_ops.td plus MLIR's OpBase and QuantizationOps .td file groups.
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"ir/tfl_ops.td",
"@local_config_mlir//:OpBaseTdFiles",
"@local_config_mlir//:QuantizationOpsTdFiles",
],
)
# Runs mlir-tblgen over ir/tfl_ops.td three times, once per entry in
# `tbl_outs`: op declarations (ir/tfl_ops.h.inc), op definitions
# (ir/tfl_ops.cc.inc) and op documentation (g3doc/tfl_ops.md).
gentbl(
name = "tensorflow_lite_ops_inc_gen",
tbl_outs = [
(
"-gen-op-decls",
"ir/tfl_ops.h.inc",
),
(
"-gen-op-defs",
"ir/tfl_ops.cc.inc",
),
(
"-gen-op-doc",
"g3doc/tfl_ops.md",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
)
# The five rules below each run mlir-tblgen with `-gen-rewriters` over one
# pattern file under transforms/ and emit a generated_*.inc file that is listed
# in the `srcs` of a pass library later in this file.

# Rewriters from transforms/prepare_patterns.td. Besides the TFLite op .td
# files this one also takes the TensorFlow op and TensorFlow optimize .td files.
gentbl(
name = "tensorflow_lite_prepare_tf_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_prepare_tf.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/prepare_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_optimize_td_files",
],
)
# Rewriters from transforms/tensorlist_patterns.td.
gentbl(
name = "tensorflow_lite_lower_static_tensor_list_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_lower_static_tensor_list.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/tensorlist_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
# Rewriters from transforms/legalize_patterns.td.
gentbl(
name = "tensorflow_lite_legalize_tf_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_legalize_tf.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/legalize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
# Rewriters from transforms/optimize_patterns.td. Only TFLite and standard-op
# .td files are inputs here.
gentbl(
name = "tensorflow_lite_optimize_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_optimize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/optimize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
],
)
# Rewriters from transforms/quantize_patterns.td. Only TFLite and standard-op
# .td files are inputs here.
gentbl(
name = "tensorflow_lite_quantize_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_quantize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/quantize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
],
)
# Small helper library (utils/validators.{h,cc}) that depends only on MLIR
# targets. The dialect library and the pass libraries in this file depend on it.
cc_library(
name = "validators",
srcs = [
"utils/validators.cc",
],
hdrs = [
"utils/validators.h",
],
deps = [
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
],
)
# The TFLite dialect library: hand-written op code (ir/tfl_ops.cc) compiled
# together with the TableGen outputs of :tensorflow_lite_ops_inc_gen
# (ir/tfl_ops.h.inc, ir/tfl_ops.cc.inc) and attribute utilities.
# NOTE(review): transforms/passes.h and utils/quantization_utils.h are also
# listed in the `hdrs` of other libraries in this file - confirm the
# duplication is intended.
cc_library(
name = "tensorflow_lite",
srcs = [
"ir/tfl_ops.cc",
"ir/tfl_ops.cc.inc",
"ir/tfl_ops.h.inc",
"utils/attribute_utils.cc",
],
hdrs = [
"ir/tfl_ops.h",
"ir/tfl_traits.h",
"transforms/passes.h",
"utils/attribute_utils.h",
"utils/quantization_utils.h",
],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:TypeUtilities",
],
# Link every object file of this library into dependents, even when no symbol
# from it is referenced directly.
alwayslink = 1,
)
# Quantization helpers. utils/generated_op_quant_spec_getters.inc is the output
# of the :op_quant_spec_getters_inc genrule defined later in this file.
cc_library(
name = "tensorflow_lite_quantization_utils",
srcs = [
"utils/generated_op_quant_spec_getters.inc",
"utils/quantization_driver.cc",
"utils/quantization_utils.cc",
],
hdrs = [
"utils/quantization_utils.h",
],
deps = [
":tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
)
# Pass library combining three transforms (prepare_tf, lower_static_tensor_list
# and legalize_tf) with the rewriter .inc files generated for each of them by
# the gentbl rules above.
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
"transforms/generated_prepare_tf.inc",
"transforms/legalize_tf.cc",
"transforms/lower_static_tensor_list.cc",
"transforms/prepare_tf.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
":tensorflow_lite_quantization_utils",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
# Keep all objects linked in even if unreferenced (presumably so statically
# registered passes are not dropped by the linker - confirm).
alwayslink = 1,
)
# Pass library for transforms/optimize.cc together with the rewriters generated
# by :tensorflow_lite_optimize_inc_gen.
cc_library(
name = "tensorflow_lite_optimize",
srcs = [
"transforms/generated_optimize.inc",
"transforms/optimize.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
":validators",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
],
# Keep all objects linked in even if unreferenced (presumably so statically
# registered passes are not dropped by the linker - confirm).
alwayslink = 1,
)
# Pass library for the quantization transforms (prepare_quantize, quantize,
# post_quantize) together with the rewriters generated by
# :tensorflow_lite_quantize_inc_gen.
cc_library(
name = "tensorflow_lite_quantize",
srcs = [
"transforms/generated_quantize.inc",
"transforms/post_quantize.cc",
"transforms/prepare_quantize.cc",
"transforms/quantize.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
":tensorflow_lite_quantization_utils",
":validators",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
# Keep all objects linked in even if unreferenced (presumably so statically
# registered passes are not dropped by the linker - confirm).
alwayslink = 1,
)
# Build-time generator tool, built from tools/op_quant_spec_getters_gen.cc
# against the LLVM/MLIR TableGen libraries.
tf_native_cc_binary(
name = "op_quant_spec_getters_gen",
srcs = [
"tools/op_quant_spec_getters_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
],
)
# Runs the generator above on ir/tfl_ops.td, with the MLIR include directory on
# its `-I` path, and writes utils/generated_op_quant_spec_getters.inc. The two
# @local_config_mlir .td files appear only in `srcs`, not on the command line
# (presumably they are included from tfl_ops.td - confirm).
genrule(
name = "op_quant_spec_getters_inc",
srcs = [
"@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
],
outs = [
"utils/generated_op_quant_spec_getters.inc",
],
cmd = ("$(location :op_quant_spec_getters_gen) " +
"-I external/local_config_mlir/include " +
"$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
tools = [":op_quant_spec_getters_gen"],
)
# Library with tensorflow Lite dialect static initialization.
# `alwayslink` keeps ir/dialect_registration.cc linked into dependents even
# though nothing references its symbols directly.
cc_library(
name = "tensorflow_lite_dialect_registration",
srcs = [
"ir/dialect_registration.cc",
],
deps = [
":tensorflow_lite",
"@local_config_mlir//:IR",
],
alwayslink = 1,
)
# Build-time generator tool, built from operator_writer_gen.cc against the
# LLVM/MLIR TableGen libraries.
tf_native_cc_binary(
name = "operator-writer-gen",
srcs = [
"operator_writer_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
],
)
# Runs the generator above on ir/tfl_ops.td, with the MLIR include directory on
# its `-I` path, and writes operator_writers.inc, which is compiled into
# :flatbuffer_tflite_operator_lib. The two @local_config_mlir .td files appear
# only in `srcs`, not on the command line (presumably they are included from
# tfl_ops.td - confirm).
genrule(
name = "operator_writer_inc",
srcs = [
"@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
],
outs = [
"operator_writers.inc",
],
cmd = ("$(location :operator-writer-gen) " +
"-I external/local_config_mlir/include " +
"$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
tools = [":operator-writer-gen"],
)
# Library that converts TFLite dialect operations into FlatBuffer operators.
# It is built from flatbuffer_operator.cc plus operator_writers.inc, the output
# of :operator_writer_inc.
cc_library(
name = "flatbuffer_tflite_operator_lib",
srcs = [
"flatbuffer_operator.cc",
"operator_writers.inc",
],
hdrs = [
"flatbuffer_operator.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/container:flat_hash_map",
"@flatbuffers",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:TransformUtils",
],
)
# Standalone tool built from flatbuffer_to_string.cc. It depends only on
# flatbuffers and the TFLite schema target with reflection support.
tf_native_cc_binary(
name = "flatbuffer_to_string",
srcs = ["flatbuffer_to_string.cc"],
deps = [
"//tensorflow/lite/schema:schema_fbs_with_reflection",
"@flatbuffers",
],
)
# Library for the MLIR <-> TFLite FlatBuffer translation implemented in
# flatbuffer_translate.cc. It depends on the operator-level writer library
# above, the TFLite and TensorFlow dialects and their registrations, and the
# TFLite runtime schema.
cc_library(
name = "flatbuffer_translate_lib",
srcs = [
"flatbuffer_translate.cc",
],
hdrs = [
"flatbuffer_translate.h",
],
deps = [
":flatbuffer_tflite_operator_lib",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:StandardDialectRegistration",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Translation",
],
# Keep all objects linked in even if unreferenced (presumably because the
# translation is registered through static initialization - confirm).
alwayslink = 1,
)
# Translation binary with no sources of its own: it links the translation
# library above together with the mlir-translate tool target from MLIR.
tf_cc_binary(
name = "flatbuffer_translate",
deps = [
":flatbuffer_translate_lib",
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
],
)
# Library built from tf_tfl_translate_cl.{h,cc}; depends only on LLVM support.
# NOTE(review): judging by the name these are the command-line options of the
# tf_tfl_translate binary - confirm against the sources.
cc_library(
name = "tf_tfl_translate_cl_options",
srcs = [
"tf_tfl_translate_cl.cc",
],
hdrs = [
"tf_tfl_translate_cl.h",
],
deps = [
"@llvm//:support",
],
# Keep all objects linked in even if unreferenced.
alwayslink = 1,
)
# Binary built from tf_tfl_translate.cc. It links the FlatBuffer translation
# library, the TFLite dialect, its command-line options and the
# :tf_to_tfl_flatbuffer conversion library.
tf_cc_binary(
name = "tf_tfl_translate",
srcs = ["tf_tfl_translate.cc"],
deps = [
":flatbuffer_translate_lib",
":tensorflow_lite",
":tf_tfl_translate_cl_options",
":tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:lib",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)
# Binary built from mlir_tflite_runner.cc. It links the FlatBuffer translation
# library together with the TFLite framework, builtin kernels and flex
# delegate, and the MLIR parser.
tf_cc_binary(
name = "mlir-tflite-runner",
srcs = ["mlir_tflite_runner.cc"],
deps = [
":flatbuffer_translate_lib",
"//tensorflow/core:lib",
"//tensorflow/core/platform/default/build_config:base",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
"@local_config_mlir//:Support",
],
)
# Library built from tf_to_tfl_flatbuffer.{h,cc}. Its deps bring together the
# TFLite pass libraries (legalize, optimize, quantize), the TensorFlow dialect
# passes and translation library, and the FlatBuffer translation library.
cc_library(
name = "tf_to_tfl_flatbuffer",
srcs = ["tf_to_tfl_flatbuffer.cc"],
hdrs = [
"tf_to_tfl_flatbuffer.h",
],
deps = [
":flatbuffer_translate_lib",
":tensorflow_lite",
":tensorflow_lite_legalize_tf",
":tensorflow_lite_optimize",
":tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
],
)

View File

@ -0,0 +1,28 @@
# Experimental code for the new TF-Lite converter, and MLIR dialects and utilities for TensorFlow Lite.
This directory contains:
1. Experimental code for the new TF-Lite converter.
2. Code for the TF-Lite dialect of [MLIR](https://github.com/tensorflow/mlir).
## API:
The API for converting TensorFlow models to TensorFlow Lite is
tf.lite.TFLiteConverterV2.
### The conversion process from TensorFlow to TensorFlow Lite includes the following major passes:
- Import from GraphDef, in .pb or .pbtxt format, into MLIR.
- Raise to Control-flow-graph. Converts TF Control Flow dialect to TF dialect.
- The Canonicalization pass iteratively applies canonicalization
transformations in a greedy way until no further changes occur.
Canonicalization includes constant folding.
- The Legalize pass converts TensorFlow operations to TensorFlow Lite
ones. The operations that cannot be mapped to TensorFlow Lite dialect
are left as TensorFlow operations. Unsupported op handling follows the
proposed TFLite mechanism.
- Optimizations are performed in both the TF and TFLite dialects, aiming for
small size and high performance (among the core value propositions of
TensorFlow Lite models).
- The Export pass writes out the TensorFlow Lite FlatBuffer format. This pass
operates on the MLIR TensorFlow Lite dialect and is a simple, direct translation.

View File

@ -0,0 +1,113 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include <vector>
#include "llvm/ADT/StringSwitch.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/lite/schema/schema_generated.h"
// TODO(jpienaar): This is a placeholder. This should be done in more efficient
// way when part of the translation of module.
//
// Maps the textual activation-function name `str` onto the matching
// tflite::ActivationFunctionType enumerator. Only the six names listed below
// are handled; there is no default case. `builder` is not used.
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
  using Activation = tflite::ActivationFunctionType;
  const Activation activation =
      llvm::StringSwitch<Activation>(str)
          .Case("NONE", tflite::ActivationFunctionType_NONE)
          .Case("RELU", tflite::ActivationFunctionType_RELU)
          .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
          .Case("RELU6", tflite::ActivationFunctionType_RELU6)
          .Case("TANH", tflite::ActivationFunctionType_TANH)
          .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
  return activation;
}
// Maps the textual padding name `str` ("SAME" or "VALID") onto the matching
// tflite::Padding enumerator. No other names are handled; there is no default
// case. `builder` is not used.
static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
  const tflite::Padding padding = llvm::StringSwitch<tflite::Padding>(str)
                                      .Case("SAME", tflite::Padding_SAME)
                                      .Case("VALID", tflite::Padding_VALID);
  return padding;
}
// Translates an MLIR element type into the corresponding tflite::TensorType.
//
// Handled inputs: F16, F32, the TensorFlow STRING and COMPLEX64 types, and
// integer types of width 1 (mapped to BOOL), 8, 16, 32 and 64. Every other
// type, or integer width, reaches llvm_unreachable. `builder` is not used.
static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
    mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
  const auto kind = type.getKind();

  if (kind == mlir::StandardTypes::F16) return tflite::TensorType_FLOAT16;
  if (kind == mlir::StandardTypes::F32) return tflite::TensorType_FLOAT32;
  if (kind == mlir::TF::TensorFlowTypes::STRING)
    return tflite::TensorType_STRING;
  if (kind == mlir::TF::TensorFlowTypes::COMPLEX64)
    return tflite::TensorType_COMPLEX64;

  if (kind == mlir::StandardTypes::Integer) {
    // Integer types are distinguished purely by bit width.
    const auto bit_width = type.cast<mlir::IntegerType>().getWidth();
    switch (bit_width) {
      case 1:
        return tflite::TensorType_BOOL;
      case 8:
        return tflite::TensorType_INT8;
      case 16:
        return tflite::TensorType_INT16;
      case 32:
        return tflite::TensorType_INT32;
      case 64:
        return tflite::TensorType_INT64;
      default:
        llvm_unreachable("invalid integer Type in conversion");
    }
  }

  llvm_unreachable("invalid Type in conversion");
}
// Converts the APInt payload of an I32Attr into the plain `int` expected by
// the flatbuffer builders, using the sign-extended value. `builder` is not
// used.
static int ConvertI32AttrForOptionWriter(
    llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) {
  const auto sign_extended = i.getSExtValue();
  return sign_extended;
}
// Converts the APFloat payload of an F32Attr into the plain `float` expected
// by the flatbuffer builders. `builder` is not used.
static float ConvertF32AttrForOptionWriter(
    llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
  const float as_float = f.convertToFloat();
  return as_float;
}
// Identity conversion: a BoolAttr value is already the `bool` the flatbuffer
// builders expect. `builder` is not used.
static bool ConvertBoolAttrForOptionWriter(
    bool b, flatbuffers::FlatBufferBuilder* builder) {
  const bool value = b;
  return value;
}
// Serializes the shape `r` as a flatbuffer vector of 32-bit integers and
// returns its offset. Each 64-bit dimension is narrowed to `int` before being
// handed to `builder`.
static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(
    llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
  std::vector<int> dims;
  dims.reserve(r.size());
  for (int64_t dim : r) dims.push_back(static_cast<int>(dim));
  return builder->CreateVector(dims);
}
// Maps the textual weights-format name `str` ("DEFAULT" or
// "SHUFFLED4x16INT8") onto the matching
// tflite::FullyConnectedOptionsWeightsFormat enumerator. No other names are
// handled; there is no default case. `builder` is not used.
static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
  using WeightsFormat = tflite::FullyConnectedOptionsWeightsFormat;
  const WeightsFormat format =
      llvm::StringSwitch<WeightsFormat>(str)
          .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
          .Case("SHUFFLED4x16INT8",
                tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
  return format;
}
// Pull in FlatBuffer writers for TFLite generated using TableGen
#include "tensorflow/compiler/mlir/lite/operator_writers.inc"

View File

@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_
#include <stdint.h>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
// Returns the builtin op code for the given MLIR operation on success; emits
// error and returns llvm::None on failure.
// NOTE(review): the definition is not visible from this header; presumably it
// comes from the TableGen-generated operator_writers.inc that
// flatbuffer_operator.cc includes - confirm.
llvm::Optional<tflite::BuiltinOperator> GetBuiltinOpCode(Operation *mlir_op);
// Packs the given MLIR operation into a TFLite FlatBuffer operator object.
// Returns the FlatBuffer offset for the operator on success; emits error and
// returns llvm::None on failure.
//
// Parameters:
//   mlir_op:      the operation to serialize.
//   opcode_index: index stored with the operator (presumably into the model's
//                 operator-code table - confirm against the caller).
//   operands:     indices identifying the operator's inputs (presumably tensor
//                 indices - confirm against the caller).
//   results:      indices identifying the operator's outputs (same caveat).
//   fbb:          builder that the operator is written into.
llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
Operation *mlir_op, uint32_t opcode_index,
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
flatbuffers::FlatBufferBuilder *fbb);
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_

View File

@ -0,0 +1,142 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Dumps a TFLite flatbuffer to a textual output format.
// This tool is intended to be used to simplify unit testing/debugging.
#include <stddef.h>
#include <stdint.h>
#include <fstream>
#include <iostream>
#include <string>
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
#include "flatbuffers/minireflect.h" // TF:flatbuffers
#include "tensorflow/lite/schema/reflection/schema_generated.h"
namespace tflite {
namespace {
// Reads a serialized model into `serialized_model` and verifies that it is a
// valid TFLite flatbuffer.
//
// `file_path` is either a path on disk or "-" to read from standard input.
//
// Returns false on success and true on failure (note the inverted convention:
// the result is used directly as "should the program exit with an error").
// A message is written to stderr on failure. If verification fails,
// `serialized_model` still holds the bytes that were read.
bool ReadAndVerify(const std::string& file_path,
                   std::string* serialized_model) {
  if (file_path == "-") {
    *serialized_model = std::string{std::istreambuf_iterator<char>(std::cin),
                                    std::istreambuf_iterator<char>()};
  } else {
    // Flatbuffers are binary data: open the file in binary mode so platforms
    // that translate line endings (or stop at an EOF marker) in text mode do
    // not alter or truncate the bytes.
    std::ifstream t(file_path, std::ios::in | std::ios::binary);
    if (!t.is_open()) {
      std::cerr << "Failed to open input file.\n";
      return true;
    }
    *serialized_model = std::string{std::istreambuf_iterator<char>(t),
                                    std::istreambuf_iterator<char>()};
  }

  flatbuffers::Verifier model_verifier(
      reinterpret_cast<const uint8_t*>(serialized_model->c_str()),
      serialized_model->length());
  if (!model_verifier.VerifyBuffer<Model>()) {
    std::cerr << "Verification failed.\n";
    return true;
  }
  return false;
}
// A FlatBuffer visitor that renders a FlatBuffer as text, indenting the
// fields of each sequence by one extra level.
//
// Output is accumulated in the inherited string `s`; the inherited `d` is the
// delimiter passed to the base-class constructor.
// TODO(wvo): ToStringVisitor already has indentation functionality, use
// that directly instead of this sub-class?
struct IndentedToStringVisitor : flatbuffers::ToStringVisitor {
  std::string indent_str;  // Text emitted once per indentation level.
  int indent_level;        // Current nesting depth.

  IndentedToStringVisitor(const std::string& delimiter,
                          const std::string& indent)
      : ToStringVisitor(delimiter), indent_str(indent), indent_level(0) {}

  // Appends `indent_str` to the output `indent_level` times.
  void indent() {
    int remaining = indent_level;
    while (remaining-- > 0) s += indent_str;
  }

  // Adjust indention for fields in sequences.
  void StartSequence() override {
    s.append("{").append(d);
    indent_level += 1;
  }

  void EndSequence() override {
    s.append(d);
    indent_level -= 1;
    indent();
    s.append("}");
  }

  // Emits the separator, indentation and "name: " prefix for a field that is
  // present; fields without a value are skipped entirely.
  void Field(size_t /*field_idx*/, size_t set_idx,
             flatbuffers::ElementaryType /*type*/, bool /*is_vector*/,
             const flatbuffers::TypeTable* /*type_table*/, const char* name,
             const uint8_t* val) override {
    if (val == nullptr) return;
    if (set_idx != 0) {
      // Not the first field that was set: separate it from the previous one.
      s.append(",").append(d);
    }
    indent();
    if (name != nullptr) {
      s.append(name).append(": ");
    }
  }

  void StartVector() override { s.append("[ "); }

  void EndVector() override { s.append(" ]"); }

  // Separates vector elements after the first with ", ".
  void Element(size_t i, flatbuffers::ElementaryType /*type*/,
               const flatbuffers::TypeTable* /*type_table*/,
               const uint8_t* /*val*/) override {
    if (i != 0) s.append(", ");
  }
};
// Walks the flatbuffer held in `serialized_model` with an
// IndentedToStringVisitor and prints the resulting text to stdout, followed by
// two newlines.
void ToString(const std::string& serialized_model) {
  const uint8_t* buffer =
      reinterpret_cast<const uint8_t*>(serialized_model.c_str());
  IndentedToStringVisitor visitor(/*delimiter=*/"\n", /*indent=*/" ");
  IterateFlatBuffer(buffer, ModelTypeTable(), &visitor);
  std::cout << visitor.s << "\n\n";
}
} // end namespace
} // end namespace tflite
// Entry point. Expects one positional argument: a file name, or "-" for
// stdin. Prints the textual form of the flatbuffer and returns 0; returns 1 if
// the argument is missing or the input cannot be read or verified.
int main(int argc, char** argv) {
  if (argc >= 2) {
    std::string serialized_model;
    // ReadAndVerify returns true on failure.
    if (tflite::ReadAndVerify(argv[1], &serialized_model)) return 1;
    tflite::ToString(serialized_model);
    return 0;
  }

  std::cerr << "Missing input argument. Usage:\n"
            << argv[0] << " <filename or - for stdin>\n\n"
            << "Converts TensorFlowLite flatbuffer to textual output format. "
            << "One positional input argument representing the source of the "
            << "flatbuffer is supported.\n";
  return 1;
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_
#include <string>
#include "mlir/IR/Module.h" // TF:local_config_mlir
// These flags are used to control the emission or not of different kinds of ops
// during the flatbuffer translation.
// They are only declared here; the definitions live in another translation
// unit.
extern bool emit_builtin_tflite_ops;
extern bool emit_select_tf_ops;
extern bool emit_custom_ops;
// The flag to control whether to lower tensorlist ops into TF ops.
extern bool lower_tensor_list_ops;
namespace tflite {
// Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string.
//
// The three `emit_*` parameters have the same names as the global flags
// declared above and therefore shadow them inside the function.
// NOTE(review): the meaning of the bool return value (success vs. failure) is
// not stated in this header - confirm against the definition before relying
// on it.
bool MlirToFlatBufferTranslateFunction(mlir::Module *module,
std::string *serialized_flatbuffer,
bool emit_builtin_tflite_ops,
bool emit_select_tf_ops,
bool emit_custom_ops);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,19 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
// Static initialization for TensorFlow Lite op registration.
// The file-scope object below is constructed during static initialization of
// any binary that links this translation unit; nothing else refers to it.
static mlir::DialectRegistration<mlir::TFL::TensorFlowLiteDialect> tfl_ops;

View File

@ -0,0 +1,574 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace TFL {
//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//
// Constructs the dialect under the name "tfl" and registers its operations.
TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
: Dialect(/*name=*/"tfl", context) {
// The generated .cc.inc is included inside the template-argument list of
// addOperations<>, with GET_OP_LIST defined first, so the include has to
// expand to the list of op classes to register.
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
>();
}
//===----------------------------------------------------------------------===//
// Common support logic
//===----------------------------------------------------------------------===//
namespace {
// Returns true if the dimensions in `a` form a suffix of the dimensions in
// `b`. For example, {2}, {1, 2} and {3, 1, 2} are all suffixes of
// {5, 4, 3, 1, 2}, whereas {1}, {5, 4} and {1, 3, 2} are not. An empty `a` is
// a suffix of any `b`.
inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  const size_t a_rank = a.size();
  const size_t b_rank = b.size();
  if (b_rank < a_rank) return false;

  // Line `a` up against the tail of `b` and compare element by element.
  const size_t offset = b_rank - a_rank;
  for (size_t i = 0; i < a_rank; ++i) {
    if (a[i] != b[offset + i]) return false;
  }
  return true;
}
// Constant-folds the binary operation `calculate` over two scalar attributes.
//
// Both `operand1` and `operand2` are cast to `AttrElementT` and are expected
// to have exactly the type `result_type` (asserted). The result is a new
// `AttrElementT` of `result_type` holding `calculate(lhs, rhs)`.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              std::function<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1,
                                        Attribute operand2,
                                        const CalculationT &calculate) {
  auto first = operand1.cast<AttrElementT>();
  auto second = operand2.cast<AttrElementT>();

  assert(first.getType() == result_type && second.getType() == result_type &&
         "values of incompatible types should be caught by op verification");

  // TODO: Need to handle overflow/underflow cases.
  return AttrElementT::get(result_type,
                           calculate(first.getValue(), second.getValue()));
}
// TODO: We have multiple functions to handle different attribute kinds in the
// following. Consider adding methods to ElementsAttr to unify these functions.
// Folds two splat constants: `operand1` and `operand2` are the (scalar) splat
// values, both expected to be `AttrElementT` attributes. The folded scalar is
// turned back into a dense attribute of the shaped type `result_type`.
// Returns a null attribute if the scalar fold produced nothing.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              std::function<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpSplatSplat(Type result_type, Attribute operand1,
                                      Attribute operand2,
                                      const CalculationT &calculate) {
  auto shaped_type = result_type.cast<ShapedType>();
  // The scalar fold works on the element type, not on the tensor type.
  Attribute folded_element = ConstFoldBinaryOpScalarScalar<AttrElementT>(
      shaped_type.getElementType(), operand1, operand2, calculate);
  if (folded_element) {
    return DenseElementsAttr::get(shaped_type, folded_element);
  }
  return {};
}
/// Folds a dense constant (`operand1`, a DenseElementsAttr) with a splat
/// constant (`operand2`, a SplatElementsAttr) by applying `calculate` to every
/// dense element paired with the splat value.
/// Broadcasting is not supported yet: unless both operand types equal
/// `result_type`, a null attribute is returned.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              std::function<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseSplat(Type result_type, Attribute operand1,
                                      Attribute operand2,
                                      const CalculationT &calculate) {
  auto dense_lhs = operand1.cast<DenseElementsAttr>();

  // TODO: Support broadcast behavior
  const bool same_types = dense_lhs.getType() == result_type &&
                          operand2.getType() == result_type;
  if (!same_types) return {};

  // Extract the single value every element of the splat stands for.
  auto splat_value = operand2.cast<SplatElementsAttr>()
                         .getSplatValue()
                         .cast<AttrElementT>()
                         .getValue();
  auto shaped_type = result_type.cast<ShapedType>();

  SmallVector<ElementValueT, 16> folded_values;
  folded_values.reserve(dense_lhs.rawSize());
  // Combine the splat value with each element of the dense attribute.
  for (auto dense_value : dense_lhs.getValues<ElementValueT>()) {
    folded_values.push_back(calculate(dense_value, splat_value));
  }
  return DenseElementsAttr::get(shaped_type, folded_values);
}
/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes that both operands are DenseElementsAttr and verified
/// to have value attributes of broadcastable types.
///
/// Only a degenerate form of broadcasting is handled: when the operand types
/// differ, the shape of one operand must be a perfect suffix of the other's
/// shape. Anything else yields a null attribute (no folding).
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              std::function<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseDense(Type result_type, Attribute operand1,
                                      Attribute operand2,
                                      const CalculationT &calculate) {
  auto lhs = operand1.cast<DenseElementsAttr>();
  auto rhs = operand2.cast<DenseElementsAttr>();

  if (lhs.getType() != rhs.getType()) {
    // We only support the case that one of the operand's dimensions are
    // a perfect suffix of the other.
    // TODO: support the general broadcast behavior.
    auto lhs_shape = lhs.getType().getShape();
    auto rhs_shape = rhs.getType().getShape();
    if (!IsTrailingDimensions(lhs_shape, rhs_shape) &&
        !IsTrailingDimensions(rhs_shape, lhs_shape))
      return {};
  }

  // Element counts are kept in 64 bits so that neither the counts nor the
  // indices derived from them are truncated for very large constants.
  const int64_t lhs_num_elements = lhs.getType().getNumElements();
  const int64_t rhs_num_elements = rhs.getType().getNumElements();

  auto type = result_type.cast<ShapedType>();
  const int64_t num_elements = type.getNumElements();

  // An empty operand cannot be tiled into a non-empty result; bailing out here
  // also keeps the modulo operations below from dividing by zero.
  const int64_t smaller_num_elements =
      std::min(lhs_num_elements, rhs_num_elements);
  if (smaller_num_elements == 0 && num_elements != 0) return {};

  // We assume the arguments have broadcast-compatible types. Make sure again.
  assert(std::max(lhs_num_elements, rhs_num_elements) == num_elements);
  assert(num_elements == 0 || num_elements % smaller_num_elements == 0);

  SmallVector<ElementValueT, 16> lhs_old_values(lhs.getValues<ElementValueT>());
  SmallVector<ElementValueT, 16> rhs_old_values(rhs.getValues<ElementValueT>());
  SmallVector<ElementValueT, 16> new_values;
  new_values.reserve(num_elements);

  // Add each pair of the corresponding values in the dense elements
  // attributes.
  for (int64_t i = 0; i < num_elements; ++i) {
    // We only support a degenerated case here: the dimensions in one operand's
    // shape is a perfect suffix to the other operand. Then conceptually it's
    // similar to broadcasting a scalar to a 1-D vector.
    // TODO: support the general broadcast behavior.
    // We are tiling the operand with less elements an integral times to match
    // the operand with more elements. We don't care which operand has less
    // elements here because we are iterating its elements in circles, which can
    // be achieved using the result index modulo the element count. For the
    // operand with more elements, since the result has the same number of
    // elements, we are only going over its elements once. The modulo operation
    // also works for that.
    const int64_t lhs_index = i % lhs_num_elements;
    const int64_t rhs_index = i % rhs_num_elements;
    new_values.push_back(
        calculate(lhs_old_values[lhs_index], rhs_old_values[rhs_index]));
  }
  return DenseElementsAttr::get(type, new_values);
}
/// Dispatches constant folding of a binary op to the helper matching the kinds
/// of the two operand attributes (scalar, splat or dense). The operands are
/// expected to have been verified as broadcast-compatible.
///
/// Supported pairings are scalar/scalar, splat/splat, dense/splat and
/// dense/dense. Splat/dense is handled by swapping the operands, and therefore
/// only when `is_commutative` is set. Every other pairing, including null
/// operands, returns a null attribute.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              std::function<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
                            Attribute operand2, const CalculationT &calculate,
                            bool is_commutative) {
  // Scalar lhs: only a scalar rhs can be folded.
  if (operand1.dyn_cast_or_null<AttrElementT>()) {
    if (!operand2.dyn_cast_or_null<AttrElementT>()) return {};
    return ConstFoldBinaryOpScalarScalar<AttrElementT>(result_type, operand1,
                                                       operand2, calculate);
  }

  // Splat lhs.
  if (auto splat_lhs = operand1.dyn_cast_or_null<SplatElementsAttr>()) {
    // Splat op splat: fold the two splat values directly.
    if (auto splat_rhs = operand2.dyn_cast_or_null<SplatElementsAttr>())
      return ConstFoldBinaryOpSplatSplat<AttrElementT>(
          result_type, splat_lhs.getSplatValue(), splat_rhs.getSplatValue(),
          calculate);
    // Splat op dense: swapping the operands turns this into dense op splat,
    // which is only valid for commutative calculations.
    if (operand2.dyn_cast_or_null<DenseElementsAttr>() && is_commutative)
      return ConstFoldBinaryOpDenseSplat<AttrElementT>(result_type, operand2,
                                                       operand1, calculate);
    return {};
  }

  // Dense lhs.
  if (operand1.dyn_cast_or_null<DenseElementsAttr>()) {
    // Dense op splat
    if (operand2.dyn_cast_or_null<SplatElementsAttr>())
      return ConstFoldBinaryOpDenseSplat<AttrElementT>(result_type, operand1,
                                                       operand2, calculate);
    // Dense op dense
    if (operand2.dyn_cast_or_null<DenseElementsAttr>())
      return ConstFoldBinaryOpDenseDense<AttrElementT>(result_type, operand1,
                                                       operand2, calculate);
  }

  // TODO: support other attribute kinds
  return {};
}
/// Constant-folds a binary op over the first two attributes in `operands`,
/// with broadcast behavior, and returns the folded attribute if possible.
/// The element type of `result_type` selects the calculation:
/// `float_calculate` for float elements, `int_calculate` for integer elements.
/// A null attribute is returned when `result_type` is not a shaped type or its
/// element type is neither float nor integer.
Attribute ConstFoldBinaryOp(
    Type result_type, ArrayRef<Attribute> operands,
    std::function<APFloat(APFloat, APFloat)> float_calculate,
    std::function<APInt(APInt, APInt)> int_calculate, bool is_commutative) {
  // Note: All types are wrapped in tensor types in TFlite. E.g., f32 is
  // represented as tensor<f32>. So we are only handling tensor types here.
  auto shaped_type = result_type.dyn_cast<ShapedType>();
  if (!shaped_type) return {};

  auto element_type = shaped_type.getElementType();

  if (element_type.isa<FloatType>()) {
    return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                        float_calculate, is_commutative);
  }

  if (element_type.isa<IntegerType>()) {
    return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0],
                                          operands[1], int_calculate,
                                          is_commutative);
  }

  return {};
}
// Populates `result` for a broadcasting comparison op on `lhs` and `rhs`.
// The single result is an i1 tensor with the broadcasted shape of the two
// operand types.
//
// If the operand types cannot be broadcast, an error is emitted at the op
// location and the result type falls back to an unranked i1 tensor, so the
// builder never dereferences the null type returned by getBroadcastedType.
void buildComparisonBinOp(Builder *builder, OperationState *result, Value *lhs,
                          Value *rhs) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
  if (!result_type)
    emitError(result->location)
        << "non-broadcastable operands: " << lhs->getType() << " and "
        << rhs->getType();
  result->addOperands({lhs, rhs});

  // Comparison binary ops always return i1 tensor.
  if (result_type && result_type.isa<ShapedType>()) {
    auto result_shape = result_type.cast<ShapedType>().getShape();
    result->types.push_back(
        builder->getTensorType(result_shape, builder->getI1Type()));
  } else {
    // No usable shape is available (the error, if any, was reported above);
    // keep the op well-formed with an unranked result.
    result->types.push_back(builder->getTensorType(builder->getI1Type()));
  }
}
// Populates `result` for a broadcasting binary op that carries a
// "fused_activation_function" attribute. The result type is the broadcasted
// type of the two operand types; an error is emitted at the op location when
// the operands are not broadcastable.
// NOTE(review): in the non-broadcastable case the null type is still appended
// to the result types after the error is emitted -- confirm callers cope.
void buildFusedBroadcastableBinOp(Builder *builder, OperationState *result,
                                  Value *lhs, Value *rhs,
                                  StringAttr fused_activation_function) {
  auto lhs_type = lhs->getType();
  auto rhs_type = rhs->getType();
  auto broadcasted_type =
      OpTrait::util::getBroadcastedType(lhs_type, rhs_type);
  if (!broadcasted_type) {
    emitError(result->location) << "non-broadcastable operands: " << lhs_type
                                << " and " << rhs_type;
  }

  result->addOperands({lhs, rhs});
  result->addAttribute("fused_activation_function", fused_activation_function);
  result->types.push_back(broadcasted_type);
}
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
// Constant-folds tfl.add when no activation function is fused into it.
// Returns a null result when the op is fused or the operands cannot be folded.
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
  // Skip fused ops for now.
  const bool has_fused_activation = fused_activation_function() != "NONE";
  if (has_fused_activation) return {};

  auto add_floats = [](APFloat a, APFloat b) { return a + b; };
  auto add_ints = [](APInt a, APInt b) { return a + b; };
  return ConstFoldBinaryOp(getType(), operands, add_floats, add_ints,
                           getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// ConcatenationOp
//===----------------------------------------------------------------------===//
// TODO(ashwinm): Implement shape inference for Concatenation
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
// Builds a tfl.gather op and infers its result type from `params`, `indices`
// and `axis`.
//
// For ranked inputs the result shape is
//   params.shape[:axis] + indices.shape + params.shape[axis + 1:]
// with the element type of `params`. The result is an unranked tensor of that
// element type when either input is unranked, or when `axis` does not name a
// dimension of `params`. In the latter case an error is emitted at the op
// location first.
static void BuildGatherOp(Builder *builder, OperationState *result,
                          Value *params, Value *indices, IntegerAttr axis) {
  auto params_type = params->getType().cast<TensorType>();
  auto indices_type = indices->getType().cast<TensorType>();

  // Shared fallback used whenever no static shape can be inferred.
  auto build_unranked = [&]() {
    TFL::GatherOp::build(
        builder, result, builder->getTensorType(params_type.getElementType()),
        params, indices, axis);
  };

  // If params/indices is unranked, then output is unranked.
  if (!params_type.hasRank() || !indices_type.hasRank())
    return build_unranked();

  int64_t params_rank = params_type.getRank();
  int64_t indices_rank = indices_type.getRank();

  int64_t axis_i = axis.getInt();
  // For neg axis values, we wrap around params, e.g. axis = -1 => params[:-1]
  if (axis_i < 0) {
    axis_i += params_rank;
  }

  // The axis must name an existing dimension of params. Report the problem
  // and stop shape inference here: continuing would index outside `shape`.
  if (axis_i < 0) {
    emitError(result->location, "axis must be at least -rank(params)");
    return build_unranked();
  }
  // params must be at least rank axis + 1
  if (params_rank < axis_i + 1) {
    emitError(result->location, "params must be atleast rank axis + 1");
    return build_unranked();
  }

  // From here on params rank is guaranteed to be at least 1.
  std::vector<int64_t> shape(params_type.getShape());

  if (indices_rank == 0) {
    // Scalar indices (output is rank(params) - 1).
    // Erase shape[axis]
    shape.erase(shape.begin() + axis_i);
  } else if (indices_rank == 1) {
    // Vector indices (output is rank(params)).
    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  } else {
    // Higher rank indices (output is rank(params) + rank(indices) - 1).
    shape.resize(params_rank + indices_rank - 1);
    // Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
    std::copy(std::begin(params_type.getShape()) + axis_i + 1,
              std::end(params_type.getShape()),
              std::begin(shape) + axis_i + indices_rank);

    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  }

  TFL::GatherOp::build(
      builder, result,
      builder->getTensorType(shape, params_type.getElementType()), params,
      indices, axis);
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
// Constant-folds tfl.mul when no activation function is fused into it.
// Returns a null result when the op is fused or the operands cannot be folded.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  // Skip fused ops for now.
  const bool has_fused_activation = fused_activation_function() != "NONE";
  if (has_fused_activation) return {};

  auto multiply_floats = [](APFloat a, APFloat b) { return a * b; };
  auto multiply_ints = [](APInt a, APInt b) { return a * b; };
  return ConstFoldBinaryOp(getType(), operands, multiply_floats, multiply_ints,
                           getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
// TODO(b/133486129): Implement shape inference for pack
// Verifies tfl.pack: the number of operands has to equal the op's
// 'values_count' attribute.
static LogicalResult Verify(PackOp op) {
  // TODO(antiagainst): Implement other checks as in
  // tensorflow/lite/kernels/pack.cc

  const auto operand_count = op.getOperation()->getNumOperands();
  if (operand_count != op.values_count()) {
    return op.emitOpError("input count should match 'values_count' attribute");
  }

  return success();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
namespace {
/// This pattern matches and merges a tfl.reshape under the following
/// condition:
/// * The input's defining op is another tfl.reshape.
// TODO(antiagainst): This pattern probably should be moved to the peephole
// category, after we have the infra for peephole passes.
struct RemoveAdjacentReshape : public RewritePattern {
  // `explicit` prevents accidental implicit conversion from MLIRContext*, in
  // line with the other single-argument pattern constructors in this file.
  explicit RemoveAdjacentReshape(MLIRContext *context)
      : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}

  // Matches when the operand of this reshape is produced by another reshape.
  // A null defining op does not match.
  PatternMatchResult match(Operation *op) const override {
    auto thisOp = cast<ReshapeOp>(op);
    auto prevOp = thisOp.getOperand()->getDefiningOp();
    return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
  }

  // Rewires this reshape to consume the input of the preceding reshape.
  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    auto thisOp = cast<ReshapeOp>(op);
    auto prevOp = cast<ReshapeOp>(thisOp.getOperand()->getDefiningOp());

    // Replace
    //   %1 = "tfl.reshape"(%0)
    //   %2 = "tfl.reshape"(%1)
    // With
    //   %2 = "tfl.reshape"(%0)
    // NOTE(review): the leading {prevOp.getResult()} argument presumably lists
    // values whose producers may be erased once dead -- confirm against the
    // PatternRewriter API in use.
    rewriter.replaceOpWithNewOp<ReshapeOp>(
        {prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand());
  }
};
}  // end anonymous namespace
// Folds a reshape whose result type equals its operand type (an identity
// reshape) to that operand; any other reshape is left alone.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  auto *input = getOperand();
  // Only identity reshapes can be removed.
  if (getType() != input->getType()) return nullptr;
  return input;
}
// Registers the canonicalization patterns of tfl.reshape: merging a reshape
// with the reshape that produces its input.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  auto merge_adjacent = llvm::make_unique<RemoveAdjacentReshape>(context);
  results.push_back(std::move(merge_adjacent));
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
// Constant-folds tfl.sub when no activation function is fused into it.
// Returns a null result when the op is fused or the operands cannot be folded.
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
  // Skip fused ops for now.
  const bool has_fused_activation = fused_activation_function() != "NONE";
  if (has_fused_activation) return {};

  auto subtract_floats = [](APFloat a, APFloat b) { return a - b; };
  auto subtract_ints = [](APInt a, APInt b) { return a - b; };
  return ConstFoldBinaryOp(getType(), operands, subtract_floats, subtract_ints,
                           getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// TopKOp
//===----------------------------------------------------------------------===//
// Builds a tfl.topk_v2 op and infers the types of its two results (values and
// i32 indices) from `input` and `k`.
//
// For ranked input the result shape is input.shape[:-1] + [k], where k is the
// constant value of `k` or -1 (dynamic) when `k` is not a constant. Both
// results are unranked when the input is unranked or has rank 0. A rank-0
// input has no last dimension to replace, so an error is emitted for it first.
static void BuildTopKOp(Builder *builder, OperationState *result, Value *input,
                        Value *k) {
  // Output size is only known if k is constant value. A negative dimension is
  // considered dynamic so use -1 here if k is not a constant value.
  // Kept as int64_t to match both getSExtValue() and the shape element type.
  int64_t const_k = -1;
  ElementsAttr cst;
  if (matchPattern(k, m_Constant(&cst)))
    // These casts should all be valid due to how Tensor constants are stored.
    // TODO(jpienaar): This should use a helper function.
    const_k = cst.getValue({}).cast<IntegerAttr>().getValue().getSExtValue();

  auto val_type = input->getType().cast<TensorType>();

  // Shared fallback used whenever no static shape can be inferred.
  auto build_unranked = [&]() {
    TFL::TopKV2Op::build(
        builder, result, builder->getTensorType(val_type.getElementType()),
        builder->getTensorType(builder->getIntegerType(32)), input, k);
  };

  // If value is unranked, then so is results.
  if (!val_type.hasRank()) return build_unranked();

  // Resultant shape is value.shape[:-1] + [k]
  std::vector<int64_t> shape(val_type.getShape());
  if (shape.empty()) {
    // Writing the last dimension of an empty shape would be out of bounds.
    emitError(result->location, "input must be at least rank 1");
    return build_unranked();
  }
  shape.back() = const_k;

  TFL::TopKV2Op::build(
      builder, result, builder->getTensorType(shape, val_type.getElementType()),
      builder->getTensorType(shape, builder->getIntegerType(32)), input, k);
}
//===----------------------------------------------------------------------===//
// FakeQuantOp
//===----------------------------------------------------------------------===//
// Returns true if the op carries a "minmax" array attribute holding exactly
// two elements.
static inline bool HasValidMinMaxAttribute(Operation *op) {
  if (auto minmax = op->getAttrOfType<ArrayAttr>("minmax")) {
    return minmax.getValue().size() == 2;
  }
  return false;
}
namespace {
/// This pattern matches and removes a tfl.fake_quant if the op itself and all
/// the users of its result have a valid "minmax" attribute set.
struct DropFakeQuant : public RewritePattern {
  explicit DropFakeQuant(MLIRContext *context)
      : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}

  PatternMatchResult match(Operation *op) const override {
    // The op itself has to carry a valid "minmax" attribute.
    if (!HasValidMinMaxAttribute(op)) return matchFailure();

    // Every user of the result has to carry one as well; a single user
    // without it keeps the fake quant op alive.
    auto fake_quant = cast<FakeQuantOp>(op);
    for (auto *user : fake_quant.getResult()->getUsers()) {
      if (HasValidMinMaxAttribute(user)) continue;
      return matchFailure();
    }

    return matchSuccess();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    // Replace the matched FakeQuantOp by its primary operand.
    auto *primary_operand = op->getOperand(0);
    rewriter.replaceOp(op, primary_operand);
  }
};
}  // end anonymous namespace
// Registers the canonicalization patterns of tfl.fake_quant: dropping fake
// quant ops matched by DropFakeQuant.
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  auto drop_fake_quant = llvm::make_unique<DropFakeQuant>(context);
  results.push_back(std::move(drop_fake_quant));
}
//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//
// TODO(b/133486129): Implement shape inference for unpack
// Verifies tfl.unpack: the number of results has to equal the op's 'num'
// attribute.
static LogicalResult Verify(UnpackOp op) {
  // TODO(antiagainst): Implement other checks as in
  // tensorflow/lite/kernels/unpack.cc

  const auto result_count = op.getOperation()->getNumResults();
  if (result_count != op.num()) {
    return op.emitOpError("output count should match 'num' attribute");
  }

  return success();
}
//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//
// TODO(b/133854225): Implement shape inference to Mean
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the MLIR TensorFlow Lite dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
namespace mlir {
namespace TFL {
// The MLIR dialect that holds the TensorFlow Lite operations declared through
// the generated tfl_ops.h.inc included below.
class TensorFlowLiteDialect : public Dialect {
public:
// Creates the dialect inside `context`. Explicit, so an MLIRContext pointer
// never converts to a dialect implicitly.
explicit TensorFlowLiteDialect(MLIRContext *context);
};
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,127 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#include <cmath>
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
namespace mlir {
namespace OpTrait {
namespace TFL {
using QuantizedType = mlir::quant::QuantizedType;
using UniformQuantizedType = mlir::quant::UniformQuantizedType;
// The base class that all the quantization related OpTraits implement.
// It supplies the default answers to the quantization queries; individual
// traits shadow these static methods to change them.
template <typename ConcreteType, template <typename> class TraitType>
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
// Default: no operand index is a bias.
static bool IsBias(int index) { return false; }
// Default: the op is quantizable.
static bool IsQuantizable() { return true; }
};
// This class provides the API for TFL ops that require the same input and
// output scale as the quantization results. This is used as a trait like this:
//
// class TransposeOp
// : public Op<TransposeOp, OpTrait::TFL::SameOperandsAndResultsScale> {
//
// The trait adds no members of its own; it only inherits the defaults from
// QuantizationSpecTraitBase and acts as a marker.
template <typename ConcreteType>
class SameOperandsAndResultsScale
: public QuantizationSpecTraitBase<ConcreteType,
SameOperandsAndResultsScale> {};
// This class provides the API for TFL ops that have a fixed output value
// range. This is used as a trait like this:
//
//   class SoftmaxOp
//       : public Op<SoftmaxOp,
//           OpTrait::TFL::FixedResultUniformScale<
//               8, -128, 390625, -8, 0, 255, false>::Impl> {
//
// The scale is given as ScaleMantissa * 10^ScaleExp; the example above encodes
// 390625e-8 = 1/256.
//
// TODO(fengliuai): create a better way to express floating point scale in the
// template argument list.
template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
          int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
class FixedResultUniformScale {
 public:
  template <typename ConcreteType>
  class Impl
      : public QuantizationSpecTraitBase<
            ConcreteType, FixedResultUniformScale<
                              BitWidth, ZeroPoint, ScaleMantissa, ScaleExp,
                              StorageTypeMin, StorageTypeMax, Sign>::Impl> {
   public:
    // Returns the uniform quantized type fixed by the template arguments for
    // the `index`-th result of the op. The expressed type is the element type
    // of that result, which must be a tensor.
    QuantizedType GetResultQuantizedType(int index) {
      auto op = this->getOperation();
      auto result_type =
          op->getResult(index)->getType().template cast<TensorType>();
      Builder builder(op->getContext());
      IntegerType storage_type = builder.getIntegerType(BitWidth);
      // std::pow is used instead of exp10(), which is a GNU extension and is
      // not available on every supported toolchain.
      const double scale = static_cast<double>(ScaleMantissa) *
                           std::pow(10.0, static_cast<double>(ScaleExp));
      return UniformQuantizedType::getChecked(
          Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
          StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
    }
  };
};
// This class provides the API for TFL ops that have an input acting as bias.
// `Bias` is the operand index of the bias and `Operands...` are the indexes of
// the non-bias operands. This is used as a trait like this:
//
// class Conv2DOp
// : public Op<Conv2DOp,
// OpTrait::TFL::AccumulatorUniformScale<2, 0, 1>::Impl> {
//
// TODO(fengliuai): supports a configurable accumulator bit width.
template <int Bias, int... Operands>
class AccumulatorUniformScale {
public:
template <typename ConcreteType>
class Impl
: public QuantizationSpecTraitBase<
ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> {
public:
// Whether the index-th operand is a bias.
static bool IsBias(int index) { return index == Bias; }
// Returns the indexes of all the non-bias operands, in the order given in
// the template argument list.
static std::vector<int> GetAllNonBiasOperands() {
return std::vector<int>({Operands...});
}
};
};
// This class provides the API for TFL ops that shouldn't be quantized. This is
// used as a trait like this:
//
// class LessOp : public Op<LessOp, OpTrait::TFL::NoQuantizableResult> {
//
template <typename ConcreteType>
class NoQuantizableResult
: public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
public:
// Shadows the base-class default (true) to mark the op as not quantizable.
static bool IsQuantizable() { return false; }
};
} // namespace TFL
} // namespace OpTrait
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_

View File

@ -0,0 +1,139 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tool to run a TFLite computation from a MLIR input using the TFLite
// interpreter.
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
using llvm::cl::opt;
// Positional command-line option naming the MLIR input file. It defaults to
// "-" and is handed to llvm::MemoryBuffer::getFileOrSTDIN in main().
// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// TODO(jpienaar): Move these functions to some debug utils.
// Renders the dimensions of `tensor` as a ", "-separated list. A tensor
// without dimension information yields an empty string.
static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
  // Joining an empty range produces "", so the missing-dims case can return
  // early.
  if (!tensor.dims) return std::string();
  auto* first_dim = tensor.dims->data;
  auto* past_last_dim = first_dim + tensor.dims->size;
  return absl::StrJoin(first_dim, past_last_dim, ", ");
}
// Renders the raw buffer of `tensor`, reinterpreted as an array of `T`, as a
// ", "-separated list. Returns "<null>" when the tensor has no data buffer.
template <typename T>
static std::string TfLiteTypedTensorString(const TfLiteTensor& tensor) {
  if (!tensor.data.raw) return "<null>";
  const T* first = reinterpret_cast<T*>(tensor.data.raw);
  // The element count is derived from the byte size of the buffer.
  const int element_count = tensor.bytes / sizeof(T);
  const T* past_last = first + element_count;
  return absl::StrJoin(first, past_last, ", ");
}
// TODO(jpienaar): This really feels like something that should exist already.
// Renders the values of `tensor` as a ", "-separated list. Only int32, int64
// and float32 tensors are supported; any other type is logged as a fatal
// error.
static std::string TfLiteTensorString(const TfLiteTensor& tensor) {
  const auto element_type = tensor.type;
  if (element_type == kTfLiteInt32) {
    return TfLiteTypedTensorString<int32_t>(tensor);
  }
  if (element_type == kTfLiteInt64) {
    return TfLiteTypedTensorString<int64_t>(tensor);
  }
  if (element_type == kTfLiteFloat32) {
    return TfLiteTypedTensorString<float>(tensor);
  }
  LOG(QFATAL) << "Unsupported type: " << TfLiteTypeGetName(tensor.type);
}
// Entry point of the MLIR TFLite runner. Parses an MLIR module from the input
// file (or stdin), translates it to a TFLite flatbuffer, runs it with the
// TFLite interpreter and prints every output tensor to stdout.
//
// Returns 0 on success and 1 when the input cannot be opened, parsed or
// translated. Model and interpreter failures abort through QCHECK.
int main(int argc, char** argv) {
  llvm::PrettyStackTraceProgram x(argc, argv);
  llvm::InitLLVM y(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");

  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
  if (std::error_code error = file_or_err.getError()) {
    LOG(ERROR) << argv[0] << ": could not open input file '" << inputFileName
               << "': " << error.message() << "\n";
    return 1;
  }

  // Load the MLIR module.
  mlir::MLIRContext context;
  llvm::SourceMgr source_mgr;
  source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
  std::unique_ptr<mlir::Module> module(
      mlir::parseSourceFile(source_mgr, &context));
  if (!module) return 1;

  // TODO(jpienaar): Expand to support inputs.
  mlir::Function* main = module->getNamedFunction("main");
  QCHECK(main) << "No 'main' function specified.";
  if (main->getType().getNumInputs() != 0)
    LOG(QFATAL) << "NYI: Only nullary functions supported.";

  // Convert to flatbuffer. A true return value is treated as failure here.
  std::string serialized_flatbuffer;
  if (tflite::MlirToFlatBufferTranslateFunction(
          module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
          emit_select_tf_ops, emit_custom_ops))
    return 1;

  // Create TFLite interpreter & invoke converted program.
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(serialized_flatbuffer.c_str(),
                                               serialized_flatbuffer.size());
  // BuildFromBuffer hands back a null pointer when the buffer is rejected;
  // fail loudly instead of dereferencing it below.
  QCHECK(model) << "Failed to build a TFLite model from the flatbuffer.";
  tflite::ops::builtin::BuiltinOpResolver builtins;
  std::unique_ptr<tflite::Interpreter> interpreter;
  QCHECK(tflite::InterpreterBuilder(*model, builtins)(&interpreter) ==
         kTfLiteOk);
  QCHECK(interpreter->AllocateTensors() == kTfLiteOk);
  QCHECK(interpreter->Invoke() == kTfLiteOk);

  // Print the resulting outputs.
  // TODO(jpienaar): Allow specifying output stream/file.
  QCHECK(interpreter->outputs().size() == main->getType().getNumResults());
  for (int index : interpreter->outputs()) {
    const auto& out = *interpreter->tensor(index);
    // Print name if named.
    if (out.name) fprintf(stdout, "%s: ", out.name);
    // Print tensor result.
    fprintf(stdout, "Tensor<type: %s, shape: %s, values: %s>\n",
            TfLiteTypeGetName(out.type), TfLiteTensorDimString(out).c_str(),
            TfLiteTensorString(out).c_str());
  }

  return 0;
}

View File

@ -0,0 +1,292 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <assert.h>
#include <sstream>
#include <string>
#include <vector>
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:local_config_mlir
using llvm::DefInit;
using llvm::dyn_cast;
using llvm::formatv;
using llvm::LessRecord;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::RecordRecTy;
using llvm::SmallVector;
using llvm::StringInit;
using llvm::StringRef;
// Returns the associated option name for the given op definition: the value of
// its string-typed "customOption" field when there is one, otherwise the def
// name with the "TFL_" prefix and "Op" suffix removed and "Options" appended.
static inline std::string GetOperatorOptionName(const Record &def) {
  assert(def.getName().startswith("TFL_") && "unexpected op prefix");
  assert(def.getName().endswith("Op") && "unexpected op suffix");

  // An explicit custom option name takes precedence.
  if (auto *custom_option =
          dyn_cast<StringInit>(def.getValueInit("customOption")))
    return custom_option->getValue().str();

  // Drop the 4-character "TFL_" prefix and the 2-character "Op" suffix.
  return def.getName().drop_front(4).drop_back(2).str() + "Options";
}
// Computes the name of the generated per-op builder function from the TableGen
// record name of the op, e.g., TFL_AddOp -> CreateAddOperator.
static inline std::string GetOperatorBuilderName(StringRef op_name) {
  assert(op_name.startswith("TFL_") && "unexpected op prefix");
  assert(op_name.endswith("Op") && "unexpected op suffix");

  // Dropping the 4-character "TFL_" prefix leaves "<Name>Op"; appending
  // "erator" turns the trailing "Op" into "Operator".
  std::string builder_name = "Create";
  builder_name += op_name.drop_front(4).str();
  builder_name += "erator";
  return builder_name;
}
// Emits one option builder function for every op definition in `defs` whose
// `hasOptions` bit is set. Each generated function has the form
//
//   flatbuffers::Offset<tflite::<Option>> Create<Option>(
//       mlir::TFL::<Op> op, flatbuffers::FlatBufferBuilder *fbb);
//
// and converts each attribute of the op with Convert<AttrKind>ForOptionWriter
// before adding it to the tflite::<Option>Builder under the attribute's name.
//
// `record_keeper` is used to look up the `Attr` class; output is written to
// `*ostream`.
static void EmitOptionBuilders(const RecordKeeper &record_keeper,
                               const std::vector<Record *> &defs,
                               raw_ostream *ostream) {
  raw_ostream &os = *ostream;
  const auto attr_class = record_keeper.getClass("Attr");

  for (const auto *def : defs) {
    // Nothing is generated for TFLite ops that have no options.
    if (!def->getValueAsBit("hasOptions")) continue;

    // Remove the 'TFL_' prefix to obtain the C++ op class name.
    StringRef op_name = def->getName().drop_front(4);
    const std::string option_name = GetOperatorOptionName(*def);

    // Function signature.
    os << "flatbuffers::Offset<tflite::" << option_name << "> Create"
       << option_name << "(mlir::TFL::" << op_name
       << " op, flatbuffers::FlatBufferBuilder *fbb) {\n";

    // Names of every option handed to the FlatBuffer builder, in the order in
    // which their conversions are emitted.
    SmallVector<std::string, 8> option_fields;

    // Options originating from (non-derived) attributes in the op arguments.
    auto *arguments = def->getValueAsDag("arguments");
    const unsigned num_arguments = arguments->getNumArgs();
    for (unsigned idx = 0; idx < num_arguments; ++idx) {
      DefInit *attr_def = dyn_cast<DefInit>(arguments->getArg(idx));
      if (!attr_def || !attr_def->getDef()->isSubClassOf(attr_class)) continue;

      // The attribute name used in the TD file is deliberately reused as the
      // name of the builder's add function, and the attribute's def name
      // selects the conversion function from the internal representation to
      // the form the FlatBuffer builder expects. This constrains naming in the
      // TD file but avoids having to specify any indirection; since this tool
      // is specific to TFLite conversion generation, simplicity was preferred
      // over flexibility.
      StringRef attr_name = arguments->getArgNameStr(idx);
      os << formatv(
          " auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n", attr_name,
          mlir::tblgen::Attribute(attr_def).getAttrDefName());
      option_fields.push_back(attr_name.str());
    }

    // Options originating from derived attributes, i.e. record fields whose
    // type is a record type deriving from `Attr`.
    for (const auto &value : def->getValues()) {
      auto *record_type = dyn_cast<RecordRecTy>(value.getType());
      if (!record_type || !record_type->isSubClassOf(attr_class)) continue;

      if (record_type->getClasses().size() != 1) {
        PrintFatalError(
            def->getLoc(),
            "unsupported attribute modelling, only single class expected");
      }
      os << formatv(
          " auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
          value.getName(), record_type->getClasses()[0]->getName());
      option_fields.push_back(value.getName());
    }

    // Instantiate the builder, add every converted option and finish.
    os << " tflite::" << option_name << "Builder b(*fbb);\n";
    for (const auto &field : option_fields)
      os << formatv(" b.add_{0}(std::move({0}));\n", field);
    os << " return b.Finish();\n}\n";
  }
}
// Emits, for every TFLite op in `defs`, a static builder function that packs
// the op into a tflite::Operator FlatBuffer object. The generated function
// takes the op, its opcode index, the operand and result tensor indices and
// the FlatBufferBuilder to write into.
//
// TODO(hinsu): Revisit if only builtin_options and mutating_variable_inputs
// arguments that depend on op definitions should be auto-generated and then
// operator should be built by the caller because it does not require
// auto-generation.
static void EmitOperatorBuilders(const std::vector<Record *> &defs,
                                 raw_ostream *ostream) {
  raw_ostream &os = *ostream;

  for (const auto *def : defs) {
    // C++ op class name: record name without the 'TFL_' prefix.
    StringRef op_name = def->getName().drop_front(4);
    const std::string builder_name = GetOperatorBuilderName(def->getName());

    // Function signature.
    os << "static flatbuffers::Offset<tflite::Operator> " << builder_name
       << "(mlir::TFL::" << op_name << " tflOp, uint32_t opcode_index, "
       << "const std::vector<int32_t>& operands,"
       << "const std::vector<int32_t>& results,"
       << "flatbuffers::FlatBufferBuilder *fbb) {\n";

    // Serialize the operand and result index vectors.
    os << " auto inputs = fbb->CreateVector(operands);\n";
    os << " auto outputs = fbb->CreateVector(results);\n\n";

    // Start of the CreateOperator call.
    os << " return tflite::CreateOperator(\n";
    os << " *fbb, opcode_index, inputs, outputs,\n";

    // Builtin options: either the op-specific options table or NONE.
    if (!def->getValueAsBit("hasOptions")) {
      os << " tflite::BuiltinOptions_NONE, /*builtin_options=*/0,\n";
    } else {
      const std::string option_name = GetOperatorOptionName(*def);
      os << " tflite::BuiltinOptions_" << option_name << ", "
         << "Create" << option_name << "(tflOp, fbb).Union(),\n";
    }

    // Builders are auto-generated for builtin ops only. custom_options are
    // used solely by custom or flex ops, which are handled manually, so they
    // are always left empty here.
    os << " /*custom_options=*/0, "
       << "tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
       << " /*mutating_variable_inputs=*/0);\n"
       << "}\n\n";
  }
}
// Returns the value of the record's `opName` field converted to upper case.
// The result is used as the suffix of the tflite::BuiltinOperator_ enumerator.
static inline std::string GetUpperCasedName(const Record &def) {
  return def.getValueAsString("opName").upper();
}
// Emits the definition of a function that maps an operation onto its TFLite
// builtin operator code.
//
// The generated function has the signature:
//
//   llvm::Optional<tflite::BuiltinOperator>
//   mlir::GetBuiltinOpCode(mlir::Operation* op);
//
// Its body is one `isa` test per op in `defs`, followed by a fall-through
// that returns llvm::None.
//
// TODO(hinsu): Consider converting this to a static constant associative
// container instead of a series of if conditions, if required.
static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
                                 raw_ostream *ostream) {
  raw_ostream &os = *ostream;

  // Function header.
  os << "llvm::Optional<tflite::BuiltinOperator> "
     << "mlir::GetBuiltinOpCode(mlir::Operation* op) {\n";

  // One test per op class.
  for (const auto *def : defs) {
    StringRef op_class = def->getName().drop_front(4);
    const std::string opcode_suffix = GetUpperCasedName(*def);
    os << " if (isa<mlir::TFL::" << op_class << ">(op))\n";
    os << " return tflite::BuiltinOperator_" << opcode_suffix << ";\n";
  }

  // No op class matched.
  os << " return llvm::None;\n";
  os << "}\n";
}
// Emits the definition of the dispatch function that, given a general
// mlir::Operation, returns the packed FlatBuffer operator object.
//
// The generated function has the signature:
//
//   llvm::Optional<flatbuffers::Offset<tflite::Operator>>
//   mlir::CreateFlatBufferOperator(
//       mlir::Operation* op,
//       uint32_t opcode_index,
//       const std::vector<int32_t>& operands,
//       const std::vector<int32_t>& results,
//       flatbuffers::FlatBufferBuilder *fbb);
//
// Its body tries a dyn_cast to each op class in `defs` and forwards to the
// matching per-op builder; llvm::None is returned when no cast succeeds.
static void EmitBuildOperator(const std::vector<Record *> &defs,
                              raw_ostream *ostream) {
  raw_ostream &os = *ostream;

  // Function header.
  os << "llvm::Optional<flatbuffers::Offset<tflite::Operator>>\n";
  os << "mlir::CreateFlatBufferOperator(mlir::Operation* op, "
     << "uint32_t opcode_index, "
     << "const std::vector<int32_t>& operands,"
     << "const std::vector<int32_t>& results,"
     << "flatbuffers::FlatBufferBuilder *fbb) {\n";

  // One cast-and-dispatch statement per op class.
  for (const auto *def : defs) {
    StringRef op_class = def->getName().drop_front(4);
    const std::string builder_name = GetOperatorBuilderName(def->getName());
    os << " if (auto tflOp = llvm::dyn_cast<mlir::TFL::" << op_class
       << ">(op))\n";
    os << " return " << builder_name
       << "(tflOp, opcode_index, operands, results, fbb);\n";
  }

  // No op class matched.
  os << " return llvm::None;\n";
  os << "}\n";
}
// TableGen backend entry point: validates the names of all definitions derived
// from TFL_Op and then emits, in order, the option builders, the operator
// builders, the builtin opcode lookup and the operator dispatch function.
// Returns false after emitting all sections; malformed op names are reported
// through PrintFatalError.
//
// The non-constant reference parameters are required by LLVM's TableGenMain.
// NOLINTNEXTLINE
static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
  emitSourceFileHeader("MLIR TFLite FlatBuffer Builders", os);

  // Collect every definition derived from TFL_Op, ordered by record name.
  std::vector<Record *> defs = records.getAllDerivedDefinitions("TFL_Op");
  llvm::sort(defs, LessRecord());

  // TFLite ops in the .td file must follow the naming convention
  // TFL_<OpName>Op, from which the following names are derived:
  //   * generated op C++ class:   TFL::<OpName>Op
  //   * operator options table:   tflite::<OpName>Options
  //   * option builder function:  Create<OpName>Options
  for (const auto *def : defs) {
    const StringRef record_name = def->getName();
    if (!record_name.startswith("TFL_")) {
      PrintFatalError(def->getLoc(),
                      "unexpected op name format: 'TFL_' prefix missing");
    }
    if (!record_name.endswith("Op")) {
      PrintFatalError(def->getLoc(),
                      "unexpected op name format: 'Op' suffix missing");
    }
  }

  // Emit the generated sections, separated by blank lines.
  EmitOptionBuilders(records, defs, &os);
  os << "\n\n";
  EmitOperatorBuilders(defs, &os);
  os << "\n\n";
  EmitGetBuiltinOpCode(defs, &os);
  os << "\n\n";
  EmitBuildOperator(defs, &os);

  return false;
}
// Program entry point: installs LLVM's crash reporting helpers, parses the
// command line and runs the TableGen driver with OperatorWritersMain as the
// backend.
int main(int argc, char **argv) {
  // Print a stack trace if the tool receives an error signal.
  llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
  // Include the program arguments in crash stack traces.
  llvm::PrettyStackTraceProgram stack_trace_program(argc, argv);
  // Calls llvm_shutdown() when main returns.
  llvm::llvm_shutdown_obj shutdown_guard;

  llvm::cl::ParseCommandLineOptions(argc, argv);
  return TableGenMain(argv[0], &OperatorWritersMain);
}

View File

@ -0,0 +1,30 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
# Package group covering //tensorflow/lite/toco and all of its subpackages.
# NOTE(review): nothing visible in this BUILD file references ":friends", and
# the package default visibility is already public — confirm whether this
# group is still needed.
package_group(
name = "friends",
packages = [
"//tensorflow/lite/toco/...",
],
)
# C++ library built from graphdef_to_tfl_flatbuffer.{h,cc}. It depends on the
# MLIR TF->TFLite conversion libraries, the TensorFlow GraphDef import
# libraries, and the toco flag protos.
cc_library(
name = "graphdef_to_tfl_flatbuffer",
srcs = ["graphdef_to_tfl_flatbuffer.cc"],
hdrs = [
"graphdef_to_tfl_flatbuffer.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"@local_config_mlir//:IR",
],
)

View File

@ -0,0 +1,142 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
namespace tensorflow {
// Maps a toco::IODataType onto the corresponding tensorflow::DataType. Only
// the constants defined in the TFLite Python API are translated; any other
// value yields DT_INVALID.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
  if (dtype == toco::IODataType::FLOAT) return DT_FLOAT;
  if (dtype == toco::IODataType::QUANTIZED_UINT8) return DT_QUINT8;
  if (dtype == toco::IODataType::INT32) return DT_INT32;
  if (dtype == toco::IODataType::INT64) return DT_INT64;
  if (dtype == toco::IODataType::STRING) return DT_STRING;
  // No mapping is defined for the remaining types.
  return DT_INVALID;
}
// Logs a warning for every flag that has been specified but is not used by
// this conversion path.
//
// A flag counts as specified when its value is non-zero / true (numeric,
// enum and boolean flags) or non-empty (string flags).
//
// Args:
//   model_flags: model-level flags; change_concat_input_ranges and
//     allow_nonexistent_arrays are inspected.
//   toco_flags: converter flags; all remaining flags checked below are read
//     from here.
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
                        const toco::TocoFlags& toco_flags) {
  if (toco_flags.inference_input_type()) {
    LOG(WARNING) << "Ignored inference_input_type.";
  }
  if (toco_flags.output_format()) {
    LOG(WARNING) << "Ignored output_format.";
  }
  if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
    LOG(WARNING) << "Ignored default_ranges_stats.";
  }
  if (toco_flags.drop_control_dependency()) {
    LOG(WARNING) << "Ignored drop_control_dependency.";
  }
  if (toco_flags.reorder_across_fake_quant()) {
    LOG(WARNING) << "Ignored reorder_across_fake_quant.";
  }
  if (model_flags.change_concat_input_ranges()) {
    LOG(WARNING) << "Ignored change_concat_input_ranges.";
  }
  if (toco_flags.post_training_quantize()) {
    LOG(WARNING) << "Ignored post_training_quantize.";
  }
  // Warn only when a dump directory was actually provided. The previous
  // condition was inverted: it warned when the flag was unset and stayed
  // silent when it was set.
  if (!toco_flags.dump_graphviz_dir().empty()) {
    LOG(WARNING) << "Ignored dump_graphviz_dir.";
  }
  if (toco_flags.dump_graphviz_include_video()) {
    LOG(WARNING) << "Ignored dump_graphviz_video.";
  }
  // NOTE(review): this message says "Allow" rather than "Ignored" like the
  // others; left unchanged — confirm whether the wording is intentional.
  if (model_flags.allow_nonexistent_arrays()) {
    LOG(WARNING) << "Allow allow_nonexistent_arrays.";
  }
}
// Converts `input` to a TFLite FlatBuffer, written to `*result`.
//
// The input arrays in `model_flags` are translated into node specs (name,
// dtype, shape and a min/max range derived from mean/std), the output arrays
// are parsed, the GraphDef is imported into an MLIR module and the module is
// then converted to a FlatBuffer. Any error from the parsing or import steps
// is returned; otherwise the status of the final conversion is returned.
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                         const toco::TocoFlags& toco_flags,
                                         const GraphDebugInfo& debug_info,
                                         const GraphDef& input,
                                         string* result) {
  mlir::MLIRContext context;
  NodeSpecs specs;

  // Collect the description of every input array.
  std::vector<string> input_names;
  std::vector<string> input_dtypes;
  std::vector<std::vector<int>> input_shapes;
  std::vector<float> input_mins;
  std::vector<float> input_maxs;
  for (const auto& input_array : model_flags.input_arrays()) {
    input_names.push_back(input_array.name());
    input_dtypes.push_back(
        DataType_Name(ConvertIODataTypeToDataType(input_array.data_type())));

    const auto& dims = input_array.shape().dims();
    input_shapes.emplace_back(dims.begin(), dims.end());

    // Derive the real-valued range covered by the quantized interval
    // [0, 255] from the mean/std pair.
    // NOTE(review): std_value is used as a divisor without a zero check —
    // confirm callers always provide a non-zero value when the range matters.
    const float mean = input_array.mean_value();
    const float stddev = input_array.std_value();
    const float quant_min = 0, quant_max = 255;
    input_mins.push_back((quant_min - mean) / stddev);
    input_maxs.push_back((quant_max - mean) / stddev);
  }

  const tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo(
      input_names, input_dtypes, input_shapes, inference_type, input_mins,
      input_maxs, &specs.inputs));

  // Collect the output arrays.
  const std::vector<string> output_names(model_flags.output_arrays().begin(),
                                         model_flags.output_arrays().end());
  TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(
      output_names, &specs.output_arrays, &specs.output_arrays_order));

  // Remaining conversion options taken from the toco flags.
  const bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  const bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
  const bool emit_custom_ops = toco_flags.allow_custom_ops();
  // Options that are fixed for this entry point.
  const bool emit_quant_adaptor_ops = false;
  const bool lower_tensor_list_ops = false;
  specs.prune_unused_nodes = true;

  // Tell the user about flags that were set but have no effect here.
  WarningUnusedFlags(model_flags, toco_flags);

  // Import the GraphDef into MLIR and export the module as a FlatBuffer.
  TF_ASSIGN_OR_RETURN(
      auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
  return ConvertTFControlFlowToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops,
      lower_tensor_list_ops, result);
}
} // namespace tensorflow

View File

@ -0,0 +1,36 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
namespace tensorflow {
// Converts the given GraphDef to a TF Lite FlatBuffer string according to the
// given model flags, toco flags and debug information. Returns error status if
// it fails to convert the input.
//
// Parameters:
//   model_flags: model-level flags for the conversion.
//   toco_flags: converter flags for the conversion.
//   debug_info: debug information accompanying `input`.
//   input: the GraphDef to convert.
//   result: output parameter that receives the serialized FlatBuffer.
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
const GraphDebugInfo& debug_info,
const GraphDef& input, string* result);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_

View File

@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
# Sets up lit tests for the files in this package with the "mlir" extension,
# using run_lit.sh as the driver and making :test_utilities available as data.
# NOTE(review): the exact targets created are defined by glob_lit_tests in
# glob_lit_test.bzl — confirm there.
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests: the tf-opt
# binary and FileCheck, both of which appear in the RUN lines of the tests in
# this package. Marked testonly, so only test targets may depend on it.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
],
)

View File

@ -0,0 +1,134 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics | FileCheck %s
// -----
// Both operands and the result are rank-0 tensors of the same type, so the op
// is written and expected back in its short custom form.
// CHECK-LABEL: @broadcast_scalar_scalar_scalar
func @broadcast_scalar_scalar_scalar(tensor<i32>, tensor<i32>) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<i32>
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<i32>
return %0 : tensor<i32>
}
// -----
// A 4-element tensor combined with a rank-0 tensor; the result has the shape
// of the 4-element operand.
// CHECK-LABEL: @broadcast_tensor_scalar_tensor
func @broadcast_tensor_scalar_tensor(tensor<4xi32>, tensor<i32>) -> tensor<4xi32> {
^bb0(%arg0: tensor<4xi32>, %arg1: tensor<i32>):
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// -----
// Check broadcasting when exactly one operand dimension has size 1: a 3x1
// operand against a 4x3x2 operand yields a 4x3x2 result.
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32> {
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32>
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32>
return %0 : tensor<4x3x2xi32>
}
// -----
// Check broadcasting when several dimensions in both operands have size 1:
// 8x1x6x1 combined with 7x1x5 yields 8x7x6x5.
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
return %0 : tensor<8x7x6x5xi32>
}
// -----
// Check that an unknown (?) leading dimension in an operand is accepted when
// the result also has an unknown leading dimension.
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
func @broadcast_tensor_tensor_tensor(tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
^bb0(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
return %0 : tensor<?x7x6x5xi32>
}
// -----
// Check that an unknown (?) dimension in the middle of an operand shape is
// accepted when the result has an unknown dimension at the same position.
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32> {
^bb0(%arg0: tensor<8x1x?x1xi32>, %arg1: tensor<7x1x5xi32>):
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32>
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32>
return %0 : tensor<8x7x?x5xi32>
}
// -----
// Check incompatible vector and tensor result type: both operands are tensors
// but the declared result is a vector, which must be diagnosed.
func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
// expected-error @+1 {{cannot broadcast vector with tensor}}
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
return %0 : vector<4xf32>
}
// -----
// Check incompatible operand types with known dimension: 4x3x2 and 3x3 differ
// in the trailing dimension (2 vs 3) and neither is 1, so a diagnostic is
// expected.
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32> {
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x3xi32>):
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32>
return %0 : tensor<4x3x2xi32>
}
// -----
// Check incompatible result type with known dimension: the operands are the
// same as in the passing 4x3x2 / 3x1 case, but the declared result is 4x3x3.
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32> {
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
// expected-error @+1 {{does not have the same shape as the one computed}}
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
return %0 : tensor<4x3x3xi32>
}
// -----
// Check incompatible result type with known dimension when several operand
// dimensions have size 1: the operands are the same as in the passing
// 8x1x6x1 / 7x1x5 case, but the declared result is 8x7x6x1.
func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
// expected-error @+1 {{does not have the same shape as the one computed}}
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
return %0 : tensor<8x7x6x1xi32>
}
// -----
// Check that a rank-1 operand of unknown size (tensor<?xi32>) is accepted
// against a fully static 4x3x2 operand with a static 4x3x2 result.
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32> {
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<?xi32>):
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
return %0 : tensor<4x3x2xi32>
}
// -----
// Check incompatible result type with unknown dimension: the first operand
// has an unknown leading dimension, but the declared result fixes that
// dimension to 8.
func @broadcast_tensor_tensor_tensor(tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
^bb0(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
// expected-error @+1 {{does not have the same shape as the one computed}}
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
return %0 : tensor<8x7x6x5xi32>
}
// -----
// Check unranked operand but ranked result: one operand is tensor<*xi32>
// while the declared result is the ranked 4x3x2 type, which must be
// diagnosed.
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<*xi32>):
// expected-error @+1 {{broadcast unranked tensor should result in unranked tensor}}
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
return %0 : tensor<4x3x2xi32>
}

View File

@ -0,0 +1,92 @@
// RUN: tf-opt -canonicalize %s | FileCheck %s
// Checks that a tfl.reshape is removed if the only user of its output is
// another tfl.reshape: the expected output contains a single reshape going
// directly from the function argument to the final type.
func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
^bb0(%arg0: tensor<4x4x4xf32>) :
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
%1 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
return %1 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacent
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
// CHECK: return
}
// Checks that a tfl.reshape is removed if its output has more than one user
// but all users are tfl.reshape: each of the two user reshapes is expected to
// read directly from the function argument.
func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32> {
^bb0(%arg0: tensor<4x4x4xf32>) :
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
%1 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
%2 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
%3 = addf %1, %2 : tensor<64xf32>
return %3 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
// CHECK: %1 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
// CHECK: %2 = addf %0, %1
// CHECK: return %2
}
// Checks that a tfl.reshape is kept if its output has more than one user and
// not all users are tfl.reshape (here the first reshape is also returned).
// The second reshape is still expected to read directly from the function
// argument.
func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32>, tensor<64xf32>) {
^bb0(%arg0: tensor<4x4x4xf32>) :
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
%1 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
// CHECK: %1 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
// CHECK: return
}
// Checks that a tfl.reshape is removed if its output type is the same as its
// input type: the function is expected to return its argument directly.
func @reshape_removeIdentity(tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
^bb0(%arg0: tensor<4x4x4xf32>) :
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
return %0 : tensor<4x4x4xf32>
// CHECK-LABEL: func @reshape_removeIdentity
// CHECK: return %arg0 : tensor<4x4x4xf32>
}
// Checks that a tfl.fake_quant is removed if all its users have valid
// "minmax" attributes: both tfl.pow users are expected to take %arg0 in place
// of the fake_quant result while keeping their own minmax attributes.
func @fakequant_dropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
%0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
%1 = tfl.pow %arg0, %0 {minmax = [0.4, 0.6]} : tensor<i32>
%2 = tfl.pow %1, %0 {minmax = [0.5, 0.7]} : tensor<i32>
return %2 : tensor<i32>
// CHECK-LABEL: fakequant_dropfakequant
// CHECK-NEXT: %0 = tfl.pow %arg0, %arg0 {minmax = [4.000000e-01, 6.000000e-01]} : tensor<i32>
// CHECK-NEXT: %1 = tfl.pow %0, %arg0 {minmax = [5.000000e-01, 0.69999999999999996]} : tensor<i32>
// CHECK-NEXT: return %1 : tensor<i32>
}
// Checks that a tfl.fake_quant is not removed if some of its users or itself
// don't have valid "minmax" attributes. Two cases are covered: a fake_quant
// whose own minmax is empty, and a fake_quant with a valid minmax whose
// tfl.pow users carry no minmax attribute. Both are expected to remain.
func @fakequant_notdropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
%0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
%1 = tfl.pow %arg0, %0 : tensor<i32>
%2 = tfl.pow %1, %0 : tensor<i32>
%5 = "tfl.fake_quant"(%arg0) {name = 1, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
%6 = tfl.pow %arg0, %5 : tensor<i32>
%7 = tfl.pow %6, %5 : tensor<i32>
%11 = addi %2, %7 : tensor<i32>
return %11 : tensor<i32>
// CHECK-LABEL: fakequant_notdropfakequant
// CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], name = 0 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
// CHECK: %3 = "tfl.fake_quant"(%arg0) {minmax = [1.000000e-01, 2.000000e-01], name = 1 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
}

View File

@ -0,0 +1,275 @@
// RUN: tf-opt %s -test-constant-fold | FileCheck %s
// Constant folding of tfl.add on f32 constants. The four adds with activation
// "NONE" cover scalar+scalar, scalar+splat, splat+scalar and splat+splat and
// are expected to fold to 6.0, 4.0, 5.0 and 3.0. The add with activation
// "SIGN_BIT" is expected to remain, which keeps its 3.5 and -0.5 operand
// constants alive as well.
// CHECK-LABEL: @add_float
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
%0 = constant dense<4.5> : tensor<f32>
%1 = constant dense<1.5> : tensor<f32>
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
%7 = "tfl.add"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor< f32>) -> tensor<4xf32>
%8 = "tfl.add"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%9 = "tfl.add"(%2, %3) {fused_activation_function = "SIGN_BIT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %5, %6, %7, %8, %9 : tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
}
// Constant folding of tfl.add on i32 constants for scalar+scalar,
// scalar+splat, splat+scalar and splat+splat operands. Expected results:
// 8+1=9, 8+(-2)=6, 4+1=5 and 4+(-2)=2.
// CHECK-LABEL: @add_int
func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%0 = constant dense<8> : tensor<i32>
%1 = constant dense<1> : tensor<i32>
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<9> : tensor<i32>
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
%7 = "tfl.add"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor< i32>) -> tensor<4xi32>
%8 = "tfl.add"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %5, %6, %7, %8 : tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
// Constant folding of tfl.sub on f32 constants for scalar-scalar,
// scalar-splat, splat-scalar and splat-splat operands. Expected results:
// 4.5-1.5=3.0, 4.5-(-0.5)=5.0, 3.5-1.5=2.0 and 3.5-(-0.5)=4.0.
// CHECK-LABEL: @sub_float
func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
%0 = constant dense<4.5> : tensor<f32>
%1 = constant dense<1.5> : tensor<f32>
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
%7 = "tfl.sub"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor< f32>) -> tensor<4xf32>
%8 = "tfl.sub"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %5, %6, %7, %8 : tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
}
// Constant folding of tfl.sub on i32 constants for scalar-scalar,
// scalar-splat, splat-scalar and splat-splat operands. Expected results:
// 8-1=7, 8-(-2)=10, 4-1=3 and 4-(-2)=6.
// CHECK-LABEL: @sub_int
func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%0 = constant dense<8> : tensor<i32>
%1 = constant dense<1> : tensor<i32>
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<7> : tensor<i32>
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
%7 = "tfl.sub"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor< i32>) -> tensor<4xi32>
%8 = "tfl.sub"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %5, %6, %7, %8 : tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
// Constant folding of tfl.mul on f32 constants for scalar*scalar,
// scalar*splat, splat*scalar and splat*splat operands. Expected results:
// 4.5*1.5=6.75, 4.5*(-0.5)=-2.25, 3.5*1.5=5.25 and 3.5*(-0.5)=-1.75.
// CHECK-LABEL: @mul_float
func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
%0 = constant dense<4.5> : tensor<f32>
%1 = constant dense<1.5> : tensor<f32>
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
%7 = "tfl.mul"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor< f32>) -> tensor<4xf32>
%8 = "tfl.mul"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %5, %6, %7, %8 : tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
}
// CHECK-LABEL: @mul_int
// Constant folding of "tfl.mul" on i32 splat constants, covering the four
// operand combinations: scalar-scalar, scalar-tensor, tensor-scalar and
// tensor-tensor. Expected values: 8*1=8, 8*(-2)=-16, 4*1=4, 4*(-2)=-8.
func @mul_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%0 = constant dense<8> : tensor<i32>
%1 = constant dense<1> : tensor<i32>
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// The constants are matched in any order and captured by name; only the
// operand order of the final return is pinned.
// NOTE(review): two folded results (8 and splat 4) equal the input constants
// %0 and %2, which is presumably why fixed %cst names are not used here;
// confirm.
// CHECK-DAG: [[cst0:%.*]] = constant dense<8> : tensor<i32>
// CHECK-DAG: [[cst1:%.*]] = constant dense<-16> : tensor<4xi32>
// CHECK-DAG: [[cst2:%.*]] = constant dense<4> : tensor<4xi32>
// CHECK-DAG: [[cst3:%.*]] = constant dense<-8> : tensor<4xi32>
// CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]]
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
%7 = "tfl.mul"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor< i32>) -> tensor<4xi32>
%8 = "tfl.mul"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %5, %6, %7, %8 : tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
// CHECK-LABEL: @add_dense_splat_int
// Folds "tfl.add" of a non-splat i32 constant (lhs) and a splat constant (rhs)
// of the same shape: [-10, -1, 42, 100] + 5 element-wise.
func @add_dense_splat_int() -> tensor<4xi32> {
%0 = constant dense<[-10, -1, 42, 100]> : tensor<4xi32>
%1 = constant dense< 5> : tensor<4xi32>
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
}
// CHECK-LABEL: @add_splat_dense_int
// Same operands as @add_dense_splat_int with the order swapped: the splat
// constant is the lhs and the non-splat constant is the rhs. The folded
// result is expected to be identical.
func @add_splat_dense_int() -> tensor<4xi32> {
%0 = constant dense< 5> : tensor<4xi32>
%1 = constant dense<[-10, -1, 42, 100]> : tensor<4xi32>
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
}
// CHECK-LABEL: @add_dense_dense_int_same_shape
// Folds "tfl.add" of two non-splat i32 constants with identical shapes;
// the expected result is the element-wise sum.
func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
%0 = constant dense<[15, 23, -44, -2]> : tensor<4xi32>
%1 = constant dense<[-10, -1, 42, 100]> : tensor<4xi32>
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %cst
}
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
// Folds "tfl.add" of non-splat i32 constants whose shapes differ in rank but
// agree on the trailing dimensions (2 vs 2x2, 2x2x2 vs 2x2, 2 vs 2x2x2). The
// lower-rank operand is expected to be broadcast along the leading dimensions.
func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>) {
%cst_0 = constant dense<[10, 20]> : tensor<2xi32>
%cst_1 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%cst_2 = constant dense<[[[1, 1], [2, 2]], [[3, 3], [4, 4]]]> : tensor<2x2x2xi32>
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor< 2xi32>, tensor< 2x2xi32>) -> tensor< 2x2xi32>
%1 = "tfl.add"(%cst_2, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2xi32>, tensor< 2x2xi32>) -> tensor<2x2x2xi32>
%2 = "tfl.add"(%cst_0, %cst_2) {fused_activation_function = "NONE"} : (tensor< 2xi32>, tensor<2x2x2xi32>) -> tensor<2x2x2xi32>
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
// Runs of '[' are wrapped in a regex block below because a literal "[[" would
// otherwise be parsed by FileCheck as the start of a variable reference.
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %cst, %cst_0, %cst_1
}
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
// Negative test: the operands have shapes 1x2 and 2x1, so each one would need
// to be broadcast along a different dimension. The add is expected to be left
// in place rather than folded.
func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
%cst_0 = constant dense<[[1, 2]]> : tensor<1x2xi32>
%cst_1 = constant dense<[[3], [4]]> : tensor<2x1xi32>
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
// We don't support this case yet.
// The next line is the (currently disabled) folded result, kept for reference.
// %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: %0 = "tfl.add"
// CHECK: return %0
}
// CHECK-LABEL: @add_dense_splat_float
// Folds "tfl.add" of a non-splat f32 constant (lhs) and a splat constant (rhs)
// of the same shape: [-10.0, -1.5, 42.0, 7.25] + 3.5 element-wise.
func @add_dense_splat_float() -> tensor<4xf32> {
%0 = constant dense<[-10.0, -1.5, 42.0, 7.25]> : tensor<4xf32>
%1 = constant dense< 3.5> : tensor<4xf32>
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
}
// CHECK-LABEL: @add_splat_dense_float
// Same operands as @add_dense_splat_float with the order swapped: the splat
// constant is the lhs and the non-splat constant is the rhs. The folded
// result is expected to be identical.
func @add_splat_dense_float() -> tensor<4xf32> {
%0 = constant dense< 3.5> : tensor<4xf32>
%1 = constant dense<[-10.0, -1.5, 42.0, 7.25]> : tensor<4xf32>
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
}
// CHECK-LABEL: @add_dense_dense_float_same_shape
// Folds "tfl.add" of two non-splat f32 constants with identical shapes.
// The first expected element is printed as -8.89999961 rather than -8.9: the
// sum 1.5 + (-10.4) is not exactly representable in f32.
func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
%0 = constant dense<[1.5, 2.3, -4.4, -2.0]> : tensor<4xf32>
%1 = constant dense<[-10.4, -1.3, 42.4, 100.0]> : tensor<4xf32>
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %cst
}
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
// Folds "tfl.add" of non-splat f32 constants whose shapes differ in rank but
// agree on the trailing dimensions (2 vs 2x2, 2x2x2 vs 2x2, 2 vs 2x2x2). The
// lower-rank operand is expected to be broadcast along the leading dimensions.
func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>) {
%cst_0 = constant dense<[1., -4.]> : tensor<2xf32>
%cst_1 = constant dense<[[-5.5, 1.5], [7.5, -4.5]]> : tensor<2x2xf32>
%cst_2 = constant dense<[[[1., 1.], [2., 2.]], [[3., 3.], [4., 4.]]]> : tensor<2x2x2xf32>
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor< 2xf32>, tensor< 2x2xf32>) -> tensor< 2x2xf32>
%1 = "tfl.add"(%cst_2, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2xf32>, tensor< 2x2xf32>) -> tensor<2x2x2xf32>
%2 = "tfl.add"(%cst_0, %cst_2) {fused_activation_function = "NONE"} : (tensor< 2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
// Runs of '[' are wrapped in a regex block below because a literal "[[" would
// otherwise be parsed by FileCheck as the start of a variable reference.
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %cst, %cst_0, %cst_1
}
// CHECK-LABEL: @add_dense_dense_float_mixing_1_n
// Negative test, f32 counterpart of @add_dense_dense_int_mixing_1_n: the
// operands have shapes 1x2 and 2x1, so each one would need to be broadcast
// along a different dimension. The add is expected to be left in place rather
// than folded.
func @add_dense_dense_float_mixing_1_n() -> tensor<2x2xf32> {
%cst_0 = constant dense<[[1.5, -2.5]]> : tensor<1x2xf32>
%cst_1 = constant dense<[[-3.], [4.]]> : tensor<2x1xf32>
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// We don't support this case yet.
// CHECK: %0 = "tfl.add"
// CHECK: return %0
}

View File

@ -0,0 +1,31 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
# Declares the lit tests of this package through the glob_lit_tests macro.
# Test files are selected by the "pbtxt" extension; the debug-info fixtures
# and the test utilities defined below are passed as data, and lit is driven
# by run_lit.sh from the MLIR repository.
# NOTE(review): presumably one test target is generated per matching file;
# confirm against glob_lit_test.bzl.
glob_lit_tests(
data = [
":debug_info_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = ["pbtxt"],
)
# Bundle together all the debug info files that are used by the tests.
# The glob is recursive, so "*.debug" files in subdirectories are included.
filegroup(
name = "debug_info_files",
srcs = glob(
["**/*.debug"],
),
)
# Bundle together all of the test utilities that are used by tests.
# These are the binaries named on the tests' run lines: the translator, the
# flatbuffer pretty-printer and FileCheck. Marked testonly so that only test
# targets may depend on this group.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm//:FileCheck",
],
)

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,175 @@
files: "fake/user/code/file_A.py"
files: "fake/user/code/file_B.py"
files: "fake/user/code/file_C.py"
files: "fake/user/code/file_D.py"
files: "fake/user/code/file_E.py"
files: "fake/user/code/file_F.py"
files: "fake/user/code/file_G.py"
files: "fake/user/code/file_H.py"
files: "fake/user/code/file_I.py"
files: "fake/user/code/file_J.py"
files: "fake/user/code/file_K.py"
files: "fake/user/code/file_L.py"
files: "fake/user/code/file_M.py"
files: "fake/user/code/file_N.py"
files: "fake/user/code/file_O.py"
files: "fake/user/code/file_P.py"
files: "fake/user/code/file_Q.py"
files: "fake/user/code/file_R.py"
files: "fake/user/code/file_S.py"
files: "fake/user/code/file_T.py"
files: "fake/user/code/file_U.py"
files: "fake/user/code/file_V.py"
files: "fake/user/code/file_W.py"
files: "fake/user/code/file_X.py"
files: "fake/user/code/file_Y.py"
files: "fake/user/code/file_Z.py"
files: "fake/user/code/file_1.py"
files: "fake/user/code/file_2.py"
files: "fake/user/code/file_3.py"
files: "fake/user/code/file_4.py"
files: "fake/user/code/file_5.py"
files: "fake/user/code/file_6.py"
files: "fake/user/code/file_a.py"
files: "fake/user/code/file_b.py"
files: "fake/user/code/file_c.py"
files: "fake/user/code/file_d.py"
files: "fake/user/code/file_e.py"
files: "fake/user/code/file_f.py"
files: "fake/user/code/file_g.py"
files: "fake/user/code/file_h.py"
files: "fake/user/code/file_i.py"
files: "fake/user/code/file_j.py"
files: "fake/user/code/file_k.py"
files: "fake/user/code/file_l.py"
files: "fake/user/code/file_m.py"
files: "fake/user/code/file_n.py"
files: "fake/user/code/file_o.py"
files: "fake/user/code/file_p.py"
files: "fake/user/code/file_q.py"
files: "fake/user/code/file_r.py"
files: "fake/user/code/file_s.py"
files: "fake/user/code/file_t.py"
files: "fake/user/code/file_u.py"
files: "fake/user/code/file_v.py"
files: "fake/user/code/file_w.py"
files: "fake/user/code/file_x.py"
files: "fake/user/code/file_y.py"
files: "fake/user/code/file_z.py"
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/beta"
value {
file_line_cols {
file_index: 33
line: 383
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read"
value {
file_line_cols {
file_index: 49
line: 886
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma"
value {
file_line_cols {
file_index: 38
line: 777
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read"
value {
file_line_cols {
file_index: 49
line: 915
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean"
value {
file_line_cols {
file_index: 44
line: 793
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read"
value {
file_line_cols {
file_index: 49
line: 335
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance"
value {
file_line_cols {
file_index: 44
line: 386
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read"
value {
file_line_cols {
file_index: 49
line: 492
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/weights"
value {
file_line_cols {
file_index: 54
line: 649
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/weights/read"
value {
file_line_cols {
file_index: 49
line: 421
}
}
}
traces {
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm"
value {
file_line_cols {
file_index: 5
line: 362
}
}
}
traces {
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D"
value {
file_line_cols {
file_index: 2
line: 27
}
}
}
traces {
key: "input"
value {
file_line_cols {
file_index: 40
line: 690
}
}
}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,187 @@
files: "fake/user/code/file_A.py"
files: "fake/user/code/file_B.py"
files: "fake/user/code/file_C.py"
files: "fake/user/code/file_D.py"
files: "fake/user/code/file_E.py"
files: "fake/user/code/file_F.py"
files: "fake/user/code/file_G.py"
files: "fake/user/code/file_H.py"
files: "fake/user/code/file_I.py"
files: "fake/user/code/file_J.py"
files: "fake/user/code/file_K.py"
files: "fake/user/code/file_L.py"
files: "fake/user/code/file_M.py"
files: "fake/user/code/file_N.py"
files: "fake/user/code/file_O.py"
files: "fake/user/code/file_P.py"
files: "fake/user/code/file_Q.py"
files: "fake/user/code/file_R.py"
files: "fake/user/code/file_S.py"
files: "fake/user/code/file_T.py"
files: "fake/user/code/file_U.py"
files: "fake/user/code/file_V.py"
files: "fake/user/code/file_W.py"
files: "fake/user/code/file_X.py"
files: "fake/user/code/file_Y.py"
files: "fake/user/code/file_Z.py"
files: "fake/user/code/file_1.py"
files: "fake/user/code/file_2.py"
files: "fake/user/code/file_3.py"
files: "fake/user/code/file_4.py"
files: "fake/user/code/file_5.py"
files: "fake/user/code/file_6.py"
files: "fake/user/code/file_a.py"
files: "fake/user/code/file_b.py"
files: "fake/user/code/file_c.py"
files: "fake/user/code/file_d.py"
files: "fake/user/code/file_e.py"
files: "fake/user/code/file_f.py"
files: "fake/user/code/file_g.py"
files: "fake/user/code/file_h.py"
files: "fake/user/code/file_i.py"
files: "fake/user/code/file_j.py"
files: "fake/user/code/file_k.py"
files: "fake/user/code/file_l.py"
files: "fake/user/code/file_m.py"
files: "fake/user/code/file_n.py"
files: "fake/user/code/file_o.py"
files: "fake/user/code/file_p.py"
files: "fake/user/code/file_q.py"
files: "fake/user/code/file_r.py"
files: "fake/user/code/file_s.py"
files: "fake/user/code/file_t.py"
files: "fake/user/code/file_u.py"
files: "fake/user/code/file_v.py"
files: "fake/user/code/file_w.py"
files: "fake/user/code/file_x.py"
files: "fake/user/code/file_y.py"
files: "fake/user/code/file_z.py"
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/beta"
value {
file_line_cols {
file_index: 33
line: 383
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read"
value {
file_line_cols {
file_index: 49
line: 886
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma"
value {
file_line_cols {
file_index: 38
line: 777
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read"
value {
file_line_cols {
file_index: 49
line: 915
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean"
value {
file_line_cols {
file_index: 44
line: 793
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read"
value {
file_line_cols {
file_index: 49
line: 335
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance"
value {
file_line_cols {
file_index: 44
line: 386
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read"
value {
file_line_cols {
file_index: 49
line: 492
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/weights"
value {
file_line_cols {
file_index: 54
line: 649
}
}
}
traces {
key: "MobilenetV1/Conv2d_0/weights/read"
value {
file_line_cols {
file_index: 49
line: 421
}
}
}
traces {
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm"
value {
file_line_cols {
file_index: 5
line: 362
}
}
}
traces {
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D"
value {
file_line_cols {
file_index: 2
line: 27
}
file_line_cols {
file_index: 3
line: 28
}
file_line_cols {
file_index: 4
line: 29
}
file_line_cols {
file_index: 5
line: 30
}
}
}
traces {
key: "input"
value {
file_line_cols {
file_index: 40
line: 690
}
}
}

View File

@ -0,0 +1,20 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
# Declares the lit tests of this package through the glob_lit_tests macro.
# Test files are selected by the "pbtxt" extension; the test utilities defined
# below are passed as data, and lit is driven by run_lit.sh from the MLIR
# repository.
# NOTE(review): presumably one test target is generated per matching file;
# confirm against glob_lit_test.bzl.
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = ["pbtxt"],
)
# Bundle together all of the test utilities that are used by tests.
# These are the binaries named on the tests' run lines: the translator, the
# flatbuffer pretty-printer and FileCheck. Marked testonly so that only test
# targets may depend on this group.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm//:FileCheck",
],
)

View File

@ -0,0 +1,94 @@
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s
# Add two tensor<4xi32> inputs and return the result
#
# The GraphDef below is translated to a flatbuffer, printed as text by
# flatbuffer_to_string, and that text is matched line by line. The expected
# model has three tensors (input0, input1, Add), one operator using AddOptions
# and four empty buffers.
# NOTE(review): the expected shape of the "Add" tensor is empty ([ ]) even
# though both inputs are declared with shape [ 4 ] - presumably the output
# shape is not inferred at this stage; confirm.
node {
name: "Add"
op: "Add"
input: "input0"
input: "input1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
node {
name: "input1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
versions {
producer: 27
}
# CHECK: {
# CHECK-NEXT: version: 3,
# CHECK-NEXT: operator_codes: [ {
# CHECK-EMPTY:
# CHECK-NEXT: } ],
# CHECK-NEXT: subgraphs: [ {
# CHECK-NEXT: tensors: [ {
# CHECK-NEXT: shape: [ 4 ],
# CHECK-NEXT: type: INT32,
# CHECK-NEXT: buffer: 1,
# CHECK-NEXT: name: "input0",
# CHECK-NEXT: quantization: {
# CHECK-EMPTY:
# CHECK-NEXT: }
# CHECK-NEXT: }, {
# CHECK-NEXT: shape: [ 4 ],
# CHECK-NEXT: type: INT32,
# CHECK-NEXT: buffer: 2,
# CHECK-NEXT: name: "input1",
# CHECK-NEXT: quantization: {
# CHECK-EMPTY:
# CHECK-NEXT: }
# CHECK-NEXT: }, {
# CHECK-NEXT: shape: [ ],
# CHECK-NEXT: type: INT32,
# CHECK-NEXT: buffer: 3,
# CHECK-NEXT: name: "Add",
# CHECK-NEXT: quantization: {
# CHECK-EMPTY:
# CHECK-NEXT: }
# CHECK-NEXT: } ],
# CHECK-NEXT: inputs: [ 0, 1 ],
# CHECK-NEXT: outputs: [ 2 ],
# CHECK-NEXT: operators: [ {
# CHECK-NEXT: inputs: [ 0, 1 ],
# CHECK-NEXT: outputs: [ 2 ],
# CHECK-NEXT: builtin_options_type: AddOptions,
# CHECK-NEXT: builtin_options: {
# CHECK-EMPTY:
# CHECK-NEXT: }
# CHECK-NEXT: } ]
# CHECK-NEXT: name: "main"
# CHECK-NEXT: } ],
# CHECK-NEXT: description: "MLIR Converted.",
# CHECK-NEXT: buffers: [ {
# CHECK-EMPTY:
# CHECK-NEXT: }, {
# CHECK-EMPTY:
# CHECK-NEXT: }, {
# CHECK-EMPTY:
# CHECK-NEXT: }, {
# CHECK-EMPTY:
# CHECK-NEXT: } ]
# CHECK-NEXT: }

View File

@ -0,0 +1,232 @@
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
#
# Translates a small graph (Placeholder -> Conv2D -> BiasAdd -> Relu ->
# Identity, with Const weights and bias read through Identity nodes) with
# -print-function-result-mapping enabled. stderr is merged into stdout so the
# printed mapping can be matched: 'main' is expected to list 'input' as its
# input and 'output_0' as its output.
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
dim {
size: 8
}
dim {
size: 8
}
dim {
size: 2
}
}
}
}
}
node {
name: "conv_net_2d/conv_2d_0/w"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 3
}
dim {
size: 3
}
dim {
size: 2
}
dim {
size: 2
}
}
tensor_content: ";;\177<5\241i\275\312f\211>#\346j>\033W\325\275\253>\210=Vr\r\276\304\222\313\276\374\346\214>\016e\211>)\253\000>\3241\337\275\235g-\276*(\216\276\326#\367\274\023\213\300\276\227\031\206>PUF=\253\330\263<\337IL\276\334\320\215>\377\306v\276\372C\302\273baM>H\314\270<2\221\352=J\026{\276\221\243\245\276?\314\240=UW2\2755\207\253\274\256\207\333\273\335\372\227>\246\232;\276%\r\374<Z\346\204>"
}
}
}
}
node {
name: "conv_net_2d/conv_2d_0/w/read"
op: "Identity"
input: "conv_net_2d/conv_2d_0/w"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@conv_net_2d/conv_2d_0/w"
}
}
}
}
node {
name: "conv_net_2d_1/conv_2d_0/convolution"
op: "Conv2D"
input: "input"
input: "conv_net_2d/conv_2d_0/w/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "data_format"
value {
s: "NHWC"
}
}
attr {
key: "dilations"
value {
list {
i: 1
i: 1
i: 1
i: 1
}
}
}
attr {
key: "explicit_paddings"
value {
list {
}
}
}
attr {
key: "padding"
value {
s: "SAME"
}
}
attr {
key: "strides"
value {
list {
i: 1
i: 1
i: 1
i: 1
}
}
}
attr {
key: "use_cudnn_on_gpu"
value {
b: true
}
}
}
node {
name: "conv_net_2d/conv_2d_0/b"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
}
tensor_content: "\315\314\314=\315\314\314="
}
}
}
}
node {
name: "conv_net_2d/conv_2d_0/b/read"
op: "Identity"
input: "conv_net_2d/conv_2d_0/b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@conv_net_2d/conv_2d_0/b"
}
}
}
}
node {
name: "conv_net_2d_1/conv_2d_0/BiasAdd"
op: "BiasAdd"
input: "conv_net_2d_1/conv_2d_0/convolution"
input: "conv_net_2d/conv_2d_0/b/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "data_format"
value {
s: "NHWC"
}
}
}
node {
name: "conv_net_2d_1/Relu"
op: "Relu"
input: "conv_net_2d_1/conv_2d_0/BiasAdd"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "output_0"
op: "Identity"
input: "conv_net_2d_1/Relu"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
}
# CHECK: 'main' inputs:
# CHECK-NEXT: name: 'input'
# CHECK-NEXT: 'main' outputs:
# CHECK-NEXT: name: 'output_0'

View File

@ -0,0 +1,45 @@
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=output %s -o - --output-mlir -tf-extra-opdefs="name: 'BannaPotatoSaladWithColeslaw' input_arg: { name: 'a' type: DT_INT32 } input_arg: { name: 'b' type: DT_INT32 } output_arg: { name: 'c' type: DT_INT32 }" | FileCheck %s
#
# Exercises -tf-extra-opdefs: the graph uses an op whose definition is supplied
# only on the command line. With --output-mlir the expected module keeps that
# node as a "tf."-prefixed op taking the two pseudo inputs and producing an
# unranked tensor<*xi32>.
node {
name: "output"
op: "BannaPotatoSaladWithColeslaw"
input: "input0"
input: "input1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
node {
name: "input1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input0, input1", outputs = "output"}} {
# CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
# CHECK-NEXT: %1 = "tfl.pseudo_input"(%arg1) : (tensor<4xi32>) -> tensor<4xi32>
# CHECK-NEXT: %2 = "tf.BannaPotatoSaladWithColeslaw"(%0, %1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %2 : tensor<*xi32>
# CHECK-NEXT: }

View File

@ -0,0 +1,54 @@
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=4 -tf-input-data-types=DT_INT32 -tf-output-arrays=output %s -o - --output-mlir | FileCheck %s
#
# The node named as the input array ("input") is an Identity that is itself
# fed by a Const ("default") in the graph. The expected MLIR takes the input as
# a function argument and returns the pseudo input directly; neither the Const
# nor the Identity nodes appear in the expected function body.
node {
name: "default"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "input"
op: "Identity"
input: "default"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "output"
op: "Identity"
input: "input"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
versions {
producer: 27
}
# CHECK: func @main(%arg0: tensor<4xi32>) -> tensor<4xi32>
# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
# CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
# CHECK-NEXT: return %0 : tensor<4xi32>
# CHECK-NEXT: }

View File

@ -0,0 +1,757 @@
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure

// Tests for the -tfl-legalize-tf pass. Each function feeds TensorFlow dialect
// ("tf.") ops through the pass and pins the printed result: either the
// TensorFlow Lite dialect ("tfl.") replacement, or the untouched op for cases
// that are expected not to be legalized.
// tf.Add legalization plus activation fusion. Per the expectations below:
//  - a tf.Add with no following activation becomes tfl.add with "NONE";
//  - a tf.Add followed by tf.Relu / tf.Relu6 becomes one tfl.add carrying
//    "RELU" / "RELU6" as its fused activation;
//  - a tf.Relu applied directly to an argument becomes a standalone tfl.relu;
//  - an existing tfl.add followed by tf.Relu6 also gets "RELU6" fused in.
func @addRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "tf.Add"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
%4 = "tf.Add"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%7 = "tf.Relu6"(%6) : (tensor<1xi32>) -> tensor<1xi32>
return %7: tensor<1xi32>
// CHECK-LABEL: addRelu
// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xi32>
// CHECK: return %4 : tensor<1xi32>
}
// tf.LeakyRelu is expected to become tfl.leaky_relu with the alpha attribute
// carried over (0.1, printed as 1.000000e-01).
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
// CHECK-LABEL: LeakyRelu
// CHECK: %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32>
}
// f32 tf.BiasAdd is expected to become tfl.add. The first op (bias of shape
// 32) prints in the generic form because its operand types differ; the second
// op is followed by tf.Relu6, which is expected to be fused as "RELU6".
func @biasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> {
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
%1 = "tf.BiasAdd"(%0, %arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32>
%2 = "tf.Relu6"(%1) : (tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32>
return %2 : tensor<1x10x10x32xf32>
// CHECK-LABEL: biasAdd
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
// CHECK: %1 = tfl.add %0, %arg0 {fused_activation_function = "RELU6"} : tensor<1x10x10x32xf32>
}
// Negative test: tf.BiasAdd on i32 tensors is expected to be left as is.
// NOTE(review): the T attribute says DT_FLOAT although the tensors are i32;
// the expectation does not look at attributes, so this is possibly a
// copy-paste leftover from @biasAdd - confirm before changing it.
func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> {
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xi32>, tensor<32xi32>) -> tensor<1x10x10x32xi32>
return %0 : tensor<1x10x10x32xi32>
// CHECK-LABEL: biasAddInt
// CHECK: %0 = "tf.BiasAdd"(%arg0, %arg1)
}
// tf.Squeeze / tf.Reshape legalization. Per the expectations below:
//  - tf.Squeeze with fully static input and result types becomes tfl.reshape;
//  - tf.Squeeze with a dynamic input dimension and an unranked result stays;
//  - tf.Reshape whose shape operand is a constant becomes a single-operand
//    tfl.reshape, and the shape constant no longer appears in the output
//    (the expected value numbering goes straight from %1 to the reshape).
// "some_op" is an arbitrary consumer that keeps both results alive.
func @sqeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i32 {
%0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
%2 = constant dense<[2, 5]> : tensor<2xi32>
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
// CHECK-LABEL: sqeezeAndReshape
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: %1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
// CHECK: %2 = "tfl.reshape"(%0) : (tensor<1x10xf32>) -> tensor<2x5xf32>
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: return %3 : i32
}
// Negative test: tf.Reshape whose shape operand is a function argument (not a
// constant), with an unranked input and a dynamically shaped result, is
// expected to be left as is.
func @dynamicReshape(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
// CHECK-LABEL: dynamicReshape
// CHECK: %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
// CHECK: return %0 : tensor<?x?xf32>
}
// tf.AvgPool legalization. Only the first op is expected to become
// tfl.average_pool_2d, with ksize [1, 3, 6, 1] mapped to filter_height = 3 and
// filter_width = 6, and strides [1, 3, 1, 1] mapped to stride_h = 3 and
// stride_w = 1. The other three ops each differ from it in one attribute and
// are expected to remain tf.AvgPool. The addf ops only combine the four
// results into the returned value.
func @avgPool2D(%arg0: tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
// OK
%0 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
// Unsupported data format
%1 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
// Unsupported ksize
%2 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [3, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
// Unsupported strides
%3 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 3]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
%5 = addf %0, %1 : tensor<1x1x1x16xf32>
%6 = addf %2, %3 : tensor<1x1x1x16xf32>
%7 = addf %5, %6 : tensor<1x1x1x16xf32>
return %7 : tensor<1x1x1x16xf32>
// CHECK-LABEL: func @avgPool2D
// CHECK: %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
// CHECK: %1 = "tf.AvgPool"(%arg0)
// CHECK: %2 = "tf.AvgPool"(%arg0)
// CHECK: %3 = "tf.AvgPool"(%arg0)
}
// tf.Softmax is expected to become tfl.softmax with beta = 1.0; the input op
// carries no beta attribute, so that value is supplied by the legalization.
func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: softmax
// CHECK: %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
// tf.FakeQuantWithMinMaxArgs with narrow_range = false is expected to become
// a tfl.quantize / tfl.dequantize pair whose quantized type encodes the
// [-0.1, 0.2] range: scale = 0.3 / 255, zero point = 85.
// NOTE(review): the expected storage type is u8 and the scale is computed
// over 255 steps although num_bits = 3 - confirm this is intended.
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
// CHECK-LABEL: fakeQuantArgsFalse
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>) -> tensor<8x8x8x8xf32>
}
// Same as @fakeQuantArgsFalse but with narrow_range = true: the expected
// quantized type restricts storage to <1:255>, so the scale is 0.3 / 254 and
// the zero point is 86.
// NOTE(review): the expected storage type is u8 although num_bits = 3 -
// confirm this is intended.
func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = true} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
// CHECK-LABEL: fakeQuantArgsTrue
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform<u8<1:255>:f32, 0.001181102379804521:86>>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform<u8<1:255>:f32, 0.001181102379804521:86>>
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform<u8<1:255>:f32, 0.001181102379804521:86>>) -> tensor<8x8x8x8xf32>
}
// tf.FakeQuantWithMinMaxVars whose min/max operands are constants (-0.1 and
// 0.2) is expected to become the same tfl.quantize / tfl.dequantize pair as
// in @fakeQuantArgsFalse. The expected ops start at %0, so the two constants
// are not expected to be printed ahead of them.
func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
%arg1 = constant dense<-0.1> : tensor<f32>
%arg2 = constant dense<0.2> : tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
// CHECK-LABEL: fakeQuantVarsFalse
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>) -> tensor<8x8x8x8xf32>
}
// Negative test: here the min/max operands are function arguments rather than
// constants, and the op is expected to be left as tf.FakeQuantWithMinMaxVars
// even though min/max attributes are also present. The expectation shows the
// attributes as they are printed back (sorted by name, num_bits typed as i64).
func @fakeQuantVarsTrue(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<8x8x8x8xf32> {
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {min = 0.0 : f32, max = 1.0 : f32, num_bits = 3, narrow_range = true} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
// CHECK-LABEL: fakeQuantVarsTrue
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {max = 1.000000e+00 : f32, min = 0.000000e+00 : f32, narrow_range = true, num_bits = 3 : i64}
}
// tf.Const is expected to become tfl.pseudo_const. The opaque value attribute
// is carried over unchanged, while the device, name and dtype attributes do
// not appear on the expected op.
func @const() -> tensor<2xi32> {
%0 = "tf.Const"() {device = "", name = "weights_quant/min", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> (tensor<2xi32>)
return %0: tensor<2xi32>
// CHECK-LABEL: @const
// CHECK: %0 = "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
}
// tf.Placeholder.input is expected to become tfl.pseudo_input; the name
// attribute does not appear on the expected op.
func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
// CHECK-LABEL: @placeholder
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}
// Placeholder carrying only a min attribute (no max, no type): the expected
// op is a plain tfl.pseudo_input. Contrast with @placeholder_min_max, where
// quantize/dequantize ops are expected as well.
func @placeholder_min(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32} : (tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
// CHECK-LABEL: @placeholder_min
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}
// tf.Placeholder.input carrying only a `type = i8` attribute (no min/max).
// The expectations list just tfl.pseudo_input.
func @placeholder_type(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", type = i8} : (tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
// CHECK-LABEL: @placeholder_type
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}
// tf.Placeholder.input with `min`, `max` and `type` all present. The expected
// output is tfl.pseudo_input followed by a tfl.quantize / tfl.dequantize pair
// using a u8 uniform quantized type (scale ~7.84e-4, zero point 128).
func @placeholder_min_max(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32, max = 0.1 : f32, type = i8} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %0: tensor<2x3xf32>
// CHECK-LABEL: @placeholder_min_max
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<2x3x!quant.uniform<u8:f32, 7.8431373717738134E-4:128>>}
// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<2x3x!quant.uniform<u8:f32, 7.8431373717738134E-4:128>>)
}
// Two tf.Shape ops, one with an explicit `out_type` attribute and one without.
// Both are expected to become tfl.shape. The trailing tf.Add only combines the
// two results and has no expectation of its own.
func @shape(%arg0: tensor<?x1001xf32>) -> tensor<2xi32> {
%0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<?x1001xf32>) -> tensor<2xi32>
%1 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT"} : (tensor<?x1001xf32>) -> tensor<2xi32>
%2 = "tf.Add"(%0, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %2: tensor<2xi32>
// CHECK-LABEL: shape
// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
// CHECK: %1 = "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
}
// tf.Fill with a dynamically shaped result is expected to become tfl.fill with
// the same operands and types.
func @fill(%arg0: tensor<3xi32>, %arg1: tensor<f32>) -> tensor<?x?x?xf32> {
%0 = "tf.Fill"(%arg0, %arg1) : (tensor<3xi32>, tensor<f32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
// CHECK-LABEL:fill
// CHECK: %0 = "tfl.fill"(%arg0, %arg1) : (tensor<3xi32>, tensor<f32>) -> tensor<?x?x?xf32>
}
// tf.Sigmoid on an f16 tensor is expected to become tfl.logistic.
func @sigmoid(%arg0: tensor<?x88xf16>) -> tensor<?x88xf16> {
%0 = "tf.Sigmoid"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
return %0 : tensor<?x88xf16>
// CHECK-LABEL: sigmoid
// CHECK: %0 = "tfl.logistic"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
}
// tf.LogSoftmax is expected to become tfl.log_softmax.
func @log_softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.LogSoftmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: log_softmax
// CHECK: %0 = "tfl.log_softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
// tf.ZerosLike is expected to become tfl.zeros_like.
func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.ZerosLike"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: zeros_like
// CHECK: %0 = "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
// Chain of tf.Div / tf.Relu / tf.Relu6 ops. Per the expectations:
//  - a Relu or Relu6 that consumes a Div result is folded into that tfl.div as
//    fused_activation_function = "RELU" / "RELU6";
//  - a Div with no activation after it gets "NONE";
//  - the Relu applied directly to %arg0 stays a standalone tfl.relu.
func @divRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "tf.Div"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
%4 = "tf.Div"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
return %5: tensor<1xi32>
// CHECK-LABEL: divRelu
// CHECK: %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
// CHECK: return %3 : tensor<1xi32>
}
// tf.SquaredDifference followed by tf.Relu6. The expected output keeps two
// separate ops: tfl.squared_difference (printed without a
// fused_activation_function attribute) and a standalone tfl.relu6.
func @squaredDifferenceRelu(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> {
^bb0(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>):
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "tf.Relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
return %1: tensor<1xi32>
// CHECK-LABEL: squaredDifferenceRelu
// CHECK: %0 = tfl.squared_difference %arg0, %arg1 : tensor<1xi32>
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: return %1 : tensor<1xi32>
}
// Four tf.MaxPool variants. Only the first (NHWC, ksize and strides both 1 in
// the first and last positions) is expected to become tfl.max_pool_2d. The
// other three each change one attribute and are expected to stay tf.MaxPool.
// The addf ops only keep all four results live.
func @maxPool2D(%arg0: tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32> {
// OK
%0 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
// Unsupported data_format
%1 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
// Unsupported ksize
%2 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [3, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
// Unsupported strides
%3 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 3]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
%5 = addf %0, %1 : tensor<1x1x1x16xf32>
%6 = addf %2, %3 : tensor<1x1x1x16xf32>
%7 = addf %5, %6 : tensor<1x1x1x16xf32>
return %7 : tensor<1x1x1x16xf32>
// CHECK-LABEL: func @maxPool2D
// CHECK: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
// CHECK: %1 = "tf.MaxPool"(%arg0)
// CHECK: %2 = "tf.MaxPool"(%arg0)
// CHECK: %3 = "tf.MaxPool"(%arg0)
}
// tf.Abs is expected to become tfl.abs.
func @abs(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Abs"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL:abs
// CHECK: %0 = "tfl.abs"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
// tf.Ceil is expected to become tfl.ceil, whose result is returned directly.
func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: ceil
// CHECK: %0 = "tfl.ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK: return %0 : tensor<8x16xf32>
}
// tf.Cos on a rank-0 tensor is expected to become tfl.cos.
func @cos(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Cos"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
// CHECK-LABEL:cos
// CHECK: %0 = "tfl.cos"(%arg0) : (tensor<f32>) -> tensor<f32>
}
// tf.Elu is expected to become tfl.elu.
func @elu(%arg0: tensor<11x16xf32>) -> tensor<11x16xf32> {
%0 = "tf.Elu"(%arg0) : (tensor<11x16xf32>) -> tensor<11x16xf32>
return %0 : tensor<11x16xf32>
// CHECK-LABEL:elu
// CHECK: %0 = "tfl.elu"(%arg0) : (tensor<11x16xf32>) -> tensor<11x16xf32>
}
// tf.ExpandDims with a non-constant axis operand is expected to become
// tfl.expand_dims with both operands kept.
func @expandDims(%arg0: tensor<2x2xf32>, %arg1: tensor<i32>) -> tensor<1x2x2xf32> {
%0 = "tf.ExpandDims"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<i32>) -> tensor<1x2x2xf32>
return %0 : tensor<1x2x2xf32>
// CHECK-LABEL:expandDims
// CHECK: %0 = "tfl.expand_dims"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<i32>) -> tensor<1x2x2xf32>
}
// tf.Gather with rank-0 indices is expected to become tfl.gather with
// axis = 0.
func @gatherScalarIndices(%arg0 : tensor<3x2xf32>, %arg1 : tensor<i32>) -> tensor<2xf32> {
%0 = "tf.Gather"(%arg0, %arg1) : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
// CHECK-LABEL:gatherScalarIndices
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
}
// tf.Gather with rank-1 indices is expected to become tfl.gather with
// axis = 0.
func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> tensor<3xf32> {
%0 = "tf.Gather"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL:gatherVectorIndices
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
}
// tf.Gather with rank-2 indices is expected to become tfl.gather with
// axis = 0.
func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5xi32>) -> tensor<4x5x3x6xf32> {
%0 = "tf.Gather"(%arg0, %arg1) : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
return %0 : tensor<4x5x3x6xf32>
// CHECK-LABEL:gatherHigherRankIndices
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
}
// tf.GatherV2 whose axis operand is the constant 1. The expected tfl.gather
// takes only the params and indices operands and carries the axis as the
// attribute `axis = 1 : i32`.
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
%0 = constant dense<[1]> : tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
return %1 : tensor<1x3x5x20xf32>
// CHECK-LABEL:gatherV2VectorIndices
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
}
// tf.GatherV2 whose axis operand is the constant -1. The expected tfl.gather
// takes only the params and indices operands and carries the axis as the
// attribute `axis = -1 : i32`.
// The label below names this function in full, so the expectations are
// anchored to this test and not to a prefix shared with the preceding one.
func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>
// CHECK-LABEL:gatherV2VectorIndicesNegAxis
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = -1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32>
}
// tf.Greater is expected to become tfl.greater with an i1 result tensor.
func @greater(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
return %0 : tensor<8x16xi1>
// CHECK-LABEL: greater
// CHECK: %0 = "tfl.greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0 : tensor<8x16xi1>
}
// tf.GreaterEqual is expected to become tfl.greater_equal.
func @greater_equal(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
return %0 : tensor<8x16xi1>
// CHECK-LABEL: greater_equal
// CHECK: %0 = "tfl.greater_equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0 : tensor<8x16xi1>
}
// tf.Rank is expected to become tfl.rank with the result type unchanged.
func @rank(%arg0: tensor<11x16xf32>) -> tensor<1xi32> {
%0 = "tf.Rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
// CHECK-LABEL:rank
// CHECK: %0 = "tfl.rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32>
}
// tf.Floor is expected to become tfl.floor.
func @floor(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Floor"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: floor
// CHECK: %0 = "tfl.floor"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK: return %0 : tensor<8x16xf32>
}
// tf.FloorDiv is expected to become tfl.floor_div, printed in its custom
// (non-generic) assembly form.
func @floor_div(tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> {
^bb0(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>):
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: floor_div
// CHECK: %0 = tfl.floor_div %arg0, %arg1 : tensor<8x16xf32>
// CHECK: return %0 : tensor<8x16xf32>
}
// tf.NotEqual is expected to become tfl.not_equal.
func @not_equal(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
return %0 : tensor<8x16xi1>
// CHECK-LABEL: not_equal
// CHECK: %0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0 : tensor<8x16xi1>
}
// tf.Select with a condition of the same shape as the values is expected to
// become tfl.select. Only the op name and operands are matched, not the types.
func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
return %0: tensor<8xf32>
// CHECK-LABEL: select
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8xf32>
}
// tf.Sin on a rank-0 tensor is expected to become tfl.sin.
func @sin(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
// CHECK-LABEL:sin
// CHECK: %0 = "tfl.sin"(%arg0) : (tensor<f32>) -> tensor<f32>
}
// tf.TopKV2 with k supplied as a function argument (not a constant) is
// expected to become a two-result tfl.topk_v2.
func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
%0, %1 = "tf.TopKV2"(%arg0, %arg1) : (tensor<8xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
return %0, %1: tensor<?xf32>, tensor<?xi32>
// CHECK-LABEL: topk
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %arg1)
// CHECK: return %0
}
// tf.TopKV2 with a constant k = 2 on a rank-1 input is expected to become
// tfl.topk_v2, with the constant passed as the second operand (%cst).
func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<8xf32>, tensor<i32>) -> (tensor<2xf32>, tensor<2xi32>)
return %1#0, %1#1: tensor<2xf32>, tensor<2xi32>
// CHECK-LABEL: topk_2
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst)
// CHECK: return %0
}
// tf.TopKV2 with a constant k = 2 on an input with a dynamic leading dimension.
// The expected tfl.topk_v2 keeps the dynamic dimension in both result types.
func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>
// CHECK-LABEL: topk_3
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
// CHECK: return %0
}
// tf.TopKV2 with a constant k = 2 on a rank-4 static input is expected to
// become tfl.topk_v2.
func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<1x2x3x4xf32>, tensor<i32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>)
return %1#0, %1#1: tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>
// CHECK-LABEL: topk_4
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst)
// CHECK: return %0
}
// tf.TopKV2 with a constant k = 2 on an unranked input is expected to become
// tfl.topk_v2.
func @topk_5(%arg0: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>
// CHECK-LABEL: topk_5
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst)
// CHECK: return %0
}
// tf.LogicalAnd is expected to become tfl.logical_and (custom assembly form).
func @logicalAnd(%arg0: tensor<8xi1>, %arg1: tensor<8xi1>) -> tensor<8xi1> {
%0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<8xi1>, tensor<8xi1>) -> tensor<8xi1>
return %0: tensor<8xi1>
// CHECK-LABEL: logicalAnd
// CHECK: %0 = tfl.logical_and %arg0, %arg1 : tensor<8xi1>
// CHECK: return %0 : tensor<8xi1>
}
// tf.LogicalNot is expected to become tfl.logical_not.
func @logicalNot(%arg0: tensor<8xi1>) -> tensor<8xi1> {
%0 = "tf.LogicalNot"(%arg0) : (tensor<8xi1>) -> tensor<8xi1>
return %0 : tensor<8xi1>
// CHECK-LABEL: logicalNot
// CHECK: %0 = "tfl.logical_not"(%arg0) : (tensor<8xi1>) -> tensor<8xi1>
}
// tf.LogicalOr is expected to become tfl.logical_or (custom assembly form).
func @logicalOr(%arg0: tensor<8xi1>, %arg1: tensor<8xi1>) -> tensor<8xi1> {
%0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<8xi1>, tensor<8xi1>) -> tensor<8xi1>
return %0: tensor<8xi1>
// CHECK-LABEL: logicalOr
// CHECK: %0 = tfl.logical_or %arg0, %arg1 : tensor<8xi1>
// CHECK: return %0 : tensor<8xi1>
}
// tf.AddV2 is expected to become tfl.add with
// fused_activation_function = "NONE".
func @addV2(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
// CHECK-LABEL: addV2
// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
}
// tf.ReverseV2 is expected to become tfl.reverse_v2 with both operands kept.
func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2x3x4xf32> {
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi32>) -> tensor<1x2x3x4xf32>
return %0 : tensor<1x2x3x4xf32>
// CHECK-LABEL:reverse_v2
// CHECK: %0 = "tfl.reverse_v2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi32>) -> tensor<1x2x3x4xf32>
// CHECK: return %0 : tensor<1x2x3x4xf32>
}
// tf.Maximum is expected to become tfl.maximum.
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL:maximum
// CHECK: %0 = "tfl.maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
}
// tf.Minimum is expected to become tfl.minimum.
func @minimum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL:minimum
// CHECK: %0 = "tfl.minimum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
}
// tf.RealDiv is expected to become tfl.div with
// fused_activation_function = "NONE".
func @realDiv(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: realDiv
// CHECK: %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<8x16xf32>
}
// tf.Equal is expected to become tfl.equal with an i1 result tensor.
func @equal(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
return %0 : tensor<8x16xi1>
// CHECK-LABEL: equal
// CHECK: %0 = "tfl.equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
// CHECK: return %0 : tensor<8x16xi1>
}
// tf.Pad is expected to become tfl.pad. The input spells the result type as
// `tensor<? x f32>` and returns `%0#0`; the expected output prints these in
// canonical form as `tensor<?xf32>` and `%0`.
func @pad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%0 = "tf.Pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
// CHECK-LABEL: pad
// CHECK: %0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
// CHECK: return %0 : tensor<?xf32>
}
// tf.PadV2 with a constant pad value is expected to become tfl.padv2, with the
// constant kept as the third operand (%cst).
func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%cst = constant dense<2.0> : tensor<f32>
%0 = "tf.PadV2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
// CHECK-LABEL: padv2
// CHECK: %0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<?xf32>
// CHECK: return %0 : tensor<?xf32>
}
// tf.Pack of two tensors with no explicit axis. The expected tfl.pack has
// axis = 0 and values_count = 2, both as i32 attributes.
func @pack2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
%0 = "tf.Pack"(%arg0, %arg1) {N = 2 : i64} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
// CHECK-LABEL: pack2Tensors
// CHECK: %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
}
// tf.Pack of three tensors with axis = 1 (i64). The expected tfl.pack has
// axis = 1 and values_count = 3 as i32 attributes.
func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {N = 3 : i64, axis = 1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
// CHECK-LABEL: pack3Tensors
// CHECK: %0 = "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
}
// tf.Pack with a negative axis (-1). The expected tfl.pack keeps the negative
// value as `axis = -1 : i32`.
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {N = 3 : i64, axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
// CHECK-LABEL: packNegAxis
// CHECK: %0 = "tfl.pack"(%arg0, %arg1, %arg2) {axis = -1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
}
// tf.Unpack into two results with no explicit axis. The expected tfl.unpack
// has axis = 0 and num = 2 as i32 attributes.
func @unpack2Tensors(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
%0:2 = "tf.Unpack"(%arg0) {num = 2 : i64} : (tensor<2x2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: unpack2Tensors
// CHECK: %0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
}
// tf.Unpack into three results along axis = 1. The expected tfl.unpack has
// axis = 1 and num = 3 as i32 attributes.
func @unpack3Tensors(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
%0:3 = "tf.Unpack"(%arg0) {num = 3 : i64, axis = 1 : i64} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: unpack3Tensors
// CHECK: %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
}
// tf.Unpack with a negative axis (-1). The expected tfl.unpack keeps the
// negative value as `axis = -1 : i32`.
func @unpackNegAxis(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
%0:3 = "tf.Unpack"(%arg0) {num = 3 : i64, axis = -1 : i64} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: unpackNegAxis
// CHECK: %0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
}
// tf.Mean with no keep_dims attribute. The expected tfl.mean has an explicit
// keep_dims = false.
func @mean(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
%0 = "tf.Mean"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
// CHECK-LABEL: mean
// CHECK: %0 = "tfl.mean"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
}
// tf.Mean with keep_dims = true is expected to become tfl.mean with the
// attribute preserved.
func @mean_true(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
%0 = "tf.Mean"(%arg0, %arg1) {keep_dims = true} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
// CHECK-LABEL: mean_true
// CHECK: %0 = "tfl.mean"(%arg0, %arg1) {keep_dims = true} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
}
// tf.Sum with keep_dims = false is expected to become tfl.sum with the
// attribute preserved.
func @sum(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: sum
// CHECK: %0 = "tfl.sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
// tf.Sum with keep_dims = true is expected to become tfl.sum with the
// attribute preserved.
func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: sum_true
// CHECK: %0 = "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
// tf.Min with keep_dims = false is expected to become tfl.reduce_min.
func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_min
// CHECK: %0 = "tfl.reduce_min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
// tf.Min with keep_dims = true is expected to become tfl.reduce_min with the
// attribute preserved.
func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_min_true
// CHECK: %0 = "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
// tf.Max with keep_dims = false is expected to become tfl.reduce_max.
func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_max
// CHECK: %0 = "tfl.reduce_max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
// tf.Max with keep_dims = true is expected to become tfl.reduce_max with the
// attribute preserved.
func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_max_true
// CHECK: %0 = "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
// tf.BatchToSpaceND is expected to become tfl.batch_to_space_nd with all three
// operands kept.
func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: batch_to_space_nd
// CHECK: %0 = "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
}
// tf.SpaceToBatchND is expected to become tfl.space_to_batch_nd with all three
// operands kept.
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: space_to_batch_nd
// CHECK: %0 = "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
}
// tf.MatMul with transpose_a = false and transpose_b = true. The expected
// output is tfl.fully_connected, whose third operand (%cst) has type `none`.
// The definition of %cst is not matched by the expectations.
func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed
// CHECK: %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
// tf.Concat takes the axis as its first operand, here the constant 1. The
// expected tfl.concatenation takes only the two value operands and carries the
// axis as the attribute `axis = 1 : i32`.
func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
%0 = constant dense<[1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
// CHECK-LABEL: concat2Tensors
// CHECK: %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
}
// tf.Concat of three tensors with a constant axis of -1 as the first operand.
// The expected tfl.concatenation keeps the negative value as
// `axis = -1 : i32`.
func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat3Tensors
// CHECK: %0 = "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
}
// tf.ConcatV2 takes the axis as its last operand, here the constant -1. The
// expected tfl.concatenation matches the one expected for tf.Concat: three
// value operands and `axis = -1 : i32`.
func @concatv2With3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2With3Tensors
// CHECK: %0 = "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
}
// tf.ResizeBilinear with only align_corners set is expected to become
// tfl.resize_bilinear with the attribute preserved.
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: resize_with_bilinear
// CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
}
// Note: half_pixel_centers isn't supported by TFLite, so a tf.ResizeBilinear
// that sets it is not legalized; the expectation below matches the original
// tf op with both attributes intact.
func @resize_with_bilinear_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: resize_with_bilinear_with_half_pixel_centers
// CHECK: "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true}
}
// tf.StridedSlice with all five mask attributes set to 0 (i64). The expected
// tfl.strided_slice carries the same masks as i32 attributes.
func @strided_slice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: strided_slice
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
}

View File

@ -0,0 +1,101 @@
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
// A tensor list built with tf.TensorListFromTensor, then read with
// tf.TensorListGetItem and tf.TensorListStack. Per the expectations, the
// GetItem becomes a tf.Gather on the original tensor (%arg0), and the Stack
// result is replaced by %arg0 itself.
func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<*x!tf.variant>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<*x!tf.variant>, tensor<1xi32>) -> tensor<3x10xf32>
return %1, %2 : tensor<10xf32>, tensor<3x10xf32>
// CHECK-LABEL: tensorlistGetItem
// CHECK: %0 = "tf.Gather"(%arg0, %arg2) {validate_indices = true} : (tensor<3x10xf32>, tensor<i32>) -> tensor<10xf32>
// CHECK: return %0, %arg0 : tensor<10xf32>, tensor<3x10xf32>
}
// Same FromTensor / GetItem / Stack pattern as above but with unranked
// tensors. The expected lowering is again a tf.Gather on %arg0, with the Stack
// result replaced by %arg0.
func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<*xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<*x!tf.variant>, tensor<i32>, tensor<1xi32>) -> tensor<*xf32>
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<*x!tf.variant>, tensor<1xi32>) -> tensor<*xf32>
return %1, %2 : tensor<*xf32>, tensor<*xf32>
// CHECK-LABEL: tensorlistGetItemWithUnknownRank
// CHECK: %0 = "tf.Gather"(%arg0, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK: return %0, %arg0 : tensor<*xf32>, tensor<*xf32>
}
// tf.TensorListSetItem on a list built from a 3x10 tensor, followed by
// tf.TensorListStack. The expected lowering uses only ordinary tf ops:
//  - Rank / ExpandDims / Fill / Add / Concat build the begin and size vectors;
//  - two tf.Slice ops take pieces of %arg0 (%12 and %13);
//  - the new item (%arg3) gets a leading dimension via tf.ExpandDims;
//  - a final tf.Concat of %12, the expanded item and %13 is returned.
func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10xf32>) -> tensor<3x10xf32> {
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>):
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<*x!tf.variant>, tensor<i32>, tensor<10xf32>) -> tensor<*x!tf.variant>
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<*x!tf.variant>, tensor<1xi32>) -> tensor<3x10xf32>
return %2 : tensor<3x10xf32>
// CHECK-LABEL: tensorlistSetItem
// CHECK: %cst = constant dense<1> : tensor<1xi32>
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<3x10xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<10xf32>) -> tensor<i32>
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
// CHECK: return %15 : tensor<3x10xf32>
}
// tf.TensorListSetItem where each list element is a scalar: the list is built
// from a tensor<5xf32> with an empty element-shape operand (tensor<0xi32>).
// The expected op sequence matches the ranked 3x10 case, with only the types
// differing: two tf.Slice ops on %arg0, a tf.ExpandDims of the scalar item,
// and a final tf.Concat.
func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i32>, tensor<f32>) -> tensor<5xf32> {
^bb0(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<5xf32>, tensor<0xi32>) -> tensor<*x!tf.variant>
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<*x!tf.variant>, tensor<i32>, tensor<f32>) -> tensor<*x!tf.variant>
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<*x!tf.variant>, tensor<0xi32>) -> tensor<5xf32>
return %2 : tensor<5xf32>
// CHECK-LABEL: tensorlistSetItemWithScalarElements
// CHECK: %cst = constant dense<1> : tensor<1xi32>
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<5xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<f32>) -> tensor<i32>
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
// CHECK: return %15 : tensor<5xf32>
}
// Test case: a tensor list is reserved with element shape %arg0 and element count %arg1 (element_dtype = f32), then one item is read at index %arg2.
// The FileCheck expectations describe the rewritten IR, which no longer contains any tf.TensorList* op:
//   - the reserve becomes a tf.Fill of 0.0 whose shape is the element count prepended to the element shape (%1),
//   - the item read becomes a tf.Gather on that filled tensor.
func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<3xf32> {
^bb0(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
%0 = "tf.TensorListReserve"(%arg0, %arg1) {element_dtype = f32} : (tensor<3xi32>, tensor<i32>) -> tensor<*x!tf.variant>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<*x!tf.variant>, tensor<i32>, tensor<3xi32>) -> tensor<3xf32>
return %1 : tensor<3xf32>
// CHECK-LABEL: tensorlistReserve
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
// CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<*xf32>
// CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<3xf32>
// CHECK: return %3 : tensor<3xf32>
}

View File

@ -0,0 +1,20 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
# Declares the lit tests for this package via the glob_lit_tests macro:
#   data           - runtime files made available to every test (the tool bundle defined below),
#   driver         - the script used to run each test,
#   test_file_exts - only files with the "mlir" extension are treated as tests.
# NOTE(review): presumably one test target is generated per matching file; confirm in glob_lit_test.bzl.
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
# These are the binaries invoked on the tests' lit command lines:
# flatbuffer_translate (MLIR -> TFLite flatbuffer), flatbuffer_to_string
# (flatbuffer -> text) and FileCheck (matches the text against expectations).
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"@llvm//:FileCheck",
],
)

View File

@ -0,0 +1,101 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// Export test for custom ops: the function chains tfl.mul -> tf.MyCustomOp -> tfl.exp and is exported with -emit-custom-ops.
// Expected flatbuffer text:
//   - tf.MyCustomOp is emitted as a CUSTOM operator code with custom_code "MyCustomOp" (opcode_index 1),
//   - its custom_options bytes begin with the ASCII strings "int_attr", "fused_activation_function" and "RELU",
//     i.e. the attributes of the op below (presumably encoded as a FlexBuffer map; confirm against the exporter),
//   - tensor names come from the loc("...") annotations on the ops,
//   - the only non-empty buffer holds four little-endian f32 1.0 values (0, 0, 128, 63) for the "Const" tensor.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MyCustomOp"
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "Input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "MyCustomOp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: ExpOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
// tf.MyCustomOp is the result of conversion to a Custom op
%3 = "tf.MyCustomOp"(%2, %1) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("MyCustomOp")
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
return %4 : tensor<4xf32>
}

View File

@ -0,0 +1,10 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1
// CHECK: loc("disable_builtin.mlir":2:1): is a TFLite builtin op but builtin emission is not enabled
// CHECK-NEXT: Verification failed.
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
^bb0(%arg0: tensor<3x2xi32>):
%0 = "std.constant" () {name = "Const2", value = dense<10> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Add" (%0, %arg0) {name = "add"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %1 : tensor<3x2xi32>
}
// Negative test: exporting with -emit-builtin-tflite-ops=false is expected to fail, which the lit command
// line observes through the exit status of flatbuffer_to_string (second stage of the pipeline).
// The expectation lines use "//" because "#" does not start a comment in MLIR, and the add takes %arg0 as
// its second operand because an operation cannot consume its own result.
// NOTE(review): the FileCheck verdict is discarded by the "; test ..." suffix and diagnostics are not piped
// into FileCheck, so the expected message and its 2:1 location are not actually enforced. Also "tf.Add" is
// not a tfl.* op although the expected message talks about a TFLite builtin - confirm the intended op.
// (Notes are kept after the function so that no source line number above changes.)

View File

@ -0,0 +1,14 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1
// CHECK: loc("disable_flex.mlir":96:8): error: 'tf.div' op is a Flex op but Flex ops are not enabled for emission
// CHECK-NEXT: Verification failed.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
%0 = "tfl.pseudo_input" (%arg0) {name = "Input"} : (tensor<4xf32>) -> tensor<4xf32>
%1 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// tf.Div is the result of conversion to a Flex TF op
%3 = "tf.Div"(%2, %1) {name = "div"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%4 = "tfl.exp"(%3) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
return %4 : tensor<4xf32>
}
// Negative test: without a flag enabling Flex (select TF) ops, exporting the tf.Div above is expected to
// fail, which the lit command line observes through the exit status of flatbuffer_to_string.
// The expectation lines use "//" because "#" does not start a comment in MLIR; with "#" the file could not
// be parsed and the exporter was never reached.
// NOTE(review): the FileCheck verdict is discarded by the "; test ..." suffix and diagnostics are not piped
// into FileCheck. The expected location (line 96) cannot exist in this short file and the expected op name
// is spelled 'tf.div' while the op is "tf.Div" - update the expectation if FileCheck is ever enforced.
// (Notes are kept after the function so that no source line number above changes.)

View File

@ -0,0 +1,98 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// Export test for a chain of builtin ops: tfl.mul ("mul0") -> tfl.mul ("mul1") -> tfl.exp ("exp").
// Expected flatbuffer text:
//   - operator_codes holds only MUL and EXP, so both multiplications share the single MUL entry
//     (neither mul operator prints an opcode_index; exp prints opcode_index 1),
//   - tensor names come from the loc("...") annotations on the ops,
//   - the only non-empty buffer holds four little-endian f32 1.0 values (0, 0, 128, 63) for the "Const" tensor.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "Input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "mul0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "mul1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: inputs: [ 2, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 3 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: ExpOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul0")
%3 = "tfl.mul"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul1")
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
return %4 : tensor<4xf32>
}

View File

@ -0,0 +1,48 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
// Export test for a TF op emitted as a Flex op: the lit command line enables select TF ops and disables
// builtin TFLite ops, and the function adds its input to itself with tf.AddV2.
// Expected flatbuffer text:
//   - a single CUSTOM operator code with custom_code "FlexAddV2",
//   - custom_options bytes that start with a length-prefixed ASCII "AddV2" and contain "AddV2" a second time
//     (presumably a serialized TF node definition wrapped in a FlexBuffer; confirm against the exporter),
//   - tensors named after the producing op ("tf.Placeholder.input", "tf.AddV2") since no loc names are given,
//   - the operator reads tensor 0 twice (inputs: [ 0, 0 ]) because both operands are %0.
func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexAddV2"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 3, 2 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "tf.Placeholder.input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tf.AddV2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: custom_options: [ 5, 65, 100, 100, 86, 50, 0, 18, 18, 5, 65, 100, 100, 86, 50, 42, 7, 10, 1, 84, 18, 2, 48, 1, 50, 0, 0, 2, 27, 21, 20, 20, 4, 40, 1 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
%0 = "tf.Placeholder.input"(%arg0) {name = "Placeholder"} : (tensor<3x2xf32>) -> tensor<3x2xf32>
%1 = "tf.AddV2"(%0, %0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}

View File

@ -0,0 +1,100 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s
// Export test mixing builtin and Flex ops: the function chains tfl.mul -> tf.Div -> tfl.exp and is exported
// with -emit-select-tf-ops.
// Expected flatbuffer text:
//   - tf.Div is emitted as a CUSTOM operator code with custom_code "FlexDiv" (opcode_index 1),
//   - its custom_options bytes start with a length-prefixed ASCII "Div" and contain "Div" a second time
//     (presumably a serialized TF node definition wrapped in a FlexBuffer; confirm against the exporter),
//   - tensor names come from the loc("...") annotations on the ops,
//   - the only non-empty buffer holds four little-endian f32 1.0 values (0, 0, 128, 63) for the "Const" tensor.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexDiv"
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "Input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "div",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 3, 68, 105, 118, 0, 16, 18, 3, 68, 105, 118, 42, 7, 10, 1, 84, 18, 2, 48, 1, 50, 0, 0, 2, 23, 19, 20, 20, 4, 40, 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: ExpOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
// tf.Div is the result of conversion to a Flex TF op
%3 = "tf.Div"(%2, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
return %4 : tensor<4xf32>
}

View File

@ -0,0 +1,171 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LESS
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Experimental_If"
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "tfl.pseudo_input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.pseudo_input1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.less",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tf.If",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 116, 104, 101, 110, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 101, 108, 115, 101, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "tfl.add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: builtin_options_type: AddOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "cond_true"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "tfl.mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "cond_false"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
// Entry function (expected as the first subgraph, "main"): compares the two inputs with tfl.less and passes
// the result as the condition of tf.If, which references @cond_true and @cond_false as its branches.
// Per the expectations at the top of this file, tf.If is exported as a CUSTOM op "Experimental_If" whose
// custom_options bytes contain the ASCII keys "then_subgraph_index" and "else_subgraph_index".
func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
%1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
%2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
return %3 : tensor<1xf32>
}
// then-branch of the tf.If in @main: returns the sum of its two unranked f32 arguments.
// Expected to be exported as the subgraph named "cond_true" containing a single tfl.add operator.
func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0 : tensor<*xf32>
}
// else-branch of the tf.If in @main: returns the product of its two unranked f32 arguments.
// Expected to be exported as the subgraph named "cond_false" containing a single tfl.mul operator.
func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -0,0 +1,89 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// Export test for boolean ops: computes (input OR Const2) AND Const1 on tensor<4xi1> values.
// Expected flatbuffer text:
//   - operator codes LOGICAL_OR and LOGICAL_AND, neither printing any builtin options,
//   - every tensor has type BOOL,
//   - "Const1" (dense<true>) is stored as bytes [ 1, 1, 1, 1 ] and "Const2" (dense<false>) as [ 0, 0, 0, 0 ],
//   - tensor names come from the loc("...") annotations on the ops.
func @main(tensor<4xi1>) -> tensor<4xi1> {
^bb0(%arg0: tensor<4xi1>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LOGICAL_OR
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: LOGICAL_AND
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "Input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "Const1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "Const2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "logical_or",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "logical_and",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 3, 1 ],
// CHECK-NEXT: outputs: [ 4 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 1, 1, 1, 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
// CHECK-EMPTY:
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xi1>) -> tensor<4xi1> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<true> : tensor<4xi1>} : () -> tensor<4xi1> loc("Const1")
%2 = "tfl.pseudo_const" () {value = dense<false> : tensor<4xi1>} : () -> tensor<4xi1> loc("Const2")
%3 = "tfl.logical_or"(%0, %2) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> loc("logical_or")
%4 = "tfl.logical_and"(%3, %1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> loc("logical_and")
return %4 : tensor<4xi1>
}

View File

@ -0,0 +1,137 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// Export test for elementwise math ops: squared_difference -> mul -> div -> exp -> neg on tensor<4xf32> values.
// Expected flatbuffer text:
//   - one operator code per op kind, in order of first use (SQUARED_DIFFERENCE, MUL, DIV, EXP, NEG),
//   - the squared_difference operator prints no builtin options, the other four print an (empty) options table,
//   - tensor names come from the loc("...") annotations on the ops,
//   - the only non-empty buffer holds four little-endian f32 1.0 values (0, 0, 128, 63) for the "Const" tensor.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DIV
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: NEG
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "Input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "squared_difference",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "div",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "neg",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3, 2 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: DivOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 4 ],
// CHECK-NEXT: outputs: [ 5 ],
// CHECK-NEXT: builtin_options_type: ExpOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 4,
// CHECK-NEXT: inputs: [ 5 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: builtin_options_type: NegOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%2 = "tfl.squared_difference"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
%3 = "tfl.mul"(%0, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
%4 = "tfl.div"(%3, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
%5 = "tfl.exp"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
%6 = "tfl.neg"(%5) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
return %6 : tensor<4xf32>
}

View File

@ -0,0 +1,55 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// Export test for op attributes -> builtin options: a single tfl.average_pool_2d on a 1x6x6x16 input.
// The expected Pool2DOptions table mirrors the op's attributes (padding VALID, stride_w 1, stride_h 3,
// filter_width 6, filter_height 3); fused_activation_function = "NONE" is not printed in the expected text.
// Tensor names come from the loc("...") annotations on the ops.
func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
^bb0(%arg0: tensor<1x6x6x16xf32>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 6, 6, 16 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "Input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "avgpool",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: builtin_options_type: Pool2DOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: padding: VALID,
// CHECK-NEXT: stride_w: 1,
// CHECK-NEXT: stride_h: 3,
// CHECK-NEXT: filter_width: 6,
// CHECK-NEXT: filter_height: 3
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> loc("Input")
%1 = "tfl.average_pool_2d"(%0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool")
return %1 : tensor<1x1x1x16xf32>
}

View File

@ -0,0 +1,19 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// Export test for an op with an omitted optional operand and two results: the third operand of
// tfl.fully_connected is the unit constant (type none), which is expected to be exported as input index -1,
// and both results are expected to be listed as operator outputs.
// The whole pipeline is on the single lit command line above; previously the FileCheck stage sat on a plain
// comment line, so the expectations below were never evaluated.
func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
%0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
%1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
%2:2 = "tfl.fully_connected"(%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>, tensor<40x40xf32>)
return %2 : tensor<40x40xf32>
}
// The first expectation is a plain match (not a "-NEXT" one) because FileCheck rejects a "-NEXT" directive
// that has no preceding match to anchor it; it locates the operators list inside the larger output.
// CHECK: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, -1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],

View File

@ -0,0 +1,158 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: QUANTIZE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CONV_2D
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: RESHAPE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: SOFTMAX
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DEQUANTIZE
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 224, 224, 3 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "tfl.pseudo_input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.quantize",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.007812 ],
// CHECK-NEXT: zero_point: [ 128 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 32, 3, 3, 3 ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.pseudo_qconst",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.021827 ],
// CHECK-NEXT: zero_point: [ 151 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 32 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tfl.pseudo_qconst1",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.000171 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "tfl.conv_2d",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.023528 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "tfl.reshape",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.023528 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "tfl.softmax",
// CHECK-NEXT: quantization: {
// CHECK-NEXT: scale: [ 0.003906 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "tfl.dequantize",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 7 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 1, 2, 3 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: Conv2DOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: stride_w: 5,
// CHECK-NEXT: stride_h: 4,
// CHECK-NEXT: dilation_w_factor: 3,
// CHECK-NEXT: dilation_h_factor: 2
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 4 ],
// CHECK-NEXT: outputs: [ 5 ],
// CHECK-NEXT: builtin_options_type: ReshapeOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: new_shape: [ 1, 1001 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 5 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: builtin_options_type: SoftmaxOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: beta: 1.0
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 4,
// CHECK-NEXT: inputs: [ 6 ],
// CHECK-NEXT: outputs: [ 7 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 
180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 
180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
%2 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
%4 = "tfl.conv_2d"(%1, %2, %3) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
%5 = "tfl.reshape"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
%6 = "tfl.softmax"(%5) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
%7 = "tfl.dequantize"(%6) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
return %7 : tensor<1x1001xf32>
}

View File

@ -0,0 +1,53 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
func @main(%arg0: tensor<3x2xi32>) -> tensor<6xi32> {
  // Exports a single tfl.reshape to a TFLite flatbuffer and matches the textual
  // dump. The op carries no shape attribute; the expected new_shape below
  // matches the static result type tensor<6xi32>.
  // CHECK: {
  // CHECK-NEXT: version: 3,
  // CHECK-NEXT: operator_codes: [ {
  // CHECK-NEXT: builtin_code: RESHAPE
  // CHECK-NEXT: } ],
  // CHECK-NEXT: subgraphs: [ {
  // CHECK-NEXT: tensors: [ {
  // CHECK-NEXT: shape: [ 3, 2 ],
  // CHECK-NEXT: type: INT32,
  // CHECK-NEXT: buffer: 1,
  // CHECK-NEXT: name: "Input",
  // CHECK-NEXT: quantization: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: }, {
  // CHECK-NEXT: shape: [ ],
  // CHECK-NEXT: type: INT32,
  // CHECK-NEXT: buffer: 2,
  // CHECK-NEXT: name: "tfl.reshape",
  // CHECK-NEXT: quantization: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: } ],
  // CHECK-NEXT: inputs: [ 0 ],
  // CHECK-NEXT: outputs: [ 1 ],
  // CHECK-NEXT: operators: [ {
  // CHECK-NEXT: inputs: [ 0 ],
  // CHECK-NEXT: outputs: [ 1 ],
  // CHECK-NEXT: builtin_options_type: ReshapeOptions,
  // CHECK-NEXT: builtin_options: {
  // CHECK-NEXT: new_shape: [ 6 ]
  // CHECK-NEXT: }
  // CHECK-NEXT: } ]
  // CHECK-NEXT: name: "main"
  // CHECK-NEXT: } ],
  // CHECK-NEXT: description: "MLIR Converted.",
  // CHECK-NEXT: buffers: [ {
  // CHECK-EMPTY:
  // CHECK-NEXT: }, {
  // CHECK-EMPTY:
  // CHECK-NEXT: }, {
  // CHECK-EMPTY:
  // CHECK-NEXT: } ]
  // CHECK-NEXT: }

  // The explicit location on the input op supplies the tensor name "Input".
  %input = "tfl.pseudo_input"(%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
  %flat = "tfl.reshape"(%input) : (tensor<3x2xi32>) -> tensor<6xi32>
  return %flat : tensor<6xi32>
}

View File

@ -0,0 +1,96 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  // Exports a tfl.sub with a fused RELU6 followed by a tfl.add whose first
  // operand is a scalar std.constant, then matches the flatbuffer dump.
  // Tensor names in the dump come from the loc("...") annotations on the ops.
  // CHECK: {
  // CHECK-NEXT: version: 3,
  // CHECK-NEXT: operator_codes: [ {
  // CHECK-NEXT: builtin_code: SUB
  // CHECK-NEXT: }, {
  // CHECK-EMPTY:
  // CHECK-NEXT: } ],
  // CHECK-NEXT: subgraphs: [ {
  // CHECK-NEXT: tensors: [ {
  // CHECK-NEXT: shape: [ 3, 2 ],
  // CHECK-NEXT: type: INT32,
  // CHECK-NEXT: buffer: 1,
  // CHECK-NEXT: name: "Input",
  // CHECK-NEXT: quantization: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: }, {
  // CHECK-NEXT: shape: [ 3, 2 ],
  // CHECK-NEXT: type: INT32,
  // CHECK-NEXT: buffer: 2,
  // CHECK-NEXT: name: "Const",
  // CHECK-NEXT: quantization: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: }, {
  // CHECK-NEXT: shape: [ ],
  // CHECK-NEXT: type: INT32,
  // CHECK-NEXT: buffer: 3,
  // CHECK-NEXT: name: "sub",
  // CHECK-NEXT: quantization: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: }, {
  // CHECK-NEXT: shape: [ ],
  // CHECK-NEXT: type: INT32,
  // CHECK-NEXT: buffer: 4,
  // CHECK-NEXT: name: "Const2",
  // CHECK-NEXT: quantization: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: }, {
  // CHECK-NEXT: shape: [ ],
  // CHECK-NEXT: type: INT32,
  // CHECK-NEXT: buffer: 5,
  // CHECK-NEXT: name: "add",
  // CHECK-NEXT: quantization: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: } ],
  // CHECK-NEXT: inputs: [ 0 ],
  // CHECK-NEXT: outputs: [ 4 ],
  // CHECK-NEXT: operators: [ {
  // CHECK-NEXT: inputs: [ 0, 1 ],
  // CHECK-NEXT: outputs: [ 2 ],
  // CHECK-NEXT: builtin_options_type: SubOptions,
  // CHECK-NEXT: builtin_options: {
  // CHECK-NEXT: fused_activation_function: RELU6
  // CHECK-NEXT: }
  // CHECK-NEXT: }, {
  // CHECK-NEXT: opcode_index: 1,
  // CHECK-NEXT: inputs: [ 3, 2 ],
  // CHECK-NEXT: outputs: [ 4 ],
  // CHECK-NEXT: builtin_options_type: AddOptions,
  // CHECK-NEXT: builtin_options: {
  // CHECK-EMPTY:
  // CHECK-NEXT: }
  // CHECK-NEXT: } ]
  // CHECK-NEXT: name: "main"
  // CHECK-NEXT: } ],
  // CHECK-NEXT: description: "MLIR Converted.",
  // CHECK-NEXT: buffers: [ {
  // CHECK-EMPTY:
  // CHECK-NEXT: }, {
  // CHECK-EMPTY:
  // CHECK-NEXT: }, {
  // CHECK-NEXT: data: [ 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0 ]
  // CHECK-NEXT: }, {
  // CHECK-EMPTY:
  // CHECK-NEXT: }, {
  // CHECK-NEXT: data: [ 10, 0, 0, 0 ]
  // CHECK-NEXT: }, {
  // CHECK-EMPTY:
  // CHECK-NEXT: } ]
  // CHECK-NEXT: }

  %input = "tfl.pseudo_input"(%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
  %table = "tfl.pseudo_const"() {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const")
  %diff = "tfl.sub"(%input, %table) {fused_activation_function = "RELU6"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("sub")
  %ten = "std.constant"() {value = dense<10> : tensor<i32>} : () -> tensor<i32> loc("Const2")
  %sum = "tfl.add"(%ten, %diff) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("add")
  return %sum : tensor<3x2xi32>
}

View File

@ -0,0 +1,9 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
func @main(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  // Negative test: an op with no dialect prefix must make the export fail with
  // the diagnostic matched below (stderr is folded into stdout by the run line).
  %input = "tfl.pseudo_input"(%arg0) {name = "Input"} : (tensor<3x2xi32>) -> tensor<3x2xi32>
  // CHECK: error: 'unknown_op' op dialect is not registered
  %bogus = "unknown_op"(%input) : (tensor<3x2xi32>) -> tensor<3x2xi32>
  return %bogus : tensor<3x2xi32>
}

View File

@ -0,0 +1,211 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Experimental_While"
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: GREATER
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: SUB
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "tfl.pseudo_input",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.pseudo_input1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tf.While",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tf.While:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: custom_options: [ 99, 111, 110, 100, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 98, 111, 100, 121, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "tfl.greater",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "cond"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "Const1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "tfl.sub",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 13,
// CHECK-NEXT: name: "tfl.add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3, 4 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: builtin_options_type: SubOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 1, 1 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: AddOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "body"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 1, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
  // Entry point of the while-loop export test: wraps both arguments in input
  // ops and feeds them to a tf.While that references @cond and @body.
  %counter = "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
  %value = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
  // While %counter is greater than zero, element wise add %value with itself.
  %loop:2 = "tf.While"(%counter, %value) {cond = @cond, body = @body} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
  // Only the second loop result (the accumulated tensor) is returned.
  return %loop#1 : tensor<1xf32>
}
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
  // Loop condition: yields whether the first argument is greater than the
  // constant zero. The second argument is unused here.
  %zero = "std.constant"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
  %keep_going = "tfl.greater"(%arg0, %zero) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
  return %keep_going : tensor<i1>
}
func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
  // Loop body: subtracts one from the first argument and adds the second
  // argument to itself. Neither op has a fused activation.
  %one = "std.constant"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
  %next = "tfl.sub"(%arg0, %one) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
  %doubled = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
  return %next, %doubled : tensor<*xi32>, tensor<*xf32>
}

View File

@ -0,0 +1,858 @@
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
// Unary math ops
// -----
// CHECK-LABEL: testCos
func @testCos(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.cos on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.cos"(%arg0)
  %cos = "tfl.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %cos : tensor<?xf32>
}
// -----
// Negative test: tfl.cos must reject an integer operand.
func @testCosWithWrongInputType(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{tfl.cos' op operand #0 must be tensor of floating-point values}}
  %cos = "tfl.cos"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %cos : tensor<?xi32>
}
// -----
// CHECK-LABEL: testExp
func @testExp(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.exp on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.exp"(%arg0)
  %exp = "tfl.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %exp : tensor<?xf32>
}
// CHECK-LABEL: testFloor
func @testFloor(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.floor on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.floor"(%arg0)
  %floor = "tfl.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %floor : tensor<?xf32>
}
// -----
// CHECK-LABEL: testGather
func @testGather(%arg0: tensor<?xf32>, %arg1: tensor<?xi32>) -> tensor<?xf32> {
  // Gather with dynamically sized params and indices; only the function label
  // is matched, so this verifies that the op is accepted.
  %gathered = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<?xf32>, tensor<?xi32>) -> tensor<?xf32>
  return %gathered : tensor<?xf32>
}
// -----
// CHECK-LABEL: testGather
func @testGather(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2xf32> {
  // Same as the dynamic variant, but with statically shaped params and indices.
  %gathered = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<2xf32>, tensor<2xi32>) -> tensor<2xf32>
  return %gathered : tensor<2xf32>
}
// -----
// CHECK-LABEL: testGatherUnknownRank
func @testGatherUnknownRank(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>) -> tensor<*xf32> {
  // Unranked params must be accepted by the gather verifier.
  %gathered = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<*xf32>, tensor<1xi32>) -> tensor<*xf32>
  return %gathered : tensor<*xf32>
}
// -----
// Negative test: i32 params with an f32 result must be rejected.
func @testGatherUnsupportedType(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xf32> {
  // expected-error @+1 {{op failed to verify that params and output must have same element type}}
  %gathered = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xf32>
  return %gathered : tensor<?xf32>
}
// -----
// Negative test: rank-0 params must be rejected.
func @testGatherUnsupportedRank(%arg0: tensor<f32>, %arg1: tensor<1xi32>) -> tensor<?xf32> {
  // expected-error @+1 {{op failed to verify that operand 0 is 1-D}}
  %gathered = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  return %gathered : tensor<?xf32>
}
// -----
// CHECK-LABEL: testAbs
func @testAbs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.abs on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.abs"(%arg0)
  %abs = "tfl.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %abs : tensor<?xf32>
}
// CHECK-LABEL: testAddN
func @testAddN(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  // Variadic sum of three f32 tensors.
  // CHECK: "tfl.add_n"(%arg0, %arg1, %arg2)
  %sum = "tfl.add_n"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %sum : tensor<?xf32>
}
// -----
// Negative test: tfl.add_n must reject f16 operands.
func @testAddNWrongOperandResultType(%arg0: tensor<?xf16>, %arg1: tensor<?xf16>, %arg2: tensor<?xf16>) -> tensor<?xf16> {
  // expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}}
  %sum = "tfl.add_n"(%arg0, %arg1, %arg2) : (tensor<?xf16>, tensor<?xf16>, tensor<?xf16>) -> tensor<?xf16>
  return %sum : tensor<?xf16>
}
// -----
// CHECK-LABEL: testLog
func @testLog(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.log on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.log"(%arg0)
  %log = "tfl.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %log : tensor<?xf32>
}
// CHECK-LABEL: testNeg
func @testNeg(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.neg on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.neg"(%arg0)
  %neg = "tfl.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %neg : tensor<?xf32>
}
// CHECK-LABEL: testRsqrt
func @testRsqrt(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.rsqrt on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.rsqrt"(%arg0)
  %rsqrt = "tfl.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %rsqrt : tensor<?xf32>
}
// CHECK-LABEL: testSin
func @testSin(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.sin on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.sin"(%arg0)
  %sin = "tfl.sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %sin : tensor<?xf32>
}
// -----
// Negative test: tfl.sin must reject an integer operand.
func @testSinWithWrongInputType(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{tfl.sin' op operand #0 must be tensor of floating-point values}}
  %sin = "tfl.sin"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %sin : tensor<?xi32>
}
// -----
// CHECK-LABEL: testSqrt
func @testSqrt(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.sqrt on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.sqrt"(%arg0)
  %sqrt = "tfl.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %sqrt : tensor<?xf32>
}
// CHECK-LABEL: testSquare
func @testSquare(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.square on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.square"(%arg0)
  %square = "tfl.square"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %square : tensor<?xf32>
}
// CHECK-LABEL: testTanh
func @testTanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.tanh on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.tanh"(%arg0)
  %tanh = "tfl.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %tanh : tensor<?xf32>
}
// CHECK-LABEL: testZerosLike
func @testZerosLike(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // tfl.zeros_like on a dynamically sized f32 tensor must survive a parse/print cycle.
  // CHECK: "tfl.zeros_like"(%arg0)
  %zeros = "tfl.zeros_like"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %zeros : tensor<?xf32>
}
// CHECK-LABEL: testDequantize
func @testDequantize(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  // Dequantize from an i32 tensor to an f32 tensor; the full printed type
  // signature is matched, not just the op name.
  // CHECK: "tfl.dequantize"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  %real = "tfl.dequantize"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  return %real : tensor<?xf32>
}
// CHECK-LABEL: testLogicalNot
func @testLogicalNot(%arg0: tensor<?xi1>) -> tensor<?xi1> {
  // Logical negation of a dynamically sized i1 tensor.
  // CHECK: "tfl.logical_not"(%arg0)
  %not = "tfl.logical_not"(%arg0) : (tensor<?xi1>) -> tensor<?xi1>
  return %not : tensor<?xi1>
}
// -----
// Negative test: tfl.logical_not must reject a non-boolean operand.
func @testLogicalNotWrongOperandType(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{'tfl.logical_not' op operand #0 must be tensor of 1-bit integer values}}
  %not = "tfl.logical_not"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %not : tensor<?xi32>
}
// Binary math ops
// -----
// CHECK-LABEL: testAdd
func @testAdd(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Custom-syntax tfl.add with a fused activation given as a string attribute.
  // TODO(jpienaar): Enable specifying label of enum for parsing.
  // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"}
  %sum = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<?xi32>
  return %sum : tensor<?xi32>
}
// CHECK-LABEL: testSub
func @testSub(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Custom-syntax tfl.sub with a fused RELU6 must survive a parse/print cycle.
  // CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
  %diff = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<?xi32>
  return %diff : tensor<?xi32>
}
// CHECK-LABEL: testMul
func @testMul(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Custom-syntax tfl.mul with a fused RELU6 must survive a parse/print cycle.
  // CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"}
  %prod = tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<?xi32>
  return %prod : tensor<?xi32>
}
// CHECK-LABEL: testDiv
func @testDiv(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Custom-syntax tfl.div with a fused RELU6 must survive a parse/print cycle.
  // CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "RELU6"}
  %quot = tfl.div %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<?xi32>
  return %quot : tensor<?xi32>
}
// CHECK-LABEL: testLess
func @testLess(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> {
  // Comparison of two i32 tensors producing an i1 tensor.
  // CHECK: "tfl.less"(%arg0, %arg1)
  %lt = "tfl.less"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
  return %lt : tensor<?xi1>
}
// -----
// CHECK-LABEL: testFloorDivI32
func @testFloorDivI32(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Custom-syntax tfl.floor_div on i32 operands.
  // CHECK: tfl.floor_div %arg0, %arg1
  %quot = tfl.floor_div %arg0, %arg1 : tensor<?xi32>
  return %quot : tensor<?xi32>
}
// -----
// CHECK-LABEL: testFloorDivF32
func @testFloorDivF32(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  // Custom-syntax tfl.floor_div on f32 operands.
  // CHECK: tfl.floor_div %arg0, %arg1
  %quot = tfl.floor_div %arg0, %arg1 : tensor<?xf32>
  return %quot : tensor<?xf32>
}
// -----
// Negative test: tfl.floor_div must reject operands whose element types differ.
func @testFloorDivF32(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2xf32> {
  // expected-error @+1 {{failed to verify that operands have same element type}}
  %quot = "tfl.floor_div"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2xf32>
  return %quot : tensor<2xf32>
}
// -----
// CHECK-LABEL: testFloorMod
func @testFloorMod(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Custom-syntax tfl.floor_mod on i32 operands.
  // CHECK: tfl.floor_mod %arg0, %arg1
  %rem = tfl.floor_mod %arg0, %arg1 : tensor<?xi32>
  return %rem : tensor<?xi32>
}
// CHECK-LABEL: testPow
func @testPow(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // Custom-syntax tfl.pow on i32 operands.
  // CHECK: tfl.pow %arg0, %arg1
  %pow = tfl.pow %arg0, %arg1 : tensor<?xi32>
  return %pow : tensor<?xi32>
}
// CHECK-LABEL: testConv2D
func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
  // Input, filter and bias operands with unit strides/dilations and a fused RELU6.
  // NOTE(review): padding is "SAME" while the declared result is 30x30 for a
  // 32x32 input; presumably the op does not verify the output shape. Confirm.
  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2)
  %conv = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "RELU6"} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  return %conv : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: testFakeQuant
func @testFakeQuant(%arg0: tensor<?xf32>, %arg1: f32, %arg2: f32) -> tensor<?xf32> {
  // Two chained fake-quant ops. The attributes are written in a different
  // order than the expected output lists them. The scalar arguments are unused.
  // CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], narrow_range = true, num_bits = 2 : i32} : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tfl.fake_quant"(%arg0) {minmax = [], num_bits = 2 : i32, narrow_range = true} : (tensor<?xf32>) -> tensor<?xf32>
  // CHECK: %1 = "tfl.fake_quant"(%0) {minmax = [3.000000e-01, 1.400000e+00], narrow_range = false, num_bits = 6 : i32} : (tensor<?xf32>) -> tensor<?xf32>
  %1 = "tfl.fake_quant"(%0) {num_bits = 6 : i32, narrow_range = false, minmax = [0.3, 1.4]} : (tensor<?xf32>) -> tensor<?xf32>
  return %1 : tensor<?xf32>
}
// CHECK-LABEL: testQuantize
func @testQuantize(%arg0: tensor<?xf32>) -> tensor<?x!quant.uniform<u8:f32, 0.1:128>> {
  // The scale is written as 0.1 in the input; the expected output spells it in
  // scientific notation.
  // CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<?x!quant.uniform<u8:f32, 1.000000e-01:128>>}
  %0 = "tfl.quantize"(%arg0) {qtype = tensor<?x!quant.uniform<u8:f32, 0.1:128>>} : (tensor<?xf32>) -> tensor<?x!quant.uniform<u8:f32, 0.1:128>>
  return %0 : tensor<?x!quant.uniform<u8:f32, 0.1:128>>
}
// CHECK-LABEL: testLogicalAnd
func @testLogicalAnd(%arg0: tensor<?xi1>, %arg1: tensor<?xi1>) -> tensor<?xi1> {
  // The op is written in generic form; the expected output uses the custom form.
  // CHECK: tfl.logical_and %arg0, %arg1
  %and = "tfl.logical_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<?xi1>) -> tensor<?xi1>
  return %and : tensor<?xi1>
}
// -----
// Negative test: tfl.logical_and must reject non-boolean operands.
func @testLogicalAndWrongOperandType(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{'tfl.logical_and' op operand #0 must be tensor of 1-bit integer values}}
  %and = "tfl.logical_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  return %and : tensor<?xi32>
}
// -----
// CHECK-LABEL: testLogicalOr
func @testLogicalOr(%arg0: tensor<?xi1>, %arg1: tensor<?xi1>) -> tensor<?xi1> {
  // The op is written in generic form; the expected output uses the custom form.
  // CHECK: tfl.logical_or %arg0, %arg1
  %or = "tfl.logical_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<?xi1>) -> tensor<?xi1>
  return %or : tensor<?xi1>
}
// -----
// Negative test: tfl.logical_or must reject non-boolean operands.
func @testLogicalOrWrongOperandType(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{'tfl.logical_or' op operand #0 must be tensor of 1-bit integer values}}
  %or = "tfl.logical_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  return %or : tensor<?xi32>
}
// -----
// CHECK-LABEL: testEluF32
// tfl.elu on an f32 tensor should parse and print back in generic form.
func @testEluF32(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK: "tfl.elu"(%arg0)
  %activated = "tfl.elu"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %activated : tensor<?xf32>
}
// -----
// CHECK-LABEL: testTileF32
// tfl.tile with an f32 input and an i32 multiples tensor should round-trip.
func @testTileF32(%arg0: tensor<4x1xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
  // CHECK: "tfl.tile"(%arg0, %arg1)
  %tiled = "tfl.tile"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4xi32>) -> tensor<?xf32>
  return %tiled : tensor<?xf32>
}
// -----
// Negative case: tfl.elu must reject an i32 operand.
func @testEluI32(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{operand #0 must be tensor of floating-point values}}
  %activated = "tfl.elu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %activated : tensor<?xi32>
}
// -----
// The six fused_activation_function spellings below must be accepted on
// tfl.add and appear in the printed output.
// The label pins the short string checks to this function; without it they
// could be satisfied by a later function's output (e.g. "RELU6" in the
// max-pool tests).
// CHECK-LABEL: testFusedActiviationFunction
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
  // CHECK: "NONE"
  %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32>
  // CHECK: "RELU"
  %1 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU"} : tensor<4xi32>
  // CHECK: "RELU_N1_TO_1"
  %2 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"} : tensor<4xi32>
  // CHECK: "RELU6"
  %3 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<4xi32>
  // CHECK: "TANH"
  %4 = tfl.add %arg0, %arg1 {fused_activation_function = "TANH"} : tensor<4xi32>
  // CHECK: "SIGN_BIT"
  %5 = tfl.add %arg0, %arg1 {fused_activation_function = "SIGN_BIT"} : tensor<4xi32>
  return %0, %1, %2, %3, %4, %5: tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
// -----
// Negative case: the mixed-case spelling "Relu6" is not a valid fused
// activation value and must be diagnosed.
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
  // expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}}
  %sum = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32>
  return %sum : tensor<4xi32>
}
// -----
// Both padding spellings used below ("SAME" and "VALID") must be accepted on
// tfl.conv_2d and appear in the printed output.
// The label pins the short string checks to this function; without it "SAME"
// could be satisfied by a later function's output.
// CHECK-LABEL: testPadding
func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
  // CHECK: "SAME"
  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  // CHECK: "VALID"
  %1 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
}
// -----
// Negative case: "SOMETHING" is not a valid padding value for tfl.conv_2d.
func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
  // expected-error @+1 {{attribute 'padding' failed to satisfy constraint: padding enum}}
  %conv = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SOMETHING", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  return %conv : tensor<256x30x30x16xf32>
}
// -----
// CHECK-LABEL: testMaxPool2D
// tfl.max_pool_2d on an f32 input should round-trip with all of its
// attributes and its type signature intact.
func @testMaxPool2D(%arg0: tensor<256x32x32x3xf32>) -> tensor<?xf32> {
  // CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor<?xf32>
  %pooled = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor<?xf32>
  return %pooled : tensor<?xf32>
}
// -----
// CHECK-LABEL: testMaxPool2DQuantized
// tfl.max_pool_2d should also accept tensors of a uniform i8 quantized type.
func @testMaxPool2DQuantized(%arg0: tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<?x!quant.uniform<i8:f32, 0.1:128>> {
  // CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
  %pooled = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<?x!quant.uniform<i8:f32, 0.1:128>>
  return %pooled : tensor<?x!quant.uniform<i8:f32, 0.1:128>>
}
// -----
// test invalid MaxPool2D
// Negative case: i32 operand/result types violate the MaxPool2D type
// constraint.
func @testMaxPool2DWrongOperandResultType(%arg0: tensor<1x7x7x16xi32>) -> tensor<1x7x7x16xi32> {
  // expected-error @+1 {{failed to verify that MaxPool2D operand and result types match specified constraints}}
  %pooled = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x16xi32>) -> tensor<1x7x7x16xi32>
  return %pooled : tensor<1x7x7x16xi32>
}
// -----
// test invalid MaxPool2D
// Negative case: a quantized type with i9 storage violates the MaxPool2D type
// constraint.
func @testMaxPool2DWrongOperandStorageType(%arg0: tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>) -> tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>> {
  // expected-error @+1 {{failed to verify that MaxPool2D operand and result types match specified constraints}}
  %pooled = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>) -> tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>
  return %pooled : tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>
}
// -----
// CHECK-LABEL: testLogistic
// tfl.logistic on a bf16 tensor should parse and print back in generic form.
func @testLogistic(%arg0: tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
  // CHECK: "tfl.logistic"(%arg0)
  %sigmoid = "tfl.logistic"(%arg0) : (tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
  return %sigmoid : tensor<1x2x3x4x5xbf16>
}
// -----
// test invalid Logistic input
// Negative case: tfl.logistic must reject an i32 operand.
// The expected message now carries the opening quote around the op name, in
// line with the other diagnostics in this file.
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
  // expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of floating-point values}}
  %0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
  return %0#0 : tensor<?xi32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstm
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// test invalid none type applied to a tensor type arg
func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.unidirectional_sequence_lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// test violation of projection weight and projection bias pred op trait
// Negative case: operand #16 is passed as `none` while every other operand is
// a tensor; the op verifier is expected to report the projection weight /
// projection bias constraint quoted in the diagnostic below.
func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.unidirectional_sequence_lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testReverseV2
// tfl.reverse_v2 with an f32 input and an i32 axis tensor should round-trip.
func @testReverseV2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2xi32>) -> tensor<1x2x3x4xf32> {
  // CHECK: "tfl.reverse_v2"(%arg0, %arg1)
  %reversed = "tfl.reverse_v2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2xi32>) -> tensor<1x2x3x4xf32>
  return %reversed : tensor<1x2x3x4xf32>
}
// -----
// test select
// CHECK-LABEL: testSelect
// tfl.select with an i1 condition and two i32 value tensors must pass
// verification; only the function label is matched in the output.
func @testSelect(%cond: tensor<?xi1>, %arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  %picked = "tfl.select"(%cond, %arg0, %arg1) : (tensor<?xi1>, tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  return %picked : tensor<?xi32>
}
// -----
// Negative case: the condition operand of tfl.select must be an i1 tensor,
// here it is i32.
func @testSelectWithUnsupportedType(%cond: tensor<?xi32>, %arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}}
  %picked = "tfl.select"(%cond, %arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  return %picked : tensor<?xi32>
}
// -----
// Negative case: the two value operands of tfl.select differ in element type
// (i32 vs f32).
func @testSelectWithUnsupportedType(%cond: tensor<?xi1>, %arg0: tensor<?xi32>, %arg1: tensor<?xf32>) -> tensor<?xi32> {
  // expected-error @+1 {{failed to verify that operands have same element type}}
  %picked = "tfl.select"(%cond, %arg0, %arg1) : (tensor<?xi1>, tensor<?xi32>, tensor<?xf32>) -> tensor<?xi32>
  return %picked : tensor<?xi32>
}
// -----
// CHECK-LABEL: topk
// tfl.topk_v2 on a ranked 1-D input with dynamically shaped results must pass
// verification.
func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
  %values, %indices = "tfl.topk_v2"(%arg0, %arg1) : (tensor<8xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
  return %values, %indices : tensor<?xf32>, tensor<?xi32>
}
// -----
// CHECK-LABEL: topk
// tfl.topk_v2 with unranked input and unranked results must pass verification.
func @topk(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>) {
  %values, %indices = "tfl.topk_v2"(%arg0, %arg1) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
  return %values, %indices : tensor<*xf32>, tensor<*xi32>
}
// -----
// CHECK-LABEL: topk_2
// tfl.topk_v2 with a constant k of 2 and fully static shapes.
func @topk_2(%arg0: tensor<3x4x8xf32>) -> (tensor<3x4x2xf32>, tensor<3x4x2xi32>) {
  %k = constant dense<2> : tensor<i32>
  %top:2 = "tfl.topk_v2"(%arg0, %k) : (tensor<3x4x8xf32>, tensor<i32>) -> (tensor<3x4x2xf32>, tensor<3x4x2xi32>)
  return %top#0, %top#1 : tensor<3x4x2xf32>, tensor<3x4x2xi32>
}
// -----
// CHECK-LABEL: topk_d
// tfl.topk_v2 with a constant k of 2 and a dynamic leading dimension.
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
  %k = constant dense<2> : tensor<i32>
  %top:2 = "tfl.topk_v2"(%arg0, %k) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
  return %top#0, %top#1 : tensor<?x2xf32>, tensor<?x2xi32>
}
// -----
// CHECK-LABEL: topk_d
// TODO(jpienaar): This should fail but doesn't as the op definition does not
// include shape verification.
// k is the constant 2 while the result types declare a trailing dimension of
// 3; this currently passes verification.
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<?x3xf32>, tensor<?x3xi32>) {
  %k = constant dense<2> : tensor<i32>
  %top:2 = "tfl.topk_v2"(%arg0, %k) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x3xf32>, tensor<?x3xi32>)
  return %top#0, %top#1 : tensor<?x3xf32>, tensor<?x3xi32>
}
// -----
// CHECK-LABEL: topk_d
// tfl.topk_v2 with a ranked input and unranked results must pass verification.
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
  %k = constant dense<2> : tensor<i32>
  %top:2 = "tfl.topk_v2"(%arg0, %k) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
  return %top#0, %top#1 : tensor<*xf32>, tensor<*xi32>
}
// -----
// CHECK-LABEL: testEqual
// tfl.equal on two f32 tensors yields an i1 tensor and should round-trip.
func @testEqual(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xi1> {
  // CHECK: "tfl.equal"(%arg0, %arg1)
  %same = "tfl.equal"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xi1>
  return %same : tensor<?xi1>
}
// -----
// CHECK-LABEL: testPad
// tfl.pad with a rank-3 input and a 3x2 paddings tensor should round-trip.
func @testPad(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<?xf32> {
  // CHECK: "tfl.pad"(%arg0, %arg1)
  %padded = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
  return %padded : tensor<?xf32>
}
// -----
// test Pad with invalid paddings size
// Negative case: the paddings tensor has 2 rows while the input is rank 3.
func @testPadWithInvalidPaddingsDim(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x2xi32>) -> tensor<?xf32> {
  // expected-error @+1 {{'tfl.pad' op failed to verify that operand 0's rank equals operand 1's size}}
  %padded = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<?xf32>
  return %padded : tensor<?xf32>
}
// -----
// test Pad with invalid paddings rank
// Negative case: the paddings tensor is rank 3 instead of 2-D.
func @testPadWithInvalidPaddingsRank(%arg0: tensor<2x1x3xf32>, %arg1: tensor<1x3x2xi32>) -> tensor<?xf32> {
  // expected-error @+1 {{'tfl.pad' op failed to verify that operand 1 is 2-D}}
  %padded = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> tensor<?xf32>
  return %padded : tensor<?xf32>
}
// -----
// CHECK-LABEL: testPadV2
// tfl.padv2 with a rank-3 input, a 3x2 paddings tensor and a scalar f32 pad
// value should round-trip.
func @testPadV2(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<?xf32> {
  %cst = constant dense<2.0> : tensor<f32>
  // CHECK: "tfl.padv2"(%arg0, %arg1, %cst)
  %padded = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<?xf32>
  return %padded : tensor<?xf32>
}
// -----
// test PadV2 with invalid paddings size
// Negative case: the paddings tensor has 2 rows while the input is rank 3.
// The diagnostic directive uses the same `//` comment prefix as the rest of
// the file.
func @testPadV2WithInvalidPaddingsDim(tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x2xi32>):
  %cst = constant dense<2.0> : tensor<f32>
  // expected-error @+1 {{'tfl.padv2' op failed to verify that operand 0's rank equals operand 1's size}}
  %0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<? x f32>
  return %0#0 : tensor<? x f32>
}
// -----
// test PadV2 with invalid paddings rank
// Negative case: the paddings tensor is rank 3 instead of 2-D.
func @testPadV2WithInvalidPaddingsRank(%arg0: tensor<2x1x3xf32>, %arg1: tensor<1x3x2xi32>) -> tensor<?xf32> {
  %cst = constant dense<2.0> : tensor<f32>
  // expected-error @+1 {{'tfl.padv2' op failed to verify that operand 1 is 2-D}}
  %padded = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<1x3x2xi32>, tensor<f32>) -> tensor<?xf32>
  return %padded : tensor<?xf32>
}
// -----
// test PadV2 with invalid constant rank
// Negative case: the pad value is a 1-element rank-1 tensor instead of a
// scalar (0-D) tensor.
// The diagnostic directive uses the same `//` comment prefix as the rest of
// the file.
func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
  %cst = constant dense<[2.0]> : tensor<1xf32>
  // expected-error @+1 {{'tfl.padv2' op failed to verify that operand 2 is 0-D}}
  %0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<1xf32>) -> tensor<? x f32>
  return %0#0 : tensor<? x f32>
}
// -----
// test PadV2 with invalid constant data type
// Negative case: the pad value is an i32 scalar while the input is f32.
// The diagnostic directive uses the same `//` comment prefix as the rest of
// the file.
func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
  %cst = constant dense<2> : tensor<i32>
  // expected-error @+1 {{'tfl.padv2' op failed to verify that input and constant value operands must have same element type}}
  %0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<i32>) -> tensor<? x f32>
  return %0#0 : tensor<? x f32>
}
// -----
// tfl.pack of two operands with values_count = 2 should round-trip.
// The label is spelled with `@` so that it cannot match inside `@unpack`.
// CHECK-LABEL: @pack
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
  // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
  %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
// -----
// Negative case: two operands are supplied but values_count says 1.
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
  // expected-error @+1 {{input count should match 'values_count' attribute}}
  %packed = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  return %packed : tensor<2x2xi32>
}
// -----
// tfl.unpack producing three results with num = 3 should round-trip.
// CHECK-LABEL: @unpack
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
  // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
  %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
  return %0#0 : tensor<2xi32>
}
// -----
// Negative case: three results are declared but num says 2.
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
  // expected-error @+1 {{output count should match 'num' attribute}}
  %parts:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
  return %parts#0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: testMean
// tfl.mean with keep_dims = false should round-trip with the attribute kept.
func @testMean(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
  // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false}
  %avg = "tfl.mean"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
  return %avg : tensor<1x2xf32>
}
// -----
// CHECK-LABEL: testMean_true
// tfl.mean with keep_dims = true should round-trip with the attribute kept.
func @testMean_true(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
  // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = true}
  %avg = "tfl.mean"(%arg0, %arg1) {keep_dims = true} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
  return %avg : tensor<1x2xf32>
}
// -----
// Negative case: tfl.mean without the keep_dims attribute must be diagnosed.
func @testMean_missing_keep_dims(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
  // expected-error @+1 {{'tfl.mean' op requires attribute 'keep_dims'}}
  %avg = "tfl.mean"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
  return %avg : tensor<1x2xf32>
}
// -----
// CHECK-LABEL: testBatchToSpaceND
// tfl.batch_to_space_nd with three operands should round-trip.
func @testBatchToSpaceND(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
  // CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2)
  %out = "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
  return %out : tensor<?xf32>
}
// -----
// CHECK-LABEL: testSpaceToBatchND
// tfl.space_to_batch_nd with three operands should round-trip.
func @testSpaceToBatchND(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
  // CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2)
  %out = "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
  return %out : tensor<?xf32>
}
// -----
// tfl.concatenation of two i32 tensors should round-trip with its attributes.
// The label anchors the match to this function, since the quantized variant
// prints an identical op prefix.
// CHECK-LABEL: testConcat
func @testConcat(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
  // CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
  %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
// -----
// tfl.concatenation of two uniform-i8 quantized tensors should round-trip.
// The label anchors the match to this function, since the non-quantized
// variant prints an identical op prefix.
// CHECK-LABEL: testConcatQuantized
func @testConcatQuantized(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
  // CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
  %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i8:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
  return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
}
// -----
// Negative case: i32 inputs concatenated into an f32 result.
func @testConcatInvalidOutputElementalType(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
  // expected-error @+1 {{'tfl.concatenation' op failed to verify that values and output must have same element type}}
  %joined = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<?xf32>
  return %joined : tensor<?xf32>
}
// -----
// Negative case: operand #0 is a quantized type with i9 storage, which is
// not in the accepted list.
func @testConcatInvalidStorageType(%arg0: tensor<2x!quant.uniform<i9:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
  // expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values}}
  %joined = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i9:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
  return %joined : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
}
// -----
// CHECK-LABEL: testResizeBilinear
// tfl.resize_bilinear with align_corners = false should round-trip.
func @testResizeBilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
  // CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false}
  %resized = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
  return %resized : tensor<?xf32>
}
// -----
// Negative case: the result of tfl.resize_bilinear is declared as i32.
func @testResizeBilinearInvalidOutputType(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xi32> {
  // expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float values}}
  %resized = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xi32>
  return %resized : tensor<?xi32>
}
// -----
// CHECK-LABEL: testStridedSlice
// tfl.strided_slice with all five mask attributes set to 0 should round-trip
// with its attributes and type signature intact.
func @testStridedSlice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
  // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
  %slice = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
  return %slice : tensor<1x2x2x5xf32>
}
// -----
// Negative case: f32 input sliced into an i32 result.
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {
  // expected-error @+1 {{op failed to verify that input and output must have same element type}}
  %slice = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xi32>
  return %slice : tensor<1x2x2x5xi32>
}

View File

@ -0,0 +1,131 @@
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
// CHECK-LABEL: fusedConv2dRelu
// A tfl.relu consuming a tfl.conv_2d with no fused activation is expected to
// be folded into the convolution as fused_activation_function = "RELU".
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
  // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  // CHECK: return %0
  %conv = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  %act = "tfl.relu"(%conv) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
  return %act : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fusedDepthwiseConv2dRelu6
// A tfl.relu6 consuming a tfl.depthwise_conv_2d with no fused activation is
// expected to be folded in as fused_activation_function = "RELU6".
func @fusedDepthwiseConv2dRelu6(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
  // CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  // CHECK: return %0
  %conv = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  %act = "tfl.relu6"(%conv) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
  return %act : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fusedConv2dTanh
// A tfl.tanh consuming a tfl.conv_2d with no fused activation is expected to
// be folded into the convolution as fused_activation_function = "TANH".
func @fusedConv2dTanh(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
  // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "TANH", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  // CHECK: return %0
  %conv = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  %act = "tfl.tanh"(%conv) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
  return %act : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fuseAddIntoConv2d
// A tfl.add of a splat constant (1.5) after a tfl.conv_2d whose bias is the
// constant [1..16] is expected to be folded into a single bias constant
// [2.5..17.5] feeding the convolution.
func @fuseAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
  // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
  // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
  %addend = constant dense<1.5> : tensor<16xf32>
  %bias = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
  %conv = "tfl.conv_2d"(%arg0, %arg1, %bias) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  %sum = "tfl.add"(%conv, %addend) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
  return %sum : tensor<256x30x30x16xf32>
}
// Depthwise variant of the add-into-bias case. Here the roles of the two
// constants are swapped relative to the conv_2d case: the conv bias is the
// splat 1.5 and the tfl.add operand is [1.0 .. 16.0]. The expected folded bias
// is the same element-wise sum (2.5 .. 17.5).
// CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%cst_0 = constant dense<1.5> : tensor<16xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32>
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst)
}
// Same shape of input as the plain add-into-conv case, except the tfl.add
// carries fused_activation_function = "RELU6". The expectations require the
// folded bias constant (2.5 .. 17.5) and additionally that the resulting conv
// line contains fused_activation_function = "RELU6".
// CHECK-LABEL: fuseAddWithRelu6IntoConv2d
func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<1.5> : tensor<16xf32>
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32>
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
// CHECK-SAME: fused_activation_function = "RELU6"
}
// Depthwise variant of the add-with-RELU6 case: the conv bias is the splat 1.5
// and the tfl.add (with fused_activation_function = "RELU6") adds
// [1.0 .. 16.0]. Expected: a folded bias of 2.5 .. 17.5 and a depthwise conv
// whose line contains fused_activation_function = "RELU6".
// CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%cst_0 = constant dense<1.5> : tensor<16xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32>
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst)
// CHECK-SAME: fused_activation_function = "RELU6"
}
// The conv result %0 has two uses here: it is returned directly and it also
// feeds a tfl.add with RELU6. The expectations require two separate conv ops
// in the output: one keeping the original bias (1.0 ..) with activation
// "NONE", and one using the folded bias (2.5 ..) with activation "RELU6";
// both results are returned.
// CHECK-LABEL: intermOpUsedTwice
func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
%cst = constant dense<1.5> : tensor<16xf32>
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
// CHECK: %cst = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00,
// CHECK: %cst_0 = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00,
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK: %1 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK: return %0, %1
}
// Input: a tfl.depthwise_conv_2d with a constant 1x3x3x2 filter and a splat
// bias of 2.0, followed by a tfl.mul by the constant [1.0, 2.0] with RELU6.
// The expectations require the multiplier to be folded into both constants:
// every filter pair (a, b) becomes (a * 1.0, b * 2.0), the bias becomes
// [2.0, 4.0], and the conv itself carries fused_activation_function = "RELU6".
// CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
return %1 : tensor<1x112x112x2xf32>
// CHECK: %cst = constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00], [5.000000e+00, 1.200000e+01]], {{\[\[}}7.000000e+00, 1.600000e+01], [9.000000e+00, 2.000000e+01], [1.100000e+01, 2.400000e+01]], {{\[\[}}1.300000e+01, 2.800000e+01], [1.500000e+01, 3.200000e+01], [1.700000e+01, 3.600000e+01]]]]> : tensor<1x3x3x2xf32>
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK: return %0
}
// Negative case for folding a tfl.mul into a depthwise conv: the multiplier
// %cst2 has shape 112x2 while the filter %cst0 has shape 1x3x3x2. The
// expectations require both the conv and the mul to remain in the output.
// CHECK-LABEL: @notFuseMulIntoDepthwiseConv2d
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%cst2 = constant dense<3.0> : tensor<112x2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0.
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
return %1 : tensor<1x112x112x2xf32>
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0)
// CHECK: %1 = "tfl.mul"(%0, %cst_1)
// CHECK: return %1
}

View File

@ -0,0 +1,40 @@
// RUN: tf-opt %s -tfl-post-quantize | FileCheck %s
// Float-in / float-out graph: the input is quantized right after
// tfl.pseudo_input and the final softmax result is dequantized before the
// return. The expectations for this function (listed after the functions in
// this file) require a signature and return that use the quantized types
// directly, with no tfl.quantize or tfl.dequantize op between the listed ops.
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
%2 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
%4 = "tfl.conv_2d"(%1, %2, %3) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
%5 = "tfl.reshape"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
%6 = "tfl.softmax"(%5) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
%7 = "tfl.dequantize"(%6) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
return %7 : tensor<1x1001xf32>
}
// Two float inputs, each quantized after its tfl.pseudo_input, added in the
// quantized domain and dequantized before the return. The expectations for
// this function (listed after the functions in this file) require quantized
// argument and result types and no tfl.quantize / tfl.dequantize ops. Note
// that the expected output lists the pseudo_input for %arg1 before the one for
// %arg0.
func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%2 = "tfl.pseudo_input"(%arg1) : (tensor<2x4xf32>) -> tensor<2x4xf32>
%3 = "tfl.quantize"(%2) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%4 = tfl.add %1, %3 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%5 = "tfl.dequantize"(%4) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32>
return %5 : tensor<2x4xf32>
}
// CHECK: func @main(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK-NEXT: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
// CHECK-NEXT: %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>}
// CHECK-NEXT: %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK-NEXT: %4 = "tfl.reshape"(%3) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK-NEXT: %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK-NEXT: return %5 : tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
// CHECK-NEXT:}
// CHECK: func @main2(%arg0: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>, %arg1: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>> {
// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg1) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: %1 = "tfl.pseudo_input"(%arg0) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: %2 = tfl.add %1, %0 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: return %2 : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT:}

View File

@ -0,0 +1,154 @@
// RUN: tf-opt %s -tfl-prepare-quantize | FileCheck %s
// Input: a quantized constant that is dequantized and immediately quantized
// again with the same qtype. The expectations require all three ops to still
// be present, in the same order, after the pass.
// CHECK-LABEL: DequantizeAndQuantize
func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%0 = "tfl.dequantize"(%cst) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %0 = "tfl.pseudo_qconst"()
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: %2 = "tfl.quantize"(%1)
// CHECK: return %2
}
// Input: a float tfl.conv_2d whose activation and filter come from
// tfl.dequantize ops and whose bias is a plain float constant. The
// expectations require a quantize/dequantize pair to be inserted on the bias
// with an i32 qtype of scale 1.7052092479439231E-4, which numerically equals
// the input scale (7.8125e-03) times the filter scale (0.021826678373682216).
// CHECK-LABEL: QuantizeConv2D
func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%cst = constant dense<-1.23697901> : tensor<32xf32>
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
%5 = "tfl.conv_2d"(%2, %4, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3)
// CHECK: %5 = "tfl.conv_2d"(%2, %4, %1)
// CHECK: %6 = "tfl.quantize"(%5)
// CHECK: return %6
}
// Depthwise counterpart of the conv_2d bias case: activation and filter come
// from tfl.dequantize ops, the bias is a float constant. The expectations
// require the bias to be wrapped in a quantize/dequantize pair with an i32
// qtype of scale 1.7052092479439231E-4 and the conv to consume the
// dequantized bias.
// CHECK-LABEL: QuantizeDepthwiseConv2D
func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%cst = constant dense<-1.23697901> : tensor<32xf32>
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
%5 = "tfl.depthwise_conv_2d"(%2, %4, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3)
// CHECK: %5 = "tfl.depthwise_conv_2d"(%2, %4, %1)
// CHECK: %6 = "tfl.quantize"(%5)
// CHECK: return %6
}
// Input: a float tfl.average_pool_2d fed by a tfl.dequantize, with its float
// result returned directly. The expectations require a quantize/dequantize
// pair to be appended after the pool and the dequantized value to be returned.
// CHECK-LABEL: QuantizeAveragePool2D
func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
%1 = "tfl.average_pool_2d"(%0) {
name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32
} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
return %1 : tensor<1x1x1x16xf32>
// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.average_pool_2d"(%0)
// CHECK: %2 = "tfl.quantize"(%1)
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x1x1x16xf32>
}
// Input: a float tfl.reshape fed by a tfl.dequantize. The expectations require
// a quantize/dequantize pair after the reshape whose qtype reuses the input's
// quantization parameters (scale 7.8125e-03, zero point 128) on the reshaped
// 1x36x16 tensor.
// CHECK-LABEL: QuantizeReshape2D
func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
%1 = "tfl.reshape"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
return %1 : tensor<1x36x16xf32>
// CHECK: %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: %1 = "tfl.reshape"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: return %3 : tensor<1x36x16xf32>
}
// Input: a float tfl.softmax fed by a tfl.dequantize. The expectations require
// a quantize/dequantize pair after the softmax whose qtype has scale
// 3.906250e-03 and zero point -128, i.e. parameters that differ from the
// input's (7.8125e-03, 128).
// CHECK-LABEL: QuantizeSoftmax
func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
%1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
return %1 : tensor<1x6x6x16xf32>
// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03:-128>>}
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x6x6x16xf32>
}
// Combined case: average_pool_2d -> conv_2d -> quantize/dequantize -> reshape
// -> softmax in one function. The expectations require a quantize/dequantize
// pair on the conv bias, after the average pool (reusing the input's
// parameters), after the reshape (reusing the conv output's parameters) and
// after the softmax (scale 3.906250e-03, zero point -128), while the pair that
// already follows the conv is listed once.
// CHECK-LABEL: QuantizeChain
func @QuantizeChain(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%cst = constant dense<-1.23697901> : tensor<32xf32>
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
%5 = "tfl.average_pool_2d"(%2) {
name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32
} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
%6 = "tfl.conv_2d"(%5, %4, %cst) {
dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32
} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// NOTE(review): the dequantize below maps a 1x112x112x32 operand to a
// 1x6x6x16 result, so the shapes do not agree. Confirm whether this is
// intentional test shorthand or should be corrected.
%8 = "tfl.dequantize"(%7) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x6x6x16xf32>
%9 = "tfl.reshape"(%8) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
%10 = "tfl.softmax"(%9) {beta = 1.000000e+00 : f32} : (tensor<1x36x16xf32>) -> tensor<1x36x16xf32>
return %10 : tensor<1x36x16xf32>
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>)
// CHECK: %5 = "tfl.average_pool_2d"(%2)
// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>}
// CHECK: %7 = "tfl.dequantize"(%6) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: %8 = "tfl.conv_2d"(%7, %4, %1)
// CHECK: %9 = "tfl.quantize"(%8) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>}
// CHECK: %10 = "tfl.dequantize"(%9) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK: %11 = "tfl.reshape"(%10)
// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 0.023528476789885875>>}
// CHECK: %13 = "tfl.dequantize"(%12) : (tensor<1x36x16x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK: %14 = "tfl.softmax"(%13)
// CHECK: %15 = "tfl.quantize"(%14) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 3.906250e-03:-128>>}
// CHECK: %16 = "tfl.dequantize"(%15) : (tensor<1x36x16x!quant.uniform<u8:f32, 3.906250e-03:-128>>)
// CHECK: return %16 : tensor<1x36x16xf32>
}
// Input: a float constant with values spanning -3.0 .. 3.0 that is returned
// directly. The expectations require a quantize/dequantize pair after the
// constant with a u8 qtype of scale 0.023529411764705882 (numerically 6 / 255)
// and zero point 128.
// CHECK-LABEL: QuantizeConstant
func @QuantizeConstant() -> tensor<2x3xf32> {
%cst = constant dense<[[-3.0, -1.0, 0.0], [0.0, 1.0, 3.0]]> : tensor<2x3xf32>
return %cst : tensor<2x3xf32>
// CHECK: %cst = constant dense{{.*}}tensor<2x3xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.023529411764705882:128>>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x3xf32>
}

View File

@ -0,0 +1,197 @@
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
// Five tf.Conv2D ops with different attributes. The expectations require the
// two marked "OK" to become tfl.conv_2d ops (filter transposed with
// permutation [3, 0, 1, 2], a zero bias constant, and dilation factors of 1
// when the source op has no dilations attribute), while the three marked
// unsupported stay as tf.Conv2D.
func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) :
// OK
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
// Unsupported data format
%1 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
// OK
%2 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", padding = "VALID", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
// Unsupported padding
%3 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "EXPLICIT", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
// Unsupported strides
%4 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %0, %1, %2, %3, %4 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
// CHECK-LABEL: conv
// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
// CHECK: %0 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %2 = "tf.Conv2D"
// CHECK: %3 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %5 = "tf.Conv2D"
// CHECK: %6 = "tf.Conv2D"
}
// Four tf.DepthwiseConv2dNative ops with different attributes. The
// expectations require the two marked "OK" to become tfl.depthwise_conv_2d
// ops (filter reshaped from 3x3x3x4 to 1x3x3x12, a zero bias of 12 elements,
// depth_multiplier = 4), while the two marked unsupported stay as
// tf.DepthwiseConv2dNative.
func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> (tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>) {
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x4xf32>) :
// OK
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
// Unsupported data format
%1 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
// OK
%2 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", padding = "VALID", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
// Unsupported strides
%3 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
return %0, %1, %2, %3 : tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>
// CHECK-LABEL: depthwiseConv2D
// CHECK: %cst = constant dense<0.000000e+00> : tensor<12xf32>
// CHECK: %cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
// CHECK: %0 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %2 = "tf.DepthwiseConv2dNative"
// CHECK: %3 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %5 = "tf.DepthwiseConv2dNative"
}
// Three chained tf.FusedBatchNorm ops. The expectations require only the
// first one (inference mode, only its first result used) to be expanded into
// tf.Add / tf.Rsqrt / tf.Mul / tf.Sub / tf.Add; the second (is_training =
// true) and the third (a second result, %2#1, is also returned) must remain
// tf.FusedBatchNorm ops.
func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Unsupported training
%1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Use other output
%2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
// CHECK-LABEL: fusedBatchNorm
// CHECK:%cst = constant dense<1.000000e-03> : tensor<f32>
// variance + epsilon
// CHECK: %0 = "tf.Add"(%arg4, %cst) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// rsqrt(variance + epsilon)
// CHECK: %1 = "tf.Rsqrt"(%0) : (tensor<8xf32>) -> tensor<8xf32>
// scale * rsqrt(variance + epsilon)
// CHECK: %2 = "tf.Mul"(%arg1, %1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
// x * scale * rsqrt(variance + epsilon)
// CHECK: %3 = "tf.Mul"(%arg0, %2) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
// mean * scale * rsqrt(variance + epsilon)
// CHECK: %4 = "tf.Mul"(%arg3, %2) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %5 = "tf.Sub"(%arg2, %4) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
// x * scale * rsqrt(variance + epsilon) +
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %6 = "tf.Add"(%3, %5) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
// CHECK: %7:5 = "tf.FusedBatchNorm"(%6, %arg1, %arg2, %arg3, %arg4)
// CHECK: %8:5 = "tf.FusedBatchNorm"(%7#0, %arg1, %arg2, %arg3, %arg4)
}
// Input: a tf.FakeQuantWithMinMaxVars with constant min (-0.1) and max (0.2)
// whose result is returned directly. The expectations require the fake-quant
// op to be kept and a tfl.quantize / tfl.dequantize pair to be appended with a
// u8 qtype of scale 0.0011764706057660721 and zero point 85.
func @fakeQuantNotFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>):
%arg1 = constant dense<-0.1> : tensor<f32>
%arg2 = constant dense<0.2> : tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
// CHECK-LABEL: fakeQuantNotFollowedByQuant
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>)
// CHECK: return %2 : tensor<8x8x8x8xf32>
}
// Same fake-quant op as the not-followed-by-quant case, but the input already
// contains the tfl.quantize / tfl.dequantize pair after it. The expectations
// list the fake-quant op, one quantize, one dequantize and `return %2`, i.e.
// the output is expected to match the input structure.
func @fakeQuantFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>):
%arg1 = constant dense<-0.1> : tensor<f32>
%arg2 = constant dense<0.2> : tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>
%2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>) -> tensor<8x8x8x8xf32>
return %2 : tensor<8x8x8x8xf32>
// CHECK-LABEL: fakeQuantFollowedByQuant
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>)
// CHECK: return %2 : tensor<8x8x8x8xf32>
}
// The min/max operands are function arguments instead of constants. The
// expected output leaves the fake-quant op untouched and returns its result
// directly (no quantize/dequantize pair is expected). Note the block arguments
// are named %arg3/%arg4 in the input but print as %arg1/%arg2.
func @fakeQuantVarsNotConst(tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8x8x8x8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
return %1 : tensor<8x8x8x8xf32>
// CHECK-LABEL: fakeQuantVarsNotConst
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
// CHECK: return %0 : tensor<8x8x8x8xf32>
}
// Fake-quant followed by tf.Transpose. In the expected output the two ops are
// swapped: the transpose is applied to the function argument first and the
// fake-quant op consumes the transposed value.
func @fakeQuantFollowedByTranspose(tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> (tensor<16x3x3x3xf32>) {
^bb0(%arg0: tensor<3x3x3x16xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
%cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
%1 = "tf.Transpose"(%0, %cst_0): (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
return %1 : tensor<16x3x3x3xf32>
// CHECK-LABEL: fakeQuantFollowedByTranspose
// CHECK: %cst = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
// CHECK: %0 = "tf.Transpose"(%arg0, %cst) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
// CHECK: return %1 : tensor<16x3x3x3xf32>
}
// Fake-quant followed by tf.Reshape. In the expected output the two ops are
// swapped: the reshape is applied to the function argument first and the
// fake-quant op consumes the reshaped value.
func @fakeQuantFollowedByReshape(tensor<3x3x3x4xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x3x3x12xf32>) {
^bb0(%arg0: tensor<3x3x3x4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
%cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x4xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x4xf32>
%1 = "tf.Reshape"(%0, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
return %1 : tensor<1x3x3x12xf32>
// CHECK-LABEL: fakeQuantFollowedByReshape
// CHECK: %cst = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
// CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
// CHECK: return %1 : tensor<1x3x3x12xf32>
}
// tf.Identity is expected to be removed: the function returns its argument
// directly in the expected output.
func @identity(tensor<10xi32>) -> tensor<10xi32> {
^bb0(%arg0: tensor<10xi32>):
%0 = "tf.Identity"(%arg0) : (tensor<10xi32>) -> tensor<10xi32>
return %0: tensor<10xi32>
// CHECK-LABEL: identity
// CHECK: return %arg0
}
// tf.MatMul with transpose_a = false and transpose_b = false. In the expected
// output the second operand goes through an explicit tf.Transpose whose
// permutation is built as Range(Rank(b), 0, -1) - 1, and the resulting
// tf.MatMul carries transpose_a = false, transpose_b = true.
func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
%166 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", _output_shapes = ["tfshape$dim { size = 1} dim { size = 1000}"], device = "", name = "matmul", transpose_a = false, transpose_b = false} : (tensor<1x1280xf32>, tensor<1280x1000xf32>) -> tensor<1x1000xf32>
return %166 : tensor<1x1000xf32>
// CHECK-LABEL: matmulNoTransposeAOrB
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %3 = "tf.Transpose"(%arg1, %2) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %4 = "tf.MatMul"(%arg0, %3) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
}
// tf.MatMul with transpose_a = true and transpose_b = false. In the expected
// output both operands go through an explicit tf.Transpose (permutation built
// from tf.Rank / tf.Range / tf.Sub) and the resulting tf.MatMul carries
// transpose_a = false, transpose_b = true.
func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
%166 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", _output_shapes = ["tfshape$dim { size = 1} dim { size = 1000}"], device = "", name = "matmul", transpose_a = true, transpose_b = false} : (tensor<1x1280xf32>, tensor<1280x1000xf32>) -> tensor<1x1000xf32>
return %166 : tensor<1x1000xf32>
// CHECK-LABEL: matmulNoTransposeB
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<1x1280xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %3 = "tf.Transpose"(%arg0, %2) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %4 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
// CHECK: %5 = "tf.Range"(%4, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %6 = "tf.Sub"(%5, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
}

View File

@ -0,0 +1,132 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s
// A splat float constant (-0.1) feeding tfl.quantize is expected to be
// replaced by a tfl.pseudo_qconst that already holds the quantized value
// (dense<0>), leaving only the tfl.dequantize before the return.
// NOTE(review): the constant is tensor<2x2xf32> while the quantized and
// returned types are rank-0; confirm this shape mismatch is intentional.
// CHECK-LABEL: QuantizeFloatConst
func @QuantizeFloatConst() -> tensor<f32> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
return %2 : tensor<f32>
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<f32>
}
// A non-splat float constant feeding tfl.quantize is expected to become a
// tfl.pseudo_qconst with per-element quantized values. The expected values are
// printed as i8: 0 for -0.1 and -1 for 1.0 and 3.0. The `{{\[\[}}` / `{{\[}}`
// pieces are FileCheck regex escapes for literal brackets.
// CHECK-LABEL: QuantizeDenseFloatConst
func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
%0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x2xf32>
}
// A splat float constant (3.0) feeding tfl.quantize is expected to become a
// splat tfl.pseudo_qconst (dense<-1> as i8), followed by the tfl.dequantize.
// NOTE(review): unlike the sibling tests, the first expected line below does
// not pin the result name (`%0 = `); it is a looser match, not a failure.
// CHECK-LABEL: QuantizeSplatFloatConst
func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
%0 = constant dense<3.0> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK: "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x2xf32>
}
// A tfl.dequantize immediately followed by a tfl.quantize back to the same
// quantized type is expected to fold away: the function returns the
// tfl.pseudo_qconst directly.
// CHECK-LABEL: DequantizeAndQuantize
func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%0 = "tfl.dequantize"(%cst) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: return %0 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
}
// Float tfl.conv_2d surrounded by dequantize/quantize ops. The expected output
// is a conv_2d that consumes the quantized input and filter directly and
// produces the quantized result type. The float bias constant is expected to
// become an i32 tfl.pseudo_qconst whose scale (1.7052092479439231E-4) is the
// product of the input scale (7.8125e-03) and the filter scale
// (0.021826678373682216).
// CHECK-LABEL: QuantizeConv2D
func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%cst = constant dense<-1.23697901> : tensor<32xf32>
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
%5 = "tfl.conv_2d"(%2, %4, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<-7254> : tensor<32xi32>}
// CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
// CHECK: %2 = "tfl.conv_2d"(%arg0, %1, %0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: return %2 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
}
// Same structure as the conv_2d case, for tfl.depthwise_conv_2d: the expected
// output has the op consuming the quantized input and filter directly, with
// the float bias replaced by an i32 tfl.pseudo_qconst. The attributes of the
// op (including depth_multiplier) are expected to be preserved.
// CHECK-LABEL: QuantizeDepthwiseConv2D
func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%cst = constant dense<-1.23697901> : tensor<32xf32>
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
%5 = "tfl.depthwise_conv_2d"(%2, %4, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<-7254> : tensor<32xi32>}
// CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
// CHECK: %2 = "tfl.depthwise_conv_2d"(%arg0, %1, %0) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK: return %2
}
// A dequantize feeding tfl.average_pool_2d with a float result. In the
// expected output the pool consumes the quantized argument directly and the
// dequantize is moved after it; the pooled value carries the same quantization
// parameters as the input (7.8125e-03, zero point 128).
// CHECK-LABEL: QuantizeAveragePool2D
func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
%1 = "tfl.average_pool_2d"(%0) {name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
return %1 : tensor<1x1x1x16xf32>
// CHECK: %0 = "tfl.average_pool_2d"(%arg0)
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x1x1x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32>
// CHECK: return %1 : tensor<1x1x1x16xf32>
}
// A dequantize feeding tfl.reshape with a float result. In the expected output
// the reshape consumes the quantized argument directly and the dequantize is
// moved after it; the reshaped value keeps the input's quantization
// parameters (7.8125e-03, zero point 128).
// CHECK-LABEL: QuantizeReshape2D
func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
%1 = "tfl.reshape"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
return %1 : tensor<1x36x16xf32>
// CHECK: %0 = "tfl.reshape"(%arg0)
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32>
// CHECK: return %1 : tensor<1x36x16xf32>
}
// A dequantize feeding tfl.softmax with a float result. In the expected output
// the softmax consumes the quantized argument directly and the dequantize is
// moved after it. Unlike pooling/reshape, the softmax result does not inherit
// the input's quantization parameters: it is expected to use the fixed scale
// 1/256 (3.906250e-03). For unsigned 8-bit storage the zero point is 0, which
// is printed without an explicit zero point (compare the
// `u8:f32, 0.023528476789885875` types elsewhere in this file); a zero point
// of -128 would lie outside the u8 storage range.
// CHECK-LABEL: QuantizeSoftmax
func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
%1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
return %1 : tensor<1x6x6x16xf32>
// CHECK: %0 = "tfl.softmax"(%arg0)
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x6x6x16xf32>
// CHECK: return %1 : tensor<1x6x6x16xf32>
}
// Float tfl.add whose two operands are dequantized values with different
// quantization parameters and whose result is re-quantized. The expected
// output is a single tfl.add that takes both quantized arguments directly and
// whose result is returned with the quantized output type.
// CHECK-LABEL: QuantizeAdd
func @QuantizeAdd(tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>) -> tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>> {
^bb0(%arg0: tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, %arg1: tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>):
%0 = "tfl.dequantize"(%arg0) : (tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>) -> tensor<1x56x56x24xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>) -> tensor<1x56x56x24xf32>
%2 = tfl.add %0, %1 {fused_activation_function = "NONE"} : tensor<1x56x56x24xf32>
%3 = "tfl.quantize"(%2) {qtype = tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>} : (tensor<1x56x56x24xf32>) -> tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>
return %3 : tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>)
// CHECK: return %0 : tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>
}

View File

@ -0,0 +1,151 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using mlir::MLIRContext;
using mlir::Module;
using stream_executor::port::StatusOr;
using tensorflow::Status;
// Debugging flag (off by default). When set, main() calls
// PrintFunctionResultMapping() after the output file has been written.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> print_function_result_mapping(
"print-function-result-mapping",
llvm::cl::desc(
"Print the mapping of function result to flatbuffer output buffer"),
llvm::cl::init(false));
// Values returned from main(): kTrSuccess is 0, kTrFailure is 1, so they double
// as the process exit status.
enum TranslationStatus { kTrSuccess, kTrFailure };
// Prints to stdout, for every subgraph of the serialized TFLite model in
// `result`, the name and tensor index of each subgraph input and output. For
// outputs, the MLIR location of the matching return operand is appended when
// the subgraph is named and `module` has a function with that name; otherwise
// an unknown location is printed.
//
// `result` holds the FlatBuffer bytes; `module` is the MLIR module the model
// was generated from. Returns kTrFailure if the buffer cannot be loaded as a
// TFLite model and kTrSuccess otherwise.
static int PrintFunctionResultMapping(const std::string &result,
                                      Module *module) {
  // Build model from the resultant string to extract the return values from
  // their source of truth.
  auto model =
      tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
  if (!model) return kTrFailure;

  // Get an unknown location for where we don't have a terminator to get the
  // location of the return value from.
  auto unknown_loc = mlir::UnknownLoc::get(module->getContext());

  // Prints one line for tensor index `buffer` of `subgraph`. `id` is the
  // position of the tensor in the input/output list and is only forwarded to
  // `loc`, which may be empty when no location should be printed.
  auto print_buffer = [&](const tflite::SubGraph &subgraph, int id, int buffer,
                          std::function<mlir::Location(int)> loc) {
    const auto &output_tensor = (*subgraph.tensors())[buffer];
    std::cout << "\tname: '"
              << (output_tensor->name() ? output_tensor->name()->str()
                                        : "<<unnamed>>")
              << "' buffer: " << buffer;
    if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
    std::cout << '\n';
  };

  // For every subgraph print out the name (if available), each result's output
  // buffer number and location of the return value (if available).
  for (auto *subgraph : *(*model)->subgraphs()) {
    std::string subgraph_name =
        subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";

    std::cout << '\'' << subgraph_name << "' inputs:\n";
    int i = 0;
    for (auto input : *subgraph->inputs())
      print_buffer(*subgraph, i++, input, nullptr);

    std::cout << '\'' << subgraph_name << "' outputs:\n";
    mlir::Operation *terminator = nullptr;
    if (subgraph->name()) {
      if (auto fn = module->getNamedFunction(subgraph->name()->str()))
        terminator = fn->back().getTerminator();
    }
    i = 0;
    for (auto output : *subgraph->outputs()) {
      // Advance the index per output so the k-th output is paired with the
      // k-th operand of the terminator; without the increment every output
      // would be reported with the location of operand 0.
      print_buffer(*subgraph, i++, output, [&](int operand_index) {
        return terminator ? terminator->getOperand(operand_index)->getLoc()
                          : unknown_loc;
      });
    }
  }
  return kTrSuccess;
}
// Entry point of the converter: parses the command line, loads the input
// (GraphDef or textual MLIR) as an MLIR module, runs the TF -> TFLite
// conversion and writes the result to the requested output file.
// Returns kTrSuccess (0) on success and kTrFailure on any error.
int main(int argc, char **argv) {
  llvm::PrettyStackTraceProgram x(argc, argv);
  // TODO(jpienaar): Revise the command line option parsing here.
  llvm::InitLLVM y(argc, argv);

  // TODO(antiagainst): We are pulling in multiple transformations as follows.
  // Each transformation has its own set of command-line options; options of one
  // transformation can essentially be aliases to another. For example, the
  // -tfl-annotate-inputs has -tfl-input-arrays, -tfl-input-data-types, and
  // -tfl-input-shapes, which are the same as -graphdef-to-mlir transformation's
  // -tf_input_arrays, -tf_input_data_types, and -tf_input_shapes, respectively.
  // We need to disable duplicated ones to provide a cleaner command-line option
  // interface. That also means we need to relay the value set in one option to
  // all its aliases.
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");

  // TODO(ashwinm): Enable command line parsing for both sides.
  // Only argv[0] is exposed to InitMain; the real flags were consumed above.
  int fake_argc = 1;
  tensorflow::port::InitMain(argv[0], &fake_argc, &argv);

  MLIRContext context;
  llvm::SourceMgr source_mgr;
  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

  StatusOr<std::unique_ptr<Module>> module =
      tensorflow::LoadFromGraphdefOrMlirSource(
          input_file_name, input_mlir, use_splatted_constant, extra_opdefs,
          debug_info_file, input_arrays, input_dtypes, input_shapes,
          output_arrays, inference_type, min_values, max_values,
          /*prune_unused_nodes=*/true, &source_mgr, &context);

  // If errors occur, the library call in the above already logged the error
  // message. So we can just return here. An OK status holding a null module
  // (which the loader may produce when parsing textual MLIR fails) is treated
  // as a failure too, since the module is dereferenced below.
  if (!module.ok() || !module.ValueOrDie()) return kTrFailure;

  std::string result;
  auto status = tensorflow::ConvertTFControlFlowToTFLOrFlatbuffer(
      module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops,
      lower_tensor_list_ops, &result);
  if (!status.ok()) return kTrFailure;

  // openOutputFile returns null when the file cannot be opened; report the
  // reason instead of dereferencing a null pointer.
  std::string error_message;
  auto output = mlir::openOutputFile(output_file_name, &error_message);
  if (!output) {
    llvm::errs() << error_message << "\n";
    return kTrFailure;
  }
  output->os() << result;
  output->keep();

  // Print out debugging info related to function mapping.
  if (print_function_result_mapping)
    return PrintFunctionResultMapping(result, module.ValueOrDie().get());
  return kTrSuccess;
}

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
using llvm::cl::opt;
// TODO(jpienaar): Revise the command line option parsing here.
// Positional argument naming the input file. Defaults to "-"
// (presumably standard input; confirm against mlir::openInputFile).
// NOLINTNEXTLINE
opt<std::string> input_file_name(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// `-o <filename>`: where the converted model is written. Defaults to "-"
// (presumably standard output; confirm against mlir::openOutputFile).
// NOLINTNEXTLINE
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
// Hidden flag (off by default). When set, the GraphDef importer variant that
// produces splatted constants is used instead of the regular one.
// NOLINTNEXTLINE
opt<bool> use_splatted_constant(
    "use-splatted-constant",
    llvm::cl::desc(
        "Replace constants with randomly generated splatted tensors"),
    llvm::cl::init(false), llvm::cl::Hidden);
// Hidden flag (off by default). When set, the input file is parsed as textual
// MLIR instead of being imported as a GraphDef.
// NOLINTNEXTLINE
opt<bool> input_mlir(
"input-mlir",
llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
"GraphDef format"),
llvm::cl::init(false), llvm::cl::Hidden);
// Off by default. When set, the converted module is emitted as MLIR text
// instead of being serialized to a TFLite FlatBuffer.
// NOLINTNEXTLINE
opt<bool> output_mlir(
"output-mlir",
llvm::cl::desc(
"Output MLIR rather than FlatBuffer for the generated TFLite model"),
llvm::cl::init(false));
// The following is a temporary approach to allow injecting opdefs in addition
// to those that are already part of the global TF registry linked in. The
// primary goal is testing. This is not intended to be a general solution for
// unregistered ops. More appropriate mechanisms, such as op hints, should be
// used instead.
// Each list entry is expected to be one OpDef in protobuf text format; entries
// are only consumed when importing a GraphDef.
// NOLINTNEXTLINE
llvm::cl::list<std::string> extra_opdefs(
"tf-extra-opdefs", llvm::cl::desc("List of extra opdefs when importing "
"graphdef (testing purposes only)"));
// Quantize and Dequantize ops pair can be optionally emitted before and after
// the quantized model as the adaptors to receive and produce floating point
// type data with the quantized model. Set this to `false` if the model input is
// integer types.
// Defaults to false; the value is forwarded to the post-quantize pass.
// NOLINTNEXTLINE
opt<bool> emit_quant_adaptor_ops(
"emit-quant-adaptor-ops",
llvm::cl::desc(
"Emit Quantize/Dequantize before and after the generated TFLite model"),
llvm::cl::init(false));

View File

@ -0,0 +1,40 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_
// This file contains command-line options aimed to provide the parameters
// required by the TensorFlow Graph(Def) to TF Lite Flatbuffer conversion. It is
// only intended to be included by binaries.
#include <string>
#include "llvm/Support/CommandLine.h"
// The commandline options are defined in LLVM style, so the caller should
// use llvm::InitLLVM to initialize the options.
//
// Please see the implementation file for documentation of details of these
// options.
// TODO(jpienaar): Revise the command line option parsing here.
extern llvm::cl::opt<std::string> input_file_name;
extern llvm::cl::opt<std::string> output_file_name;
extern llvm::cl::opt<bool> use_splatted_constant;
extern llvm::cl::opt<bool> input_mlir;
extern llvm::cl::opt<bool> output_mlir;
extern llvm::cl::list<std::string> extra_opdefs;
extern llvm::cl::opt<bool> emit_quant_adaptor_ops;
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_

View File

@ -0,0 +1,166 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
using mlir::MLIRContext;
using mlir::Module;
using stream_executor::port::StatusOr;
// Loads the model in `input_filename` into an MLIR module owned by `context`.
//
// When `input_mlir` is true the file is parsed as textual MLIR: its buffer is
// registered with `source_mgr` and parsed; all GraphDef-related arguments are
// ignored. Otherwise every string in `extra_tf_opdefs` is parsed as a
// text-format OpDef and registered with the global TF op registry, and the
// file is imported as a GraphDef (through the splatted-constant importer when
// `use_splatted_constant` is true), forwarding the remaining arguments.
//
// Returns InvalidArgument if the MLIR input cannot be opened or parsed, or if
// an extra OpDef fails to parse; otherwise returns whatever the selected
// GraphDef importer returns.
StatusOr<std::unique_ptr<Module>> LoadFromGraphdefOrMlirSource(
    const std::string &input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string> &extra_tf_opdefs,
    absl::string_view debug_info_file, absl::string_view input_arrays,
    absl::string_view input_dtypes, absl::string_view input_shapes,
    absl::string_view output_arrays, absl::string_view inference_type,
    absl::string_view min_values, absl::string_view max_values,
    bool prune_unused_nodes, llvm::SourceMgr *source_mgr,
    MLIRContext *context) {
  if (input_mlir) {
    // Set up the input file.
    std::string error_message;
    auto file = mlir::openInputFile(input_filename, &error_message);
    if (!file) {
      llvm::errs() << error_message << "\n";
      return errors::InvalidArgument("fail to open input file");
    }

    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
    // The parser signals failure with a null module; turn that into an error
    // status so callers never receive an OK status wrapping a null pointer.
    Module *parsed_module = mlir::parseSourceFile(*source_mgr, context);
    if (!parsed_module) {
      return errors::InvalidArgument("fail to parse input file");
    }
    return std::unique_ptr<Module>(parsed_module);
  }

  for (const auto &tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
      return errors::InvalidArgument("fail to parse extra OpDef");
    }
    // Register extra opdefs.
    // TODO(b/133770952): Support shape functions.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData *op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }

  if (use_splatted_constant) {
    return tensorflow::GraphdefToSplattedMlirTranslateFunction(
        input_filename, debug_info_file, input_arrays, input_dtypes,
        input_shapes, output_arrays, inference_type, min_values, max_values,
        prune_unused_nodes, context);
  }
  return tensorflow::GraphdefToMlirTranslateFunction(
      input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, inference_type, min_values, max_values, prune_unused_nodes,
      context);
}
// Returns true when module `m` has a function named "main" that carries a
// "tf.quantize" unit attribute; returns false when there is no such function
// or the attribute is absent.
bool ShouldRunQuantizePasses(mlir::Module *m) {
  mlir::Function *entry_fn = m->getNamedFunction("main");
  if (!entry_fn) return false;
  // A missing attribute compares equal to a default-constructed Attribute.
  return entry_fn->getAttrOfType<mlir::UnitAttr>("tf.quantize") !=
         mlir::Attribute();
}
void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
bool emit_quant_adaptor_ops,
bool lower_tensor_list_ops,
mlir::PassManager *pass_manager) {
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
// TODO(jpienaar): Revise post dialect constants.
pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
// Canonicalization includes const folding, which is utilized here to optimize
// away ops that can't get constant folded after PrepareTF pass. For example,
// tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
pass_manager->addPass(mlir::createCanonicalizerPass());
// The below passes only make sense if Builtin TFLite ops are enabled
// for emission.
if (emit_builtin_tflite_ops) {
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
// TODO(haoliang): Add this pass by default.
if (lower_tensor_list_ops) {
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
pass_manager->addPass(mlir::TFL::CreatePrepareTFPass());
pass_manager->addPass(mlir::createCanonicalizerPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
if (run_quantize) {
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass());
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
pass_manager->addPass(mlir::createCanonicalizerPass());
pass_manager->addPass(mlir::createCSEPass());
}
}
// Runs the TF -> TFLite conversion pipeline on `module` and stores the outcome
// in `*result`: the printed MLIR module when `export_to_mlir` is true,
// otherwise the serialized TFLite FlatBuffer.
//
// The quantization passes are included only when ShouldRunQuantizePasses()
// says so for `module`. On a pass or serialization failure, the status
// collected by the scoped diagnostic handler is returned.
Status ConvertTFControlFlowToTFLOrFlatbuffer(
    mlir::Module *module, bool export_to_mlir, bool emit_builtin_tflite_ops,
    bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
    bool lower_tensor_list_ops, std::string *result) {
  mlir::StatusScopedDiagnosticHandler diagnostics(module->getContext(),
                                                  /*propagate=*/true);

  mlir::PassManager pipeline;
  AddTFToTFLConversionPasses(emit_builtin_tflite_ops,
                             ShouldRunQuantizePasses(module),
                             emit_quant_adaptor_ops, lower_tensor_list_ops,
                             &pipeline);
  if (failed(pipeline.run(module))) return diagnostics.ConsumeStatus();

  if (!export_to_mlir) {
    // Write MLIR TFLite dialect into FlatBuffer. The translate function
    // returns true on failure.
    const bool translation_failed = tflite::MlirToFlatBufferTranslateFunction(
        module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
        emit_custom_ops);
    if (translation_failed) return diagnostics.ConsumeStatus();
    return Status::OK();
  }

  // Textual output: the stream flushes into `*result` when it is destroyed.
  llvm::raw_string_ostream text_stream(*result);
  module->print(text_stream);
  return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
// Loads a TF model from a GraphDef definition or a TF control flow dialect MLIR
// source into a MLIR module. If `input_mlir` is true, loads from a MLIR source
// file; otherwise, loads from a GraphDef.
// Setting `prune_unused_nodes` to true prunes unreachable nodes if
// `output_arrays` is specified.
// On success the returned StatusOr holds the loaded module.
stream_executor::port::StatusOr<std::unique_ptr<mlir::Module>>
LoadFromGraphdefOrMlirSource(
const std::string& input_filename, bool input_mlir,
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, absl::string_view inference_type,
absl::string_view min_values, absl::string_view max_values,
bool prune_unused_nodes, llvm::SourceMgr* source_mgr,
mlir::MLIRContext* context);
// Returns true if the quantization passes should run on module `m`.
// Quantization passes will run only when the user specifies a quantized type
// in the `-tf-inference-type` flag, which is converted to the function
// attribute "tf.quantize" by the importer module.
// TODO(fengliuai): switch to the cmd flag once the flags are moved to this
// file with main method.
bool ShouldRunQuantizePasses(mlir::Module* m);
// Adds the MLIR passes that convert TF control flow dialect to TF Lite dialect
// to a MLIR `pass_manager`. These passes first raise the control flow in the TF
// control flow dialect, decode the constant tensors, and then legalize the
// module to TF Lite dialect with some optimizations afterwards.
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops. If `run_quantize` is true, quantization
// passes will be added (only together with the TF Lite legalization passes,
// i.e. when `emit_builtin_tflite_ops` is also true). If
// `emit_quant_adaptor_ops` is true, Quantize and Dequantize ops are added to
// the inputs and outputs of the quantized model.
// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
// TF ops before legalization to TF Lite dialect.
void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
bool emit_quant_adaptor_ops,
bool lower_tensor_list_ops,
mlir::PassManager* pass_manager);
// Takes a MLIR module in TF control flow dialect and a set of parameters,
// applies a set of passes to convert the module to TF Lite dialect and
// serializes the output into the string pointed to by `result`. Quantization
// is applied depending on an attribute in the module's main function (see
// ShouldRunQuantizePasses). If `export_to_mlir` is true, the result is
// exported in MLIR text format; otherwise it is exported as a flat buffer.
// Returns a non-OK status if running the passes or the flat buffer export
// fails.
Status ConvertTFControlFlowToTFLOrFlatbuffer(
mlir::Module* module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
bool lower_tensor_list_ops, std::string* result);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_

View File

@ -0,0 +1,98 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Regex.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
using llvm::LessRecord;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using mlir::tblgen::Operator;
// Helper macro that indents the stream `os` by X columns and returns it, so
// the emitted text can be chained with `<<`. A variable named `os` must be in
// scope wherever the macro is used.
#define OUT(X) os.indent((X))
// TableGen backend entry point: writes the C++ definition of
// `GetOpQuantSpec(mlir::Operation *op)` to `os`. For every op record and every
// native op trait on it whose name starts with "TFL::", one
// `if (auto tfl = llvm::dyn_cast<Op>(op)) { ... }` block is emitted that fills
// in the matching fields of the returned spec. Always returns false.
//
// The parameters are non-constant references because that is the signature
// required by LLVM's TableGenMain.
// NOLINTNEXTLINE
static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
  llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"};
  emitSourceFileHeader("TensorFlow Lite Ops Quant Spec Getters", os);

  // Retrieve all the definitions derived from "Op", sorted by record name.
  std::vector<Record *> op_records = records.getAllDerivedDefinitions("Op");
  llvm::sort(op_records, LessRecord());

  OUT(0) << "static std::unique_ptr<OpQuantSpec> "
            "GetOpQuantSpec(mlir::Operation *op) {\n";
  OUT(2) << "auto spec = absl::make_unique<OpQuantSpec>();\n";

  for (auto *op_record : op_records) {
    Operator op(op_record);
    for (const auto t : op.getTraits()) {
      auto native_trait = llvm::dyn_cast<mlir::tblgen::NativeOpTrait>(&t);
      if (!native_trait) continue;

      // We only handle TFL specific native op traits; strip the namespace
      // prefix so the checks below compare against the bare trait name.
      auto trait_name = native_trait->getTrait();
      if (!trait_name.consume_front("TFL::")) continue;

      OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName()
             << ">(op)) {\n";

      // "NoQuantizableResult" trait: clear the quantizable flag.
      if (trait_name == "NoQuantizableResult") {
        OUT(4) << "spec->is_quantizable = false;\n";
      }

      // "SameOperandsAndResultsScale" trait: set the same-scale flag.
      if (trait_name == "SameOperandsAndResultsScale") {
        OUT(4) << "spec->requires_same_scale = true;\n";
      }

      // "FixedResultUniformScale" trait: record the type of every result.
      if (trait_name.startswith("FixedResultUniformScale")) {
        OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n";
        OUT(6) << "spec->restricted_output_params.push_back(tfl."
                  "GetResultQuantizedType(i));\n";
      }

      // "AccumulatorUniformScale" trait: record the bias type, keyed by the
      // number captured from the trait's first template argument. The regex
      // is matched against the full (still "TFL::"-prefixed) trait string.
      auto full_trait_name = native_trait->getTrait().str();
      llvm::SmallVector<llvm::StringRef, 1> matches;
      if (acc_uniform_trait_regex.match(full_trait_name, &matches)) {
        OUT(4) << "spec->biases_params.emplace(std::make_pair(" << matches[1]
               << ", std::make_pair(tfl.GetAllNonBiasOperands(),"
               << "GetUniformQuantizedTypeForBias)));\n";
      }

      OUT(2) << "}\n";
    }
  }

  OUT(2) << "return spec;\n";
  OUT(0) << "}\n";
  return false;
}
// Tool entry point: sets up LLVM's stack-trace and initialization helpers,
// parses the command line and hands control to TableGenMain with
// OpQuantSpecWriter as the backend.
int main(int argc, char **argv) {
  llvm::PrettyStackTraceProgram stack_trace_printer(argc, argv);
  llvm::InitLLVM llvm_initializer(argc, argv);

  llvm::cl::ParseCommandLineOptions(argc, argv);

  const char *program_name = argv[0];
  return TableGenMain(program_name, &OpQuantSpecWriter);
}

View File

@ -0,0 +1,243 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// TFLite legalization patterns
include "mlir/IR/OpBase.td"
include "mlir/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Matches an ElementsAttr whose element type is f32.
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
// Extracts the ith int element from the ArrayAttr bound to $_self as a 32-bit
// IntegerAttr created with the builder.
class ExtractI32At<int i> : NativeCodeCall<
"$_builder.getI32IntegerAttr($_self.cast<ArrayAttr>().getValue()[" # i #
"].cast<IntegerAttr>().getInt())">;
// Merges the two attributes $0 and $1 into an ArrayAttr.
def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Use the tensor type information from $0 and convert min $1, max $2 and
// storage type $3 to a QuantizedType.
def ConvertToQuantType : NativeCodeCall<
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3.getValue())">;
// Use the tensor type information from $0 and convert min $1, max $2,
// numBits $3 and narrowRange $4 to a QuantizedType.
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3, $4)">;
// Predicate that holds if all three attributes are set.
def HasAll3Attrs : Constraint<CPred<"HasAll3Attrs($0, $1, $2)">>;
// Converts an integer attribute $0 to a 32-bit IntegerAttr with the builder.
def convertIntAttrTo32Bit : NativeCodeCall<
"$_builder.getI32IntegerAttr($0.cast<IntegerAttr>().getInt())">;
//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//
// A TF constant holding an ElementsAttr becomes a TFL constant with the same
// value.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
//===----------------------------------------------------------------------===//
// Unary ops patterns.
//===----------------------------------------------------------------------===//
// Attribute constraints used by the pooling and bias-add patterns in this
// file.
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
// NOTE(review): presumably holds for 4-element int lists of the form
// [1, X, Y, 1]; the check lives in the C++ helper TFIntListIs1XY1 - confirm.
def IsIntList1XY1 : AttrConstraint<CPred<"TFIntListIs1XY1($_self)">>;
def : Pat<(TF_AbsOp $arg), (TFL_AbsOp $arg)>;
// Elements 1 and 2 of ksize/strides supply the height and width values.
def : Pat<(TF_AvgPoolOp $value,
IsIntList1XY1:$ksize,
IsIntList1XY1:$strides,
$padding,
IsDataFormatNHWC:$format),
(TFL_AveragePool2DOp $value,
/*filter_height=*/ExtractI32At<1>:$ksize,
/*filter_width=*/ExtractI32At<2>:$ksize,
/*padding=*/$padding,
/*stride_h=*/ExtractI32At<1>:$strides,
/*stride_w=*/ExtractI32At<2>:$strides,
/*fused_activation_function=*/TFL_AF_None)>;
def : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;
def : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;
def : Pat<(TF_EluOp $arg), (TFL_EluOp $arg)>;
def : Pat<(TF_ExpandDimsOp $input, $dim), (TFL_ExpandDimsOp $input, $dim)>;
// Fake-quant with attribute min/max becomes a quantize/dequantize pair whose
// quantized type is built from min, max, num_bits and narrow_range.
def : Pat<(TF_FakeQuantWithMinMaxArgsOp $inputs,
$min, $max,
$num_bits, $narrow_range),
(TFL_DequantizeOp
(TFL_QuantizeOp $inputs,
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
$num_bits, $narrow_range)))>;
def : Pat<(TF_FillOp $arg, $value), (TFL_FillOp $arg, $value)>;
def : Pat<(TF_FloorOp $arg), (TFL_FloorOp $arg)>;
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
def : Pat<(TF_LogicalNotOp $arg), (TFL_LogicalNotOp $arg)>;
def : Pat<(TF_LogSoftmaxOp $arg), (TFL_LogSoftmaxOp $arg)>;
// Note: the result op's arguments are listed in a different order than for
// TFL_AveragePool2DOp above (padding first, then width before height).
def : Pat<(TF_MaxPoolOp $value,
IsIntList1XY1:$ksize,
IsIntList1XY1:$strides,
$padding,
IsDataFormatNHWC:$format),
(TFL_MaxPool2DOp $value,
/*padding=*/$padding,
/*stride_w=*/ExtractI32At<2>:$strides,
/*stride_h=*/ExtractI32At<1>:$strides,
/*filter_width=*/ExtractI32At<2>:$ksize,
/*filter_height=*/ExtractI32At<1>:$ksize,
/*fused_activation_function=*/TFL_AF_None)>;
def : Pat<(TF_MaximumOp $arg1, $arg2), (TFL_MaximumOp $arg1, $arg2)>;
def : Pat<(TF_MinimumOp $arg1, $arg2), (TFL_MinimumOp $arg1, $arg2)>;
def : Pat<(TF_RangeOp $start, $limit, $delta), (TFL_RangeOp $start, $limit, $delta)>;
def : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>;
def : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>;
// The second operand is captured in the type for this transform, so the
// pattern only applies when both the input and the result are statically
// shaped.
def : Pat<(TF_ReshapeOp:$res AnyStaticShapeTensor:$arg, $ignored),
(TFL_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
// TODO(jpienaar): this is not true for all selects, TF's select supports rank 0
// condition
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
// The TFL softmax op takes an extra f32 attribute, fixed to 1.0 here.
def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
// Squeeze of a statically shaped input is expressed as a reshape; the squeeze
// dims attribute is not used by the result.
def : Pat<(TF_SqueezeOp AnyStaticShapeTensor:$arg, $ignored_dims),
(TFL_ReshapeOp $arg)>;
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
// The following two rules can both match a tf.Placeholder.input node with
// min/max/type attributes, so we increase the benefit of the first rule by one
// so the tfl.quantize and tfl.dequantize ops will be inserted if it matches.
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
(TFL_DequantizeOp
(TFL_QuantizeOp
(TFL_InputOp $inputs),
(ConvertToQuantType $inputs, $min, $max, $type))),
[(HasAll3Attrs $min, $max, $type)], (addBenefit 1)>;
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
(TFL_InputOp $inputs)>;
//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_LessOp $l, $r), (TFL_LessOp $l, $r)>;
def : Pat<(TF_GreaterOp $l, $r), (TFL_GreaterOp $l, $r)>;
def : Pat<(TF_GreaterEqualOp $l, $r), (TFL_GreaterEqualOp $l, $r)>;
def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
def : Pat<(TF_NotEqualOp $l, $r), (TFL_NotEqualOp $l, $r)>;
def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
// For each (activation op, fused activation attribute) pair, two extra
// patterns are emitted: one that fuses the activation applied to FromOp, and
// one that fuses the activation applied to an already converted ToOp that has
// no fused activation yet.
multiclass FusedBinaryActivationFuncOpPat<dag FromOp, dag ToOp> {
def : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
(ToOp $l, $r, TFL_AF_None)>;
foreach actFnPair = [[TF_ReluOp, TFL_AF_Relu],
[TF_Relu6Op, TFL_AF_Relu6]] in {
def : Pat<(actFnPair[0] (FromOp $lhs, $rhs)),
(ToOp $lhs, $rhs, actFnPair[1])>;
// TODO: Maybe move these below to general pass?
def : Pat<(actFnPair[0] (ToOp $lhs, $rhs, TFL_AF_None)),
(ToOp $lhs, $rhs, actFnPair[1])>;
}
}
// Instantiated FusedBinary patterns for the from-to pairs of ops.
foreach fromToPair = [[TF_AddOp, TFL_AddOp],
[TF_AddV2Op, TFL_AddOp],
[TF_DivOp, TFL_DivOp],
[TF_MulOp, TFL_MulOp],
[TF_RealDivOp, TFL_DivOp],
[TF_SubOp, TFL_SubOp]] in
defm : FusedBinaryActivationFuncOpPat<fromToPair[0], fromToPair[1]>;
// An NHWC bias add on f32 tensors is expressed as a TFL add without a fused
// activation.
def : Pat<(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r,
IsDataFormatNHWC:$data_format),
(TFL_AddOp $l, $r, TFL_AF_None)>;
// TODO(jpienaar): These should be handled by the pattern rewriter, find out
// why it isn't.
def : Pat<(TF_Relu6Op (TF_BiasAddOp F32Tensor:$l, F32Tensor:$r,
IsDataFormatNHWC:$data_format)),
(TFL_AddOp $l, $r, TFL_AF_Relu6)>;
// Fake-quant whose min/max operands are f32 constants becomes a
// quantize/dequantize pair whose quantized type is built from those constants
// together with num_bits and narrow_range.
def : Pat<(TF_FakeQuantWithMinMaxVarsOp $inputs,
(ConstantOp F32ElementsAttr:$min),
(ConstantOp F32ElementsAttr:$max),
$num_bits, $narrow_range),
(TFL_DequantizeOp
(TFL_QuantizeOp $inputs,
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
$num_bits, $narrow_range)))>;
def : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
def : Pat<(TF_SquaredDifferenceOp $l, $r), (TFL_SquaredDifferenceOp $l, $r)>;
// Note(ycling): We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// The TFLite interpreter doesn't support Relu+int32 for now, so the test
// cases fail without the following pattern; optimizing the Relu away fixes
// the problem.
def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)),
(TFL_SquaredDifferenceOp $l, $r)>;
def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>;
def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>;
def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
def : Pat<(TF_PadV2Op $arg0, $arg1, $cst), (TFL_PadV2Op $arg0, $arg1, $cst)>;
def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>;
def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
// Only matches when half_pixel_centers is false; the result op is built
// without that attribute.
def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners)>;
// The five mask attributes are re-created as 32-bit integer attributes for
// the result op.
def : Pat<
(TF_StridedSliceOp $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, $new_axis_mask, $shrink_axis_mask),
(TFL_StridedSliceOp $input, $begin, $end, $strides,
(convertIntAttrTo32Bit $begin_mask), (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
(convertIntAttrTo32Bit $new_axis_mask), (convertIntAttrTo32Bit $shrink_axis_mask))>;

View File

@ -0,0 +1,222 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass converts operations in TensorFlow dialect into
// operations that are legal in the TensorFlow Lite dialect. Operations that
// can be legalized to TensorFlow Lite dialect with simple replacements are part
// of this pass and other operations that may create extra ops should be part of
// the PrepareTF pass which should be run before this pass. That way any
// constant folding opportunities from the extra ops can be exploited by the
// constant folding support for the TensorFlow ops.
#include <climits>
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
//===----------------------------------------------------------------------===//
// The actual LegalizeTF Pass.
namespace {
// Function pass that legalizes TensorFlow dialect operations in a function to
// the TensorFlow Lite dialect.
struct LegalizeTF : public FunctionPass<LegalizeTF> {
// Builds the rewrite pattern list (generated plus hand-written patterns) and
// applies it greedily to the current function.
void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
// Declares a rewrite pattern struct named ConvertTF<tf_op>Op that is rooted
// at TF::<tf_op>Op and registered with benefit 1. Only the declaration of
// matchAndRewrite is produced here; each pattern's matchAndRewrite is defined
// out of line further down in this file.
#define DECL_CONVERT_OP(tf_op) \
struct ConvertTF##tf_op##Op : public RewritePattern { \
explicit ConvertTF##tf_op##Op(MLIRContext* context) \
: RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
PatternMatchResult matchAndRewrite( \
Operation* op, PatternRewriter& rewriter) const override; \
}
// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
// operands are properly supported in declarative rewrite rule specification.
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(Gather);
DECL_CONVERT_OP(GatherV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(TopKV2);
DECL_CONVERT_OP(Unpack);
// The macro is only needed for the declarations above.
#undef DECL_CONVERT_OP
// Rewrites tf.Concat into tfl.concatenation with fused activation "NONE".
// Fails to match unless the concat_dim operand is a constant, because the TFL
// op receives the axis as an attribute extracted from that constant.
PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto concat = cast<TF::ConcatOp>(op);

  // Extract axis attribute from constant concat_dims tensor.
  ElementsAttr concat_dim_attr;
  if (!matchPattern(concat.concat_dim(), m_Constant(&concat_dim_attr))) {
    return matchFailure();
  }

  SmallVector<Value*, 4> inputs(concat.values());
  auto result_type = concat.output()->getType();
  StringAttr no_activation = StringAttr::get("NONE", rewriter.getContext());

  rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
      op, result_type, inputs,
      mlir::TFL::ExtractSingleElementAsInteger(concat_dim_attr),
      no_activation);
  return matchSuccess();
}
// Rewrites tf.ConcatV2 into tfl.concatenation with fused activation "NONE".
// Fails to match unless the axis operand is a constant, because the TFL op
// receives the axis as an attribute extracted from that constant.
PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto concat_v2 = cast<TF::ConcatV2Op>(op);

  // Extract axis attribute from constant axis tensor.
  ElementsAttr axis_attr;
  const bool axis_is_constant =
      matchPattern(concat_v2.axis(), m_Constant(&axis_attr));
  if (!axis_is_constant) return matchFailure();

  SmallVector<Value*, 4> inputs(concat_v2.values());
  auto result_type = concat_v2.output()->getType();
  StringAttr no_activation = StringAttr::get("NONE", rewriter.getContext());

  rewriter.replaceOpWithNewOp<ConcatenationOp>(
      op, result_type, inputs, ExtractSingleElementAsInteger(axis_attr),
      no_activation);
  return matchSuccess();
}
// Rewrites tf.Gather into tfl.gather, forwarding both operands unchanged and
// fixing the axis attribute to the 32-bit integer 0. Always matches.
PatternMatchResult mlir::TFL::ConvertTFGatherOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  // Gather in TF -> Gather in TFL with axis=0
  auto zero_axis =
      IntegerAttr::get(IntegerType::get(32, rewriter.getContext()), 0);

  Value* first_operand = op->getOperand(0);
  Value* second_operand = op->getOperand(1);
  rewriter.replaceOpWithNewOp<TFL::GatherOp>(op, first_operand, second_operand,
                                             zero_axis);
  return matchSuccess();
}
// Rewrites tf.GatherV2 into tfl.gather. Fails to match unless the axis operand
// is a constant; the constant is converted to the TFL op's axis attribute and
// the first two operands are forwarded unchanged.
PatternMatchResult ConvertTFGatherV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto gather_v2 = cast<TF::GatherV2Op>(op);

  ElementsAttr axis_attr;
  if (!matchPattern(gather_v2.axis(), m_Constant(&axis_attr))) {
    return matchFailure();
  }

  Value* first_operand = op->getOperand(0);
  Value* second_operand = op->getOperand(1);
  rewriter.replaceOpWithNewOp<GatherOp>(
      op, first_operand, second_operand,
      ExtractSingleElementAsInteger(axis_attr));
  return matchSuccess();
}
// Rewrites tf.MatMul into tfl.fully_connected when transpose_a is false and
// transpose_b is true; any other combination fails to match. This is
// effectively:
// def : Pat<
//   (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
//      ConstBoolAttrTrue:$transpose_b),
//   (TFL_FullyConnectedOp:$__0 $a, $b,
//     NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto matmul = cast<TF::MatMulOp>(op);

  // Reject every transpose combination other than (false, true).
  const bool supported = !matmul.transpose_a() && matmul.transpose_b();
  if (!supported) return matchFailure();

  Type result_type = matmul.getResult()->getType();
  auto loc = op->getLoc();

  // TODO(jpienaar): Follow up post shuffle discussion.
  // The third operand of the new op is a none-typed constant standing in for
  // "no input".
  auto none_operand = rewriter.create<ConstantOp>(loc, rewriter.getNoneType(),
                                                  rewriter.getUnitAttr());
  auto fully_connected = rewriter.create<FullyConnectedOp>(
      loc, ArrayRef<Type>{result_type}, op->getOperand(0), op->getOperand(1),
      none_operand, rewriter.getStringAttr("NONE"),
      rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));

  rewriter.replaceOp(op, {fully_connected.getResult(0)});
  return matchSuccess();
}
// Rewrites tf.Pack into tfl.pack, re-creating the N and axis attributes as
// 32-bit integer attributes. Always matches.
PatternMatchResult ConvertTFPackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto pack = cast<TF::PackOp>(op);

  // The count is read zero-extended; the axis can be negative, so it is read
  // sign-extended.
  auto count_attr = rewriter.getI32IntegerAttr(pack.N().getZExtValue());
  auto axis_attr = rewriter.getI32IntegerAttr(pack.axis().getSExtValue());

  SmallVector<Value*, 4> inputs(pack.values());
  auto result_type = pack.output()->getType();
  rewriter.replaceOpWithNewOp<PackOp>(op, result_type, inputs, count_attr,
                                      axis_attr);
  return matchSuccess();
}
// Rewrites tf.TopKV2 into tfl.topk_v2, forwarding both operands unchanged.
// Always matches.
PatternMatchResult ConvertTFTopKV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  // TopK in TFL is always sorted so we ignore that attribute here.
  Value* first_operand = op->getOperand(0);
  Value* second_operand = op->getOperand(1);
  rewriter.replaceOpWithNewOp<TFL::TopKV2Op>(op, first_operand,
                                             second_operand);
  return matchSuccess();
}
// Rewrites tf.Unpack into tfl.unpack. The new op gets the same result types
// as the original outputs and 32-bit versions of the num and axis attributes.
// Always matches.
PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto unpack = cast<TF::UnpackOp>(op);

  // One result type per output of the original op.
  auto result_types = functional::map(
      [](Value* result) { return result->getType(); }, unpack.output());

  // The count is read zero-extended; the axis can be negative, so it is read
  // sign-extended.
  auto num_attr = rewriter.getI32IntegerAttr(unpack.num().getZExtValue());
  auto axis_attr = rewriter.getI32IntegerAttr(unpack.axis().getSExtValue());

  auto* packed_input = unpack.value();
  rewriter.replaceOpWithNewOp<UnpackOp>(op, result_types, packed_input,
                                        num_attr, axis_attr);
  return matchSuccess();
}
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto* ctx = &getContext();
auto& func = getFunction();
// Add the generated patterns to the list.
populateWithGenerated(ctx, &patterns);
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFGatherOp,
ConvertTFGatherV2Op, ConvertTFMatMulOp, ConvertTFPackOp,
ConvertTFTopKV2Op, ConvertTFUnpackOp>::build(patterns,
ctx);
applyPatternsGreedily(func, std::move(patterns));
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. The pass
// is allocated with `new` and returned as a raw pointer.
FunctionPassBase* CreateLegalizeTFPass() { return new LegalizeTF(); }
// Registers the pass under the command-line name "tfl-legalize-tf".
static PassRegistration<LegalizeTF> pass(
"tfl-legalize-tf", "Legalize from TensorFlow to TensorFlow Lite dialect");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,338 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass prepares for legalization to the TFLite dialect by
// converting Tensorlist operations in TensorFlow dialect into operations that
// can be legalized to TensorFlow Lite dialect with simple replacements. The
// newly created operations are in the TensorFlow dialect if the operation can
// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
// is used.
#include <climits>
#include <cstdint>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#define DEBUG_TYPE "tf-tfl-legalization"
//===----------------------------------------------------------------------===//
// The actual LowerStaticTensorList Pass.
//
namespace mlir {
namespace {
// Function pass that lowers TensorList ops in functions for subsequent
// legalization.
struct LowerStaticTensorListPass
: public FunctionPass<LowerStaticTensorListPass> {
// Pass entry point, invoked once per function.
void runOnFunction() override;
// NOTE(review): presumably performs the actual TensorList rewriting and
// reports success/failure; the definition is not shown here - confirm.
LogicalResult ModifyTensorList();
};
// Creates a constant op at `op`'s location holding an i32 tensor of the given
// `shape` in which every element equals `val`, and returns it as a value.
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
                           ArrayRef<int64_t> shape, int32_t val) {
  auto tensor_type =
      rewriter->getTensorType(shape, rewriter->getIntegerType(32));
  // A dense attribute built from a single element acts as a splat.
  auto element = rewriter->getI32IntegerAttr(val);
  auto splat_attr = DenseElementsAttr::get(tensor_type, element);
  return rewriter->create<ConstantOp>(op->getLoc(), tensor_type, splat_attr);
}
// Creates a tf.Fill op at `op`'s location that fills a tensor whose shape is
// given at runtime by `shape_tensor` with the i32 scalar `val`. The result is
// typed as a 1-D i32 tensor of unknown size.
Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter,
                            Value *shape_tensor, int32_t val) {
  // Scalar (rank-0) constant holding the fill value.
  Value *fill_value = CreateI32SplatConst(op, rewriter, {}, val);
  auto result_type =
      rewriter->getTensorType({-1}, rewriter->getIntegerType(32));
  return rewriter->create<TF::FillOp>(op->getLoc(), result_type, shape_tensor,
                                      fill_value);
}
// Pattern that lowers `tf.TensorListSetItem` into slice, expand_dims and
// concat ops operating directly on the list's backing tensor.
struct ConvertTFTensorListSetItem : public RewritePattern {
  explicit ConvertTFTensorListSetItem(MLIRContext *context)
      : RewritePattern(TF::TensorListSetItemOp::getOperationName(), 1,
                       context) {}

  // Replaces the op with the concatenation (along dimension 0) of three
  // pieces:
  //   1. the rows of the input that precede `index`,
  //   2. the new item with a leading dimension added,
  //   3. the rows of the input starting at `index + 1`.
  // Roughly equivalent to:
  //    def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
  //      (Concat
  //        concat_dim = 0,
  //        (Slice $input, [0, 0, ...],
  //               (Concat (ExpandDims $index, expand_dim = 0), [-1, -1, ...])),
  //        (ExpandDims $item, expand_dim = 0),
  //        (Slice $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
  // Ops are created in a fixed order; all of them use the original op's
  // location.
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    TF::TensorListSetItemOp set_item = cast<TF::TensorListSetItemOp>(op);

    auto list = set_item.input_handle();
    auto i32_type = rewriter.getIntegerType(32);
    auto list_rank = rewriter.create<TF::RankOp>(
        loc, rewriter.getTensorType({}, i32_type), list);
    auto new_item = set_item.item();
    auto new_item_rank = rewriter.create<TF::RankOp>(
        loc, rewriter.getTensorType({}, i32_type), new_item);

    // Begin offsets of the leading slice: [0, 0, .., 0], one entry per
    // dimension of the input.
    auto axis_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
    auto list_rank_vec = rewriter.create<TF::ExpandDimsOp>(
        loc, rewriter.getTensorType({1}, i32_type), list_rank, axis_zero);
    auto head_begin = CreateI32SplatTensor(op, &rewriter, list_rank_vec, 0);

    // Begin offsets of the trailing slice: [index + 1, 0, 0, .., 0].
    // The leading entry is `index + 1` ...
    auto index = set_item.index();
    auto one_elem_type = rewriter.getTensorType({1}, i32_type);
    auto one = CreateI32SplatConst(op, &rewriter, {1}, 1);
    auto tail_first =
        rewriter.create<TF::AddOp>(loc, one_elem_type, index, one);
    // ... followed by `rank(item)` zeros.
    auto item_rank_vec = rewriter.create<TF::ExpandDimsOp>(
        loc, rewriter.getTensorType({1}, i32_type), new_item_rank, axis_zero);
    auto tail_rest = CreateI32SplatTensor(op, &rewriter, item_rank_vec, 0);
    auto offsets_type = head_begin->getType();
    auto tail_begin = rewriter.create<TF::ConcatOp>(
        loc, offsets_type, axis_zero,
        ArrayRef<Value *>({tail_first, tail_rest}),
        rewriter.getI64IntegerAttr(2));

    // Sizes of the leading slice: [index, -1, -1, .., -1].
    auto head_size_first = rewriter.create<TF::ExpandDimsOp>(
        loc, one_elem_type, index, axis_zero);
    auto head_size_rest =
        CreateI32SplatTensor(op, &rewriter, item_rank_vec, -1);
    auto head_size = rewriter.create<TF::ConcatOp>(
        loc, offsets_type, axis_zero,
        ArrayRef<Value *>({head_size_first, head_size_rest}),
        rewriter.getI64IntegerAttr(2));

    // Sizes of the trailing slice: [-1, -1, .., -1].
    auto tail_size = CreateI32SplatTensor(op, &rewriter, list_rank_vec, -1);

    // The two slices of the input; their results are left unranked.
    auto element_type = list->getType().cast<TensorType>().getElementType();
    auto unranked_type = rewriter.getTensorType(element_type);
    auto head = rewriter.create<TF::SliceOp>(loc, unranked_type, list,
                                             head_begin, head_size);
    auto tail = rewriter.create<TF::SliceOp>(loc, unranked_type, list,
                                             tail_begin, tail_size);

    // Give the item a leading dimension so its rank matches the input's.
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        loc, unranked_type, new_item, axis_zero);

    // Stitch the three pieces back together along dimension 0.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, list->getType(), axis_zero,
        ArrayRef<Value *>({head, expanded_item, tail}),
        rewriter.getI64IntegerAttr(3));
    return matchSuccess();
  }
};
// TODO(hinsu): Fix end-to-end test when passing string `element_dtype`
// attribute.
// Pattern that lowers `tf.TensorListReserve` into a `tf.Fill` producing a
// zero-initialized tensor.
struct ConvertTFTensorListReserve : public RewritePattern {
  explicit ConvertTFTensorListReserve(MLIRContext *context)
      : RewritePattern(TF::TensorListReserveOp::getOperationName(), 1,
                       context) {}

  // Rewrites the original op into `tf.fill`. The result tensor shape is
  // [num_element, element_shape]. All the values in the result tensor will be
  // initialized to 0.
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    TF::TensorListReserveOp tf_op = cast<TF::TensorListReserveOp>(op);

    auto element_shape = tf_op.element_shape();
    auto shape_dtype =
        element_shape->getType().cast<TensorType>().getElementType();
    auto num_elements = tf_op.num_elements();

    int64_t input_rank = -1;  // -1 means unknown dimension.
    if (auto type = element_shape->getType().dyn_cast<RankedTensorType>()) {
      // `element_shape` is expected to be a 1-D shape vector whose length is
      // the element's rank; add 1 to get the rank of the whole list tensor.
      // The rank-1 check matters: a rank-0 `element_shape` also reports a
      // static shape but has no dimension 0 to read, so indexing it would be
      // out of bounds. In that case the rank stays unknown.
      auto dims = type.getShape();
      if (dims.size() == 1 && type.hasStaticShape()) {
        input_rank = dims[0] + 1;
      }
    }

    auto element_dtype = tf_op.element_dtype();

    // The output shape of the result tensor should be [num_elements +
    // element_shape].
    auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
    auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
        op->getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
        scalar_zero);
    auto shape_type = rewriter.getTensorType({input_rank}, shape_dtype);
    auto list_shape = rewriter.create<TF::ConcatOp>(
        op->getLoc(), shape_type, scalar_zero,
        ArrayRef<Value *>({leading_dim, element_shape}),
        rewriter.getI64IntegerAttr(2));

    // Create a zero-initialized constant tensor that has the same type
    // as specified by element_dtype.
    auto zero_type = rewriter.getTensorType({}, element_dtype);
    auto zero_attr = rewriter.getZeroAttr(zero_type);
    auto zero = rewriter.create<ConstantOp>(op->getLoc(), zero_type, zero_attr);

    rewriter.replaceOpWithNewOp<TF::FillOp>(
        op, rewriter.getTensorType(element_dtype), list_shape, zero);
    return matchSuccess();
  }
};
} // namespace
namespace TFL {
namespace {
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
} // namespace
} // namespace TFL
// Walks every block of the function and retypes the list operand and result
// of each `TensorListSetItem` op to the tensor type recorded from the most
// recent `TensorListFromTensor`/`TensorListReserve` op seen in that block.
// Returns failure (after emitting a diagnostic on the offending op) when:
//   * a `TensorListReserve` has an unsupported `element_dtype`,
//   * a `TensorListSetItem` is reached before any list-producing op in the
//     same block, so no tensor type is known for it, or
//   * any op outside the supported TensorList set has a DT_VARIANT operand or
//     result.
LogicalResult LowerStaticTensorListPass::ModifyTensorList() {
  // In `runOnFunction`, there is no guarantee about
  // in which order those patterns will be applied. Our transformation requires
  // that at runtime each `TensorListSetItem` op takes in a normal tensor type
  // rather than a `DT_VARIANT` tensor. So here we need to manually walk-through
  // the IR and change the argument/return types of each `TensorListSetItemOp`.
  // TODO(haoliang): 1) support modifying more `TensorList` ops that consumes/
  // produces `DT_VARIANT` tensor. 2) More robust support for handling multiple
  // different tensorlist types. For example, consider the case like:
  // l1 = list_ops.tensor_list_from_tensor(t, element_shape1)
  // l2 = list_ops.tensor_list_from_tensor(t, element_shape2)
  // l1 = list_ops.tensor_list_set_item(l1, 0, item1)
  // l2 = list_ops.tensor_list_set_item(l2, 0, item2)
  // 3) Handle the case where a tensorlist output is passed to multiple
  // functions.
  for (Block &block : getFunction()) {
    // Type of the most recently created tensor list in this block; stays null
    // until a list-producing op has been seen.
    Type tensor_type;

    for (Operation &op : block) {
      if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) {
        tensor_type = tf_op.tensor()->getType();
      } else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
        auto element_dtype = tf_op.element_dtype();
        if (!(element_dtype.isF16() || element_dtype.isF32() ||
              element_dtype.isF64() || element_dtype.isa<IntegerType>())) {
          return tf_op.emitError(
              "requires element_dtype to be integer or 16-bit/32-bit/64-bit "
              "float type during TF Lite transformation pass");
        }
        // TODO(haoliang): figure out better way of specify shape.
        tensor_type = UnrankedTensorType::get(element_dtype);
      }

      if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
        // Without a preceding list-producing op in this block there is no
        // type to install. Setting a null type on the values would only crash
        // later consumers, so report the unsupported input instead.
        if (!tensor_type) {
          return tf_op.emitError(
              "requires the tensor list to be created by TensorListFromTensor "
              "or TensorListReserve earlier in the same block during TF Lite "
              "transformation pass");
        }
        tf_op.input_handle()->setType(tensor_type);
        tf_op.getResult()->setType(tensor_type);
      }

      // Currently we will raise an error if an op other than the following
      // contains a DT_VARIANT tensor as its input or output. Below ops already
      // have proper transformation patterns that eliminate the need of
      // `DT_VARIANT`, we consider it's safe to not raise an error on those ops.
      if (llvm::isa<TF::TensorListFromTensorOp>(op) ||
          llvm::isa<TF::TensorListReserveOp>(op) ||
          llvm::isa<TF::TensorListSetItemOp>(op) ||
          llvm::isa<TF::TensorListStackOp>(op) ||
          llvm::isa<TF::TensorListGetItemOp>(op)) {
        continue;
      }

      // Check if any of the input operand is a DT_VARIANT.
      for (Type type : op.getOperandTypes()) {
        if (type.isa<TF::VariantType>()) {
          return op.emitError(
              "op's input contains a DT_VARIANT tensor. Currently we only "
              "allow "
              "TensorListFromTensor/TensorListReserve/TensorListStack/"
              "TensorListSetItem/"
              "TensorListGetItem to have DT_VARIANT input/output");
        }
      }

      // Check if any of the output is a DT_VARIANT.
      for (Type type : op.getResultTypes()) {
        if (type.isa<TF::VariantType>()) {
          return op.emitError(
              "op's output contains a DT_VARIANT tensor. Currently we only "
              "allow "
              "TensorListFromTensor/TensorListReserve/TensorListStack/"
              "TensorListSetItem/"
              "TensorListGetItem to have DT_VARIANT input/output");
        }
      }
    }
  }
  return success();
}
void LowerStaticTensorListPass::runOnFunction() {
if (failed(ModifyTensorList())) {
signalPassFailure();
return;
}
OwningRewritePatternList patterns;
auto &func = getFunction();
TFL::populateWithGenerated(&getContext(), &patterns);
patterns.push_back(
llvm::make_unique<ConvertTFTensorListReserve>(&getContext()));
patterns.push_back(
llvm::make_unique<ConvertTFTensorListSetItem>(&getContext()));
applyPatternsGreedily(func, std::move(patterns));
}
// Factory for the TensorFlow Lite dialect LowerStaticTensorList pass. Returns
// a newly heap-allocated pass instance.
FunctionPassBase *TFL::CreateLowerStaticTensorListPass() {
  FunctionPassBase *lowering_pass = new LowerStaticTensorListPass();
  return lowering_pass;
}
// Static registration object for LowerStaticTensorListPass: associates the
// pass with the argument string "tfl-lower-static-tensor-list" and the
// description below.
static PassRegistration<LowerStaticTensorListPass> pass(
"tfl-lower-static-tensor-list",
"Lower TensorList ops within TensorFlow Lite dialect");
} // namespace mlir

View File

@ -0,0 +1,66 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass takes operations in TensorFlowLite dialect and
// optimizes them to resulting operations in TensorFlowLite dialect.
#include <climits>
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
namespace mlir {
namespace TFL {
//===----------------------------------------------------------------------===//
// The actual Optimize Pass.
namespace {
// Optimize TFLite operations in functions.
// Function pass; its `runOnFunction` (defined later in this file) greedily
// applies the TableGen-generated rewrite patterns to the function.
struct Optimize : public FunctionPass<Optimize> {
void runOnFunction() override;
};
// Reports whether the types of attributes `a` and `b` can be broadcast
// together, i.e. whether OpTrait::util::getBroadcastedType yields a non-null
// type for them.
bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) {
  Type broadcasted_type =
      OpTrait::util::getBroadcastedType(a.getType(), b.getType());
  // A default-constructed (null) Type signals "not broadcastable".
  return broadcasted_type != Type();
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
void Optimize::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
// Add the generated patterns to the list.
TFL::populateWithGenerated(&getContext(), &patterns);
applyPatternsGreedily(func, std::move(patterns));
}
} // namespace
// Factory for the TensorFlow Lite dialect Optimize pass. Returns a newly
// heap-allocated pass instance.
FunctionPassBase *CreateOptimizePass() {
  FunctionPassBase *optimize_pass = new Optimize();
  return optimize_pass;
}
// Static registration object for the Optimize pass: associates it with the
// argument string "tfl-optimize" and the description below.
static PassRegistration<Optimize> pass(
"tfl-optimize", "Optimize within the TensorFlow Lite dialect");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,112 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the optimization pattern definition file for TensorFlow Lite.
include "mlir/IR/OpBase.td"
include "mlir/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Attribute constraint: the attribute is cast to ElementsAttr and accepted
// only if the element type of its type is f32.
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
// Parameters:
//   ActFnOp   - the stand-alone activation op to match.
//   ActFnAttr - the fused-activation attribute written into the rewritten op.
// Each instantiation produces two patterns, one for TFL_Conv2DOp and one for
// TFL_DepthwiseConv2DOp.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
// Only matches a convolution whose fused activation is TFL_AF_None; every
// other operand/attribute is carried over unchanged.
def : Pat<(ActFnOp (TFL_Conv2DOp $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w)),
(TFL_Conv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w)>;
// Same rewrite for the depthwise variant, which carries an extra
// $multiplier attribute.
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier)),
(TFL_DepthwiseConv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w,
$multiplier)>;
}
// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
// activation functions.
// Instantiates FuseActFnIntoConvOpPat once per [activation op, fused
// activation attribute] pair listed here.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_TanhOp, TFL_AF_Tanh]] in
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
// If we see an add op adding a constant value to a convolution op with constant
// bias, we can fuse the add into the convolution op by constant folding the
// bias and the add op's constant operand.
// The following pattern restricts to float constant values for now.
// Conv2D variant. Matches only when the convolution has no fused activation
// (TFL_AF_None) and both the bias and the added value are f32 constants. The
// add's own activation ($act_fn) becomes the convolution's fused activation;
// the new bias is expressed as an add (with TFL_AF_None) of the two constants.
// NOTE(review): unlike the mul pattern in this file, no broadcast
// compatibility constraint is placed on $bias and $value - confirm that is
// intended.
def : Pat<(TFL_AddOp (TFL_Conv2DOp $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_Conv2DOp $input, $filter,
(TFL_AddOp (ConstantOp $bias),
(ConstantOp $value), TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w)>;
// Depthwise variant of the add-into-bias fusion: same conditions and rewrite
// as the Conv2D pattern, with the extra $multiplier attribute carried over
// unchanged.
def : Pat<(TFL_AddOp (TFL_DepthwiseConv2DOp $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input, $filter,
(TFL_AddOp (ConstantOp $bias),
(ConstantOp $value),
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
$multiplier)>;
// Two-argument pattern constraint that forwards its arguments to the C++
// helper IsBroadcastableElementsAttrs.
def BroadcastableElementsAttrs :
Constraint<CPred<"IsBroadcastableElementsAttrs($0, $1)">>;
// If we see a mul op multiplying a constant value to a convolution op with
// constant filter and bias, we can fuse the multiplication into the convolution
// op by constant folding the filter/bias and the mul op's constant operand.
// The following pattern restricts to float constant values for now.
// Only TFL_DepthwiseConv2DOp is matched here. The convolution must have no
// fused activation (TFL_AF_None) and constant f32 filter and bias; the
// multiplied value must be a constant f32 as well. Both filter and bias are
// rewritten as a mul (with TFL_AF_None) by $value, and the mul's activation
// ($act_fn) becomes the convolution's fused activation.
// The trailing constraint requires $filter and $value to be broadcast
// compatible.
// NOTE(review): no equivalent constraint is stated for $bias and $value -
// confirm it is implied by the filter check.
def : Pat<(TFL_MulOp (TFL_DepthwiseConv2DOp $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input,
(TFL_MulOp (ConstantOp $filter),
(ConstantOp $value),
TFL_AF_None),
(TFL_MulOp (ConstantOp $bias),
(ConstantOp $value),
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
$multiplier),
[(BroadcastableElementsAttrs $filter, $value)]>;
// This pattern applies when the same quantize/dequantize have been used twice
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
// Replaces quantize(dequantize($in)) with $in directly.
// NOTE(review): the pattern carries no constraint tying $qt to the type of
// $in, so the "same scale" condition is assumed rather than checked here -
// confirm callers guarantee it.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_
namespace mlir {
// Forward declaration; only pointers to it are used in this header.
class FunctionPassBase;
namespace TFL {
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
FunctionPassBase *CreateLegalizeTFPass();
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
FunctionPassBase *CreateOptimizePass();
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
FunctionPassBase *CreatePrepareTFPass();
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.
FunctionPassBase *CreateLowerStaticTensorListPass();
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
FunctionPassBase *CreateQuantizePass();
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
FunctionPassBase *CreatePrepareQuantizePass();
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
// `emit_quant_adaptor_ops` is forwarded to the pass; when it is false the pass
// removes the quantization adaptor ops from the function.
FunctionPassBase *CreatePostQuantizePass(bool emit_quant_adaptor_ops);
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_

View File

@ -0,0 +1,137 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass applies some clean up steps after quantization.
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
//===----------------------------------------------------------------------===//
// The post-quantize Pass.
//
namespace mlir {
namespace TFL {
namespace {
// Function pass that performs the post-quantization clean up.
class PostQuantizePass : public FunctionPass<PostQuantizePass> {
 public:
  // Default construction (used by PassRegistration) behaves like passing
  // `false`, i.e. the adaptor ops are removed.
  explicit PostQuantizePass() : PostQuantizePass(false) {}

  // Used when the pass is instantiated manually.
  explicit PostQuantizePass(bool emit_quant_adaptor_ops)
      : emit_quant_adaptor_ops_(emit_quant_adaptor_ops) {}

  void runOnFunction() override;

 private:
  // True when the model's inputs and outputs stay in floating point. In that
  // case the quant adaptor ops are kept: they quantize the inputs before
  // feeding them to the model and dequantize the results on the way out.
  bool emit_quant_adaptor_ops_;
};
// Rewrites `func` so that quantization adaptor ops at its boundary are
// removed:
//   * For each argument whose only user is an op whose first result feeds
//     exactly one QuantizeOp, the argument is replaced by a new argument of
//     the quantized type, wrapped in a fresh InputOp; the old input op and the
//     quantize op are erased. All other arguments are re-created unchanged.
//   * For each operand of the entry block's terminator that is produced by a
//     DequantizeOp, the terminator is made to use the dequantize op's input
//     instead, and the dequantize op is erased once nothing else uses it.
// Finally the function type is updated to the new argument/result types.
// Functions without a body are left untouched.
void RemoveQuantizationAdaptorOps(Function* func) {
  // Nothing to do for a function without blocks; `front()` below would be
  // invalid on an empty block list.
  if (func->getBlocks().empty()) return;

  mlir::OpBuilder builder(func->getBody());
  auto& bb = func->getBlocks().front();
  auto* terminator = bb.getTerminator();

  int num_args = bb.getNumArguments();
  llvm::SmallVector<Type, 4> input_types;
  input_types.reserve(num_args);
  // Edit the block arguments and create the new input ops in place to replace
  // the old input ops and quantize ops.
  for (int i = 0; i != num_args; ++i) {
    // Previous loop iteration may invalidate the insertion point so we have to
    // reset insertion point each iteration.
    builder.setInsertionPointToStart(&bb);

    // In each iteration, a new argument is appended to the end of the list
    // and the current argument is erased, so here we always process the first
    // argument in the list.
    auto* arg = bb.getArgument(0);

    // We can drop the quantization adaptor only when the pseudo input op has
    // one user and it is the quantize op. Otherwise, we have to keep the
    // adaptor and allow the floating point inputs.
    // The argument itself must have exactly one user as well: an unused
    // argument has no user to inspect, and folding an argument with several
    // users would leave the remaining users without an operand. The user must
    // also produce a result (e.g. an argument returned directly is used by
    // the terminator, which has none).
    Operation* input_op = nullptr;
    Operation* quantize_op = nullptr;
    if (arg->hasOneUse()) {
      Operation* arg_user = *arg->user_begin();
      if (arg_user->getNumResults() != 0) {
        auto* input_result = arg_user->getResult(0);
        if (input_result->hasOneUse() &&
            isa<QuantizeOp>(*input_result->user_begin())) {
          input_op = arg_user;
          quantize_op = *input_result->user_begin();
        }
      }
    }

    if (quantize_op) {
      auto quantize_output = quantize_op->getResult(0);
      auto quantize_type = quantize_output->getType();
      input_types.push_back(quantize_type);
      auto* new_arg = bb.addArgument(quantize_type);
      // Make a copy of input op with quantized input and output type.
      auto new_input =
          builder.create<InputOp>(input_op->getLoc(), quantize_type, new_arg);
      quantize_output->replaceAllUsesWith(new_input);
      quantize_op->erase();
      input_op->erase();
    } else {
      // Make a copy of current argument and append it to the end of the list.
      Type arg_type = arg->getType();
      input_types.push_back(arg_type);
      auto* new_arg = bb.addArgument(arg_type);
      arg->replaceAllUsesWith(new_arg);
    }
    arg->dropAllUses();
    bb.eraseArgument(0);
  }

  // Edit the return ops and remove the dequantize ops in place.
  int num_return_operands = terminator->getNumOperands();
  llvm::SmallVector<Type, 4> output_types;
  output_types.reserve(num_return_operands);
  for (int i = 0; i != num_return_operands; ++i) {
    auto* returned_value = terminator->getOperand(i);
    // A block argument has no defining op, so guard before calling `isa`.
    Operation* returned_op = returned_value->getDefiningOp();
    if (returned_op && isa<DequantizeOp>(returned_op)) {
      auto* quantized_value = returned_op->getOperand(0);
      output_types.push_back(quantized_value->getType());
      terminator->setOperand(i, quantized_value);
      // The dequantize result may still be used elsewhere (another return
      // operand or another op); only erase it once it is dead.
      if (returned_value->use_empty()) returned_op->erase();
    } else {
      output_types.push_back(returned_value->getType());
    }
  }

  auto new_func_type = builder.getFunctionType(input_types, output_types);
  func->setType(new_func_type);
}
void PostQuantizePass::runOnFunction() {
auto& func = getFunction();
if (!emit_quant_adaptor_ops_) {
RemoveQuantizationAdaptorOps(&func);
}
}
} // namespace
// Factory for the TensorFlow Lite dialect PostQuantize pass. The flag is
// forwarded to the pass constructor; returns a newly heap-allocated instance.
FunctionPassBase* CreatePostQuantizePass(bool emit_quant_adaptor_ops) {
  FunctionPassBase* post_quantize_pass =
      new PostQuantizePass(emit_quant_adaptor_ops);
  return post_quantize_pass;
}
// Static registration object for PostQuantizePass: associates it with the
// argument string "tfl-post-quantize" and the description below.
static PassRegistration<PostQuantizePass> pass(
"tfl-post-quantize", "Apply post quantization clean up after quantization");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "tensorflow/compiler/mlir/tensorflow/transforms/optimize.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Attribute constraint satisfied when the attribute's getValue() is false.
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
// Converts tf.FusedBatchNorm into a sequence of more primitive arithmetic
// operations. Specifically, performs the following calculation:
//
// (x - mean) * scale / sqrt(variance + epsilon) + offset
//
// Let multiplier = scale / sqrt(variance + epsilon),
// to compute
// (x - mean) * scale / sqrt(variance + epsilon) + offset,
// is then to compute
// (x * multiplier) + (offset - mean * multiplier).
// Matches only when $epsilon is an F32Attr and $is_training satisfies
// FalseBoolAttr. The first result is replaced by
//   x * multiplier + (offset - mean * multiplier)
// where multiplier = scale * rsqrt(variance + epsilon) is bound once via
// TF_MulOp:$multiplier and reused.
// NOTE(review): the four (verifyUnusedValue) entries presumably require the
// remaining results of the op to be unused - confirm against the DRR docs.
def : Pattern<
(TF_FusedBatchNormOp $x, $scale, $offset, $mean, $variance,
F32Attr:$epsilon, $data_format,
FalseBoolAttr:$is_training),
[(TF_AddOp
(TF_MulOp
$x,
(TF_MulOp:$multiplier
$scale,
(TF_RsqrtOp
(TF_AddOp $variance,
(TF_ConstOp $epsilon))))),
(TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
/*batch_mean=*/(verifyUnusedValue),
/*batch_variance=*/(verifyUnusedValue),
/*reserve_space_1=*/(verifyUnusedValue),
/*reserve_space_2=*/(verifyUnusedValue)
]>;
// TODO(jpienaar): Move to opbase something more general.
// Attribute definition for a dense integer elements attribute.
//   * The predicate accepts attributes that are DenseIntElementsAttr. The
//     `isa<...>()` call needs its parentheses, otherwise the emitted C++ names
//     the member template without calling it.
//   * The storage type is DenseIntElementsAttr (the C++ class checked by the
//     predicate).
//   * When materialized from a constant, the attribute is built with
//     getDenseElementsAttr over a `getTensorType({}, i32)` type from a single
//     i32 integer attribute holding the given value.
def TFi32ElementsAttr : Attr<CPred<"$_self.isa<DenseIntElementsAttr>()">,
                             "scalar int attribute"> {
  let storageType = [{ DenseIntElementsAttr }];
  let constBuilderCall = "$_builder.getDenseElementsAttr("
    "$_builder.getTensorType({}, $_builder.getIntegerType(32)), "
    "{$_builder.getI32IntegerAttr($0)})";
}
// Constant TFi32ElementsAttr holding the integer `v`; the value is passed to
// ConstantAttr as its string form.
class TFi32<int v> : ConstantAttr<TFi32ElementsAttr, !cast<string>(v)>;
// Matmul without transpose on b to matmul with explicit transpose op and
// transposed b.
// Matches a MatMul whose two transpose attributes are both false. $b is
// wrapped in a TF_TransposeOp and the second transpose attribute becomes
// true; the first one ($at) is carried over.
// The permutation is built as Range(start=Rank($b), limit=0, delta=-1) - 1.
// NOTE(review): presumably this yields [rank-1, ..., 0], i.e. a full
// reversal of the dimensions - confirm against tf.Range semantics.
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
(TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp
/*start=*/(TF_RankOp $b),
/*limit=*/(ConstantOp TFi32<0>),
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))),
$at, ConstBoolAttrTrue)>;
// Matmul with transpose on a to matmul with explicit transpose op and a not
// transposed.
// Matches a MatMul whose first transpose attribute is true. $a is wrapped in
// a TF_TransposeOp and that attribute becomes false; the second one ($bt) is
// carried over. The permutation is built the same way as for $b:
// Range(start=Rank($a), limit=0, delta=-1) - 1.
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
(TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp
/*start=*/(TF_RankOp $a),
/*limit=*/(ConstantOp TFi32<0>),
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
ConstBoolAttrFalse, $bt)>;
//===----------------------------------------------------------------------===//
// Op removal patterns.
//===----------------------------------------------------------------------===//
// Replaces every use of a tf.Identity result with the identity's operand.
def : Pat<(TF_IdentityOp $arg), (replaceWithValue $arg)>;
//===----------------------------------------------------------------------===//
// Op quantization pass-through patterns.
//===----------------------------------------------------------------------===//
// TODO(fengliuai): Implement similar rule in the QuantizePass if the constant
// folding hook of tfl.transpose and tfl.reshape are implemented.
// Swaps the order of a fake-quant followed by a transpose: the transpose is
// applied to the fake-quant's input and the fake-quant (with the same min,
// max, num_bits and narrow_range) is applied to the transposed value.
def : Pat<(TF_TransposeOp
(TF_FakeQuantWithMinMaxVarsOp
$input, $min, $max, $num_bits, $narrow_range),
$perm),
(TF_FakeQuantWithMinMaxVarsOp (TF_TransposeOp $input, $perm),
$min, $max, $num_bits, $narrow_range)>;
// Swaps the order of a fake-quant followed by a reshape: the reshape is
// applied to the fake-quant's input and the fake-quant (with the same min,
// max, num_bits and narrow_range) is applied to the reshaped value.
def : Pat<(TF_ReshapeOp
(TF_FakeQuantWithMinMaxVarsOp
$input, $min, $max, $num_bits, $narrow_range),
$shape),
(TF_FakeQuantWithMinMaxVarsOp (TF_ReshapeOp $input, $shape),
$min, $max, $num_bits, $narrow_range)>;

View File

@ -0,0 +1,54 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass applies quantization propagation on TFLite dialect.
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
//===----------------------------------------------------------------------===//
// The prepare-quantize Pass.
//
namespace mlir {
namespace TFL {
namespace {
// Applies prepare quantization on the model in TFL dialect. This pass runs
// before the quantization pass and propagates the quantization parameters
// across ops. This step is necessary for post-training quantization and also
// makes the quantization rules for some operations in quantization-aware
// training simpler.
// Function pass; its `runOnFunction` (defined right below) hands the current
// function to ApplyQuantizationParamsPropagation.
struct PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
void runOnFunction() override;
};
void PrepareQuantizePass::runOnFunction() {
ApplyQuantizationParamsPropagation(&getFunction());
}
} // namespace
// Factory for the TensorFlow Lite dialect PrepareQuantize pass. Returns a
// newly heap-allocated pass instance.
FunctionPassBase *CreatePrepareQuantizePass() {
  FunctionPassBase *prepare_quantize_pass = new PrepareQuantizePass();
  return prepare_quantize_pass;
}
// Static registration object for PrepareQuantizePass: associates it with the
// argument string "tfl-prepare-quantize" and the description below.
static PassRegistration<PrepareQuantizePass> pass(
"tfl-prepare-quantize", "Prepare TFL dialect for quantization");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,379 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass prepares for legalization to the TFLite dialect by
// converting operations in TensorFlow dialect into operations that can be
// legalized to TensorFlow Lite dialect with simple replacements. The newly
// created operations are in the TensorFlow dialect if the operation can be
// represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op is
// used. For example, Conv2D in TFLite which uses OHWI data format for filters
// is not supported in TensorFlow because TensorFlow requires filters in the
// HWIO data format.
//
// Motivation to prepare for the TFLite legalization before the actual
// legalization is to exploit constant folding opportunities in any newly
// created ops by leveraging constant folding support for the TensorFlow ops.
// This way TFLite can be used as a serialization format only and does not
// require access to the TFLite runtime for optimizations as required by the
// TFLite team.
#include <climits>
#include <cstdint>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#define DEBUG_TYPE "tf-tfl-legalization"
namespace mlir {
namespace TFL {
//===----------------------------------------------------------------------===//
// The actual PrepareTF Pass.
//
// TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
// pass.
namespace {
// Prepare TF operations in functions for subsequent legalization.
// Declared as a function pass; `runOnFunction` is defined out of line.
struct PrepareTFPass : public FunctionPass<PrepareTFPass> {
void runOnFunction() override;
};
// TODO(fengliuai): move this rule to PreparePatterns.td
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair after the
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
// folding logic will use a "std.constant" op to replace the
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
// convert the output type to the next op.
struct InsertTFLQuantOpsAfterTFFakeQuantOp : public RewritePattern {
InsertTFLQuantOpsAfterTFFakeQuantOp(MLIRContext *context)
: RewritePattern(TF::FakeQuantWithMinMaxVarsOp::getOperationName(), 1,
context) {}
struct MatchedState : public PatternState {
FloatAttr min;
FloatAttr max;
APInt num_bits;
bool narrow_range;
};
PatternMatchResult match(Operation *op) const override {
auto tf_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
auto res = tf_op.outputs();
if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
return matchFailure();
auto state = absl::make_unique<MatchedState>();
ElementsAttr min_value, max_value;
if (!matchPattern(tf_op.min(), m_Constant(&min_value)))
return matchFailure();
if (!matchPattern(tf_op.max(), m_Constant(&max_value)))
return matchFailure();
state->min = ExtractSingleElementAsFloat(min_value);
state->max = ExtractSingleElementAsFloat(max_value);
if (!state->min || !state->max) return matchFailure();
state->num_bits = tf_op.num_bits();
state->narrow_range = tf_op.narrow_range();
return matchSuccess(std::move(state));
}
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override {
auto &s = *static_cast<MatchedState *>(state.get());
Location loc = op->getLoc();
Value *copied = OpBuilder(op).clone(*op)->getResult(0);
Type res_type = copied->getType();
Type storage_type = rewriter.getIntegerType(s.num_bits.getSExtValue());
TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, s.min, s.max,
storage_type, s.narrow_range);
Value *quantize_op =
rewriter.create<TFL::QuantizeOp>(loc, qtype.getValue(), copied, qtype);
rewriter.replaceOpWithNewOp<TFL::DequantizeOp>(op, res_type, quantize_op);
}
};
// Templated class for declaring a converter from some TensorFlow convolution
// op into its counterpart in TensorFlow Lite.
//
// The `ConcreteType` deriving from this template must provide the following
// method for constructing TensorFlow Lite op:
//
// TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
// PatternRewriter &rewriter, Location loc,
// Type result_type, Value *input,
// Value *filter, Value *bias) const;
//
// And also the following method for getting the dimension for bias tensor:
//
// int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
template <typename ConcreteType, typename TFConvOpType>
struct ConvertTFConvOp : public RewritePattern {
  // Transient state for preserving data from match to rewrite.
  struct ConvertTFConvOpMatchState : public PatternState {
    IntegerAttr dilation_height_factor;
    IntegerAttr dilation_width_factor;
    StringAttr padding;
    IntegerAttr stride_height;
    IntegerAttr stride_width;
  };

  ConvertTFConvOp(MLIRContext *context)
      : RewritePattern(TFConvOpType::getOperationName(), 1, context),
        intAttrOne(Builder(context).getI32IntegerAttr(1)) {}

  // Decides whether `op` can be converted and, if so, returns the attributes
  // that rewrite() needs wrapped in a ConvertTFConvOpMatchState.
  PatternMatchResult match(Operation *op) const override {
    // Assumes TensorFlow convolution op is already verified to be
    // in valid form.

    // Match a TFConvOpType under the following conditions:
    // * The 'T' attribute must exist and be of value DT_FLOAT.
    // * The 'data_format' attribute must exist and be of value "NHWC".
    // * The 'strides' attribute must exist and is of the form [1, X, Y, 1].
    // * The 'dilations' attribute is optional, but it must be of the form
    //   [1, X, Y, 1] if exists.
    // * The 'padding' attribute must be accepted by TFPaddingIsSameOrValid.
    // * The filter operand must be a ranked 4-D tensor.
    TFConvOpType tf_op = cast<TFConvOpType>(op);

    if (!TFTypeIsFloatTensor(tf_op.input()) || !TFDataFormatIsNHWC(op))
      return matchFailure();

    IntegerAttr height, width;
    if (!TFIntListIs1XY1(op, "strides", &height, &width)) return matchFailure();

    auto state = llvm::make_unique<ConvertTFConvOpMatchState>();
    state->stride_height = height;
    state->stride_width = width;

    if (op->getAttr("dilations")) {
      // The attribute is present, so it has to be well formed. Falling back to
      // the default here would silently drop the requested dilation.
      if (!TFIntListIs1XY1(op, "dilations", &height, &width))
        return matchFailure();
      state->dilation_height_factor = height;
      state->dilation_width_factor = width;
    } else {
      // If the 'dilations' attribute is missing, we use the default value (1)
      // for both dilation height and width factor.
      state->dilation_height_factor = intAttrOne;
      state->dilation_width_factor = intAttrOne;
    }

    StringAttr padding_attr;
    if (!TFPaddingIsSameOrValid(op, &padding_attr)) return matchFailure();
    state->padding = padding_attr;

    // Additionally, we require the filter operand to be of 4-D tensor type so
    // that we can extract info from the shape (e.g., for constructing bias
    // tensor, for setting depth_multiplier attribute, etc.).
    // NOTE(review): only the rank is checked here; rewrite() also reads
    // individual dimension sizes, so dynamic dimensions are presumably not
    // expected. Confirm whether a static-shape check is needed.
    auto filter_type =
        tf_op.filter()->getType().template dyn_cast<RankedTensorType>();
    if (filter_type && filter_type.getRank() == 4)
      return matchSuccess(std::move(state));

    return matchFailure();
  }

  // Replaces `op` with the TFLite op built by ConcreteType::createTFLOp, after
  // creating a zero-valued constant to serve as the bias operand.
  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
               PatternRewriter &rewriter) const override {
    // TensorFlow convolution op only has two inputs, while the TFLite one has
    // three, with the bias vector marked as optional. However, TOCO has a
    // dedicated pass, EnsureBiasVectors, to create default bias vectors for all
    // those missing. So we model TFLite convolution op as requiring three
    // inputs to achieve the legalization task of EnsureBiasVector. This
    // requires the filter tensor to have static shape.

    // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)

    TFConvOpType tf_op = cast<TFConvOpType>(op);

    // Get a splat zero tensor with the expected dimension for the bias tensor.
    auto filter = tf_op.filter();
    auto filter_type = filter->getType().template cast<RankedTensorType>();
    auto elem_type = filter_type.getElementType();
    auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
        filter_type.getShape());
    auto bias_type = rewriter.getTensorType({bias_dim}, elem_type);
    auto bias_attr = rewriter.getZeroAttr(bias_type);
    auto bias = rewriter.create<ConstantOp>(op->getLoc(), bias_type, bias_attr);

    // Hand the matched attributes to the concrete converter (CRTP dispatch).
    auto *conv_state = static_cast<ConvertTFConvOpMatchState *>(state.get());
    auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
        conv_state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(),
        filter, bias);

    rewriter.replaceOp(op, conv_op.getResult());
  }

  // Cached i32 attribute with value 1, used as the default dilation factor.
  const IntegerAttr intAttrOne;
};
class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
 public:
  using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;

  ConvertTFConv2D(MLIRContext *context) : BaseType(context) {}

  // The bias has one element per entry of the last filter dimension.
  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
    return filterShape[filterShape.size() - 1];
  }

  // Builds the tfl.conv_2d op from the matched attributes, using a transposed
  // copy of the TensorFlow filter.
  TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
                            PatternRewriter &rewriter, Location loc,
                            Type result_type, Value *input, Value *filter,
                            Value *bias) const {
    Value *ohwi_filter = legalizeFilter(rewriter, loc, filter);
    auto no_activation = rewriter.getStringAttr("NONE");
    return rewriter.create<TFL::Conv2DOp>(
        loc, result_type, input, ohwi_filter, bias,
        /*dilation_h_factor=*/state->dilation_height_factor,
        /*dilation_w_factor=*/state->dilation_width_factor,
        /*fused_activation_function=*/no_activation,
        /*padding=*/state->padding,
        /*stride_h=*/state->stride_height,
        /*stride_w=*/state->stride_width);
  }

 private:
  // Converts the filter from the TensorFlow layout HWIO to the OHWI layout
  // used by the TFLite Conv2D op by emitting a tf.Transpose, and returns the
  // transposed value. The match method must already have verified that the
  // filter is a 4-D RankedTensorType.
  Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
                        Value *filter) const {
    // Constant op holding the HWIO -> OHWI permutation.
    SmallVector<int, 4> perm = {3, 0, 1, 2};
    const int perm_rank = static_cast<int>(perm.size());
    auto perm_type =
        rewriter.getTensorType({perm_rank}, rewriter.getIntegerType(32));
    auto perm_attr =
        DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
    auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);

    // Shape of the transpose result: source dimensions picked in perm order.
    auto hwio_type = filter->getType().cast<RankedTensorType>();
    SmallVector<int64_t, 4> ohwi_shape;
    for (int src_dim : perm) ohwi_shape.push_back(hwio_type.getDimSize(src_dim));
    auto ohwi_type =
        rewriter.getTensorType(ohwi_shape, hwio_type.getElementType());

    return rewriter.create<TF::TransposeOp>(loc, ohwi_type, filter, perm_op);
  }
};
class ConvertTFDepthwiseConv2dNative
    : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
                             TF::DepthwiseConv2dNativeOp> {
 public:
  using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
                                   TF::DepthwiseConv2dNativeOp>;

  ConvertTFDepthwiseConv2dNative(MLIRContext *context) : BaseType(context) {}

  // The bias size is the product of the third and fourth filter dimensions.
  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
    const int64_t in_channels = filterShape[2];
    const int64_t channel_multiplier = filterShape[3];
    return in_channels * channel_multiplier;
  }

  // Builds the tfl.depthwise_conv_2d op from the matched attributes, using a
  // reshaped copy of the TensorFlow filter.
  TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
                                     PatternRewriter &rewriter, Location loc,
                                     Type result_type, Value *input,
                                     Value *filter, Value *bias) const {
    // tfl.depthwise_conv_2d carries an explicit 'depth_multiplier' attribute
    // that tf.DepthwiseConv2dNative lacks; there the multiplier is the fourth
    // dimension of the 4-D filter. Read it before the filter is reshaped.
    auto tf_filter_type = filter->getType().cast<RankedTensorType>();
    auto multiplier = tf_filter_type.getDimSize(3);

    Value *tfl_filter = legalizeFilter(rewriter, loc, filter);
    auto no_activation = rewriter.getStringAttr("NONE");
    auto multiplier_attr = rewriter.getI32IntegerAttr(multiplier);
    return rewriter.create<TFL::DepthwiseConv2DOp>(
        loc, result_type, input, tfl_filter, bias,
        /*dilation_h_factor=*/state->dilation_height_factor,
        /*dilation_w_factor=*/state->dilation_width_factor,
        /*fused_activation_function=*/no_activation,
        /*padding=*/state->padding,
        /*stride_h=*/state->stride_height,
        /*stride_w=*/state->stride_width,
        /*depth_multiplier=*/multiplier_attr);
  }

 private:
  /// Reshapes the filter from the TensorFlow layout
  /// [filter_height, filter_width, in_channels, channel_multiplier] to the
  /// TFLite DepthwiseConv2D layout
  /// [1, filter_height, filter_width, out_channels] by emitting a tf.Reshape,
  /// and returns the reshaped value. The match method must already have
  /// verified that the filter is a 4-D RankedTensorType.
  Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
                        Value *filter) const {
    auto tf_filter_type = filter->getType().cast<RankedTensorType>();
    auto tf_shape = tf_filter_type.getShape();

    // out_channels is in_channels * channel_multiplier.
    SmallVector<int64_t, 4> tfl_shape;
    tfl_shape.push_back(1);
    tfl_shape.push_back(tf_shape[0]);
    tfl_shape.push_back(tf_shape[1]);
    tfl_shape.push_back(tf_shape[2] * tf_shape[3]);

    auto tfl_filter_type =
        rewriter.getTensorType(tfl_shape, tf_filter_type.getElementType());

    // The target shape is passed to tf.Reshape as a constant i64 tensor.
    auto shape_type = rewriter.getTensorType({4}, rewriter.getIntegerType(64));
    auto shape_attr =
        DenseElementsAttr::get(shape_type, llvm::makeArrayRef(tfl_shape));
    auto shape_op = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);

    return rewriter.create<TF::ReshapeOp>(loc, tfl_filter_type, filter,
                                          shape_op);
  }
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
void PrepareTFPass::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
TFL::populateWithGenerated(&getContext(), &patterns);
// TODO(karimnosseir): Split to separate pass probably after
// deciding on long term plan for this optimization.
// This will allow optimizing any TF_Mul->TF_Conv in the graph
// and any expanded from FusedBatchNorm. We need to do this
// before converting TF_Conv to TFL_Conv
applyPatternsGreedily(func, std::move(patterns));
patterns.push_back(llvm::make_unique<ConvertTFConv2D>(&getContext()));
patterns.push_back(
llvm::make_unique<ConvertTFDepthwiseConv2dNative>(&getContext()));
patterns.push_back(
llvm::make_unique<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext()));
applyPatternsGreedily(func, std::move(patterns));
}
} // namespace
// Factory for the TensorFlow Lite dialect PrepareTF pass: returns a newly
// heap-allocated PrepareTFPass through its FunctionPassBase interface.
FunctionPassBase *CreatePrepareTFPass() {
  auto *prepare_pass = new PrepareTFPass();
  return prepare_pass;
}
// File-local registration object for PrepareTFPass. It passes the pass name
// "tfl-prepare-tf" and a one-line description to PassRegistration.
// NOTE(review): presumably this is what makes the pass selectable by that name
// in tools such as tf-opt - confirm against the PassRegistration docs.
static PassRegistration<PrepareTFPass> pass(
"tfl-prepare-tf", "Prepare TF for legalization to TensorFlow Lite dialect");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass applies quantization on TFLite dialect.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
namespace mlir {
namespace TFL {
//===----------------------------------------------------------------------===//
// The actual Quantize Pass.
//
namespace {
/// Function pass that applies the generated quantization rewrite patterns to a
/// function in the TFL dialect.
struct QuantizePass : FunctionPass<QuantizePass> {
  /// Entry point invoked once per function; defined out of line below.
  void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
void QuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
auto *context = func.getContext();
populateWithGenerated(context, &patterns);
applyPatternsGreedily(func, std::move(patterns));
}
} // namespace
/// Factory for the TensorFlow Lite dialect quantize pass: returns a newly
/// heap-allocated QuantizePass through its FunctionPassBase interface.
FunctionPassBase *CreateQuantizePass() {
  auto *quantize_pass = new QuantizePass();
  return quantize_pass;
}
// File-local registration object for QuantizePass. It passes the pass name
// "tfl-quantize" and a one-line description to PassRegistration.
// NOTE(review): presumably this is what makes the pass selectable by that name
// in tools such as tf-opt - confirm against the PassRegistration docs.
static PassRegistration<QuantizePass> pass(
"tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,128 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the quantization pattern definition file for TensorFlow Lite.
include "mlir/IR/OpBase.td"
include "mlir/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Quantize attribute $0 by using the quantization parameter from $1.
// Expands to a call of the C++ function `Quantize` with $0 and
// `$1.getValue()` as its arguments.
def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;
// Call the generic builder of `op`. Use the result type of $0 in the new op.
// $0 supplies the location and the type of its first result; $1, $2 and $3
// are forwarded unchanged as the remaining builder arguments, so this helper
// only fits ops whose builder takes exactly three further arguments.
class ReplaceWith<string op> : NativeCodeCall<"$_builder.create<" # op #
">($0->getLoc(), $0->getResult(0)->getType(), $1, $2, $3)">;
// Squash tfl.dequantize and tfl.quantize pairs: a tfl.quantize applied to the
// result of a tfl.dequantize is replaced by the dequantize's input $in. The
// quantize type $qt is not compared against the type of $in.
// TODO(fengliuai): Compare the scale of input and output. This can also be
// squashed to a requantize op if the scales are different.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
// Quantize the value of a constant op if the quantization parameters have been
// propagated to the output. The tfl.quantize of a constant with an
// ElementsAttr value becomes a TFL_QConstOp that carries $qtype and the value
// quantized by QuantizeByQuantizedType.
def : Pat<(TFL_QuantizeOp
(ConstantOp ElementsAttr:$value),
$qtype),
(TFL_QConstOp
$qtype,
(QuantizeByQuantizedType $value, $qtype))>;
// Quantize the AddOp if both inputs are dequantized and the output is
// quantized. The new TFL::AddOp consumes the pre-dequantize values $lhs and
// $rhs directly and takes its result type from the matched quantize op $q.
def : Pat<(TFL_QuantizeOp:$q
(TFL_AddOp (TFL_DequantizeOp $lhs), (TFL_DequantizeOp $rhs),
$fused_activation_function),
$output_type),
(ReplaceWith<"TFL::AddOp"> $q, $lhs, $rhs,
$fused_activation_function)>;
// Quantize the Conv2DOp if the input, weight and bias are all dequantized and
// the output is quantized. The scale of the bias input is determined by the
// scales of input and weight operands. The rewritten op consumes the
// pre-dequantize values and keeps every attribute unchanged.
// TODO(fengliuai): propagate the quantization parameters to the bias input.
def : Pat<(TFL_QuantizeOp
(TFL_Conv2DOp
(TFL_DequantizeOp $in),
(TFL_DequantizeOp $weight),
(TFL_DequantizeOp $bias),
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w),
$output_type),
(TFL_Conv2DOp
$in,
$weight,
$bias,
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w)>;
// Quantize the DepthwiseConv2DOp if the input, weight and bias are all
// dequantized and the output is quantized. The scale of the bias input is
// determined by the scales of input and weight operands. The rewritten op
// consumes the pre-dequantize values and keeps every attribute, including
// $multiplier, unchanged.
// TODO(fengliuai): propagate the quantization parameters to the bias input.
def : Pat<(TFL_QuantizeOp
(TFL_DepthwiseConv2DOp
(TFL_DequantizeOp $in),
(TFL_DequantizeOp $weight),
(TFL_DequantizeOp $bias),
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w,
$multiplier),
$output_type),
(TFL_DepthwiseConv2DOp
$in,
$weight,
$bias,
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w,
$multiplier)>;
// Quantize the ReshapeOp if the input is dequantized and output is quantized.
// The pre-quantize pass can guarantee both quantization parameters are the
// same. The rewritten reshape consumes the pre-dequantize value $in.
def : Pat<(TFL_QuantizeOp (TFL_ReshapeOp (TFL_DequantizeOp $in)), $output_type),
(TFL_ReshapeOp $in)>;
// Quantize the SoftmaxOp if the input is dequantized and output is quantized.
// The pre-quantize pass has set the output quantization parameters to a
// pre-defined value. The rewritten softmax consumes the pre-dequantize value
// $in and keeps $beta unchanged.
def : Pat<(TFL_QuantizeOp (TFL_SoftmaxOp (TFL_DequantizeOp $in), $beta),
$output_type),
(TFL_SoftmaxOp $in, $beta)>;
// Quantize the AveragePool2DOp if the input is dequantized and output is
// quantized. The pre-quantize pass can guarantee both quantization parameters
// are the same. The rewritten op consumes the pre-dequantize value $in and
// keeps every attribute unchanged.
def : Pat<(TFL_QuantizeOp (TFL_AveragePool2DOp (TFL_DequantizeOp $in),
$filter_height, $filter_width, $fused_activation_function,
$padding, $stride_h, $stride_w), $output_type),
(TFL_AveragePool2DOp $in,
$filter_height, $filter_width, $fused_activation_function,
$padding, $stride_h, $stride_w)>;

View File

@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
//===----------------------------------------------------------------------===//
// TensorList transformation patterns.
// Note that the patterns below rewrite `TensorList` tensors (which have type
// DT_VARIANT) into regular tensors. We also assume that every element in the
// `TensorList` has the same constant shape.
//===----------------------------------------------------------------------===//
// tf.TensorListFromTensor is replaced by its $tensor operand; $element_shape
// is dropped.
def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape),
(replaceWithValue $tensor)>;
// tf.TensorListStack is replaced by its $input operand; $element_shape and
// $num_elements are dropped.
def : Pat<(TF_TensorListStackOp $input, $element_shape, $num_elements),
(replaceWithValue $input)>;
// tf.TensorListGetItem becomes a tf.Gather of $input at $index;
// $element_shape is dropped. The third tf.Gather argument is a boolean
// attribute set to true.
// NOTE(review): presumably that attribute is `validate_indices` - confirm
// against the TF_GatherOp definition.
def : Pat<(TF_TensorListGetItemOp $input, $index, $element_shape),
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
namespace mlir {
namespace TFL {
// Returns the only element of `attr` as a FloatAttr. Returns an empty
// attribute when `attr` does not hold exactly one element or when its element
// type is not a float type.
FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr) {
  auto shaped_type = attr.getType();
  if (shaped_type.getNumElements() != 1) return {};
  if (!shaped_type.getElementType().isa<FloatType>()) return {};

  // An all-zero index of matching rank addresses the single element.
  SmallVector<uint64_t, 8> zero_index(shaped_type.getRank(), 0);
  return attr.getValue(zero_index).cast<FloatAttr>();
}
// For an ElementsAttr, returns its single float element (see
// ExtractSingleElementAsFloat); otherwise returns `attr` itself as a
// FloatAttr, or an empty attribute when it is null or not a FloatAttr.
FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) {
  auto elements = attr.dyn_cast_or_null<ElementsAttr>();
  if (elements) return ExtractSingleElementAsFloat(elements);
  return attr.dyn_cast_or_null<FloatAttr>();
}
// Returns the only element of `attr` as an IntegerAttr. Returns an empty
// attribute when `attr` does not hold exactly one element or when its element
// type is not an integer type.
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
  auto shaped_type = attr.getType();
  if (shaped_type.getNumElements() != 1) return {};
  if (!shaped_type.getElementType().isa<IntegerType>()) return {};

  // An all-zero index of matching rank addresses the single element.
  SmallVector<uint64_t, 8> zero_index(shaped_type.getRank(), 0);
  return attr.getValue(zero_index).cast<IntegerAttr>();
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
namespace mlir {
namespace TFL {
// Returns true only when each of `a`, `b` and `c` differs from a
// default-constructed (empty) Attribute.
inline bool HasAll3Attrs(Attribute a, Attribute b, Attribute c) {
  const Attribute empty;
  const bool any_missing = (a == empty) || (b == empty) || (c == empty);
  return !any_missing;
}
// Returns the single float element from an ElementsAttr. Returns an empty
// attribute if the number of elements in the attribute is not 1 or the
// element type isn't a float type.
FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr);
// Returns the single float element if the input is an ElementsAttr, or
// returns the input itself as a float attribute. Returns an empty attribute if
// the ElementsAttr does not hold exactly one float element, or if the input is
// neither an ElementsAttr nor a float attribute.
FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr);
// Returns the single integer element from an ElementsAttr. Returns an empty
// attribute if the number of elements in the attribute is not 1 or the
// element type isn't an integer type.
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr);
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_

View File

@ -0,0 +1,607 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <unordered_map>
#include "absl/memory/memory.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
namespace mlir {
namespace TFL {
namespace {
// Quantization parameters are represented directly by a quantized type.
using QuantParams = quant::QuantizedType;
// Callable that derives one set of quantization parameters from a list of
// others; OpQuantSpec::biases_params stores one per bias operand.
using AccumulatorScaleFunc =
std::function<QuantParams(const std::vector<QuantParams> &)>;
// Quantization specs of ops, driving the TF Lite quantization algorithm.
struct OpQuantSpec {
// Whether the op has a quantizable result. This flag is set to false if the op
// has the "TFL::NoQuantizableResult" trait.
bool is_quantizable = true;
// Whether it requires the same scale for inputs and result. This flag is set
// to true if the op has the "TFL::SameOperandsAndResultScale" trait.
bool requires_same_scale = false;
// Maps the operand index of a bias input to its quantization specifications,
// including the non-bias operand indexes and the method retrieving
// quantization parameters from the list of parameters of the non-bias
// operands. This map is empty if the op doesn't have a bias operand.
std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
biases_params;
// Quantization parameters for value-restricted outputs. These are the
// "hard-coded" parameters and should be used unconditionally for the
// quantized op. This vector is empty if the op doesn't have value-restricted
// outputs.
llvm::SmallVector<QuantParams, 1> restricted_output_params;
};
// Returns true when `p` equals a default-constructed quantized type, i.e. no
// quantization parameters have been set.
static bool EmptyParams(QuantParams p) {
  const QuantParams unset;
  return p == unset;
}
// The state for each op result during the quantization parameters propagation.
struct QuantState {
// Quantization parameters propagated to an op result.
QuantParams params;
// A flag indicates this state (the params) shouldn't be changed after it is
// initialized. This flag will be set to true if the quantization parameters
// are from the quantization-aware training.
const bool immutable;
bool IsEmpty() { return EmptyParams(params); }
};
// State describing how propagated quantization parameters must be rescaled.
// The rescale can sit on the input side, to satisfy the constraint of the
// previous operation, or on the output side, to satisfy the constraint of the
// next operation.
struct RequantizeState {
  // Where the "requantize" of a quantization result has to happen so that all
  // constraints are satisfied.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  };

  // Requested position; no requantize by default.
  RequantizePosition pos = NO_REQUANTIZE;

  // Quantization parameters used when adding the requantize ops.
  QuantParams params;
};
// This is a worklist-driven driver for propagating quantization parameters
// across operations.
//
// The initial quantization parameters are extracted from the quantized type
// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
// parameters are marked as immutable because they are from quantization-aware
// training.
//
// The algorithm traverses each op and sets the quantization parameters of its
// operands and results, according to its quantization specification, and then
// adds the operands and results to the worklist. If there are any conflicts
// (for example, there are quantization parameters propagated from the previous
// iteration), this process stops if the existing parameters are immutable;
// otherwise a `requantize` op is added to resolve the conflicts.
//
// After the algorithm has converged, pairs of tfl.quantize and tfl.dequantize
// are inserted to the right position to materialize the propagation and
// requantize results.
//
class QuantizationDriver {
public:
explicit QuantizationDriver(Function *fn) : builder_(fn->getBody()) {}
// The entry point of the quantization parameters propagation.
void Run();
private:
// This is used to identify an operand or result of an op. The second element
// of this pair is the index of the operand or result.
using OpValue = std::pair<mlir::Operation *, int>;
// Sets up the states for all the op results in the function.
void Initialize();
// Propagates the quantization parameters across all the ops.
bool PropagateParams();
// Inserts the Quantize and Dequantize ops according to the propagation
// result.
void Finalize();
// Whether the constant is used as a bias input of another op. Here we assume
// bias is used immediately by the user. This assumption is always correct
// after constant folding.
bool UsedAsBias(ConstantOp cst) {
Value *value = cst.getResult();
for (auto &use : value->getUses()) {
auto biases = GetQuantSpec(use.getOwner())->biases_params;
if (biases.find(use.getOperandNumber()) != biases.end()) return true;
}
return false;
}
// Returns all the related quantization constraints of the op.
std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);
// Whether Quantization parameters have been propagated to the results of this
// op.
bool IsQuantized(Operation *op);
// Adds all the users of index-th result of op to the work list.
void AddUserToList(Operation *op, int index) {
for (auto &user : op->getResult(index)->getUses()) {
work_list_.push_back(user.getOwner());
}
}
// Adds the defining op of index-th operand of op to the work list.
void AddOperandToList(Operation *op, int index) {
if (auto *inst = op->getOperand(index)->getDefiningOp())
work_list_.push_back(inst);
}
// Returns the quantization params for the bias input from the non-bias
// operands which have their indexes in the `non_biases` vector. The returned
// parameters are calculated by `func`.
QuantParams GetBiasParams(Operation *op, int bias,
const std::vector<int> &non_biases,
AccumulatorScaleFunc func);
// Sets the quantization parameters of the result to a fixed value. If any
// quantization parameters have been propagated, a `requantize` will happen on
// the input of propagated quantization.
bool SetResultParams(Operation *op, int index, QuantParams params);
// Sets the quantization parameters of the operand to a fixed value. If any
// quantization parameters have been propagated, a `requantize` will happen on
// the output of propagated quantization.
bool SetOperandParams(Operation *op, int index, QuantParams params);
// Sets the quantization parameters of the constant result according to its
// content.
bool SetConstantResultParams(Operation *op, unsigned storage_type_width,
bool narrow_range);
// Inserts the Quantize and Dequantize ops for quantizing the index-th result
// of the op.
void QuantizeOpResult(Operation *op, int index, QuantParams params);
// Retrieves all the operands and results quantization states of the op.
// Mutable and immutable states are collected in two vectors. Returns false
// if there is more than one immutable state and their parameters differ,
// because then there is no way the same scale constraint can be satisfied.
bool GetAllQuantStatesCanBeSameScale(
Operation *op, std::vector<QuantState *> *mutable_states,
std::vector<QuantState *> *immutable_states);
// A heuristic to determine what the quantization parameters are to satisfy
// the same scale constraints. Returns NULL if it isn't determined, mainly
// because none of the values are quantized.
QuantState *GetFinalSameState(
Operation *op, const std::vector<QuantState *> &mutable_states,
const std::vector<QuantState *> &immutable_states);
// Looks up the state index recorded for the index-th operand of `op` and
// returns a mutable reference to the matching entry of `states_`.
QuantState &GetOperandState(Operation *op, int index) {
  const int state_index = operand_states_[{op, index}];
  return states_[state_index];
}
// Looks up the state index recorded for the index-th result of `op` and
// returns a mutable reference to the matching entry of `states_`.
QuantState &GetResultState(Operation *op, int index) {
  const int state_index = result_states_[{op, index}];
  return states_[state_index];
}
// Returns the requantize state associated with the index-th operand of `op`.
// The entry is keyed by the operand's state index and is created in
// `rescale_states_` on first access.
RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
  const int state_index = operand_states_[{op, index}];
  return rescale_states_[state_index];
}
// Returns the requantize state associated with the index-th result of `op`.
// The entry is keyed by the result's state index and is created in
// `rescale_states_` on first access.
RequantizeState &GetResultRequantizeState(Operation *op, int index) {
  const int state_index = result_states_[{op, index}];
  return rescale_states_[state_index];
}
// Uses the type of `val` to set the initial state of the index-th result if
// `as_result` is true or index-th operand if `as_result` is false. The state
// is immutable if the type is a quantized type. Returns the index of this
// new state in the state vector.
int InitializeState(Operation *op, int index, Value *val, bool as_result);
// Binds the index-th operand of `op` to a quantization state for value `in`.
// `cache` maps values to state indexes: when `in` is already present its state
// index is reused for this operand; otherwise a new state is created through
// InitializeState and its index is stored in the cache.
void InitializeOperandState(Operation *op, int index, Value *in,
                            llvm::DenseMap<Value *, int> *cache) {
  auto lookup = cache->insert({in, 0});
  const bool newly_inserted = lookup.second;
  if (newly_inserted) {
    // First time this value is seen: allocate its state and remember it.
    lookup.first->second =
        InitializeState(op, index, in, /*as_result=*/false);
  } else {
    operand_states_.insert({{op, index}, lookup.first->second});
  }
}
// Binds the index-th result of `op` to a quantization state for value `res`.
// `cache` maps values to state indexes: when `res` is already present its
// state index is reused for this result; otherwise a new state is created
// through InitializeState and its index is stored in the cache.
void InitializeResultState(Operation *op, int index, Value *res,
                           llvm::DenseMap<Value *, int> *cache) {
  auto lookup = cache->insert({res, 0});
  const bool newly_inserted = lookup.second;
  if (newly_inserted) {
    // First time this value is seen: allocate its state and remember it.
    lookup.first->second =
        InitializeState(op, index, res, /*as_result=*/true);
  } else {
    result_states_.insert({{op, index}, lookup.first->second});
  }
}
OpBuilder builder_;
// All the ops to which the quantization parameters still need to be
// propagated.
std::vector<Operation *> work_list_;
// The vector contains all the quantization parameters propagated from the
// defining operations of the value, or from the quantization aware training.
std::vector<QuantState> states_;
// The map contains all the quantization parameters which are required to
// satisfy the same operands and results constraint. The keys of this map are
// the values from `operand_states_` and `result_states_`.
std::unordered_map<int, RequantizeState> rescale_states_;
// Maps of indexes to the propagation state vector from the ops results and
// op operands. Both maps are unmodified after initialization.
llvm::DenseMap<OpValue, int> operand_states_;
llvm::DenseMap<OpValue, int> result_states_;
};
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
} // namespace
// Returns the quantization specification of `op` by forwarding to
// GetOpQuantSpec. Nothing is cached here, so every call requests a new
// specification object.
// TODO(fengliuai): cache the quantization parameters.
std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
return GetOpQuantSpec(op);
}
// Returns true when every result of `op` already has quantization parameters
// in its state (trivially true for an op without results).
bool QuantizationDriver::IsQuantized(Operation *op) {
  const int num_results = op->getNumResults();
  for (int res = 0; res != num_results; ++res) {
    const bool has_params = !GetResultState(op, res).IsEmpty();
    if (!has_params) return false;
  }
  return true;
}
// Appends a new state to `states_`, seeded from the quantized element type of
// `val` (if any), and records its index for the index-th result of `op` when
// `as_result` is true, or for the index-th operand otherwise. The state is
// immutable exactly when the type already carries quantization parameters.
// Returns the index of the new state.
int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
                                        bool as_result) {
  const int new_index = states_.size();
  QuantParams type_params =
      quant::QuantizedType::getQuantizedElementType(val->getType());
  const bool is_fixed = !EmptyParams(type_params);
  states_.push_back({type_params, is_fixed});

  auto &index_map = as_result ? result_states_ : operand_states_;
  index_map.insert({{op, index}, new_index});
  return new_index;
}
// Derives quantization parameters for result 0 of `op` from the content of the
// constant it produces and stores them through SetResultParams. Returns false
// when result 0 does not match a constant, when no quantized type can be
// derived from the constant, or when SetResultParams reports no change.
bool QuantizationDriver::SetConstantResultParams(Operation *op,
                                                 unsigned storage_type_width,
                                                 bool narrow_range) {
  ElementsAttr constant_value;
  if (!matchPattern(op->getResult(0), m_Constant(&constant_value)))
    return false;

  Type derived_type = GetUniformQuantizedTypeForElementsAttr(
      constant_value, storage_type_width, narrow_range);
  auto quantized_type = derived_type.dyn_cast_or_null<quant::QuantizedType>();
  if (quantized_type) return SetResultParams(op, 0, quantized_type);
  return false;
}
// Sets the quantization parameters of the `res_index`-th result of `op`.
//
// - If the state is immutable, nothing happens and false is returned.
// - If the state has no parameters yet, `params` is stored, the users of the
//   result are added to the work list, and true is returned.
// - If the state already holds different parameters, a requantize request
//   (`params`, positioned ON_INPUT) is recorded for the state and true is
//   returned.
// - If the state already holds `params`, false is returned.
bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
                                         QuantParams params) {
  auto &state = GetResultState(op, res_index);
  if (state.immutable) return false;
  if (state.IsEmpty()) {
    state.params = params;
    AddUserToList(op, res_index);
    return true;
  }
  if (state.params != params) {
    // Bind by reference: plain `auto` would deduce a value type and the
    // request would be written to a temporary copy instead of the entry stored
    // in `rescale_states_`.
    auto &rescale = GetResultRequantizeState(op, res_index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_INPUT;
    return true;
  }
  return false;
}
// Returns the quantization parameters for the `bias` operand of `op`. If the
// bias state already has parameters they are returned unchanged. Otherwise the
// parameters of the operands listed in `non_biases` are collected, in order,
// and passed to `func`; an empty `non_biases` yields empty parameters.
QuantParams QuantizationDriver::GetBiasParams(
    Operation *op, int bias, const std::vector<int> &non_biases,
    AccumulatorScaleFunc func) {
  QuantState &existing = GetOperandState(op, bias);
  if (!existing.IsEmpty()) return existing.params;

  if (non_biases.empty()) return {};

  std::vector<QuantParams> input_params;
  input_params.reserve(non_biases.size());
  for (int operand_index : non_biases)
    input_params.push_back(GetOperandState(op, operand_index).params);
  return func(input_params);
}
// Sets the quantization parameters of the `index`-th operand of `op`.
//
// - If the state is immutable, nothing happens and false is returned.
// - If the state has no parameters yet, `params` is stored, the defining op of
//   the operand is added to the work list, and true is returned.
// - If the state already holds different parameters, a requantize request
//   (`params`, positioned ON_OUTPUT) is recorded for the state and true is
//   returned.
// - If the state already holds `params`, false is returned.
bool QuantizationDriver::SetOperandParams(Operation *op, int index,
                                          QuantParams params) {
  auto &state = GetOperandState(op, index);
  if (state.immutable) return false;
  if (state.IsEmpty()) {
    state.params = params;
    AddOperandToList(op, index);
    return true;
  }
  if (state.params != params) {
    // Bind by reference: plain `auto` would deduce a value type and the
    // request would be written to a temporary copy instead of the entry stored
    // in `rescale_states_`.
    auto &rescale = GetOperandRequantizeState(op, index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_OUTPUT;
    return true;
  }
  return false;
}
// Materializes `params` on the index-th result of `op`: a tfl.quantize op
// followed by a tfl.dequantize op is placed right after `op`, and all previous
// users of the result are redirected to the dequantized value.
void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
QuantParams params) {
Value *original_result = op->getResult(index);
// The current (expressed) type of the result and the quantized type obtained
// from it with `params`.
Type expressed_type = original_result->getType();
Type new_type = params.castFromExpressedType(expressed_type);
TypeAttr type_attr = builder_.getTypeAttr(new_type);
// Both new ops are created immediately before `op`.
builder_.setInsertionPoint(op);
auto quantize = builder_.create<TFL::QuantizeOp>(op->getLoc(), new_type,
original_result, type_attr);
auto dequantize = builder_.create<TFL::DequantizeOp>(
op->getLoc(), expressed_type, quantize.output());
// New ops are inserted before `op`, so here to adjust the order.
op->moveBefore(quantize);
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
original_result->replaceAllUsesWith(dequantize);
// Restore the input of `quantize`, which the line above also rewired.
quantize.getOperation()->replaceUsesOfWith(dequantize, original_result);
}
// Sorts the states of all operands (first) and all results (second) of `op`
// into `mutable_states` and `immutable_states`. Returns false as soon as an
// immutable state is found whose parameters differ from those of the first
// immutable state collected, since the same-scale constraint cannot be met;
// the output vectors then hold only the states visited so far.
bool QuantizationDriver::GetAllQuantStatesCanBeSameScale(
    Operation *op, std::vector<QuantState *> *mutable_states,
    std::vector<QuantState *> *immutable_states) {
  // Files one state into the right bucket; false signals a conflict between
  // immutable states.
  auto classify = [&](QuantState &state) -> bool {
    if (!state.immutable) {
      mutable_states->push_back(&state);
      return true;
    }
    const bool compatible =
        immutable_states->empty() ||
        state.params == immutable_states->front()->params;
    if (!compatible) return false;
    immutable_states->push_back(&state);
    return true;
  };

  const int num_operands = op->getNumOperands();
  for (int i = 0; i != num_operands; ++i) {
    if (!classify(GetOperandState(op, i))) return false;
  }
  const int num_results = op->getNumResults();
  for (int i = 0; i != num_results; ++i) {
    if (!classify(GetResultState(op, i))) return false;
  }
  return true;
}
// Picks the state whose parameters all operands and results of a same-scale op
// should adopt. Candidates are tried in this order:
//   - the first immutable state, if there is one;
//   - the state of the only operand, when the op has exactly one operand and
//     that state has parameters;
//   - the state of the only result, under the same conditions;
//   - the first mutable state that has parameters.
// Returns nullptr when none of the candidates has parameters.
QuantState *QuantizationDriver::GetFinalSameState(
    Operation *op, const std::vector<QuantState *> &mutable_states,
    const std::vector<QuantState *> &immutable_states) {
  if (!immutable_states.empty()) return immutable_states.front();

  if (op->getNumOperands() == 1) {
    QuantState *sole_operand = &GetOperandState(op, 0);
    if (!sole_operand->IsEmpty()) return sole_operand;
  }
  if (op->getNumResults() == 1) {
    QuantState *sole_result = &GetResultState(op, 0);
    if (!sole_result->IsEmpty()) return sole_result;
  }

  // Rare fallback: the first mutable state that already has parameters.
  for (auto it = mutable_states.begin(), end = mutable_states.end();
       it != end; ++it) {
    if (!(*it)->IsEmpty()) return *it;
  }
  return nullptr;
}
// This method scans the operations in the function to setup the initial
// states for quantization parameter propagation. Every quantizable,
// non-terminator op is pushed on the work list and gets a state for each of
// its results and for each operand that has a defining op.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
// Shared by operands and results so that a value seen on both sides (as the
// result of one op and the operand of another) maps to a single state.
llvm::DenseMap<Value *, int> value_to_state;
builder_.getRegion()->walk([&](Operation *op) {
if (op->isKnownTerminator()) return;
if (!GetQuantSpec(op)->is_quantizable) return;
work_list_.push_back(op);
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
auto *operand = op->getOperand(i);
auto *inst = operand->getDefiningOp();
// Operands without a defining op get no entry in `operand_states_` here.
// NOTE(review): GetOperandState looks entries up with operator[], which
// would create an entry with index 0 for such operands -- confirm that
// these operands are never queried.
if (!inst) continue;
// If the operand comes from a tfl.dequantize op, we use the quantized
// input of this tfl.dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
operand = dq.input();
}
InitializeOperandState(op, i, operand, &value_to_state);
}
for (int res = 0, e = op->getNumResults(); res != e; ++res) {
auto *result = op->getResult(res);
// If the result has been quantized, it should only be used by a
// tfl.quantize op. For this case, we use the quantized result to create
// the state and mark it immutable.
if (result->hasOneUse()) {
auto user = result->use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
result = q.output();
}
}
InitializeResultState(op, res, result, &value_to_state);
}
});
}
// Drains the work list, taking ops from the back, and propagates quantization
// parameters according to each op's quantization specification. Setting new
// parameters may push further ops on the work list.
//
// Returns true if any state was updated. Returns false either when nothing
// changed or, immediately, when the same-scale constraint of an op cannot be
// satisfied; Run() skips Finalize() in both cases.
bool QuantizationDriver::PropagateParams() {
// TODO(fengliuai): uses a typed indicator instead of a bool value.
bool changed = false;
while (!work_list_.empty()) {
Operation *op = work_list_.back();
work_list_.pop_back();
auto spec = GetQuantSpec(op);
// If the op has no quantizable result, the quantization parameters will not
// be propagated to the results.
if (!spec->is_quantizable) continue;
if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
// This constant is used as a bias in another op, then the quantization
// parameters are determined by that op.
if (UsedAsBias(cst) || IsQuantized(op)) continue;
// The quantization parameters are determined by the content of the
// constant.
// TODO(fengliuai): the storage_type_width should be from higher level.
changed |= SetConstantResultParams(op, /*storage_type_width=*/8,
/*narrow_range=*/false);
continue;
}
if (spec->requires_same_scale) {
std::vector<QuantState *> mutable_states, immutable_states;
if (!GetAllQuantStatesCanBeSameScale(op, &mutable_states,
&immutable_states)) {
// Constraints couldn't be satisfied, so this needs to return `false`
// unconditionally, then the Finalize step will be skipped. It shouldn't
// continue or partially quantize the model.
return false;
}
auto *final_state =
GetFinalSameState(op, mutable_states, immutable_states);
// No operand or result has parameters yet; the remaining steps for this op
// are skipped as well.
if (!final_state) continue;
// Use the final state to set all the operands' parameters.
for (int i = 0, e = op->getNumOperands(); i != e; ++i)
changed |= SetOperandParams(op, i, final_state->params);
// Use the final state to set all the results' parameters.
for (int res = 0, e = op->getNumResults(); res != e; ++res)
changed |= SetResultParams(op, res, final_state->params);
}
// Apply the parameters the specification prescribes for the results; the
// i-th entry is applied to the i-th result.
for (int i = 0, e = spec->restricted_output_params.size(); i != e; ++i)
changed |= SetResultParams(op, i, spec->restricted_output_params[i]);
// For each bias entry of the specification, the key is the bias operand
// index and the value holds the non-bias operand indexes together with the
// function computing the bias parameters from them.
for (auto &it : spec->biases_params) {
auto params =
GetBiasParams(op, it.first, it.second.first, it.second.second);
changed |= SetOperandParams(op, it.first, params);
}
}
return changed;
}
void QuantizationDriver::Finalize() {
for (auto it : result_states_) {
Operation *op = it.first.first;
int res_index = it.first.second;
auto &state = GetResultState(op, res_index);
if (state.IsEmpty() || state.immutable) {
continue;
}
QuantizeOpResult(op, res_index, state.params);
auto &requantize = GetResultRequantizeState(op, res_index);
DCHECK(requantize.pos == RequantizeState::NO_REQUANTIZE)
<< "Unimplemented requantize handling";
}
}
// Runs the driver: set up the initial states, propagate the parameters, and
// rewrite the IR only if the propagation reports a change.
void QuantizationDriver::Run() {
  Initialize();
  const bool propagated = PropagateParams();
  if (!propagated) return;
  Finalize();
}
// Entry point of the propagation: builds a temporary QuantizationDriver for
// `func` and runs it.
void ApplyQuantizationParamsPropagation(mlir::Function *func) {
QuantizationDriver(func).Run();
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,122 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
namespace mlir {
namespace TFL {
// Returns the quantized type for the
// input_type/min/max/storage_type_width/narrow_range: a uniform quantized
// element type is built from the fake-quant style attributes and then applied
// to `input_type` through the expressed-to-quantized converter.
static Type GetQuantizedType(Builder builder, Type input_type, double min,
                             double max, int storage_type_width,
                             bool narrow_range) {
  auto type_converter =
      quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
  auto unknown_loc = builder.getUnknownLoc();
  quant::UniformQuantizedType element_type = quant::fakeQuantAttrsToType(
      unknown_loc, storage_type_width, min, max, narrow_range,
      type_converter.expressedType);
  return type_converter.convert(element_type);
}
// Builds the quantized type for `input_type` from the min/max attributes, the
// bit width of `storage_type` (which must be an integer type) and
// `narrow_range`, and wraps the result in a TypeAttr.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
                              FloatAttr max, Type storage_type,
                              bool narrow_range) {
  const int bit_width = storage_type.cast<IntegerType>().getWidth();
  const double range_min = min.getValueAsDouble();
  const double range_max = max.getValueAsDouble();
  Type quantized_type = GetQuantizedType(builder, input_type, range_min,
                                         range_max, bit_width, narrow_range);
  return builder.getTypeAttr(quantized_type);
}
// Attribute-based variant: `min` and `max` are first reduced to float
// attributes with GetSingleElementAsFloatOrSelf. If either reduction yields an
// empty attribute, an empty TypeAttr is returned; otherwise the call is
// forwarded with an integer storage type of `num_bits` bits.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                              Attribute max, IntegerAttr num_bits,
                              BoolAttr narrow_range) {
  FloatAttr range_min = GetSingleElementAsFloatOrSelf(min);
  FloatAttr range_max = GetSingleElementAsFloatOrSelf(max);
  const bool have_range = range_min && range_max;
  if (!have_range) return {};

  Type storage_type = builder.getIntegerType(num_bits.getInt());
  const bool is_narrow = narrow_range.getValue();
  return GetQuantizedTypeAttr(builder, input_type, range_min, range_max,
                              storage_type, is_narrow);
}
// Derives a uniform quantized element type from the value range of `attr`.
//
// Only dense floating-point attributes are handled: the minimum and maximum
// element values are computed and, if the range [min, max] contains zero, a
// quantized type with `storage_type_width` bits is built from it and its
// element type is returned. An empty Type is returned when `attr` is not a
// DenseFPElementsAttr, when the range does not straddle zero (including the
// case of no elements), or when the quantized type is not a tensor type.
Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
                                            unsigned storage_type_width,
                                            bool narrow_range) {
  Builder builder(attr.getContext());
  // Seed the running extremes so that any element replaces them. `lowest()` is
  // the most negative finite double; `min()` is the smallest *positive* double
  // and would make an all-negative tensor look as if it reached above zero,
  // defeating the straddle check below.
  double min = std::numeric_limits<double>::max();
  double max = std::numeric_limits<double>::lowest();
  if (auto fp = attr.dyn_cast<DenseFPElementsAttr>()) {
    for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
      double ele_value = FloatAttr::getValueAsDouble(*it);
      min = std::min(min, ele_value);
      max = std::max(max, ele_value);
    }
    // The range must straddle zero.
    if (min > 0.0 || max < 0.0) return {};
    auto type = GetQuantizedType(builder, attr.getType(), min, max,
                                 storage_type_width, narrow_range);
    if (auto ele_type = type.dyn_cast_or_null<TensorType>())
      return ele_type.getElementType();
  }
  // The range from SplatElementAttr and other element attribute types couldn't
  // straddle zero, so the quantization parameters couldn't be derived from its
  // range.
  return {};
}
// Computes the quantized type of a bias from the quantized types of the other
// operands: the bias scale is the product of all operand scales, the zero
// point is 0 and the storage is a signed 32-bit integer covering its default
// range. The expressed type is taken from the last entry of `op_types`.
// Returns an empty type when `op_types` is empty or contains an entry that is
// not a uniform quantized type.
quant::QuantizedType GetUniformQuantizedTypeForBias(
    const std::vector<quant::QuantizedType>& op_types) {
  if (op_types.empty()) return {};

  double accumulated_scale = 1.0;
  for (const quant::QuantizedType& operand_type : op_types) {
    auto uniform_type =
        operand_type.dyn_cast_or_null<quant::UniformQuantizedType>();
    if (!uniform_type) return {};
    accumulated_scale *= uniform_type.getScale();
  }

  auto last_type = op_types.back().cast<quant::UniformQuantizedType>();
  Builder builder(last_type.getContext());
  IntegerType bias_storage = builder.getIntegerType(32);
  auto storage_min =
      quant::QuantizedType::getDefaultMininumForInteger(/*isSigned=*/true, 32);
  auto storage_max =
      quant::QuantizedType::getDefaultMaxinumForInteger(/*isSigned=*/true, 32);
  return quant::UniformQuantizedType::getChecked(
      /*flags=*/true, bias_storage, last_type.getExpressedType(),
      accumulated_scale,
      /*zeroPoint=*/0, storage_min, storage_max, builder.getUnknownLoc());
}
// Quantizes the elements of `real_value` with the quantization parameters
// carried by `tensor_type`. Returns an empty attribute when `tensor_type` has
// no quantized element type, when the quantization fails, or when the
// quantized attribute is not an ElementsAttr.
ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
  if (auto q_type =
          quant::QuantizedType::getQuantizedElementType(tensor_type)) {
    Type converted_type;
    // quantizeAttr may return a null attribute; `cast` must not be used on it,
    // so convert with the null-tolerant dynamic cast instead.
    return quant::quantizeAttr(real_value, q_type, converted_type)
        .dyn_cast_or_null<ElementsAttr>();
  }
  return {};
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,70 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines common utils used by TFLite transformation
// passes to work with quantization types and parameters.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
namespace mlir {
namespace TFL {
// Converts the min/max/storage_type/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
FloatAttr max, Type storage_type,
bool narrow_range = false);
// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
Attribute max, IntegerAttr num_bits,
BoolAttr narrow_range);
// Quantizes the elements in the attribute `real_value` by the quantization
// parameters in `tensor_type`. Returns empty Attribute if the
// `tensor_type` is not a QuantizedType or the quantization fails.
ElementsAttr Quantize(Attribute real_value, Type tensor_type);
// Returns the quantized type for an element attribute. The quantization
// parameters in this type is based on the min and max element of the attribute.
// When the elements in the `attr` are not in floating-point, or the value range
// isn't straddling zero, an empty type is returned.
Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
unsigned storage_type_width,
bool narrow_range = false);
// Returns the quantized type of a bias input, given the quantized types of
// other operands which are multiply-accumulated (the bias is added to the
// accumulated value).
quant::QuantizedType GetUniformQuantizedTypeForBias(
const std::vector<quant::QuantizedType>& op_types);
// Propagates quantization parameters across ops in this function and satisfies
// the quantization specification of the ops. This method assumes the initial
// quantization parameters are stored as adjacent quantize and dequantize ops
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops to this function.
void ApplyQuantizationParamsPropagation(mlir::Function* func);
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_

View File

@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
namespace mlir {
namespace TFL {
// Returns true if the given `op`
//   * has an array attribute with the given `name`,
//   * and that attribute is an integer list of the form [1, X, Y, 1];
// in that case X and Y are written to `x` and `y` as 32-bit integer
// attributes. `x` and `y` are left untouched when false is returned.
bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
                     IntegerAttr *y) {
  auto list_attr = op->getAttrOfType<ArrayAttr>(name);
  if (!list_attr) return false;

  auto elements = list_attr.getValue();
  if (elements.size() != 4) return false;
  for (Attribute element : elements) {
    if (!element.isa<IntegerAttr>()) return false;
  }

  // All four entries are integer attributes at this point.
  auto int_at = [&](unsigned position) {
    return elements[position].cast<IntegerAttr>().getInt();
  };
  if (int_at(0) != 1 || int_at(3) != 1) return false;

  Builder attr_builder(op->getContext());
  *x = attr_builder.getI32IntegerAttr(int_at(1));
  *y = attr_builder.getI32IntegerAttr(int_at(2));
  return true;
}
// Returns true if the attribute is an integer list of the form [1, X, Y, 1],
bool TFIntListIs1XY1(const ArrayAttr &attr) {
const auto &elements = attr.getValue();
if (elements.size() != 4 ||
std::any_of(elements.begin(), elements.end(),
[](Attribute e) { return !e.isa<IntegerAttr>(); }))
return false;
if (elements.front().cast<IntegerAttr>().getValue() != 1 ||
elements.back().cast<IntegerAttr>().getValue() != 1)
return false;
return true;
}
// Returns whether the types of `a` and `b` are broadcast-compatible, i.e.
// whether a broadcasted type can be computed for them.
bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
  // An unranked tensor would make this return false (although it should
  // probably count as broadcastable), but since the inputs are attributes
  // that shouldn't be an issue.
  Type broadcasted =
      OpTrait::util::getBroadcastedType(a.getType(), b.getType());
  return broadcasted != Type();
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,73 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines common validators used by TFLite transformation
// passes to validate op attributes or values.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
namespace mlir {
namespace TFL {
// TODO(jpienaar): Change these to being one of these variants and/or generate
// these predicates.
// Returns true if the given TensorFlow op either has no string `data_format`
// attribute (the format then defaults to "NHWC") or has one equal to "NHWC".
inline bool TFDataFormatIsNHWC(Operation *op) {
  StringAttr data_format = op->getAttrOfType<StringAttr>("data_format");
  if (!data_format) return true;
  return data_format.getValue() == "NHWC";
}
// Returns true if the given `op`
// * has an attribute with the given `name`,
// * and the attribute is an integer list of the form [1, X, Y, 1],
// and writes X, Y as 32-bit integer attribute to `x`, `y`.
bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
IntegerAttr *y);
// Returns true if the attribute is an integer list of the form [1, X, Y, 1].
bool TFIntListIs1XY1(const ArrayAttr &attr);
// Returns true iff the given value is a float tensor, i.e. its type is a
// tensor type whose element type is a floating-point type.
inline bool TFTypeIsFloatTensor(Value *value) {
  if (auto tensor_type = value->getType().dyn_cast<TensorType>())
    return tensor_type.getElementType().isa<FloatType>();
  return false;
}
// Returns true iff the given TensorFlow op has a string `padding` attribute
// whose value is "SAME" or "VALID", and writes the attribute to `padding`.
// Returns false, leaving `padding` untouched, when the attribute is missing,
// is not a string attribute, or has any other value.
inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
  auto padding_attr = op->getAttrOfType<StringAttr>("padding");
  // getAttrOfType yields a null attribute when there is no such string
  // attribute; it must not be dereferenced.
  if (!padding_attr) return false;
  StringRef value = padding_attr.getValue();
  if (value != "SAME" && value != "VALID") return false;
  *padding = padding_attr;
  return true;
}
/// Returns whether the given `a` and `b` have broadcast-compatible
/// types.
bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_

View File

@ -0,0 +1,57 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lit runner configuration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import lit.formats
from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
# Lint for undefined variables is disabled as config is not defined inside this
# file, instead config is injected by way of evaluating runlit.cfg.py from
# runlit.site.cfg.py which in turn is evaluated by lit.py. The structure is
# common for lit tests and intended to only persist temporarily (b/136126535).
# pylint: disable=undefined-variable
# Configuration file for the 'lit' test runner.
# name: The name of this test suite, built from the last path component of the
# test directory.
config.name = 'MLIR ' + os.path.basename(config.mlir_test_dir)
# test_format: shell-script tests. The argument is the negation of
# llvm_config.use_lit_shell (presumably ShTest's "execute externally" flag --
# confirm against the lit documentation).
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# test_source_root: The root path where tests are located.
config.test_source_root = config.mlir_test_dir
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.environ['RUNFILES_DIR']
llvm_config.use_default_substitutions()
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
# Directories searched for the tools below: the TensorFlow MLIR tool
# directories first, then the MLIR and LLVM tool directories.
tool_dirs = config.mlir_tf_tools_dirs + [
config.mlir_tools_dir, config.llvm_tools_dir
]
# Tool names that get a substitution in test RUN lines.
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate'
]
# unresolved='ignore': a tool that cannot be found in `tool_dirs` is not
# treated as an error (per the option name -- confirm against lit).
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)
# pylint: enable=undefined-variable

View File

@ -0,0 +1,56 @@
# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lit runner site configuration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import lit.llvm
# Lint for undefined variables is disabled as config is not defined inside this
# file, instead config is injected by lit.py. The structure is common for lit
# tests and intended to only persist temporarily (b/136126535).
# pylint: disable=undefined-variable
# All locations below are derived from environment variables: TEST_SRCDIR,
# TEST_WORKSPACE and TEST_TARGET must be set or a KeyError is raised.
config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm')
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR'])
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'],
'local_config_mlir')
# File suffixes assigned to config.suffixes.
# TODO(jpienaar): Replace with suffixes in build rule.
config.suffixes = ['.td', '.mlir', '.pbtxt']
# Workspace-relative directories that contain the TensorFlow MLIR tools.
mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir',
'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/xla',
]
# The same directories, made absolute under TEST_SRCDIR/TEST_WORKSPACE.
config.mlir_tf_tools_dirs = [
os.path.join(os.environ['TEST_SRCDIR'], os.environ['TEST_WORKSPACE'], s)
for s in mlir_tf_tools_dirs
]
# Turn the test target label into a directory: strip leading/trailing slashes
# and drop everything from the last ':' on (e.g. '//a/b:c' becomes 'a/b').
test_dir = os.environ['TEST_TARGET']
test_dir = test_dir.strip('/').rsplit(':', 1)[0]
config.mlir_test_dir = os.path.join(os.environ['TEST_SRCDIR'],
os.environ['TEST_WORKSPACE'], test_dir)
lit.llvm.initialize(lit_config, config)
# Let the main config do the real work.
# Note: the outer os.path.join receives a single argument and returns it
# unchanged.
lit_config.load_config(
config,
os.path.join(
os.path.join(os.environ['TEST_SRCDIR'], os.environ['TEST_WORKSPACE'],
'tensorflow/compiler/mlir/runlit.cfg.py')))
# pylint: enable=undefined-variable

View File

@ -0,0 +1,567 @@
load("@local_config_mlir//:tblgen.bzl", "gentbl")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
# Package-wide defaults: targets in this BUILD file are publicly visible unless
# they set their own `visibility`, and are declared under the "notice" license
# type.
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
# Named set of package trees that can be used in `visibility` attributes.
# NOTE(review): no rule in the visible part of this file references ":friends"
# (the package default is public visibility); presumably it is used by other
# BUILD files -- confirm.
package_group(
name = "friends",
packages = [
"//learning/brain/experimental/mlir/...",
"//learning/brain/google/xla/...",
"//tensorflow/compiler/mlir/...",
],
)
# TableGen inputs for the TensorFlow op definitions: the three local .td files
# plus MLIR's OpBaseTdFiles. The gentbl rules in this file that process
# TensorFlow ops list this group in their `td_srcs`.
filegroup(
name = "tensorflow_ops_td_files",
srcs = [
"ir/tf_generated_ops.td",
"ir/tf_op_base.td",
"ir/tf_ops.td",
"@local_config_mlir//:OpBaseTdFiles",
],
)
# Runs mlir-tblgen over ir/tf_ops.td three times, once per (flag, output) pair
# in `tbl_outs`: op declarations, op definitions and op documentation.
gentbl(
name = "tensorflow_ops_inc_gen",
tbl_outs = [
(
"-gen-op-decls",
"ir/tf_ops.h.inc",
),
(
"-gen-op-defs",
"ir/tf_ops.cc.inc",
),
(
"-gen-op-doc",
"g3doc/tf_ops.md",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "ir/tf_ops.td",
td_srcs = [
":tensorflow_ops_td_files",
],
)
# Runs mlir-tblgen over ir/tf_executor_ops.td to produce the tf_executor op
# declarations, definitions and documentation. Unlike the other gentbl rules
# in this file, the TableGen dependencies are listed as individual MLIR .td
# files rather than through a filegroup.
gentbl(
name = "tensorflow_executor_inc_gen",
tbl_outs = [
(
"-gen-op-decls",
"ir/tf_executor.h.inc",
),
(
"-gen-op-defs",
"ir/tf_executor.cc.inc",
),
(
"-gen-op-doc",
"g3doc/tf_executor.md",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "ir/tf_executor_ops.td",
td_srcs = [
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"@local_config_mlir//:include/mlir/StandardOps/Ops.td",
],
)
# Runs mlir-tblgen with -gen-rewriters over transforms/canonicalize.td to
# produce transforms/generated_canonicalize.inc, which is listed in the srcs
# of the ":tensorflow" library.
gentbl(
name = "tensorflow_canonicalize_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_canonicalize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/canonicalize.td",
td_srcs = [
":tensorflow_ops_td_files",
],
)
# Main library of this package: the control flow, tf_executor and TF op IR
# sources together with the transforms listed in `srcs`. The *.inc entries in
# `srcs` are outputs of the *_inc_gen gentbl rules listed in `deps`.
cc_library(
name = "tensorflow",
srcs = [
"ir/control_flow_ops.cc",
"ir/tf_executor.cc",
"ir/tf_executor.cc.inc",
"ir/tf_executor.h.inc",
"ir/tf_ops.cc",
"ir/tf_ops.cc.inc",
"ir/tf_ops.h.inc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc",
"transforms/optimize.cc",
"transforms/raise_control_flow.cc",
],
hdrs = [
"ir/control_flow_ops.h",
"ir/tf_executor.h",
"ir/tf_ops.h",
"ir/tf_types.def",
"ir/tf_types.h",
"transforms/passes.h",
],
# NOTE(review): no other rule visible in this file refers to an "include"
# directory in this package -- confirm this entry is needed.
includes = ["include"],
deps = [
":tensorflow_canonicalize_inc_gen",
":tensorflow_executor_inc_gen",
":tensorflow_ops_inc_gen",
":tensorflow_optimize_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:TransformUtils",
"@local_config_mlir//:TypeUtilities",
],
# TODO(jpienaar): Merge in the dialect registration.
# alwayslink keeps this library's object files in any binary that depends
# on it, even when none of their symbols are referenced.
alwayslink = 1,
)
# Library with TensorFlow dialect static initialization.
# It contains only ir/dialect_registration.cc; alwayslink = 1 forces that
# object file into the final link even though no symbol in it is referenced
# directly.
cc_library(
name = "tensorflow_dialect_registration",
srcs = ["ir/dialect_registration.cc"],
deps = [
":tensorflow",
"@local_config_mlir//:IR",
],
alwayslink = 1,
)
# GraphDef translation library: bundles the export_graphdef and
# import_graphdef sources/headers under translate/.
cc_library(
name = "convert_graphdef",
srcs = [
"translate/export_graphdef.cc",
"translate/import_graphdef.cc",
],
hdrs = [
"translate/export_graphdef.h",
"translate/import_graphdef.h",
],
deps = [
":convert_tensor",
":convert_type",
":export_tf_dialect_op",
":export_utils",
":mangling_util",
":mlir_roundtrip_flags",
":tensorflow",
"//tensorflow/compiler/jit:shape_inference_helpers",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardDialectRegistration",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
)
# Import helper library built from utils/import_utils.{h,cc}.
cc_library(
name = "import_utils",
srcs = [
"utils/import_utils.cc",
],
hdrs = [
"utils/import_utils.h",
],
deps = [
":error_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"@com_google_absl//absl/strings",
"@llvm//:support",
],
)
# Export helper library built from utils/export_utils.{h,cc}. It is a
# dependency of ":convert_graphdef" and ":export_tf_dialect_op".
cc_library(
name = "export_utils",
srcs = [
"utils/export_utils.cc",
],
hdrs = [
"utils/export_utils.h",
],
deps = [
":convert_tensor",
":convert_type",
":mangling_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardDialectRegistration",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
)
# Library for exporting TF dialect ops. Besides the hand-written source it
# compiles in translate/derived_attr_populator.inc, which is the output of the
# ":derived_attr_populator_inc" genrule in this file.
cc_library(
name = "export_tf_dialect_op",
srcs = [
"translate/derived_attr_populator.inc",
"translate/export_tf_dialect_op.cc",
],
hdrs = [
"translate/export_tf_dialect_op.h",
],
deps = [
":convert_type",
":export_utils",
":tensorflow",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
],
)
# Source-only library (no headers) wrapping translate_tf_dialect_op.cc.
# alwayslink = 1 keeps its object file in dependent binaries even when no
# symbol is referenced. NOTE(review): presumably the file registers a
# translation through static initialization -- confirm.
cc_library(
name = "translate_tf_dialect_op",
srcs = ["translate/translate_tf_dialect_op.cc"],
deps = [
":export_tf_dialect_op",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
"@local_config_mlir//:Translation",
],
alwayslink = 1,
)
# Library built from translate/mlir_roundtrip_pass.{h,cc}; depends on the
# GraphDef conversion library and the roundtrip flags.
cc_library(
name = "mlir_roundtrip_pass",
srcs = ["translate/mlir_roundtrip_pass.cc"],
hdrs = ["translate/mlir_roundtrip_pass.h"],
deps = [
":convert_graphdef",
":mlir_roundtrip_flags",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"@local_config_mlir//:IR",
],
)
# Library built from translate/mlir_roundtrip_flags.{h,cc}. It has no
# dependency on other targets of this package and is used by the conversion
# and translation libraries in this file.
cc_library(
name = "mlir_roundtrip_flags",
srcs = ["translate/mlir_roundtrip_flags.cc"],
hdrs = ["translate/mlir_roundtrip_flags.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@llvm//:support",
],
)
# Type conversion library built from utils/convert_type.{h,cc}. It depends on
# ":tensorflow_dialect_registration", so linking this library also links the
# dialect static registration.
cc_library(
name = "convert_type",
srcs = ["utils/convert_type.cc"],
hdrs = ["utils/convert_type.h"],
deps = [
":tensorflow",
":tensorflow_dialect_registration",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)
# Tensor conversion library built from utils/convert_tensor.{h,cc}; builds on
# ":convert_type" and ":mangling_util".
cc_library(
name = "convert_tensor",
srcs = ["utils/convert_tensor.cc"],
hdrs = ["utils/convert_tensor.h"],
deps = [
":convert_type",
":mangling_util",
":tensorflow",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
],
)
# Library built from utils/mangling_util.{h,cc}. It depends only on TensorFlow
# core and absl strings, not on MLIR or other targets of this package.
cc_library(
name = "mangling_util",
srcs = ["utils/mangling_util.cc"],
hdrs = ["utils/mangling_util.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"@com_google_absl//absl/strings",
],
)
# Library built from utils/error_util.{h,cc}; exercised by ":error_util_test".
cc_library(
name = "error_util",
srcs = ["utils/error_util.cc"],
hdrs = ["utils/error_util.h"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
],
)
# TF dialect transforms: constant folding, constant decoding and the dialect
# hooks source. alwayslink = 1 keeps the object files in dependent binaries
# even when none of their symbols are referenced directly.
cc_library(
name = "tf_dialect_passes",
srcs = [
"transforms/constant_fold.cc",
"transforms/decode_constant.cc",
"transforms/dialect_hooks.cc",
],
hdrs = [
"transforms/constant_fold.h",
"transforms/decode_constant.h",
],
deps = [
":convert_tensor",
":eval_util",
":tensorflow",
"//tensorflow/c:tf_status",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:framework",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
],
alwayslink = 1,
)
# Convenience target with no sources of its own: depending on it pulls in the
# TF dialect registration, the TF dialect passes and the standard dialect
# registration together.
cc_library(
name = "tf_dialect_lib",
deps = [
":tensorflow_dialect_registration",
":tf_dialect_passes",
"@local_config_mlir//:StandardDialectRegistration",
],
)
# Library built from utils/eval_util.{h,cc}; used by ":tf_dialect_passes". It
# depends on the TensorFlow eager C API, including its internal target.
cc_library(
name = "eval_util",
srcs = ["utils/eval_util.cc"],
hdrs = ["utils/eval_util.h"],
deps = [
":convert_tensor",
":convert_type",
":export_tf_dialect_op",
":mangling_util",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)
# Translation library built from translate/tf_mlir_translate.{h,cc}; combines
# the GraphDef conversion, import utilities and the MLIR parser.
cc_library(
name = "translate_lib",
srcs = [
"translate/tf_mlir_translate.cc",
],
hdrs = [
"translate/tf_mlir_translate.h",
],
deps = [
":convert_graphdef",
":error_util",
":import_utils",
":mangling_util",
":mlir_roundtrip_flags",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_proto_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
],
)
# Library built from translate/tf_mlir_translate_cl.{h,cc}; its only
# dependency is LLVM support. alwayslink = 1 keeps the object file in
# dependent binaries. NOTE(review): presumably it defines LLVM command-line
# option globals that must survive linking -- confirm.
cc_library(
name = "translate_cl_options",
srcs = [
"translate/tf_mlir_translate_cl.cc",
],
hdrs = [
"translate/tf_mlir_translate_cl.h",
],
deps = [
"@llvm//:support",
],
alwayslink = 1,
)
# Source-only library wrapping translate/tf_mlir_translate_registration.cc.
# alwayslink = 1 keeps the object file in dependent binaries even though it
# exports no header. NOTE(review): presumably it registers translations via
# static initializers -- confirm.
cc_library(
name = "translate_registration",
srcs = [
"translate/tf_mlir_translate_registration.cc",
],
deps = [
":convert_graphdef",
":mlir_roundtrip_flags",
":translate_cl_options",
":translate_lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Translation",
],
alwayslink = 1,
)
# The tf-mlir-translate tool. The rule lists no `srcs`; the binary is
# assembled purely from its dependencies. NOTE(review): the entry point
# presumably comes from the mlir-translate target listed last -- confirm.
tf_cc_binary(
name = "tf-mlir-translate",
deps = [
":convert_graphdef",
":mlir_roundtrip_flags",
":translate_cl_options",
":translate_lib",
":translate_registration",
":translate_tf_dialect_op",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Translation",
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
],
)
# Unit test for ":error_util", built from utils/error_util_test.cc and linked
# against the TensorFlow test_main target.
tf_cc_test(
name = "error_util_test",
srcs = ["utils/error_util_test.cc"],
deps = [
":error_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm//:support",
"@local_config_mlir//:IR",
],
)
# Build-time generator binary, linked against LLVM/MLIR TableGen libraries.
# It is run as a tool by the ":derived_attr_populator_inc" genrule.
tf_native_cc_binary(
name = "derived_attr_populator_gen",
srcs = [
"translate/derived_attr_populator_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
],
)
# Generates translate/derived_attr_populator.inc by running
# :derived_attr_populator_gen on ir/tf_ops.td, with MLIR's include directory
# on the -I search path. The output is compiled into ":export_tf_dialect_op".
# The remaining `srcs` entries are not named on the command line; they are
# listed so that they are available as inputs when the tool runs.
genrule(
name = "derived_attr_populator_inc",
srcs = [
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"ir/tf_generated_ops.td",
"ir/tf_op_base.td",
"ir/tf_ops.td",
],
outs = [
"translate/derived_attr_populator.inc",
],
cmd = ("$(location :derived_attr_populator_gen) " +
"-I external/local_config_mlir/include " +
"$(location //tensorflow/compiler/mlir/tensorflow:ir/tf_ops.td) " + " -o $@"),
tools = [":derived_attr_populator_gen"],
)
# Filegroup wrapping transforms/optimize.td. NOTE(review): the
# ":tensorflow_optimize_inc_gen" rule in this file names the .td file directly
# rather than this group; presumably the group is referenced from elsewhere --
# confirm.
filegroup(
name = "tensorflow_optimize_td_files",
srcs = [
"transforms/optimize.td",
],
)
# Runs mlir-tblgen with -gen-rewriters over transforms/optimize.td to produce
# transforms/generated_optimize.inc, which is listed in the srcs of the
# ":tensorflow" library.
gentbl(
name = "tensorflow_optimize_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_optimize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/optimize.td",
td_srcs = [
":tensorflow_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
],
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements the named operations for the "Control Flow" dialect of
// TensorFlow graphs
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
namespace mlir {
namespace TFControlFlow {
// TODO(ycao): Implement following verify methods when we know more about their
// invariant.
// Until then, each op-specific verifier below is a placeholder that
// unconditionally reports success.
LogicalResult EnterOp::verify() { return success(); }
LogicalResult MergeOp::verify() { return success(); }
LogicalResult NextIterationSourceOp::verify() { return success(); }
LogicalResult NextIterationSinkOp::verify() { return success(); }
LogicalResult LoopCondOp::verify() { return success(); }
LogicalResult SwitchOp::verify() { return success(); }
LogicalResult ExitOp::verify() { return success(); }
// Constructs the control flow dialect under the "_tf" name, then adds its
// seven control flow operations and the TFControlType to the dialect.
TFControlFlowDialect::TFControlFlowDialect(MLIRContext *context)
: Dialect(/*name=*/"_tf", context) {
addOperations<SwitchOp, MergeOp, EnterOp, NextIterationSourceOp,
NextIterationSinkOp, ExitOp, LoopCondOp>();
addTypes<TFControlType>();
// We allow unregistered TensorFlow operations in the control dialect.
allowUnknownOperations();
}
// Parses a type registered to this dialect. The only accepted spelling is
// "control", which yields the TFControlType. Any other text emits an
// "unknown TFControl type" error at `loc` and returns a null Type.
Type TFControlFlowDialect::parseType(StringRef tyData, Location loc) const {
  if (tyData == "control") return TFControlType::get(getContext());

  emitError(loc, "unknown TFControl type: " + tyData);
  return nullptr;
}
// Prints a type registered to this dialect. TFControlType is the only type
// this function handles (checked by the assert), so the printed text is
// always "control".
void TFControlFlowDialect::printType(Type type, raw_ostream &os) const {
assert(type.isa<TFControlType>());
os << "control";
}
} // namespace TFControlFlow
} // namespace mlir

View File

@ -0,0 +1,278 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations for the "Control Flow" dialect of TensorFlow
// graphs. The TensorFlow control flow dialect represents control flow with
// Switch/Merge and a few related control flow nodes, along with control
// dependencies. This dialect can be raised to the standard TensorFlow dialect
// by transforming Switch/Merge and other control flow ops into functional
// control flow ops and removing control dependencies.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
namespace mlir {
namespace TFControlFlow {
// Dialect for the "Control Flow" representation of TensorFlow graphs. It
// overrides the type parsing and printing hooks of Dialect.
class TFControlFlowDialect : public Dialect {
public:
// Builds the dialect within `context`.
explicit TFControlFlowDialect(MLIRContext *context);
// Parses a type registered to this dialect.
Type parseType(StringRef tyData, Location loc) const override;
// Prints a type registered to this dialect.
void printType(Type type, raw_ostream &os) const override;
};
// Type kind identifiers for this dialect. The single kind, Control, takes the
// value Type::FIRST_TENSORFLOW_CONTROL_TYPE.
namespace TensorFlowControlTypes {
enum Kind {
Control = Type::FIRST_TENSORFLOW_CONTROL_TYPE,
};
}
// The control type of this dialect. The op comments in this file describe the
// additional result of each op as its control output. The type has no
// parameters: it is identified solely by the
// TensorFlowControlTypes::Control kind.
class TFControlType : public Type::TypeBase<TFControlType, Type> {
public:
using Base::Base;
// Returns the control type instance for `context`.
static TFControlType get(MLIRContext *context) {
return Base::get(context, TensorFlowControlTypes::Control);
}
// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) {
return kind == TensorFlowControlTypes::Control;
}
};
// The "_tf.Enter" operation forwards its input to a TensorFlow while loop.
// Each tensor needs its own _tf.Enter to be made available inside the while
// loop.
//
// More details can be found in Tensorflow Controlflow white paper:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in Tensorflow as:
//
// REGISTER_OP("Enter")
// .Input("data: T")
// .Output("output: T")
// .Attr("T: type")
// .Attr("frame_name: string")
// .Attr("is_constant: bool = false")
// .Attr("parallel_iterations: int = 10")
//
// For example:
// %1 = "_tf.Enter"(%0#0) {T: "tfdtype$DT_INT32", frame_name:
// "while/while_context",} : (tensor<i32>) -> (tensor<*xi32>)
//
// Note: Additional result corresponds to the control output.
// NOTE(review): the example above lists a single result type although the op
// is declared with two results; presumably the trailing !_tf.control was
// omitted -- confirm.
class EnterOp
: public Op<EnterOp, OpTrait::AtLeastNOperands<1>::Impl,
OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> {
public:
using Op::Op;
static StringRef getOperationName() { return "_tf.Enter"; }
// Accessors for operand 0, the `data` input.
Value *getData() { return getOperand(0); }
void setData(Value *value) { setOperand(0, value); }
LogicalResult verify();
};
// The "_tf.Merge" operation takes a list of input operands and returns a value
// of the operand type along with the index of the first match encountered.
//
// More details can be found in Tensorflow Controlflow white paper:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("Merge")
// .Input("inputs: N * T")
// .Output("output: T")
// .Output("value_index: int32")
//
// For example:
// %2 = _tf.Merge %0, %1, %2, %3 : tensor<??xf32>
//
// Note: Additional result corresponds to the control output.
// The op therefore declares three results: the two outputs listed above plus
// the control output.
class MergeOp : public Op<MergeOp, OpTrait::VariadicOperands,
OpTrait::NResults<3>::Impl> {
public:
using Op::Op;
static StringRef getOperationName() { return "_tf.Merge"; }
LogicalResult verify();
};
// The "_tf.NextIteration.source" and "_tf.NextIteration.sink" operations form
// a logical pair. Together, they represent NextIteration op in Tensorflow.
//
// Tensorflow NextIteration operation forwards its input to the next iteration
// of a while loop. Each loop variable needs its own NextIteration op.
//
// More details can be found in Tensorflow Controlflow white paper:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// NextIteration op is broken into _tf.NextIteration.sink and
// _tf.NextIteration.source because NextIteration is a back-edge in Tensorflow
// graph, which would form a data flow cycle if expressed naively in a basic
// block. _tf.NextIteration.source takes no input but returns results while
// _tf.NextIteration.sink takes input but doesn't return anything. When
// optimizing these ops, they are paired by op names and considered as a
// single op.
//
// This is defined in Tensorflow as:
//
// REGISTER_OP("NextIteration")
// .Input("data: T")
// .Output("output: T")
// .Attr("T: type")
//
// For example:
// %11 = "_tf.NextIteration.source"() {name: "while/NextIteration", T:
// "tfdtype$DT_INT32", id: 0} : () -> (tensor<*xi32>, !_tf.control)
// "_tf.NextIteration.sink"(%10#0) {name: "while/NextIteration", T:
// "tfdtype$DT_INT32", id: 0} : (tensor<*xi32>) -> ()
//
// Note: Additional result corresponds to the control output.
// The source half declares two results (the forwarded value and the control
// output) and no operand trait.
class NextIterationSourceOp
: public Op<NextIterationSourceOp, OpTrait::NResults<2>::Impl> {
public:
using Op::Op;
static StringRef getOperationName() { return "_tf.NextIteration.source"; }
LogicalResult verify();
};
// Sink half of the NextIteration pair: takes at least one operand, the value
// forwarded to the next iteration.
// NOTE(review): this class declares OpTrait::OneResult, while the description
// and example in this file show the sink returning nothing; presumably the
// single result is the control output -- confirm.
class NextIterationSinkOp
: public Op<NextIterationSinkOp, OpTrait::AtLeastNOperands<1>::Impl,
OpTrait::OneResult> {
public:
using Op::Op;
static StringRef getOperationName() { return "_tf.NextIteration.sink"; }
// Accessors for operand 0, the `data` input.
Value *getData() { return getOperand(0); }
void setData(Value *value) { setOperand(0, value); }
LogicalResult verify();
};
// The "_tf.LoopCond" operation forwards a boolean value as loop condition of
// Tensorflow while loops.
//
// More details can be found in Tensorflow Controlflow white paper:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in Tensorflow as:
//
// REGISTER_OP("LoopCond")
// .Input("input: bool")
// .Output("output: bool")
//
// For example:
// %5 = "_tf.LoopCond"(%4#0) {device: "", name: "while/LoopCond"} :
// (tensor<*xi1>) -> (i1, !_tf.control)
//
// Note: Additional result corresponds to the control output.
class LoopCondOp
: public Op<LoopCondOp, OpTrait::AtLeastNOperands<1>::Impl,
OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> {
public:
using Op::Op;
static StringRef getOperationName() { return "_tf.LoopCond"; }
// Accessors for operand 0, the boolean `input`.
Value *getData() { return getOperand(0); }
void setData(Value *value) { setOperand(0, value); }
LogicalResult verify();
};
// The "_tf.Switch" operation takes a data operand and a boolean predicate
// condition, and returns two values matching the type of the data operand.
//
// More details can be found in Tensorflow Controlflow white paper:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("Switch")
// .Input("data: T")
// .Input("pred: bool")
// .Output("output_false: T")
// .Output("output_true: T")
//
// For example:
// %2 = _tf.Switch %0, %1 : tensor<??xf32>
//
// Note: Additional result corresponds to the control output.
// The op therefore declares three results: the two outputs listed above plus
// the control output.
class SwitchOp : public Op<SwitchOp, OpTrait::AtLeastNOperands<2>::Impl,
OpTrait::NResults<3>::Impl> {
public:
using Op::Op;
static StringRef getOperationName() { return "_tf.Switch"; }
// Accessors for operand 0, the `data` input.
Value *getData() { return getOperand(0); }
void setData(Value *value) { setOperand(0, value); }
// Accessors for operand 1, the boolean `pred` input.
Value *getPredicate() { return getOperand(1); }
void setPredicate(Value *value) { setOperand(1, value); }
LogicalResult verify();
};
// The "_tf.Exit" operation forwards a value from a while loop to its consumer
// outside of the loop. Each returned tensor needs its own _tf.Exit.
//
// More details can be found in Tensorflow Controlflow white paper:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
//
// This is defined in Tensorflow as:
//
// REGISTER_OP("Exit")
// .Input("data: T")
// .Output("output: T")
// .Attr("T: type")
//
// For example:
// %1 = "_tf.Exit"(%0#0) {T: "tfdtype$DT_INT32",} : (tensor<*xi32>) ->
// (tensor<*xi32>, !_tf.control)
//
// Note: Additional result corresponds to the control output.
class ExitOp : public Op<ExitOp, OpTrait::AtLeastNOperands<1>::Impl,
OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> {
public:
using Op::Op;
static StringRef getOperationName() { return "_tf.Exit"; }
// Accessors for operand 0, the `data` input.
Value *getData() { return getOperand(0); }
void setData(Value *value) { setOperand(0, value); }
LogicalResult verify();
};
} // namespace TFControlFlow
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_

View File

@ -0,0 +1,29 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {

// Static initialization for TF dialect registration. One file-local
// DialectRegistration object is defined per dialect: the "_tf" control flow
// dialect, the TensorFlow dialect, and the tf_executor dialect. The objects
// are never referenced by name; they exist for their constructors.
static DialectRegistration<TFControlFlow::TFControlFlowDialect>
    tf_control_flow_ops;
static DialectRegistration<TF::TensorFlowDialect> tf_ops;
static DialectRegistration<tf_executor::TensorFlowExecutorDialect>
    tf_executor_dialect;

}  // namespace mlir

View File

@ -0,0 +1,643 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include <algorithm>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
namespace mlir {
namespace tf_executor {
//===----------------------------------------------------------------------===//
// TF Executor Dialect
//===----------------------------------------------------------------------===//
// Constructs the dialect under the "tf_executor" name, then adds its
// operations and the ControlType to the dialect.
TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
: Dialect(/*name=*/"tf_executor", context) {
// With GET_OP_LIST defined, the generated .inc file expands to the list of
// op classes, which becomes the template argument list of addOperations.
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
>();
addTypes<ControlType>();
}
// Parses a type of the tf_executor dialect. "control" is the only recognized
// spelling and yields the ControlType. Anything else emits an error at `loc`
// and returns a null Type.
Type TensorFlowExecutorDialect::parseType(StringRef data_type,
                                          Location loc) const {
  if (data_type != "control") {
    emitError(loc) << "unknown tf_executor type: " << data_type;
    return nullptr;
  }
  return ControlType::get(getContext());
}
// Prints a type of the tf_executor dialect. ControlType is the only type this
// function handles (checked by the assert), so the output is always
// "control".
void TensorFlowExecutorDialect::printType(Type type, raw_ostream &os) const {
assert(type.isa<ControlType>());
os << "control";
}
//===----------------------------------------------------------------------===//
// Implementation for all the operations defined in ODS (op definition spec).
//===----------------------------------------------------------------------===//
namespace {
// Guarantees that the last block of `region` ends with a terminator. An empty
// region first receives a fresh block. Then, unless that block already ends
// with a known terminator, a `Terminator` operation is built at `loc` with an
// empty argument list and appended to the block.
template <typename Terminator>
void EnsureExecutorTerminator(Region *region, Builder *builder, Location loc) {
  if (region->empty()) region->push_back(new Block);

  Block &last_block = region->back();
  const bool already_terminated =
      !last_block.empty() && last_block.back().isKnownTerminator();
  if (already_terminated) return;

  OperationState state(loc, Terminator::getOperationName());
  Terminator::build(builder, &state, {});
  last_block.push_back(Operation::create(state));
}
// Verifies that all control operands of `op` come after every non-control
// operand. Emits an op error naming the first non-control operand found after
// a control operand.
// Used by the constraint `ControlOperandsAfterAllData` in ODS.
LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
  bool seen_control = false;
  const int num_operands = op->getNumOperands();
  for (int idx = 0; idx < num_operands; ++idx) {
    const bool is_control = op->getOperand(idx)->getType().isa<ControlType>();
    if (is_control) {
      seen_control = true;
    } else if (seen_control) {
      return op->emitOpError() << "found non-control operand #" << idx
                               << " after control operand";
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// tf_executor.graph
//===----------------------------------------------------------------------===//
// Verifies the structural invariants of a tf_executor.graph operation:
//   * the body is non-empty and directly contains only operations of the
//     graph's own dialect;
//   * the last operation of the body is a tf_executor.fetch;
//   * the fetch operands that precede the first control operand match the
//     graph results one-to-one, with identical types.
LogicalResult Verify(GraphOp graph) {
  Block &body = graph.GetBody();
  if (body.empty()) return graph.emitOpError() << "expects a non-empty body";

  // Only tf_executor dialect operations are allowed to be immediately nested
  // in a tf_executor.graph region.
  auto *executor_dialect = graph.getDialect();
  for (Operation &nested_op : body) {
    if (nested_op.getDialect() == executor_dialect) continue;
    return nested_op.emitOpError()
           << "unallowed inside a tf_executor.graph region";
  }

  Operation &fetch = body.back();
  if (!isa<FetchOp>(fetch))
    return fetch.emitOpError()
           << "invalid tf_executor.graph terminator, fetch expected";

  // The fetch must carry at least one operand per graph result.
  const int num_results = graph.getNumResults();
  const int num_fetch_operands = fetch.getNumOperands();
  if (num_fetch_operands < num_results)
    return fetch.emitOpError() << "does not have enough operands to cover the "
                                  "graph returned values";

  for (int idx = 0; idx < num_fetch_operands; ++idx) {
    Value *operand = fetch.getOperand(idx);

    // The first control operand ends the data operands. It must sit exactly
    // one past the last graph result, and nothing after it is checked here.
    if (operand->getType().isa<ControlType>()) {
      if (idx != num_results)
        return fetch.emitOpError()
               << "operand #" << idx
               << " is a control type, can't be bound to a graph result";
      break;
    }

    // A data operand needs a graph result at the same index, of the same
    // type.
    if (idx >= num_results)
      return fetch.emitOpError()
             << "operand #" << idx << " does not have a graph results to bind";
    if (graph.getResult(idx)->getType() != operand->getType())
      return fetch.emitOpError()
             << "operand #" << idx << " type mismatch graph results";
  }
  return success();
}
// Prints a tf_executor.graph: the operation name, then its first region, then
// the optional attribute dictionary. Result types are not printed.
void Print(GraphOp graph, OpAsmPrinter *p) {
*p << graph.getOperationName();
p->printRegion(graph.getOperation()->getRegion(0));
p->printOptionalAttrDict(graph.getAttrs());
}
// Parses a tf_executor.graph: a region of at most one block followed by an
// optional attribute dictionary. The result types are not parsed from the
// text; they are taken from the operand types of the fetch terminator, up to
// but excluding its first control operand.
ParseResult ParseGraphOp(OpAsmParser *parser, OperationState *result) {
  llvm::SMLoc start_loc = parser->getCurrentLocation();

  // Parses the body region.
  Region &graph_region = *result->addRegion();
  if (parser->parseRegion(graph_region, llvm::None, llvm::None))
    return failure();
  if (graph_region.getBlocks().size() > 1)
    return parser->emitError(start_loc) << "expects a single block region";

  // Makes the region well formed: afterwards it holds at least one block and
  // that block ends with a terminator (a FetchOp is added if it had none).
  EnsureExecutorTerminator<FetchOp>(&graph_region, &parser->getBuilder(),
                                    result->location);

  // A terminator that was already present may be something other than a
  // fetch; reject that case.
  Operation &terminator = graph_region.back().back();
  if (!isa<FetchOp>(terminator))
    return parser->emitError(start_loc)
           << "expects a tf_executor.fetch terminator";

  // The graph results are the non-control operands of the fetch operation.
  result->types.reserve(terminator.getNumOperands());
  for (Type operand_type : terminator.getOperandTypes()) {
    if (operand_type.isa<ControlType>()) break;
    result->types.push_back(operand_type);
  }

  // Parses the optional attribute list.
  if (parser->parseOptionalAttributeDict(result->attributes)) return failure();
  return success();
}
//===----------------------------------------------------------------------===//
// tf_executor.fetch
//===----------------------------------------------------------------------===//
// Prints a tf_executor.fetch operation as:
//   tf_executor.fetch [operands : operand-types] [attr-dict]
// The operand and type lists are omitted when the fetch has no operands.
void Print(FetchOp fetch, OpAsmPrinter *p) {
  OpAsmPrinter &printer = *p;
  printer << fetch.getOperationName();
  const bool has_operands = fetch.getNumOperands() > 0;
  if (has_operands) {
    printer << ' ';
    printer.printOperands(fetch.operand_begin(), fetch.operand_end());
    printer << " : ";
    interleaveComma(fetch.getOperandTypes(), printer);
  }
  printer.printOptionalAttrDict(fetch.getAttrs());
}
// Parses a tf_executor.fetch operation: an operand list, a colon-prefixed
// type list (only when at least one operand was parsed), and an optional
// attribute dictionary. Returns failure as soon as one step fails.
ParseResult ParseFetchOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  SmallVector<Type, 2> operand_types;
  llvm::SMLoc start_loc = parser->getCurrentLocation();

  if (parser->parseOperandList(operands)) return failure();
  // The type list is only present when there are operands to type.
  if (!operands.empty() && parser->parseColonTypeList(operand_types))
    return failure();
  if (parser->resolveOperands(operands, operand_types, start_loc,
                              result->operands))
    return failure();
  if (parser->parseOptionalAttributeDict(result->attributes)) return failure();
  return success();
}
//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//
// Verifies a tf_executor.island operation:
//  - the body is not empty and ends with a tf_executor.yield;
//  - the yield has one operand per island result, not counting the trailing
//    control result, and the types match pairwise;
//  - no result other than the last one has the control type.
LogicalResult Verify(IslandOp island) {
  if (island.GetBody().empty())
    return island.emitOpError() << "expects a non-empty body";

  Operation &terminator = island.GetBody().back();
  if (!isa<YieldOp>(terminator))
    return terminator.emitOpError()
           << "invalid tf_executor.island terminator, yield expected";

  // The last island result is the control token, it has no matching operand
  // on the yield.
  int result_count = island.getNumResults() - 1;
  if (terminator.getNumOperands() != result_count)
    return terminator.emitOpError()
           << "has " << terminator.getNumOperands()
           << " operand, but island returns " << result_count;

  // Pairwise comparison of the yielded types with the island result types.
  for (int idx : llvm::seq<int>(0, terminator.getNumOperands())) {
    auto result_type = island.getResult(idx)->getType();
    auto yielded_type = terminator.getOperand(idx)->getType();
    if (result_type != yielded_type)
      return terminator.emitOpError()
             << "operand #" << idx << " type mismatch island results";
  }

  // Only the last result is allowed to be a control token.
  Type control_type = ControlType::get(island.getContext());
  for (int idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
    if (island.getResult(idx)->getType() == control_type)
      return terminator.emitOpError()
             << "unexpected control type for operand #" << idx;
  }
  return success();
}
// Prints a tf_executor.island operation as:
//   tf_executor.island[(control-operands)] region [attr-dict]
// The parenthesized operand list is omitted when there are no operands.
void Print(IslandOp op, OpAsmPrinter *p) {
  OpAsmPrinter &printer = *p;
  printer << op.getOperationName();
  if (op.getNumOperands() != 0) {
    // Operands are always control tokens, so their type is left implicit.
    printer << '(';
    printer.printOperands(op.getOperands());
    printer << ')';
  }
  Operation *island = op.getOperation();
  printer.printRegion(island->getRegion(0));
  printer.printOptionalAttrDict(op.getAttrs());
}
// Parses a tf_executor.island operation:
//   tf_executor.island[(control-operands)] region [attr-dict]
// The operands, when present, are all control tokens, so no type is spelled
// in the textual form. The result types are derived from the operands of the
// terminator of the body, followed by one control type.
ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) {
  llvm::SMLoc loc = parser->getCurrentLocation();
  Type control_type = ControlType::get(parser->getBuilder().getContext());

  // Parses optional argument list (control dependencies only).
  SmallVector<OpAsmParser::OperandType, 4> op_infos;
  if (parser->parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
    return failure();
  if (!op_infos.empty()) {
    // One control type per parsed operand: resolveOperands needs the operand
    // and type lists to have the same length.
    SmallVector<Type, 2> types(op_infos.size(), control_type);
    // Resolution failures (e.g. an operand that is not a control token) must
    // abort the parsing instead of being silently dropped.
    if (parser->resolveOperands(op_infos, types, loc, result->operands))
      return failure();
  }

  // Parses the body region.
  Region &body = *result->addRegion();

  // TODO(b/134773778): the custom parser is missing support to implement to
  // short syntax right now.
  // if (!parser->parseOptionalKeyword("wraps")) {
  //   body.push_back(new Block);
  //   Block &block = body.back();
  //   parser->getBuilder().setInsertionPointToEnd(&block);
  //   if (parser->parseOperation())
  //     return failure();
  // }
  if (parser->parseRegion(body, llvm::None, llvm::None)) return failure();

  EnsureExecutorTerminator<YieldOp>(&body, &parser->getBuilder(),
                                    result->location);

  // Gets the results type for the island from the terminator operands.
  Operation &yield = body.back().back();
  result->types.reserve(yield.getNumOperands() + 1);
  result->types.append(yield.operand_type_begin(), yield.operand_type_end());
  // The control token is always the last result.
  result->types.push_back(control_type);

  // Parses the optional attribute list.
  if (parser->parseOptionalAttributeDict(result->attributes)) return failure();
  return success();
}
//===----------------------------------------------------------------------===//
// tf_executor.yield
//===----------------------------------------------------------------------===//
// Prints a tf_executor.yield operation as:
//   tf_executor.yield [operands : operand-types] [attr-dict]
// The operand and type lists are omitted when the yield has no operands.
void Print(YieldOp yield, OpAsmPrinter *p) {
  OpAsmPrinter &printer = *p;
  printer << yield.getOperationName();
  const bool has_operands = yield.getNumOperands() > 0;
  if (has_operands) {
    printer << ' ';
    printer.printOperands(yield.operand_begin(), yield.operand_end());
    printer << " : ";
    interleaveComma(yield.getOperandTypes(), printer);
  }
  printer.printOptionalAttrDict(yield.getAttrs());
}
// Parses a tf_executor.yield operation: an operand list, a colon-prefixed
// type list (only when at least one operand was parsed), and an optional
// attribute dictionary. Returns failure as soon as one step fails.
ParseResult ParseYieldOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  SmallVector<Type, 2> operand_types;
  llvm::SMLoc start_loc = parser->getCurrentLocation();

  if (parser->parseOperandList(operands)) return failure();
  // The type list is only present when there are operands to type.
  if (!operands.empty() && parser->parseColonTypeList(operand_types))
    return failure();
  if (parser->resolveOperands(operands, operand_types, start_loc,
                              result->operands))
    return failure();
  if (parser->parseOptionalAttributeDict(result->attributes)) return failure();
  return success();
}
//===----------------------------------------------------------------------===//
// tf_executor.Switch
//===----------------------------------------------------------------------===//
// Parses a tf_executor.Switch operation:
//   tf_executor.Switch data, predicate[, control-operands] : type [attr-dict]
// `type` is either a single data type (short form: the data operand and both
// data results use it, the predicate is tensor<i1> and every extra operand is
// a control token) or a function type that fully spells the operand and
// result types.
//
// The operand list is not limited to two entries: the printer emits every
// operand, including the optional control inputs, so the parser has to accept
// them for the textual form to round-trip.
ParseResult ParseSwitchOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> op_infos;
  SmallVector<Type, 1> types;
  if (parser->parseOperandList(op_infos) || parser->parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser->emitError(parser->getNameLoc())
           << " expects only a single data type";
  // The data and the predicate operands are mandatory. This also guards the
  // unsigned `op_infos.size() - 2` computation below.
  if (op_infos.size() < 2)
    return parser->emitError(parser->getNameLoc())
           << " expects a data operand and a predicate operand";

  // Supports parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the data
  // input and the outputs are all using this type).
  if (types.front().isa<FunctionType>()) {
    FunctionType type = types.front().cast<FunctionType>();
    // At least the data and the predicate types; any extra input type
    // belongs to a control operand.
    if (type.getNumInputs() < 2)
      return parser->emitError(parser->getNameLoc())
             << " expects a single data type and a predicate";
    result->types.assign(type.getResults().begin(), type.getResults().end());
    types.assign(type.getInputs().begin(), type.getInputs().end());
  } else {
    Type control_type = ControlType::get(parser->getBuilder().getContext());
    // Both data results share the data operand type, followed by the control
    // result.
    result->types.append(2, types[0]);
    result->types.push_back(control_type);
    Type i1_type = parser->getBuilder().getI1Type();
    RankedTensorType predicate_type = RankedTensorType::get({}, i1_type);
    types.push_back(predicate_type);
    // Every operand after the data and the predicate is a control input.
    types.append(op_infos.size() - 2, control_type);
  }

  // A mismatch between the number of operands and the number of types is
  // diagnosed by resolveOperands.
  llvm::SMLoc loc = parser->getCurrentLocation();
  if (parser->resolveOperands(op_infos, types, loc, result->operands))
    return failure();
  return parser->parseOptionalAttributeDict(result->attributes);
}
// Prints a tf_executor.Switch operation as:
//   tf_executor.Switch operands : type [attr-dict]
// `type` is the type of the first result when both data results have the
// type of the data operand, and the full functional type otherwise.
void Print(SwitchOp switch_op, OpAsmPrinter *p) {
  OpAsmPrinter &printer = *p;
  printer << switch_op.getOperationName() << ' ';
  printer.printOperands(switch_op.getOperands());

  Type data_type = switch_op.data()->getType();
  const bool results_match_data =
      switch_op.trueOutput()->getType() == data_type &&
      switch_op.falseOutput()->getType() == data_type;

  printer << " : ";
  if (results_match_data) {
    // Short form: a single type is enough to describe the operation.
    printer << switch_op.getType(0);
  } else {
    printer.printFunctionalType(switch_op.getOperation());
  }
  printer.printOptionalAttrDict(switch_op.getAttrs());
}
//===----------------------------------------------------------------------===//
// tf_executor.SwitchN
//===----------------------------------------------------------------------===//
// Verifies a tf_executor.SwitchN operation:
//  - it carries an integer `num_outs` attribute;
//  - it has `num_outs` + 1 results (the extra one is the control output);
//  - every data output has the type of the first operand.
LogicalResult Verify(SwitchNOp switchn) {
  IntegerAttr num_outs_attr = switchn.getAttrOfType<IntegerAttr>("num_outs");
  if (!num_outs_attr)
    return switchn.emitOpError() << "expects a `num_outs` integer attribute";

  // Expects num_outs results + 1 control output.
  if (switchn.getNumResults() != num_outs_attr.getInt() + 1)
    return switchn.emitOpError()
           << "expect `num_outs` (" << num_outs_attr.getInt()
           << ") results but got " << (switchn.getNumResults() - 1);

  auto data_type = switchn.getOperand(0)->getType();
  for (Value *output : switchn.outputs()) {
    if (data_type == output->getType()) continue;
    return switchn.emitOpError()
           << "type mismatch between data operand and result: " << data_type
           << " vs " << output->getType();
  }
  return success();
}
// Prints a tf_executor.SwitchN operation as:
//   tf_executor.SwitchN data, index of N [(control-operands)] : type [attrs]
// where N is the number of results minus one (the control output). The
// `num_outs` attribute is not printed in the attribute dictionary.
void Print(SwitchNOp switchn, OpAsmPrinter *p) {
  OpAsmPrinter &printer = *p;
  printer << switchn.getOperationName() << ' ';

  // The first two operands are the data and the index.
  auto operands = switchn.getOperands();
  auto data_operands_end = std::next(operands.begin(), 2);
  printer.printOperands(operands.begin(), data_operands_end);
  printer << " of " << (switchn.getNumResults() - 1);

  // Control dependencies, if any, are printed between parentheses.
  const bool has_control_inputs = !llvm::empty(switchn.controlInputs());
  if (has_control_inputs) {
    printer << " (";
    printer.printOperands(switchn.controlInputs());
    printer << ")";
  }

  printer << " : " << switchn.getType(0);
  printer.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
}
// Parses a tf_executor.SwitchN operation, for example:
//   %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<?x?xf32>
// The first operand is the data to forward and the second one is an i32
// selecting the output to populate. They are followed by the keyword `of` and
// the number of data outputs (the operation has one more result for the
// control token), then by an optional parenthesized list of control operands
// and by the data type.
ParseResult ParseSwitchNOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  SmallVector<Type, 1> types;
  llvm::SMLoc start_loc = parser->getCurrentLocation();
  IntegerAttr num_outs;
  Type i64_type = parser->getBuilder().getIntegerType(64);

  // Data and index operands.
  if (parser->parseOperandList(operands, 2)) return failure();
  // `of` <num_outs>, recorded as the `num_outs` attribute.
  if (parser->parseKeyword("of")) return failure();
  if (parser->parseAttribute(num_outs, i64_type, "num_outs",
                             result->attributes))
    return failure();
  // Optional control operands, appended after the two data operands.
  if (parser->parseOperandList(operands,
                               OpAsmParser::Delimiter::OptionalParen))
    return failure();
  if (parser->parseColonTypeList(types)) return failure();

  if (types.size() != 1)
    return parser->emitError(parser->getNameLoc())
           << " expects only a single data type";
  if (num_outs.getInt() <= 0)
    return parser->emitError(parser->getNameLoc())
           << " expects a positive number of outputs";

  // Operand types: the parsed data type, an i32 for the index, and one
  // control type per remaining operand.
  Type data_type = types[0];
  Type control_type = ControlType::get(parser->getBuilder().getContext());
  types.push_back(parser->getBuilder().getIntegerType(32));
  types.append(operands.size() - 2, control_type);
  if (parser->resolveOperands(operands, types, start_loc, result->operands))
    return failure();

  // Result types: `num_outs` times the data type, then the control token.
  result->types.append(num_outs.getInt(), data_type);
  result->types.push_back(control_type);
  return parser->parseOptionalAttributeDict(result->attributes);
}
//===----------------------------------------------------------------------===//
// tf_executor.Merge
//===----------------------------------------------------------------------===//
// Verifies a tf_executor.Merge operation:
//  - it has at least one operand and the first one is not a control token;
//  - every operand before the first control operand is broadcast-compatible
//    with the type accumulated so far, starting from the output type.
LogicalResult Verify(MergeOp merge) {
  if (merge.getNumOperands() == 0)
    return merge.emitOpError() << "expects at least one operand";

  Type first_operand_type = merge.getOperand(0)->getType();
  if (first_operand_type.isa<ControlType>())
    return merge.emitOpError() << "expects a non-control input";

  // Checks that all operands can be broadcasted to a common type compatible
  // with the result type.
  Type accumulated_type = merge.output()->getType();
  for (Type operand_type : merge.getOperandTypes()) {
    // Control operands come last, nothing to check past the first one.
    if (operand_type.isa<ControlType>()) break;

    Type candidate_type =
        OpTrait::util::getBroadcastedType(accumulated_type, operand_type);
    if (!candidate_type)
      return merge.emitOpError()
             << "expects all operands to be broadcastable"
             << " but got " << accumulated_type << " vs " << operand_type;

    // Adopts the broadcasted type unless doing so would lose rank
    // information. For example, starting from a tensor<4xf32> result, an
    // unranked first operand broadcasts to an unranked type, and any later
    // tensor operand would then be compatible with that unranked type.
    const bool accumulated_is_unranked =
        accumulated_type.isa<TensorType>() &&
        !accumulated_type.cast<TensorType>().hasRank();
    const bool candidate_is_ranked =
        candidate_type.isa<TensorType>() &&
        candidate_type.cast<TensorType>().hasRank();
    if (accumulated_is_unranked || candidate_is_ranked)
      accumulated_type = candidate_type;
  }
  return success();
}
// Prints a tf_executor.Merge operation as:
//   tf_executor.Merge operands : type [attr-dict]
// where `type` is the type of the first result.
void Print(MergeOp merge, OpAsmPrinter *p) {
  OpAsmPrinter &printer = *p;
  printer << merge.getOperationName() << ' ';
  printer.printOperands(merge.getOperands());
  // The type of the first result stands for the whole signature.
  printer << " : " << merge.getType(0);
  printer.printOptionalAttrDict(merge.getAttrs());
}
// Parses a tf_executor.Merge operation:
//   tf_executor.Merge operands : type [attr-dict]
// The single parsed type is the data type. It is used for the first two
// operands (or for the only operand when a single one is provided); every
// operand after the second one is expected to be a control input. The results
// are the data type, a tensor<i32> for the value index, and a control token.
ParseResult ParseMergeOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> op_infos;
  SmallVector<Type, 1> types;
  llvm::SMLoc loc = parser->getCurrentLocation();
  if (parser->parseOperandList(op_infos) || parser->parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser->emitError(parser->getNameLoc())
           << " expects only a single data type";
  // The verifier requires at least one (data) operand; diagnose it here
  // instead of computing an operand type list for an empty operand list.
  if (op_infos.empty())
    return parser->emitError(loc) << " expects at least one data operand";

  // Expects the type once, but use it for the (up to two) data operands.
  // Clamping the count keeps the unsigned subtraction below from wrapping
  // around when fewer than two operands are provided.
  Type data_type = types.front();
  const size_t num_data_operands = op_infos.size() < 2 ? op_infos.size() : 2;
  types.assign(num_data_operands, data_type);
  // Extra operands are expected to be control inputs.
  Type control_type = ControlType::get(parser->getBuilder().getContext());
  types.append(op_infos.size() - num_data_operands, control_type);
  if (parser->resolveOperands(op_infos, types, loc, result->operands))
    return failure();

  RankedTensorType i32_tensor =
      RankedTensorType::get({}, parser->getBuilder().getIntegerType(32));
  result->types = {data_type, i32_tensor, control_type};
  return parser->parseOptionalAttributeDict(result->attributes);
}
//===----------------------------------------------------------------------===//
// tf_executor.Enter
//===----------------------------------------------------------------------===//
// Default number for the parallel_iterations attributes on Enter nodes.
// The Enter printer omits `parallel_iterations` when the attribute has this
// value, and the Enter parser uses this value when the keyword is absent.
constexpr int kDefaultParallelIterations = 10;
// Prints a tf_executor.Enter operation as:
//   tf_executor.Enter operands frame "name" [parallel_iterations N]
//       [constant] : type [attr-dict]
// `parallel_iterations` is only printed when it differs from the default and
// `constant` only when the `is_constant` attribute is set. `type` is the type
// of the first result when it matches the data operand type, and the full
// functional type otherwise. The attributes with a dedicated syntax are not
// repeated in the attribute dictionary.
void Print(EnterOp enter, OpAsmPrinter *p) {
  OpAsmPrinter &printer = *p;
  printer << enter.getOperationName() << ' ';
  printer.printOperands(enter.getOperands());

  // Frame name, as an escaped string literal.
  printer << " frame \"";
  printEscapedString(enter.frame_name(), printer.getStream());
  printer << "\"";

  if (enter.parallel_iterations() != kDefaultParallelIterations)
    printer << " parallel_iterations " << enter.parallel_iterations();
  if (enter.is_constant()) printer << " constant ";

  printer << " : ";
  const bool output_matches_data =
      enter.data()->getType() == enter.output()->getType();
  if (output_matches_data) {
    // Short form: a single type is enough to describe the operation.
    printer << enter.getType(0);
  } else {
    printer.printFunctionalType(enter.getOperation());
  }

  printer.printOptionalAttrDict(
      enter.getAttrs(), {"frame_name", "parallel_iterations", "is_constant"});
}
// Parses a tf_executor.Enter operation:
//   tf_executor.Enter operands frame "name" [parallel_iterations N]
//       [constant] : type [attr-dict]
// `type` is either a single data type (short form: the data operand and the
// data result use it and every extra operand is a control token) or a
// function type that fully spells the operand and result types.
ParseResult ParseEnterOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  llvm::SMLoc start_loc = parser->getCurrentLocation();
  MLIRContext *ctx = parser->getBuilder().getContext();

  if (parser->parseOperandList(operands)) return failure();
  if (operands.empty())
    return parser->emitError(start_loc)
           << " expects at least one data operand";

  // `frame` <attribute>, recorded as the `frame_name` attribute.
  Attribute frame_attr;
  if (parser->parseKeyword("frame")) return failure();
  if (parser->parseAttribute(frame_attr, NoneType::get(ctx), "frame_name",
                             result->attributes))
    return failure();

  // Optional `parallel_iterations` <integer>; the default value is recorded
  // when the keyword is absent.
  Type i64_type = parser->getBuilder().getIntegerType(64);
  if (succeeded(parser->parseOptionalKeyword("parallel_iterations"))) {
    IntegerAttr parallel_iterations;
    if (parser->parseAttribute(parallel_iterations, i64_type,
                               "parallel_iterations", result->attributes))
      return failure();
  } else {
    result->addAttribute(
        "parallel_iterations",
        IntegerAttr::get(i64_type, kDefaultParallelIterations));
  }

  // Optional `constant` keyword, always recorded as `is_constant`.
  bool is_constant = succeeded(parser->parseOptionalKeyword("constant"));
  result->addAttribute("is_constant", BoolAttr::get(is_constant, ctx));

  SmallVector<Type, 1> types;
  if (parser->parseColonTypeList(types)) return failure();
  if (types.size() != 1)
    return parser->emitError(start_loc) << " expects only a single data type";

  if (types.front().isa<FunctionType>()) {
    // Fully qualified form: operand and result types come from the function
    // type.
    FunctionType signature = types.front().cast<FunctionType>();
    if (signature.getNumInputs() != 1)
      return parser->emitError(parser->getNameLoc())
             << " expects a single data type";
    result->types.assign(signature.getResults().begin(),
                         signature.getResults().end());
    types.assign(signature.getInputs().begin(), signature.getInputs().end());
  } else {
    // Short form: the data type is shared by the operand and the result, and
    // extra operands are expected to be control inputs.
    Type control_type = ControlType::get(ctx);
    types.append(operands.size() - 1, control_type);
    result->addTypes({types.front(), control_type});
  }

  if (parser->resolveOperands(operands, types, start_loc, result->operands))
    return failure();
  return parser->parseOptionalAttributeDict(result->attributes);
}
} // anonymous namespace
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
} // namespace tf_executor
} // namespace mlir

View File

@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the tf_executor dialect: it models the TensorFlow executor
// semantics and can represent arbitrary TensorFlow graphs. As such it follows
// the existing execution model that includes deadness propagation, concurrent
// semantics, and control dependencies.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace tf_executor {
// Dialect class for tf_executor. It derives from the MLIR `Dialect` base class
// and overrides the hooks used to parse and print the types registered to
// this dialect.
class TensorFlowExecutorDialect : public Dialect {
public:
// Builds the dialect for the given MLIR context. The definition is not part
// of this header.
explicit TensorFlowExecutorDialect(MLIRContext *context);
// Parses a type registered to this dialect.
// `data` is the textual form of the type and `loc` the location to use for
// diagnostics -- NOTE(review): inferred from the parameter names, confirm
// against the definition.
Type parseType(StringRef data, Location loc) const override;
// Prints a type registered to this dialect.
// `type` is written to the stream `os`.
void printType(Type type, raw_ostream &os) const override;
};
// Kind identifiers for the types defined by the tf_executor dialect. The
// first (and only) entry starts at Type::FIRST_TENSORFLOW_EXECUTOR_TYPE.
namespace TFTypes {
enum Kind {
// Kind of the ControlType declared below.
Control = Type::FIRST_TENSORFLOW_EXECUTOR_TYPE,
};
} // namespace TFTypes
// The Control type is a token-like value that models control dependencies from
// TensorFlow graphs.
// It is printed as `!tf_executor.control` in the textual form (see the dialect
// description in the TableGen definition file).
class ControlType : public Type::TypeBase<ControlType, Type> {
public:
// Inherits the constructors of the TypeBase helper.
using Base::Base;
// Returns the control type for the given context. The type takes no
// parameter other than its kind.
static ControlType get(MLIRContext *context) {
return Base::get(context, TFTypes::Control);
}
// Supports method to enable LLVM-style type casting.
// Returns true only for the TFTypes::Control kind.
static bool kindof(unsigned kind) { return kind == TFTypes::Control; }
};
// Declares the operations for this dialect using the generated header.
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h.inc"
} // namespace tf_executor
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_

View File

@ -0,0 +1,423 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the definition file for the TensorFlow Executor Dialect.
//
#ifdef TF_EXECUTOR_DIALECT
#else
#define TF_EXECUTOR_DIALECT
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//
// Declares the `tf_executor` dialect. Operations of this dialect are prefixed
// with `tf_executor.` and the generated C++ code lives in the `tf_executor`
// namespace.
def TfExecutor_Dialect : Dialect {
  let name = "tf_executor";

  let description = [{
    The TensorFlow Executor dialect.

    This dialect models the TensorFlow executor semantics and can represent
    arbitrary TensorFlow graphs. As such it follows the existing execution model
    that includes deadness propagation, concurrent semantics, and control
    dependencies.

    Operations in this dialect return a value of type `!tf_executor.control` as
    last returned value (exceptions are `tf_executor.NextIteration.graph`,
    `tf_executor.NextIteration.sink` and `tf_executor.fetch` which don't return
    any value).
  }];

  let cppNamespace = "tf_executor";
}
// Control type.
// Type constraint satisfied by values whose type is the C++ `ControlType`; it
// is described as "control" in generated diagnostics and documentation.
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">;
//===----------------------------------------------------------------------===//
// TensorFlow Executor Type Constraint
//===----------------------------------------------------------------------===//
// Op trait verifying that all the control operands of an operation appear
// after all of its non-control (data) operands. The check is delegated to the
// C++ helper `VerifyControlOperandsAfterAllData`, called on the operation.
def ControlOperandsAfterAllData :
PredOpTrait<"all control inputs must appear after any non-control input",
CPred<"succeeded(VerifyControlOperandsAfterAllData(&$_op))">>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//
// Base class for the operation in this dialect, it'll forward the verifier,
// printer, and parser to free functions.
// - `mnemonic` is the operation name without the `tf_executor.` prefix.
// - `traits` is the list of traits attached to the operation.
// Operations without a custom `Verify` overload reset the verifier with
// `let verifier = ?;`.
class TfExecutor_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TfExecutor_Dialect, mnemonic, traits> {
// Forwards to the `Verify` overload taking the operation.
let verifier = [{ return ::mlir::tf_executor::Verify(*this); }];
// Forwards to the `Print` overload taking the operation and the printer.
let printer = [{ return ::mlir::tf_executor::Print(*this, p); }];
// Forwards to the free function named `Parse` followed by the C++ class name
// of the operation (for example `ParseGraphOp`).
let parser = [{ return Parse$cppClass(parser, result); }];
}
// `tf_executor.graph`: holds a single-block region listing the operations of
// a TensorFlow graph. It takes no operand; its results are the values fetched
// by the `tf_executor.fetch` terminator of the region.
def TfExecutor_GraphOp : TfExecutor_Op<"graph", []> {
  let summary = [{The `tf_executor.graph` operation contains a region with a
    single block that lists the operations in a TensorFlow graph.}];

  let description = [{
    The operations are topologically sorted in-order (no cycles are allowed in
    the values). The execution model for operations in this block follows the
    TensorFlow executor semantics:
      1. Operations that don't have any transitive dependencies through the
         def/use chains may be executed in parallel
         (`tf_executor.NextIteration.Source` is the exception).
      2. SSA values in this block can be implicitly dead. This means that every
         SSA value defined in a `tf_executor.graph` can be considered implicitly
         wrapped in a conceptual `dead_or<T>` structure, and includes a runtime
         flag indicating if the value is dead or present.
      3. Operations may have special case handling of dead values.

    The `tf_executor.graph` op only allows specific `tf_executor` dialect
    operations in its body: the `tf_executor.graph` verifier will reject any
    unknown operation. In order to execute standard `tf` dialect operations
    (like `tf.Add`) they must be wrapped in the `tf_executor.island` operation.

    The `tf_executor.graph` operation does not accept any operands, inputs are
    implicitly captured by the region, representing the feeds to the graph.

    The region attached to `tf_executor.graph` is terminated by a
    `tf_executor.fetch` operation. The operands of the terminator correspond to
    the result values (or fetches) of the `tf_executor.graph` operation. The
    behavior is undefined if any of the operands of the `tf_executor.fetch` is
    dead.
  }];

  let results = (outs
    Variadic<AnyType>:$results
  );

  // The body of the graph is a region made of exactly one block.
  let regions = (region SizedRegion<1>:$body);

  let extraClassDeclaration = [{
    Block &GetBody() { return getOperation()->getRegion(0).front(); }
  }];
}
// `tf_executor.fetch`: terminator of the `tf_executor.graph` region. Control
// operands, if any, must come after all the data operands.
def TfExecutor_FetchOp : TfExecutor_Op<"fetch", [Terminator, ControlOperandsAfterAllData]> {
  let summary = [{
    The `tf_executor.fetch` operation terminates the graph and returns values.
  }];

  let description = [{
    The non-control operands of the fetch operation are returned outside of the
    graph and must match the return type of the graph.
  }];

  let arguments = (ins
    Variadic<AnyType>:$fetches
  );

  // Unsets the verifier inherited from TfExecutor_Op: there is no dedicated
  // verifier function for this operation.
  let verifier = ?;
}
// `tf_executor.island`: wraps operations from other dialects so that they can
// be nested in a `tf_executor.graph`. It only takes control operands, and
// returns the values yielded by its body followed by a control token.
def TfExecutor_IslandOp : TfExecutor_Op<"island", []> {
  let summary = [{
    The `tf_executor.island` operation is a wrapper for operations in other
    dialects to be nested in a `tf_executor.graph`.
  }];

  let description = [{
    The `tf_executor.graph` operation does not allow `tf` dialect operations to
    be immediately nested underneath it. The `tf_executor.island` is introduced
    as a wrapper for `tf` dialect operations: this results in a more consistent
    representation which makes analysis and transformation simpler.

    The `tf_executor.island` operation has a single region with a single block
    attached (only functional control flow is allowed). The block is terminated
    by a `tf_executor.yield` operation. The operands of the terminator
    correspond to the result values of the `tf_executor.island` operation. An
    extra result of type `!tf_executor.control` is always produced by every
    `tf_executor.island`.

    Within an island, execution semantics follow standard sequential behavior as
    expected by TF2 and by compiler analyses and transformations, and values
    can't be dead. Other nested `tf_executor.graph` operations can be present in
    the region to re-enable the TensorFlow executor for a subsection of the
    code.
     - Initially the functional control flow operations are calling functions
       involving graphs, if `tf_executor.graph` weren't allowed in an island,
       these operations would need to have an equivalent in the `tf_executor`
       dialect to be modelled in a graph.
     - Nesting also allows forming islands without involving inter-procedural
       analyses: any function call may involve a callee with a graph.

    The `tf_executor.island` region allows implicit capture. If any value
    captured by a `tf_executor.island` is dead, the whole region does not
    execute and every produced value is marked as dead as well.

    An arbitrary number of `!tf_executor.control` operands are accepted by a
    `tf_executor.island` operation.

    If any operand or implicitly captured value are dead, the region is not
    executed and dead values are immediately returned for every result.
  }];

  let arguments = (ins
    Variadic<TfeControlType>:$controlInputs
  );

  // The control token is always the last result.
  let results = (outs
    Variadic<AnyType>:$outputs,
    TfeControlType:$control
  );

  // The body of the island is a region made of exactly one block.
  let regions = (region SizedRegion<1>:$body);

  let extraClassDeclaration = [{
    Block &GetBody() { return getOperation()->getRegion(0).front(); }
  }];
}
// `tf_executor.yield`: terminator of the `tf_executor.island` region. Control
// operands, if any, must come after all the data operands.
def TfExecutor_YieldOp :
TfExecutor_Op<"yield", [Terminator, ControlOperandsAfterAllData]> {
let summary = [{
The `tf_executor.yield` operation terminates and returns values for the
`tf_executor.island` operation.
}];
// Values returned by the enclosing island.
let arguments = (ins
Variadic<AnyType>:$fetches
);
// Unsets the verifier inherited from TfExecutor_Op.
let verifier = ?;
}
// `tf_executor.Switch`: takes a data operand and a boolean predicate, plus
// optional control inputs, and returns two data values and a control token.
// The data operand must be broadcastable to each of the two data results.
//
// NOTE(review): the TensorFlow op quoted in the description lists its outputs
// as (output_false, output_true), while the results below are declared as
// (trueOutput, falseOutput). Confirm which result actually carries the value
// for a true predicate before relying on the accessor names.
def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
    [NoSideEffect, ControlOperandsAfterAllData,
     PredOpTrait<"data operand must be broadcastable to true result",
                 TCOpIsBroadcastableToRes<0, 0>>,
     PredOpTrait<"data operand must be broadcastable to false result",
                 TCOpIsBroadcastableToRes<0, 1>>]>{
  let summary = [{
    The "tf_executor.Switch" operation takes a data operand and a boolean
    predicate condition, and returns two values matching the type of the data
    operand.
  }];

  let description = [{
    More details can be found in Tensorflow Control Flow white paper:
    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf

    This is defined in TensorFlow as:

    REGISTER_OP("Switch")
       .Input("data: T")
       .Input("pred: bool")
       .Output("output_false: T")
       .Output("output_true: T")

    For example:
      %2 = tf_executor.Switch %0, %1 : tensor<*xf32>

    Note: Additional result corresponds to the control output.
  }];

  let arguments = (ins
    AnyType:$data,
    TensorOf<[I1]>:$predicate,
    // Optional extra control inputs.
    Variadic<TfeControlType>:$controlInputs
  );

  let results = (outs
    AnyType: $trueOutput,
    AnyType: $falseOutput,
    TfeControlType: $control
  );

  let builders = [OpBuilder<
    // Builds from a flat operand list: data, predicate, then control inputs.
    "Builder *builder, OperationState *result, ArrayRef<Value *> operands = {}",
    [{
      assert(operands.size() >= 2 && "tf_executor.Switch builder expects at "
                                     "least two operands");
      return build(builder, result, operands[0], operands[1], operands.drop_front(2));
    }]>,
    OpBuilder<
    // Builds from explicit data and predicate values; both data results take
    // the type of the data operand.
    "Builder *builder, OperationState *result, Value *data, Value *predicate, ArrayRef<Value *> controls = {}",
    [{
      Type dataTy = data->getType();
      Type controlTy = ControlType::get(builder->getContext());
      result->types = { dataTy, dataTy, controlTy };
      result->operands.push_back(data);
      result->operands.push_back(predicate);
      result->operands.insert(result->operands.end(), controls.begin(), controls.end());
    }]>
  ];

  // Unsets the verifier inherited from TfExecutor_Op: the constraints are
  // expressed by the traits above.
  let verifier = ?;
}
// `tf_executor.SwitchN`: forwards its data operand to the output selected by
// an index operand. It returns `num_outs` data values and a control token.
def TfExecutor_SwitchNOp :
    TfExecutor_Op<"SwitchN", [NoSideEffect, ControlOperandsAfterAllData]> {
  let summary = [{
    The "tf_executor.SwitchN" operation takes two inputs, `data` and `index` and
    an integer attribute `num_outs` indicating the number of outputs. The `data`
    input is copied to output indicated by the `index` input. The other outputs
    are marked as dead. If one of the inputs or a control token is dead, then
    all of the outputs are marked as dead as well.
  }];

  let description = [{
    This is defined in TensorFlow as:

    REGISTER_OP("_SwitchN")
        .Input("data: T")
        .Input("output_index: int32")
        .Output("outputs: num_outs * T")
        .Attr("num_outs: int >= 1")
        .Attr("T: type")
        .SetShapeFn(SwitchNShape);

    For example:
      %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<?x?xf32>

    Note: One additional result corresponds to the control output.
  }];

  let arguments = (ins
    AnyType:$data,
    I32:$index,
    // Optional extra control inputs.
    Variadic<TfeControlType>:$controlInputs,
    I64Attr:$num_outs
  );

  // `num_outs` data outputs followed by the control token.
  let results = (outs
    Variadic<AnyType>:$outputs,
    TfeControlType: $control
  );
}
// `tf_executor.Merge`: takes a list of data operands, followed by optional
// control inputs, and returns a data value, the index of the selected operand
// and a control token.
def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAfterAllData]> {
  let summary = [{
    The "tf_executor.Merge" operation takes a list of input operands and returns
    a value of the operand type along with the index of the first match encountered.
  }];

  let description = [{
    More details can be found in Tensorflow Control Flow white paper:
    http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf

    This is defined in TensorFlow as:

    REGISTER_OP("Merge")
       .Input("inputs: N * T")
       .Output("output: T")
       .Output("value_index: int32")

    For example:
      %4:3 = tf_executor.Merge %0, %1, %2, %3 : tensor<*xf32>

    Note: Additional result corresponds to the control output.
  }];

  // Data operands first, then the optional control inputs (enforced by the
  // ControlOperandsAfterAllData trait).
  let arguments = (ins
    Variadic<AnyType>:$inputs_and_control
  );

  let results = (outs
    AnyType:$output,
    TensorOf<[I32]>:$valueIndex,
    TfeControlType:$control
  );

  let builders = [OpBuilder<
    "Builder *builder, OperationState *result, ArrayRef<Value *> operands",
    [{
      assert(operands.size() >= 1 && "tf_executor.Merge builder expects at "
                                     "least one operand");
      Type data_type = operands[0]->getType();
      Type control_type = ControlType::get(builder->getContext());
      // `valueIndex` is constrained to a tensor of i32: build the same scalar
      // tensor<i32> type as the custom parser, not a bare i32.
      Type index_type =
          RankedTensorType::get({}, builder->getIntegerType(32));
      result->types = { data_type, index_type, control_type};
      result->operands.append(operands.begin(), operands.end());
    }]>
  ];
}
// `tf_executor.Enter`: takes a data operand, plus optional control inputs, and
// returns a data value and a control token. The data operand must be
// broadcastable to the data result.
def TfExecutor_EnterOp : TfExecutor_Op<"Enter",
[NoSideEffect, ControlOperandsAfterAllData,
PredOpTrait<"data operand must be broadcastable to result",
TCOpIsBroadcastableToRes<0, 0>>]>{
let summary = [{
The "tf_executor.Enter" operation forwards its input to Tensorflow while
loop.
}];
let description = [{
More details can be found in Tensorflow Control Flow white paper:
http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
Each tensor needs its own tf_executor.Enter to be made available inside a
while loop.
This is defined in Tensorflow as:
REGISTER_OP("Enter")
.Input("data: T")
.Output("output: T")
.Attr("T: type")
.Attr("frame_name: string")
.Attr("is_constant: bool = false")
.Attr("parallel_iterations: int = 10")
For example:
%res:2 = tf_executor.Enter %arg0 frame "some/frame" parallel_iterations 42 constant : tensor<*xf32>
Note: Additional result corresponds to the control output.
}];
// `is_constant` defaults to false and `parallel_iterations` to 10, matching
// the TensorFlow op definition quoted in the description.
let arguments = (ins
AnyType:$data,
StrAttr:$frame_name,
DefaultValuedAttr<BoolAttr, "false">:$is_constant,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
// Optional extra control inputs.
Variadic<TfeControlType>:$controlInputs
);
let results = (outs
AnyType:$output,
TfeControlType:$control
);
// Unsets the verifier inherited from TfExecutor_Op: the constraints are
// expressed by the traits above.
let verifier = ?;
// Builds from a flat operand list: the first operand is the data, its type is
// used for the data result. No attribute is set by this builder.
let builders = [OpBuilder<
"Builder *builder, OperationState *result, ArrayRef<Value *> operands",
[{
assert(operands.size() >= 1 && "tf_executor.Enter builder "
"expects at least one operand");
result->operands.append(operands.begin(), operands.end());
Type control_type = ControlType::get(builder->getContext());
result->types.push_back(operands[0]->getType());
result->types.push_back(control_type);
}]>
];
}
#endif // TF_EXECUTOR_DIALECT

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,265 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow types, attributes, and builders.
#ifdef TF_OP_BASE
#else
#define TF_OP_BASE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//
// The TensorFlow dialect record: registered under the `tf` name, with its
// generated C++ placed in the `TF` namespace.
def TF_Dialect : Dialect {
let name = "tf";
let description = [{
The TensorFlow dialect.
This dialect maps to TensorFlow operations.
Invariants:
* All values are of Tensor type (in particular, scalars are
represented using zero-dimensional tensors);
TODO: Make invariants more structured so that we can reference them in ops.
}];
let cppNamespace = "TF";
}
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//
// Base class for ops in the TensorFlow dialect: binds `TF_Dialect` and forwards
// the op mnemonic and an optional trait list (empty by default) to `Op`.
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TF_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//
// Any tensor element type defined in the TensorFlow dialect, i.e. any type for
// which `isa<TensorFlowType>()` holds.
// NOTE(review): `TensorFlowType` is spelled unqualified here while other
// predicates in this file use `TF::` / `mlir::TF::` - confirm it resolves in
// every namespace the predicate is emitted into.
def TF_TFDialectType :
Type<CPred<"$_self.isa<TensorFlowType>()">, "TensorFlow type">;
// Any tensor element type allowed in TensorFlow ops: any float, any integer,
// or a TensorFlow dialect type.
def TF_ElementType : Type<Or<[AnyFloat.predicate, AnyInteger.predicate,
TF_TFDialectType.predicate]>,
"tf.dtype">;
// Any TensorFlow tensor type: a tensor whose element type is a TF_ElementType.
def TF_Tensor : TensorOf<[TF_ElementType]>;
//===----------------------------------------------------------------------===//
// Integer types
// 32-bit or 64-bit integer, and tensors thereof.
def TF_I32Or64 : IntOfWidths<[32, 64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
// Integer of width 8, 16, 32 or 64.
def TF_Int : IntOfWidths<[8, 16, 32, 64]>;
// Tensor of 8/16/32/64-bit integers (other integer widths are not accepted).
def TF_IntTensor : TensorOf<[TF_Int]>;
//===----------------------------------------------------------------------===//
// Floating-point types
// 32-bit or 64-bit float, and tensors thereof.
def TF_F32Or64 : FloatOfWidths<[32, 64]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;
// Tensor of any floating-point element type.
def TF_FpTensor : TensorOf<[AnyFloat]>;
//===----------------------------------------------------------------------===//
// Complex types
// complex64 element type (matched via `TF::Complex64Type`) and its tensor.
def TF_Complex64 :
Type<CPred<"$_self.isa<TF::Complex64Type>()">, "complex64 type">;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
// complex128 element type (matched via `TF::Complex128Type`) and its tensor.
def TF_Complex128 :
Type<CPred<"$_self.isa<TF::Complex128Type>()">, "complex128 type">;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
// Either of the two complex element types above, and tensors thereof.
def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128],
"64/128-bit complex type">;
def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>;
//===----------------------------------------------------------------------===//
// String/variant/resource types
// TensorFlow string element type and its tensor. `BuildableType` supplies the
// C++ expression used to construct the type.
def TF_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
"TensorFlow string type">,
BuildableType<"getType<mlir::TF::StringType>()">;
def TF_StrTensor : TensorOf<[TF_Str]>;
// TensorFlow variant element type and its tensor.
def TF_Variant : Type<CPred<"$_self.isa<mlir::TF::VariantType>()">,
"TensorFlow variant type">,
BuildableType<"getType<mlir::TF::VariantType>()">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;
// TensorFlow resource element type and its tensor. The description names the
// resource type so diagnostics do not mislabel it as a variant.
def TF_Resource : Type<CPred<"$_self.isa<mlir::TF::ResourceType>()">,
"TensorFlow resource type">,
BuildableType<"getType<mlir::TF::ResourceType>()">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
//===----------------------------------------------------------------------===//
// Multi-category type constraints
// Tensor of 8/16/32/64-bit integers or 32/64-bit floats.
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>;
// Tensor of any float or 32/64-bit integers.
def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>;
// Tensor of 8/16/32/64-bit integers or any float.
def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
// Tensor of any float or either complex type.
def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;
// "Number": 8/16/32/64-bit integer, any float, or either complex type.
def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyComplex], "number">;
def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;
// Tensor of numbers (as above) or TensorFlow strings.
def TF_NumberOrStrTensor : TensorOf<[TF_AnyNumber, TF_Str]>;
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// String attribute constraints
// A string attribute whose value is one of the values in `cases`.
// The predicate is the `||`-joined chain of equality tests against each case,
// and the description is the cases joined with ", or ".
// Both folds are seeded with `!head(cases)`, so `cases` must be non-empty.
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
CPred<!foldl(
"$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
!foreach(case, !tail(cases),
"$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
prev, cur, prev # " || " # cur)>,
"string attribute whose value is " #
!foldl(/*init*/!head(cases), /*list*/!tail(cases),
prev, cur, prev # ", or " # cur)>;
// TODO: Use EnumAttr to define the common attribute cases
// String attribute restricted to the two convnet data formats, "NHWC" and
// "NCHW".
def TF_ConvnetDataFormatAttr : StringBasedAttr<
CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
"$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
"'NHWC' or 'NCHW' convnet data format">;
//===----------------------------------------------------------------------===//
// Type attributes
// A derived attribute that returns the element type of `idx`-th ODS-declared
// operand. If the `idx`-th operand is a variadic operand, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic operand has at least one tensor and the tensors all have
// the same element type.
// The generated body dereferences `getODSOperands(idx).begin()` without an
// emptiness check, so an empty variadic operand pack must not reach it.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
"return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;
// A derived attribute that returns the element types of the tensors in the
// dynamic value pack that corresponds to the `idx`-th ODS-declared variadic
// operand. This returns a list of element types so it is used for variadic
// operands that can have different element types.
// The result is an `mlir::OperandElementTypeRange` spanning the whole pack.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
"mlir::OperandElementTypeRange",
"auto values = getODSOperands(" # idx # ");\n"
"return {mlir::OperandElementTypeIterator(values.begin()), "
"mlir::OperandElementTypeIterator(values.end())};"
>;
// A derived attribute that returns the element type of `idx`-th ODS-declared
// result. If the `idx`-th result is a variadic result, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic result has at least one tensor and the tensors all have
// the same element type.
// The generated body dereferences `getODSResults(idx).begin()` without an
// emptiness check, so an empty variadic result pack must not reach it.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
"return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;
// A derived attribute that returns the element types of the tensors in the
// dynamic value pack that corresponds to the `idx`-th ODS-declared variadic
// result. This returns a list of element types so it is used for variadic
// results that can have different element types.
// The result is an `mlir::ResultElementTypeRange` spanning the whole pack.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
"mlir::ResultElementTypeRange",
"auto values = getODSResults(" # idx # ");\n"
"return {mlir::ResultElementTypeIterator(values.begin()), "
"mlir::ResultElementTypeIterator(values.end())};"
>;
// Type attribute constrained to `IntegerType` (described as "integer type");
// its accessor return type is overridden to the generic `Type`.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
let returnType = "Type";
}
//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//
// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
// The builder takes the two operands `x` and `y`, computes the result type
// with `OpTrait::util::getBroadcastedType`, and delegates to the
// (resultType, x, y) builder.
// NOTE(review): when the operand types are not broadcastable an error is
// emitted at the op location, but building still proceeds with the null
// result type - confirm callers or the verifier reject that op afterwards.
class WithBroadcastableBinOpBuilder {
list<OpBuilder> builders = [OpBuilder<
"Builder *builder, OperationState *result, Value* x, Value* y",
[{
auto resultType =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
if (!resultType)
mlir::emitError(result->location, "non-broadcastable operands");
return build(builder, result, resultType, x, y);
}]
>];
}
// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool element type.
// If either operand is an unranked tensor the result is an unranked i1
// tensor; otherwise the result shape is the broadcast of the two operand
// shapes and the element type is i1.
// NOTE(review): when the shapes are not broadcastable an error is emitted at
// the op location, but building still proceeds with whatever
// `getBroadcastedShape` left in `resultShape` - confirm this is caught later.
// The ranked path uses `cast<ShapedType>()`, which asserts on non-shaped
// operand types.
class WithBroadcastableCmpOpBuilder {
list<OpBuilder> builders = [OpBuilder<
"Builder *builder, OperationState *result, Value* x, Value* y",
[{
Type resultType;
if (x->getType().isa<UnrankedTensorType>() ||
y->getType().isa<UnrankedTensorType>()) {
resultType = builder->getTensorType(builder->getI1Type());
} else {
SmallVector<int64_t, 4> resultShape;
if (!OpTrait::util::getBroadcastedShape(
x->getType().cast<ShapedType>().getShape(),
y->getType().cast<ShapedType>().getShape(), resultShape)) {
mlir::emitError(result->location,
"operands have no broadcastable shapes");
}
resultType = builder->getTensorType(resultShape, builder->getI1Type());
}
return build(builder, result, resultType, x, y);
}]
>];
}
#endif // TF_OP_BASE

View File

@ -0,0 +1,905 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include <algorithm>
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir
namespace mlir {
namespace TF {
// The generated canonicalization code is included inside an anonymous
// namespace so its definitions stay local to this translation unit.
// NOTE(review): presumably this is where the pattern classes named in the
// getCanonicalizationPatterns() methods below come from - confirm.
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//
/// Checks that `value` has a ranked tensor type of exactly `rank` dimensions
/// whose element type is a floating-point type. Unranked and non-tensor types
/// yield false.
static inline bool isOfRankedFloatTensorType(Value *value, int rank) {
  auto rankedType = value->getType().dyn_cast<RankedTensorType>();
  if (!rankedType) return false;
  if (rankedType.getRank() != rank) return false;
  return rankedType.getElementType().isa<FloatType>();
}
// Accepts `value` when its type is a ranked tensor of exactly `rank`
// dimensions, or when its type is not a ranked tensor at all (so nothing can
// be checked statically).
static inline bool IsOfRankOrUnranked(Value *value, int64_t rank) {
  auto rankedType = value->getType().dyn_cast<RankedTensorType>();
  return !rankedType || rankedType.getRank() == rank;
}
/// Tells whether `type` may be used as the element type of a TensorFlow
/// tensor: any float type, any integer type, or a TensorFlow dialect type.
static inline bool isValidTFElementType(Type type) {
  if (type.isa<FloatType>()) return true;
  if (type.isa<IntegerType>()) return true;
  return type.isa<TensorFlowType>();
}
// Tells whether `type` is a valid TensorFlow tensor type, i.e. a tensor type
// whose element type passes isValidTFElementType(). Non-tensor types are
// rejected.
static inline bool isValidTFTensorType(Type type) {
  auto tensorType = type.dyn_cast<TensorType>();
  return tensorType && isValidTFElementType(tensorType.getElementType());
}
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  // Registers the generated `AddToAddV2` rewrite pattern.
  using AddPatterns = RewriteListBuilder<AddToAddV2>;
  AddPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// AddV2Op
//===----------------------------------------------------------------------===//
void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  // Registers the generated `AddV2OfNegLeft` and `AddV2OfNegRight` rewrite
  // patterns, in that order.
  using AddV2Patterns = RewriteListBuilder<AddV2OfNegLeft, AddV2OfNegRight>;
  AddV2Patterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  // Registers the generated `BitcastSameType` and `BitcastNested` rewrite
  // patterns, in that order.
  using BitcastPatterns = RewriteListBuilder<BitcastSameType, BitcastNested>;
  BitcastPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// BroadcastToOp
//===----------------------------------------------------------------------===//
// Verifier hook for tf.BroadcastTo. It currently performs no checks and
// always succeeds; the intended checks are listed in the TODO below.
static LogicalResult Verify(BroadcastToOp op) {
// TODO(antiagainst): check that
// * The 'shape' input is an 1-D int tensor.
// * Each dimension pair of the source and target shapes are either equal
// or one of them is one.
return success();
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  // Registers the generated `CastSameType` rewrite pattern.
  using CastPatterns = RewriteListBuilder<CastSameType>;
  CastPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//
void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  // Registers the generated `ConjNested` rewrite pattern.
  using ConjPatterns = RewriteListBuilder<ConjNested>;
  ConjPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//
// Folds a tf.Const to the attribute it holds. `operands` must be empty
// because a constant has no operands (asserted).
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
// Returns the held attribute value.
return value();
}
// Builds a tf.Const from the attribute `value` alone, deducing the result
// type from it:
//  * an ElementsAttr contributes its own shaped type;
//  * a bool/float/integer attribute is wrapped into a rank-0 tensor of the
//    attribute's type, and `value` is replaced by the matching
//    DenseElementsAttr so the op holds a valid tensor constant.
// Any other attribute kind is unsupported and trips the assertion.
void ConstOp::build(Builder *builder, OperationState *result, Attribute value) {
  ShapedType resultType;
  if (auto elements = value.dyn_cast<ElementsAttr>()) {
    resultType = elements.getType();
  } else {
    const bool isScalarAttr = value.isa<BoolAttr>() ||
                              value.isa<FloatAttr>() ||
                              value.isa<IntegerAttr>();
    if (isScalarAttr) {
      // Every TensorFlow value is a tensor, so scalar attributes accepted
      // here for convenience are promoted to 0-d tensor constants.
      resultType = RankedTensorType::get(/*shape=*/{}, value.getType());
      value = DenseElementsAttr::get(resultType, value);
    }
  }

  // TODO: support other TensorFlow specific types.
  assert(resultType && "unsupported attribute type for building tf.Const");
  result->types.push_back(resultType);
  result->addAttribute("value", value);
}
// Builds a tf.Const from `value` via the attribute-only builder, then asserts
// that the deduced result type equals the requested `type`. With assertions
// disabled `type` is not otherwise used.
void ConstOp::build(Builder *builder, OperationState *result, Type type,
Attribute value) {
ConstOp::build(builder, result, value);
assert(type == result->types[0] && "type mismatch in construction");
}
//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//
void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  // Registers the generated `DivWithSqrtDivisor` rewrite pattern.
  using DivPatterns = RewriteListBuilder<DivWithSqrtDivisor>;
  DivPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxArgsOp
//===----------------------------------------------------------------------===//
// Verifies tf.FakeQuantWithMinMaxArgs:
//  * the [min, max] range must be non-empty (min < max);
//  * the range must contain zero;
//  * num_bits must lie in [2, 16].
// The bounds are held as `double` so that values read through
// convertToDouble() are compared at full precision; narrowing them to `float`
// could collapse two distinct bounds and wrongly report an invalid range.
// Single-precision values convert to double exactly, so their diagnostics are
// unchanged.
static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) {
  // TODO(fengliuai): moving the following to an utility method.
  // `max` is assumed to use the same float semantics as `min`.
  const llvm::fltSemantics &semantics = op.min().getSemantics();
  double rmin, rmax;
  if (&semantics == &APFloat::IEEEsingle()) {
    rmin = op.min().convertToFloat();
    rmax = op.max().convertToFloat();
  } else {
    rmin = op.min().convertToDouble();
    rmax = op.max().convertToDouble();
  }
  // Range boundaries must be valid.
  if (rmin >= rmax) {
    return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
                          "," + Twine(std::to_string(rmax)) + "]");
  }
  // Range must straddle zero.
  if (rmin > 0.0 || rmax < 0.0) {
    return op.emitOpError("range failed to straddle zero: [" +
                          Twine(std::to_string(rmin)) + "," +
                          Twine(std::to_string(rmax)) + "]");
  }
  int64_t num_bits = op.num_bits().getSExtValue();
  if (num_bits < 2 || num_bits > 16) {
    return op.emitOpError(
        "requires num_bits to be between 2 and 16, inclusive");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxVarsOp
//===----------------------------------------------------------------------===//
// Verifies tf.FakeQuantWithMinMaxVars: `min` and `max` must each be a ranked
// rank-0 float tensor, and `num_bits` must lie in [2, 16]. Checks run in that
// order and the first failure is reported.
static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
  auto isFloatScalar = [](Value *operand) {
    return isOfRankedFloatTensorType(operand, 0);
  };
  if (!isFloatScalar(op.min()))
    return op.emitOpError("requires min to be a 0d float tensor");
  if (!isFloatScalar(op.max()))
    return op.emitOpError("requires max to be a 0d float tensor");

  const int64_t bits = op.num_bits().getSExtValue();
  const bool bitsInRange = bits >= 2 && bits <= 16;
  if (!bitsInRange)
    return op.emitOpError(
        "requires num_bits to be between 2 and 16, inclusive");

  return success();
}
//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//
// Verifies tf.FusedBatchNorm operand types: `x` must be a ranked 4-D float
// tensor, and `scale`, `offset`, `mean` and `variance` must each be a ranked
// 1-D float tensor. Checks run in that order and the first failure is
// reported.
static LogicalResult Verify(FusedBatchNormOp op) {
  if (!isOfRankedFloatTensorType(op.x(), 4))
    return op.emitOpError("requires x to be a 4D float tensor");

  auto isFloatVector = [](Value *operand) {
    return isOfRankedFloatTensorType(operand, 1);
  };
  if (!isFloatVector(op.scale()))
    return op.emitOpError("requires scale to be a 1D float tensor");
  if (!isFloatVector(op.offset()))
    return op.emitOpError("requires offset to be a 1D float tensor");
  if (!isFloatVector(op.mean()))
    return op.emitOpError("requires mean to be a 1D float tensor");
  if (!isFloatVector(op.variance()))
    return op.emitOpError("requires variance to be a 1D float tensor");

  // TODO(antiagainst): check attributes
  return success();
}
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
// Verifies a tf.If op against the functions named by its `then_branch` and
// `else_branch` attributes:
//  * both attributes must be present and name functions in the enclosing
//    module;
//  * each branch must take one input per non-condition operand and produce
//    one result per op result;
//  * operand/input, input/input and result/result types must be
//    cast-compatible according to TensorCastOp::areCastCompatible.
// The first violation found is reported.
// NOTE(review): the `.cast<TensorType>()` calls assert instead of emitting a
// diagnostic when an operand, result or branch signature type is not a tensor
// type; `getNumOperands() - 1` also assumes the condition operand is present.
LogicalResult IfOp::verify() {
auto thenAttr = getAttrOfType<FunctionAttr>("then_branch");
if (!thenAttr) return emitOpError("requires then_branch attribute");
auto elseAttr = getAttrOfType<FunctionAttr>("else_branch");
if (!elseAttr) return emitOpError("requires else_branch attribute");
// Resolve both branch functions by name in the module containing this op.
auto *module = getOperation()->getFunction()->getModule();
auto *thenFn = module->getNamedFunction(thenAttr.getValue());
if (!thenFn)
return emitOpError("then_branch refers to an undefined function : ")
<< thenAttr;
auto *elseFn = module->getNamedFunction(elseAttr.getValue());
if (!elseFn)
return emitOpError("else_branch refers to an undefined function : ")
<< elseAttr;
auto thenFuncType = thenFn->getType();
auto elseFuncType = elseFn->getType();
// Non-conditional operands starting with the second operand are passed to
// branches and should be pair-wise compatible with branches' inputs.
unsigned expectedNumInputs = getNumOperands() - 1;
if (thenFuncType.getNumInputs() != expectedNumInputs ||
elseFuncType.getNumInputs() != expectedNumInputs)
return emitError("branches should have " + Twine(expectedNumInputs) +
" inputs");
for (unsigned i = 0; i < expectedNumInputs; ++i) {
// Operand 0 is the condition, so branch input i pairs with operand i + 1.
auto operandType = getOperand(i + 1)->getType().cast<TensorType>();
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
if (!TensorCastOp::areCastCompatible(operandType, thenInputType))
return emitError(
llvm::formatv("then branch input type {0} is incompatible with "
"operand type {1} at index {2}",
thenInputType, operandType, i));
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
if (!TensorCastOp::areCastCompatible(operandType, elseInputType))
return emitError(
llvm::formatv("else branch input type {0} is incompatible with "
"operand type {1} at index {2}",
elseInputType, operandType, i));
// If branches have incompatible input types that means that no tensor can
// serve as input to both the functions. Hence, the op is invalid.
if (!TensorCastOp::areCastCompatible(thenInputType, elseInputType))
return emitError(llvm::formatv(
"branches inputs have incompatible types {0} and {1} at index {2}",
thenInputType, elseInputType, i));
}
// Branches' results should be pair-wise compatible with the op results.
unsigned expectedNumResults = getNumResults();
if (thenFuncType.getNumResults() != expectedNumResults ||
elseFuncType.getNumResults() != expectedNumResults)
return emitError("branches should have " + Twine(expectedNumResults) +
" results");
for (unsigned i = 0; i < expectedNumResults; ++i) {
auto resultType = getResult(i)->getType().cast<TensorType>();
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
if (!TensorCastOp::areCastCompatible(thenResultType, resultType))
return emitError(
llvm::formatv("then branch result type {0} is incompatible with op "
"result type {1} at index {2}",
thenResultType, resultType, i));
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
if (!TensorCastOp::areCastCompatible(elseResultType, resultType))
return emitError(
llvm::formatv("else branch result type {0} is incompatible with op "
"result type {1} at index {2}",
elseResultType, resultType, i));
}
return success();
}
//===----------------------------------------------------------------------===//
// InvertOp
//===----------------------------------------------------------------------===//
void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  // Registers the generated `InvertNested` rewrite pattern.
  using InvertPatterns = RewriteListBuilder<InvertNested>;
  InvertPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// LeakyReluOp
//===----------------------------------------------------------------------===//
// Folds tf.LeakyRelu in two situations:
//  * alpha == 1: the op is the identity, so the operand itself is returned;
//  * the operand is a constant float or a splat of floats: the result is
//    computed directly (negative values are multiplied by alpha, all other
//    values pass through).
// Anything else is left unfolded.
OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "leaky relu has one operand");

  // leaky_relu(x, alpha: 1) -> x
  if (alpha().convertToFloat() == 1.0f) return getOperand();

  // Applies the leaky-relu formula to a single float attribute.
  auto applyLeakyRelu = [&](FloatAttr input) {
    APFloat scaled = input.getValue();
    if (scaled.isNegative()) scaled = alpha() * scaled;
    return FloatAttr::get(input.getType(), scaled);
  };

  if (auto scalar = operands[0].dyn_cast_or_null<FloatAttr>())
    return applyLeakyRelu(scalar);

  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
    if (auto element = splat.getSplatValue().dyn_cast<FloatAttr>())
      return DenseElementsAttr::get(splat.getType(), applyLeakyRelu(element));
  }

  return {};
}
//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  // Registers the generated `LogOfSoftmax` rewrite pattern.
  using LogPatterns = RewriteListBuilder<LogOfSoftmax>;
  LogPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// LogicalNotOp
//===----------------------------------------------------------------------===//
void LogicalNotOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // Registers the generated rewrite patterns for tf.LogicalNot, in the order
  // listed below.
  using LogicalNotPatterns =
      RewriteListBuilder<LogicalNotNested, LogicalNotOfEqual,
                         LogicalNotOfNotEqual, LogicalNotOfGreater,
                         LogicalNotOfGreaterEqual, LogicalNotOfLess,
                         LogicalNotOfLessEqual>;
  LogicalNotPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  // Registers the generated `NegNested` rewrite pattern.
  using NegPatterns = RewriteListBuilder<NegNested>;
  NegPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// ReciprocalOp
//===----------------------------------------------------------------------===//
void ReciprocalOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // Registers the generated `ReciprocalNested` rewrite pattern.
  using ReciprocalPatterns = RewriteListBuilder<ReciprocalNested>;
  ReciprocalPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//
// Verifies tf.RandomUniform: the `shape` operand must be rank-1 whenever its
// type is a ranked tensor; other shape types are accepted as-is.
static LogicalResult Verify(RandomUniformOp op) {
  const bool shapeIsAcceptable = IsOfRankOrUnranked(op.shape(), 1);
  if (!shapeIsAcceptable) return op.emitOpError("shape must be 1D tensor");
  return success();
}
//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//
// Builds a tf.Range from its three operands, inferring the result type.
//
// All three operands must have the same tensor type (asserted). When
// `start`, `limit` and `delta` are all integer constants the result is a 1-D
// tensor with a static size of ceil((limit - start) / delta): a range covers
// every value start + i * delta that lies strictly before `limit`, so a
// partial last step still contributes an element. Otherwise, or when the
// constants describe an unusable range (zero step, or a step pointing away
// from `limit`), the result is a 1-D tensor of dynamic size.
void RangeOp::build(Builder *builder, OperationState *result, Value *start,
                    Value *limit, Value *delta) {
  assert(start->getType() == limit->getType());
  assert(start->getType() == delta->getType());
  Type elementType = start->getType().cast<TensorType>().getElementType();

  DenseIntElementsAttr start_val;
  DenseIntElementsAttr limit_val;
  DenseIntElementsAttr delta_val;
  if (matchPattern(start, m_Constant(&start_val)) &&
      matchPattern(limit, m_Constant(&limit_val)) &&
      matchPattern(delta, m_Constant(&delta_val))) {
    const llvm::APInt step = *delta_val.begin();
    // A zero step would divide by zero; leave such ops dynamically shaped.
    if (step != 0) {
      // Round up: e.g. start=0, limit=5, delta=2 yields {0, 2, 4}, i.e. 3
      // elements, not 2.
      const llvm::APInt size = llvm::APIntOps::RoundingSDiv(
          *limit_val.begin() - *start_val.begin(), step,
          llvm::APInt::Rounding::UP);
      // A negative size is not a valid static dimension.
      if (!size.isNegative()) {
        return RangeOp::build(
            builder, result,
            builder->getTensorType(size.getSExtValue(), elementType), start,
            limit, delta);
      }
    }
  }
  return RangeOp::build(builder, result,
                        builder->getTensorType({-1}, elementType), start,
                        limit, delta);
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
// Builds a tf.Rank whose result type is fixed to a rank-0 tensor of 32-bit
// integers, regardless of the input type.
void RankOp::build(Builder *builder, OperationState *result, Value *input) {
  auto scalarI32Type =
      builder->getTensorType(/*shape=*/{}, builder->getIntegerType(32));
  RankOp::build(builder, result, scalarI32Type, input);
}
//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  // Registers the generated `RealDivWithSqrtDivisor` rewrite pattern.
  using RealDivPatterns = RewriteListBuilder<RealDivWithSqrtDivisor>;
  RealDivPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
// TODO(b/128020684): Verify the rank of the output and change to use
// m_Constant.
//
// Verifies tf.Reshape as far as static information allows:
//  * a ranked `shape` operand must be 1-D;
//  * when `shape` is a constant, at most one of its components may be -1;
//  * when additionally the input tensor has a fully static shape, the element
//    count implied by `shape` must match (or, with one -1 component, evenly
//    divide) the input's element count.
// Everything that cannot be decided at compile time is accepted.
static LogicalResult Verify(ReshapeOp op) {
  auto shapeType = op.shape()->getType().cast<TensorType>();
  // An unranked shape operand carries no static rank to check; querying its
  // rank would be invalid.
  if (!shapeType.hasRank()) return success();
  if (shapeType.getRank() != 1)
    return op.emitOpError("shape must be 1D tensor");
  auto rankByShape = shapeType.getShape()[0];
  auto typeOfTensor = op.tensor()->getType().cast<TensorType>();
  // No compile time verification for unknown sized shape.
  if (rankByShape == -1 || !typeOfTensor.hasRank()) return success();

  // Checks values if constant shape. No compiling time verification for
  // non-constant shape.
  auto *shapeOp = op.shape()->getDefiningOp();
  if (!shapeOp) return success();
  Attribute shapeCst;
  if (auto shapeStdOp = dyn_cast<ConstantOp>(shapeOp)) {
    shapeCst = shapeStdOp.getValue();
  } else if (auto shapeTFOp = dyn_cast<ConstOp>(shapeOp)) {
    shapeCst = shapeTFOp.value();
  } else {
    return success();
  }
  auto shapeCstAttr = shapeCst.dyn_cast<ElementsAttr>();
  if (!shapeCstAttr) return op.emitOpError("shape must be a valid tensor");
  if (auto opaqueAttr = shapeCstAttr.dyn_cast<OpaqueElementsAttr>()) {
    opaqueAttr.decode(shapeCstAttr);
  }

  // We know the shape is a 1-D Tensor, then let us get the number of
  // elements it implies. A 64-bit accumulator avoids truncating the product
  // of 64-bit dimension sizes.
  int64_t numByShape = 1;
  unsigned unknownDimCount = 0;
  for (int i = 0, e = rankByShape; i != e; ++i) {
    auto num = shapeCstAttr.getValue(i).cast<IntegerAttr>().getInt();
    // The dimension size value can be -1, and that the real size needs to
    // be computed so that the total size remains constant. At most one
    // component of shape can be -1.
    if (num == -1) {
      if (++unknownDimCount > 1) {
        return op.emitOpError("more than one component of shape are -1");
      }
    } else {
      numByShape *= num;
    }
  }

  // The element count of the input is only defined for fully static shapes;
  // with dynamic dimensions the remaining checks are deferred to runtime.
  if (!typeOfTensor.hasStaticShape()) return success();
  auto numByTensor = typeOfTensor.getNumElements();

  // If there is one component of shape is -1, the dimension should be
  // computed so that the total size remains constant.
  if (unknownDimCount == 1) {
    // A zero among the known components makes the missing dimension
    // impossible to infer (and would otherwise divide by zero below).
    if (numByShape == 0 || numByTensor % numByShape != 0)
      return op.emitOpError(
          "one component of shape is -1 but couldn't infer the dimension");
    return success();
  }
  // If the elements by the tensor and implies by the shape don't match,
  // fail this static check.
  if (numByTensor != numByShape) {
    return op.emitOpError(
        "mismatch in tensor elements and shape implied elements");
  }
  return success();
}
// Builds a tf.Reshape from `tensor` and `shape`, inferring the result type.
// The element type is taken from `tensor`. If `shape` is an integer constant
// its values become the result's dimensions; otherwise the result is an
// unranked tensor.
void ReshapeOp::build(Builder *builder, OperationState *result, Value *tensor,
                      Value *shape) {
  auto elementType = tensor->getType().cast<ShapedType>().getElementType();

  DenseIntElementsAttr shapeAttr;
  if (!matchPattern(shape, m_Constant(&shapeAttr))) {
    // Non-constant shape: nothing is known about the result's dimensions.
    return ReshapeOp::build(builder, result,
                            builder->getTensorType(elementType), tensor, shape);
  }

  const auto numDims = shapeAttr.getType().getNumElements();
  llvm::SmallVector<int64_t, 4> dims;
  if (shapeAttr.isSplat()) {
    // A splat stores a single value standing for every element.
    dims.assign(numDims, (*shapeAttr.begin()).getSExtValue());
  } else {
    dims.reserve(numDims);
    for (auto dim : shapeAttr) dims.push_back(dim.getSExtValue());
  }
  return ReshapeOp::build(builder, result,
                          builder->getTensorType(dims, elementType), tensor,
                          shape);
}
//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//
// Verifies tf.Shape:
//  * the result must be a ranked 1-D tensor;
//  * for a ranked operand and a static result, the result's size must equal
//    the operand's rank (a rank-0 operand is expected to give size 1);
//  * for any other operand type the result must not be statically shaped;
//  * the result element type must be a 32-bit or 64-bit integer.
static LogicalResult Verify(ShapeOp op) {
  auto resultTy = op.getType().dyn_cast<RankedTensorType>();
  if (!resultTy || resultTy.getShape().size() != 1)
    return op.emitOpError("requires 1D result type");

  const bool resultIsStatic = resultTy.hasStaticShape();
  if (auto operandTy = op.input()->getType().dyn_cast<RankedTensorType>()) {
    // The operand is a ranked tensor.
    if (resultIsStatic) {
      ArrayRef<int64_t> operandShape = operandTy.getShape();
      const int64_t expectedSize =
          operandShape.empty() ? 1
                               : static_cast<int64_t>(operandShape.size());
      if (resultTy.getDimSize(0) != expectedSize)
        return op.emitOpError(
            "requires dimension size of result to match rank of operand");
    }
  } else if (resultIsStatic) {
    // The operand is an unranked tensor, verify that the result is dynamic.
    return op.emitOpError("requires dynamic shape result for unranked input");
  }

  Type elementTy = op.getType().cast<ShapedType>().getElementType();
  const bool isI32OrI64 = elementTy.isInteger(32) || elementTy.isInteger(64);
  if (!isI32OrI64)
    return op.emitOpError("requires int32 or int64 return type");
  return success();
}
// Folds tf.Shape into a constant when the operand is a ranked tensor whose
// shape is fully static. Returns an empty fold result (no folding) for
// unranked operands, operands with dynamic dimensions, and rank-0 operands.
OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
  auto operandType = getOperand()->getType().dyn_cast<RankedTensorType>();
  if (!operandType) return {};
  if (!operandType.hasStaticShape()) return {};

  // TODO: This is handling the simple case where the resultant type matches the
  // size of getShape()'s returned type.
  auto operandDims = operandType.getShape();
  int numDims = operandDims.size();
  if (numDims == 0) return {};

  Builder attrBuilder(getContext());
  auto resultElementType = getType().cast<ShapedType>().getElementType();

  // One integer attribute per operand dimension, typed like the result.
  SmallVector<Attribute, 4> dimAttrs;
  dimAttrs.reserve(numDims);
  for (int64_t dimSize : operandDims)
    dimAttrs.push_back(attrBuilder.getIntegerAttr(resultElementType, dimSize));

  auto foldedType = attrBuilder.getTensorType({numDims}, resultElementType);
  return attrBuilder.getDenseElementsAttr(foldedType, dimAttrs);
}
//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//
// Verifies that the logits operand of tf.Softmax passes the
// IsOfRankOrUnranked(..., 2) check; emits an op error otherwise.
static LogicalResult Verify(SoftmaxOp op) {
  const bool logitsOk = IsOfRankOrUnranked(op.logits(), 2);
  if (logitsOk) return success();
  return op.emitOpError("requires operand to be 2D tensor");
}
//===----------------------------------------------------------------------===//
// SquareOp
//===----------------------------------------------------------------------===//
// Adds the canonicalization pattern for tf.Square (SquareOfSub) to `results`.
void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  using SquarePatterns = RewriteListBuilder<SquareOfSub>;
  SquarePatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
// Adds the canonicalization pattern for tf.Sub (SubOfNeg) to `results`.
void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  using SubPatterns = RewriteListBuilder<SubOfNeg>;
  SubPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//
// Verifies the operands of tf.TensorListReserve: `element_shape` must pass
// the rank-0 or rank-1 check and `num_elements` must pass the rank-0 check
// (IsOfRankOrUnranked is used for both).
static LogicalResult Verify(TensorListReserveOp op) {
  const bool elementShapeOk = IsOfRankOrUnranked(op.element_shape(), 0) ||
                              IsOfRankOrUnranked(op.element_shape(), 1);
  if (!elementShapeOk)
    return op.emitOpError("requires element_shape operand to be 0D/1D tensor");

  if (IsOfRankOrUnranked(op.num_elements(), 0)) return success();
  return op.emitOpError("requires num_elements operand to be 0D tensor");
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
// Verifier hook for tf.Transpose. It currently performs no checks and always
// reports success; `op` is not inspected. The intended checks are listed in
// the TODO below.
static LogicalResult Verify(TransposeOp op) {
// TODO(hinsu): Verify using a custom verifier that,
// * Transpose permutation is 1-D of size equal to the rank of the first
// input, if the shapes are partially known. Requires use of a more
// restrictive type than TF_Tensor.
// * Result shape dimensions are possible based on the input shape.
return success();
}
// TODO(jpienaar): perm could be optional too.
// Builds a tf.Transpose op, deriving the result type from `x` and `perm`.
//
// The result is an unranked tensor when `x` is unranked or when `perm` is not
// a constant. When `perm` is a constant, each result dimension is the
// dimension of `x` selected by the corresponding permutation entry.
void TransposeOp::build(Builder *builder, OperationState *result, Value *x,
                        Value *perm) {
  auto inputType = x->getType().cast<TensorType>();
  auto elementType = inputType.getElementType();

  // If value is unranked, then so is results.
  // TODO(jpienaar): Handle unknown perm case.
  // TODO(jpienaar): Extract utility function.
  DenseIntElementsAttr permAttr;
  if (!inputType.hasRank() || !matchPattern(perm, m_Constant(&permAttr))) {
    return TransposeOp::build(builder, result,
                              builder->getTensorType(elementType), x, perm);
  }

  // Look up the input dimension selected by every permutation entry.
  const auto numEntries = permAttr.getType().getNumElements();
  llvm::SmallVector<int64_t, 4> resultDims;
  if (permAttr.isSplat()) {
    // A splat holds a single value that stands for every entry.
    resultDims.assign(
        numEntries,
        inputType.getDimSize((*permAttr.begin()).getSExtValue()));
  } else {
    resultDims.reserve(numEntries);
    for (auto permEntry : permAttr)
      resultDims.push_back(inputType.getDimSize(permEntry.getSExtValue()));
  }

  return TransposeOp::build(builder, result,
                            builder->getTensorType(resultDims, elementType), x,
                            perm);
}
//===----------------------------------------------------------------------===//
// TruncateDivOp
//===----------------------------------------------------------------------===//
// Adds the canonicalization pattern for tf.TruncateDiv
// (TruncateDivWithSqrtDivisor) to `results`.
void TruncateDivOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  using TruncateDivPatterns = RewriteListBuilder<TruncateDivWithSqrtDivisor>;
  TruncateDivPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
// Verifies a tf.While op.
//
// Checks that:
//   * the "cond" and "body" function attributes are present and refer to
//     functions that exist in the enclosing module;
//   * the cond function has exactly one result;
//   * the op operands, op results, cond inputs, body inputs and body results
//     have matching counts and pairwise tensor_cast-compatible types (except
//     for the operand / body-result pair, see below).
//
// Returns success() when all checks pass; otherwise emits a diagnostic on the
// op and returns failure.
LogicalResult WhileOp::verify() {
  auto condAttr = getAttrOfType<FunctionAttr>("cond");
  if (!condAttr) return emitOpError("requires cond attribute");

  auto *module = getOperation()->getFunction()->getModule();
  auto *condFn = module->getNamedFunction(condAttr.getValue());
  // The attribute may name a function that is not defined in the module;
  // report that instead of dereferencing a null function.
  if (!condFn)
    return emitOpError("cond refers to an undefined function: " +
                       condAttr.getValue());
  auto condFuncType = condFn->getType();

  // Verify that the cond function has exactly one result.
  if (condFuncType.getNumResults() != 1)
    return emitOpError("requires cond function to have exactly one result");

  auto bodyAttr = getAttrOfType<FunctionAttr>("body");
  if (!bodyAttr) return emitOpError("requires body attribute");

  auto *bodyFn = module->getNamedFunction(bodyAttr.getValue());
  // Same guard as for the cond function above.
  if (!bodyFn)
    return emitOpError("body refers to an undefined function: " +
                       bodyAttr.getValue());
  auto bodyFuncType = bodyFn->getType();

  SmallVector<Type, 4> operands(getOperandTypes());
  SmallVector<Type, 4> results(getResultTypes());

  // Collect all the type lists for the op so that different pairs of type lists
  // can be compared for compatibility. The first two entries (operands and
  // body results) are the one pair that is intentionally never compared.
  std::pair<std::string, ArrayRef<Type>> typeLists[] = {
      {"operand", operands},
      {"body function result", bodyFuncType.getResults()},
      {"result", results},
      {"cond function input", condFuncType.getInputs()},
      {"body function input", bodyFuncType.getInputs()},
  };
  // Derived from the array so that it cannot drift from the list above.
  const int numTypeLists = sizeof(typeLists) / sizeof(typeLists[0]);

  // A pair of type lists should be cast compatible with each other if one is
  // converted to the other for a function call or assignment, or if there is
  // a common source of inputs for both. Therefore, the While op requires the
  // following pairs of type lists to be cast compatible for the tensor_cast
  // operation:
  //
  // * Operands and cond inputs to call the cond function before the
  //   first iteration.
  // * Operands and body inputs to call the body function for the first
  //   iteration if the cond function returns True or equivalent result.
  // * Operands and results to assign cond function arguments to op results if
  //   the cond function returns False or equivalent result.
  // * All three pairs using cond inputs, body inputs and results, as the
  //   operand is a common source for all three.
  // * Body result and cond inputs to call the cond function for the subsequent
  //   iterations. Similarly, body result should be compatible with body inputs
  //   and op results.
  //
  // Note that the operands and body results need not be compatible as they are
  // never converted from one to the other, nor is there a common source
  // tensor. The compatibility requirement is not transitive.
  for (int i = 0; i < numTypeLists; ++i) {
    // Starting j at max(2, i + 1) skips the (operand, body result) pair, which
    // does not need to be compatible.
    for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
      auto &a = typeLists[i];
      auto &b = typeLists[j];

      // Compare sizes as signed ints to avoid a signed/unsigned comparison.
      int aSize = a.second.size();
      int bSize = b.second.size();
      if (aSize != bSize)
        return emitOpError(
            llvm::formatv("requires the number of {0}s to be equal to the "
                          "number of {1}s. Found {2} and {3}, respectively",
                          a.first, b.first, aSize, bSize));

      for (int idx = 0; idx < aSize; ++idx) {
        auto aType = a.second[idx];
        auto bType = b.second[idx];

        if (!TensorCastOp::areCastCompatible(aType, bType))
          return emitError(llvm::formatv(
              "{0} type {1} is incompatible with {2} type {3} at index {4}",
              a.first, aType, b.first, bType, idx));
      }
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//
// Adds the canonicalization pattern for tf.Xdivy (XdivyWithSqrtDivisor) to
// `results`.
void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  using XdivyPatterns = RewriteListBuilder<XdivyWithSqrtDivisor>;
  XdivyPatterns::build(results, context);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc"
//===----------------------------------------------------------------------===//
// TF Dialect
//===----------------------------------------------------------------------===//
// Constructs the dialect under the "tf" name in `context` and registers its
// operations and types.
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
: Dialect(/*name=*/"tf", context) {
// Register the op list expanded from the GET_OP_LIST section of
// tf_ops.cc.inc, followed by IfOp and WhileOp, which are appended explicitly.
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc"
, IfOp, WhileOp>();
// Register the types expanded from tf_types.def. Each entry expands to
// `<tftype>Type,`; HANDLE_LAST_TF_TYPE expands without the trailing comma so
// that the template argument list stays well formed.
addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
// Supports unknown operations because not all TensorFlow operations are
// registered.
allowUnknownOperations();
}
// Parses a type registered to this dialect.
//
// `data` is the textual type name and `loc` is the location used for the
// diagnostic. Returns the matching type, or emits an "unknown TensorFlow
// type" error at `loc` and returns a null Type when the name is not
// recognized.
Type TensorFlowDialect::parseType(StringRef data, Location loc) const {
// Map the name to its TensorFlowTypes enumerant with one StringSwitch case per
// tf_types.def entry. Unrecognized names map to 0.
// NOTE(review): presumably 0 matches no enumerant, so it reaches the
// `default` label below -- confirm against TensorFlowTypes.
auto typeKind = llvm::StringSwitch<unsigned>(data)
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
.Case(name, TensorFlowTypes::enumerant)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
.Default(0);
// One `case` per tf_types.def entry returns the corresponding type instance.
switch (typeKind) {
default:
// The comma expression emits the diagnostic, then yields the null result.
return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
case TensorFlowTypes::enumerant: \
return tftype##Type::get(getContext());
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
}
}
// Prints a type registered to this dialect.
//
// Writes the textual name of `ty` to `os`. `ty` must be a TensorFlowType
// (asserted); a kind without a tf_types.def entry hits llvm_unreachable.
void TensorFlowDialect::printType(Type ty, raw_ostream &os) const {
assert(ty.isa<TensorFlowType>());
// One `case` per tf_types.def entry streams that entry's name.
switch (ty.getKind()) {
default:
llvm_unreachable("unexpected tensorflow type kind");
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
case TensorFlowTypes::enumerant: \
os << name; \
break;
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
}
}
// Materializes `value` as a tf.Const operation at `loc`.
//
// Only opaque elements attributes whose own type equals the requested `type`
// are handled; returns nullptr for every other attribute.
Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
                                                  Attribute value, Type type,
                                                  Location loc) {
  const bool isOpaqueOfRequestedType =
      value.isa<OpaqueElementsAttr>() && value.getType() == type;
  if (!isOpaqueOfRequestedType) return nullptr;

  // If this is an opaque elements attribute, then generate a tf.Const.
  return builder.create<ConstOp>(loc, type, value);
}
// Checks that `op` is a well-formed TensorFlow op: its name must belong to
// the "tf" dialect, and every operand type and result type must satisfy
// isValidTFTensorType. Emits a diagnostic on `op` for the first violation
// found.
LogicalResult verifyTensorFlowOp(Operation *op) {
  const bool inTfDialect = op->getName().getDialect() == "tf";
  if (!inTfDialect) {
    return op->emitError("TensorFlow op ")
           << op->getName() << " should start with 'tf.'";
  }

  // Operands are checked before results.
  for (Type operandType : op->getOperandTypes()) {
    if (isValidTFTensorType(operandType)) continue;
    return op->emitOpError(
        "requires operands to have a valid TensorFlow tensor type");
  }

  for (Type resultType : op->getResultTypes()) {
    if (isValidTFTensorType(resultType)) continue;
    return op->emitOpError(
        "requires results to have a valid TensorFlow tensor type");
  }

  return success();
}
} // namespace TF
} // namespace mlir

View File

@ -0,0 +1,164 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the standard MLIR TensorFlow dialect
// after control dependences are raised to the standard form.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace TF {
// The MLIR dialect for TensorFlow. It declares the hooks for parsing and
// printing the dialect's types and for materializing constants.
class TensorFlowDialect : public Dialect {
public:
// Constructs the dialect in `context`.
TensorFlowDialect(MLIRContext *context);
// The gradient attribute ("tf.gradient") in a function's list of
// NamedAttributes refers to that function's gradient function. This attribute
// is used in the TensorFlow dialect to model TF GradientDef.
// GetGradientAttrName() returns the name of the gradient attribute.
static StringRef GetGradientAttrName() { return "tf.gradient"; }
// Parses a type registered to this dialect.
Type parseType(StringRef data, Location loc) const override;
// Prints a type registered to this dialect.
void printType(Type ty, raw_ostream &os) const override;
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
};
// This verifies that the Op is a well-formed TensorFlow op, checking
// that all inputs and results are Tensor or other TensorFlow types, etc.
//
// Declared with external linkage on purpose. A `static` declaration in a
// header gives every including translation unit its own internal-linkage
// function, so a single out-of-line definition could not be reached from other
// translation units that instantiate the TensorFlowOp trait.
LogicalResult verifyTensorFlowOp(Operation *op);
// This Trait should be used by all TensorFlow Ops.
//
// Its verification hook forwards to verifyTensorFlowOp, so every op that
// carries the trait is checked by that function.
template <typename ConcreteType>
class TensorFlowOp : public OpTrait::TraitBase<ConcreteType, TensorFlowOp> {
public:
// Runs the common TensorFlow op verification on `op`.
static LogicalResult verifyTrait(Operation *op) {
return verifyTensorFlowOp(op);
}
};
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
// purpose is to catch bug on `tensorflow::mutex_lock`. We don't use
// `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and
// `tf.ConsumeMutexLock`) with getter methods named as `mutex_lock()`. Need to
// undefine here to avoid expanding the getter symbol as macro when including
// both mutex.h and this header file.
#undef mutex_lock
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
// The "tf.If" operation takes a condition operand, a list of inputs, and a
// function attribute for the then/else branches. The condition operand
// doesn't have to be a boolean tensor. It is handled according to these
// rules, quoting the TensorFlow op definition:
//
// If the tensor is a scalar of non-boolean type, the scalar is converted to
// a boolean according to the following rule: if the scalar is a numerical
// value, non-zero means True and zero means False; if the scalar is a
// string, non-empty means True and empty means False. If the tensor is not a
// scalar, being empty means False and being non-empty means True.
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("If")
// .Input("cond: Tcond")
// .Input("input: Tin")
// .Output("output: Tout")
// .Attr("Tcond: type")
// .Attr("Tin: list(type) >= 0")
// .Attr("Tout: list(type) >= 0")
// .Attr("then_branch: func")
// .Attr("else_branch: func")
//
// Note: Additional result corresponds to the control output.
class IfOp : public Op<IfOp, TensorFlowOp, OpTrait::AtLeastNOperands<1>::Impl,
                       OpTrait::VariadicResults> {
 public:
  using Op::Op;

  // Name under which this operation is registered.
  static StringRef getOperationName() { return "tf.If"; }

  // Returns the condition, which is the first operand of the op.
  Value *getCondition() {
    Value *condition = getOperand(0);
    return condition;
  }

  // TODO(b/132271680): This is not following Google naming style
  // Returns the value of the "then_branch" function attribute.
  StringRef getThen() {
    auto thenAttr = getAttrOfType<FunctionAttr>("then_branch");
    return thenAttr.getValue();
  }

  // Returns the value of the "else_branch" function attribute.
  StringRef getElse() {
    auto elseAttr = getAttrOfType<FunctionAttr>("else_branch");
    return elseAttr.getValue();
  }

  // Op verification hook; defined out of line.
  LogicalResult verify();
};
// The "tf.While" operation takes a list of inputs and function attributes for
// the loop condition and body. Inputs are updated repeatedly by the body
// function while the loop condition with the tensors evaluates to true. The
// condition result doesn't have to be a boolean tensor. It is handled
// according to these rules, quoting the TensorFlow op definition:
//
// If the tensor is a scalar of non-boolean type, the scalar is converted to
// a boolean according to the following rule: if the scalar is a numerical
// value, non-zero means True and zero means False; if the scalar is a
// string, non-empty means True and empty means False. If the tensor is not a
// scalar, being empty means False and being non-empty means True.
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("While")
// .Input("input: T")
// .Output("output: T")
// .Attr("T: list(type) >= 0")
// .Attr("cond: func")
// .Attr("body: func")
// .Attr("output_shapes: list(shape) = []")
//
class WhileOp : public Op<WhileOp, TensorFlowOp, OpTrait::VariadicOperands,
                          OpTrait::VariadicResults> {
 public:
  using Op::Op;

  // Name under which this operation is registered.
  static StringRef getOperationName() { return "tf.While"; }

  // Returns the value of the "cond" function attribute.
  StringRef getCond() {
    auto condAttr = getAttrOfType<FunctionAttr>("cond");
    return condAttr.getValue();
  }

  // Returns the value of the "body" function attribute.
  StringRef getBody() {
    auto bodyAttr = getAttrOfType<FunctionAttr>("body");
    return bodyAttr.getValue();
  }

  // Op verification hook; defined out of line.
  LogicalResult verify();
};
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_

Some files were not shown because too many files have changed in this diff Show More