Start open development of TF, TFLite, XLA MLIR dialects.

This change adds dialects and tools for TF, TFLite and XLA using MLIR (https://github.com/tensorflow/mlir). This is under active development and not built by default.

PiperOrigin-RevId: 255538027

commit eab4b9c4cc
parent 3398e887f5
56 tensorflow/compiler/mlir/BUILD (new file)
@@ -0,0 +1,56 @@
# Description:
#   TensorFlow/TensorFlow Lite/XLA MLIR dialects and tools.

load("//tensorflow:tensorflow.bzl", "tf_cc_binary")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

# To reference all tablegen files here when checking for updates to them.
filegroup(
    name = "td_files",
    srcs = glob(["**/*.td"]),
)

cc_library(
    name = "tf_mlir_opt_main",
    srcs = ["tf_mlir_opt_main.cc"],
    deps = [
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/xla",
        "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
        "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
        "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
        "//tensorflow/core:lib",
        "@llvm//:support",
        "@local_config_mlir//:MlirOptLib",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:QuantOpsDialectRegistration",
        "@local_config_mlir//:Support",
        "@local_config_mlir//test:TestDialect",
        "@local_config_mlir//test:TestTransforms",
    ],
)

tf_cc_binary(
    name = "tf-opt",
    deps = [
        ":tf_mlir_opt_main",
    ],
)

filegroup(
    name = "litfiles",
    srcs = glob(["runlit*py"]),
)
11 tensorflow/compiler/mlir/README.md (new file)
@@ -0,0 +1,11 @@
# MLIR dialects and utilities for TensorFlow, TensorFlow Lite and XLA

This module contains the MLIR
([Multi-Level Intermediate Representation](https://github.com/tensorflow/mlir))
dialects and utilities for:

1. TensorFlow
2. XLA
3. TF Lite

See the [MLIR repo](https://github.com/tensorflow/mlir) for complete documentation.
47 tensorflow/compiler/mlir/glob_lit_test.bzl (new file)
@@ -0,0 +1,47 @@
# Test definitions for Lit, the LLVM test runner.
#
# This is reusing the LLVM Lit test runner in the interim until the new build
# rules are upstreamed.
# TODO(b/136126535): remove this custom rule.
"""Lit runner globbing test."""

def glob_lit_tests(
        exclude = None,
        test_file_exts = ["mlir"],
        default_size = "small",
        size_override = None,
        data = None,
        per_test_extra_data = None,
        default_tags = None,
        tags_override = None,
        driver = None,
        features = []):
    """Creates all plausible Lit tests (and their inputs) under this directory.

    Args:
      exclude: [str], paths to exclude (for tests and inputs).
      test_file_exts: [str], extensions for files that are tests.
      default_size: str, the test size for targets not in "size_override".
      size_override: {str: str}, sizes to use for specific tests.
      data: [str], additional input data to the test.
      per_test_extra_data: {str: [str]}, extra data to attach to a given file.
      default_tags: [str], additional tags to attach to the test.
      tags_override: {str: str}, tags to add to specific tests.
      driver: str, label of the driver shell script.
      features: [str], list of extra features to enable.
    """
    native.py_test(
        name = "glob_lit_tests",
        srcs = ["@llvm//:lit"],
        args = [
            "tensorflow/compiler/mlir --config-prefix=runlit",
        ],
        data = data + [
            "//tensorflow/compiler/mlir:litfiles",
            "@llvm//:FileCheck",
            "@llvm//:count",
            "@llvm//:not",
        ] + native.glob(["*." + ext for ext in test_file_exts]),
        main = "lit.py",
    )
518 tensorflow/compiler/mlir/lite/BUILD (new file)
@@ -0,0 +1,518 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_native_cc_binary")
load(
    "@local_config_mlir//:tblgen.bzl",
    "gentbl",
)

package(
    default_visibility = [
        # TODO(jpienaar): Make the visibility more restrictive.
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    packages = [
        "//learning/brain/experimental/mlir/...",
        "//learning/brain/google/xla/...",
        "//tensorflow/compiler/mlir/...",
    ],
)

filegroup(
    name = "tensorflow_lite_ops_td_files",
    srcs = [
        "ir/tfl_ops.td",
        "@local_config_mlir//:OpBaseTdFiles",
        "@local_config_mlir//:QuantizationOpsTdFiles",
    ],
)

gentbl(
    name = "tensorflow_lite_ops_inc_gen",
    tbl_outs = [
        (
            "-gen-op-decls",
            "ir/tfl_ops.h.inc",
        ),
        (
            "-gen-op-defs",
            "ir/tfl_ops.cc.inc",
        ),
        (
            "-gen-op-doc",
            "g3doc/tfl_ops.md",
        ),
    ],
    tblgen = "@local_config_mlir//:mlir-tblgen",
    td_file = "ir/tfl_ops.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
    ],
)

gentbl(
    name = "tensorflow_lite_prepare_tf_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "transforms/generated_prepare_tf.inc",
        ),
    ],
    tblgen = "@local_config_mlir//:mlir-tblgen",
    td_file = "transforms/prepare_patterns.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
        "@local_config_mlir//:StdOpsTdFiles",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_optimize_td_files",
    ],
)

gentbl(
    name = "tensorflow_lite_lower_static_tensor_list_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "transforms/generated_lower_static_tensor_list.inc",
        ),
    ],
    tblgen = "@local_config_mlir//:mlir-tblgen",
    td_file = "transforms/tensorlist_patterns.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
        "@local_config_mlir//:StdOpsTdFiles",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
    ],
)

gentbl(
    name = "tensorflow_lite_legalize_tf_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "transforms/generated_legalize_tf.inc",
        ),
    ],
    tblgen = "@local_config_mlir//:mlir-tblgen",
    td_file = "transforms/legalize_patterns.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
        "@local_config_mlir//:StdOpsTdFiles",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
    ],
)

gentbl(
    name = "tensorflow_lite_optimize_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "transforms/generated_optimize.inc",
        ),
    ],
    tblgen = "@local_config_mlir//:mlir-tblgen",
    td_file = "transforms/optimize_patterns.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
        "@local_config_mlir//:StdOpsTdFiles",
    ],
)

gentbl(
    name = "tensorflow_lite_quantize_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "transforms/generated_quantize.inc",
        ),
    ],
    tblgen = "@local_config_mlir//:mlir-tblgen",
    td_file = "transforms/quantize_patterns.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
        "@local_config_mlir//:StdOpsTdFiles",
    ],
)

cc_library(
    name = "validators",
    srcs = [
        "utils/validators.cc",
    ],
    hdrs = [
        "utils/validators.h",
    ],
    deps = [
        "@local_config_mlir//:Dialect",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:StandardOps",
    ],
)

cc_library(
    name = "tensorflow_lite",
    srcs = [
        "ir/tfl_ops.cc",
        "ir/tfl_ops.cc.inc",
        "ir/tfl_ops.h.inc",
        "utils/attribute_utils.cc",
    ],
    hdrs = [
        "ir/tfl_ops.h",
        "ir/tfl_traits.h",
        "transforms/passes.h",
        "utils/attribute_utils.h",
        "utils/quantization_utils.h",
    ],
    deps = [
        ":tensorflow_lite_ops_inc_gen",
        ":validators",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:Dialect",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
        "@local_config_mlir//:TypeUtilities",
    ],
    alwayslink = 1,
)

cc_library(
    name = "tensorflow_lite_quantization_utils",
    srcs = [
        "utils/generated_op_quant_spec_getters.inc",
        "utils/quantization_driver.cc",
        "utils/quantization_utils.cc",
    ],
    hdrs = [
        "utils/quantization_utils.h",
    ],
    deps = [
        ":tensorflow_lite",
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
)

cc_library(
    name = "tensorflow_lite_legalize_tf",
    srcs = [
        "transforms/generated_legalize_tf.inc",
        "transforms/generated_lower_static_tensor_list.inc",
        "transforms/generated_prepare_tf.inc",
        "transforms/legalize_tf.cc",
        "transforms/lower_static_tensor_list.cc",
        "transforms/prepare_tf.cc",
    ],
    hdrs = [
        "transforms/passes.h",
    ],
    deps = [
        ":tensorflow_lite",
        ":tensorflow_lite_quantization_utils",
        ":validators",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/memory",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
    alwayslink = 1,
)

cc_library(
    name = "tensorflow_lite_optimize",
    srcs = [
        "transforms/generated_optimize.inc",
        "transforms/optimize.cc",
    ],
    hdrs = [
        "transforms/passes.h",
    ],
    deps = [
        ":tensorflow_lite",
        ":validators",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:Support",
    ],
    alwayslink = 1,
)

cc_library(
    name = "tensorflow_lite_quantize",
    srcs = [
        "transforms/generated_quantize.inc",
        "transforms/post_quantize.cc",
        "transforms/prepare_quantize.cc",
        "transforms/quantize.cc",
    ],
    hdrs = [
        "transforms/passes.h",
    ],
    deps = [
        ":tensorflow_lite",
        ":tensorflow_lite_quantization_utils",
        ":validators",
        "@com_google_absl//absl/memory",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
    alwayslink = 1,
)

tf_native_cc_binary(
    name = "op_quant_spec_getters_gen",
    srcs = [
        "tools/op_quant_spec_getters_gen.cc",
    ],
    deps = [
        "@llvm//:support",
        "@llvm//:tablegen",
        "@local_config_mlir//:TableGen",
    ],
)

genrule(
    name = "op_quant_spec_getters_inc",
    srcs = [
        "@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
        "@local_config_mlir//:include/mlir/IR/OpBase.td",
        "//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
    ],
    outs = [
        "utils/generated_op_quant_spec_getters.inc",
    ],
    cmd = ("$(location :op_quant_spec_getters_gen) " +
           "-I external/local_config_mlir/include " +
           "$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
    tools = [":op_quant_spec_getters_gen"],
)

# Library with tensorflow Lite dialect static initialization.
cc_library(
    name = "tensorflow_lite_dialect_registration",
    srcs = [
        "ir/dialect_registration.cc",
    ],
    deps = [
        ":tensorflow_lite",
        "@local_config_mlir//:IR",
    ],
    alwayslink = 1,
)

tf_native_cc_binary(
    name = "operator-writer-gen",
    srcs = [
        "operator_writer_gen.cc",
    ],
    deps = [
        "@llvm//:support",
        "@llvm//:tablegen",
        "@local_config_mlir//:TableGen",
    ],
)

genrule(
    name = "operator_writer_inc",
    srcs = [
        "@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
        "@local_config_mlir//:include/mlir/IR/OpBase.td",
        "//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
    ],
    outs = [
        "operator_writers.inc",
    ],
    cmd = ("$(location :operator-writer-gen) " +
           "-I external/local_config_mlir/include " +
           "$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
    tools = [":operator-writer-gen"],
)

cc_library(
    name = "flatbuffer_tflite_operator_lib",
    srcs = [
        "flatbuffer_operator.cc",
        "operator_writers.inc",
    ],
    hdrs = [
        "flatbuffer_operator.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/container:flat_hash_map",
        "@flatbuffers",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:TransformUtils",
    ],
)

tf_native_cc_binary(
    name = "flatbuffer_to_string",
    srcs = ["flatbuffer_to_string.cc"],
    deps = [
        "//tensorflow/lite/schema:schema_fbs_with_reflection",
        "@flatbuffers",
    ],
)

cc_library(
    name = "flatbuffer_translate_lib",
    srcs = [
        "flatbuffer_translate.cc",
    ],
    hdrs = [
        "flatbuffer_translate.h",
    ],
    deps = [
        ":flatbuffer_tflite_operator_lib",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_proto_cc",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:schema_fbs_version",
        "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@flatbuffers",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:QuantOpsDialectRegistration",
        "@local_config_mlir//:StandardDialectRegistration",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
        "@local_config_mlir//:Translation",
    ],
    alwayslink = 1,
)

tf_cc_binary(
    name = "flatbuffer_translate",
    deps = [
        ":flatbuffer_translate_lib",
        "@local_config_mlir//:tools/mlir-translate/mlir-translate",
    ],
)

cc_library(
    name = "tf_tfl_translate_cl_options",
    srcs = [
        "tf_tfl_translate_cl.cc",
    ],
    hdrs = [
        "tf_tfl_translate_cl.h",
    ],
    deps = [
        "@llvm//:support",
    ],
    alwayslink = 1,
)

tf_cc_binary(
    name = "tf_tfl_translate",
    srcs = ["tf_tfl_translate.cc"],
    deps = [
        ":flatbuffer_translate_lib",
        ":tensorflow_lite",
        ":tf_tfl_translate_cl_options",
        ":tf_to_tfl_flatbuffer",
        "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
        "//tensorflow/core:lib",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/stream_executor/lib",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Support",
    ],
)

tf_cc_binary(
    name = "mlir-tflite-runner",
    srcs = ["mlir_tflite_runner.cc"],
    deps = [
        ":flatbuffer_translate_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform/default/build_config:base",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/delegates/flex:delegate",
        "//tensorflow/lite/kernels:builtin_ops",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Parser",
        "@local_config_mlir//:Support",
    ],
)

cc_library(
    name = "tf_to_tfl_flatbuffer",
    srcs = ["tf_to_tfl_flatbuffer.cc"],
    hdrs = [
        "tf_to_tfl_flatbuffer.h",
    ],
    deps = [
        ":flatbuffer_translate_lib",
        ":tensorflow_lite",
        ":tensorflow_lite_legalize_tf",
        ":tensorflow_lite_optimize",
        ":tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_proto_cc",
        "//tensorflow/stream_executor/lib",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Parser",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:QuantOpsDialectRegistration",
        "@local_config_mlir//:Support",
        "@local_config_mlir//:Transforms",
    ],
)
28 tensorflow/compiler/mlir/lite/README.md (new file)
@@ -0,0 +1,28 @@
# Experimental code for the new TF-Lite converter, and MLIR dialects and utilities for TensorFlow Lite

This directory contains:

1. Experimental code for the new TF-Lite converter.
2. Code for the TF-Lite dialect in [MLIR](https://github.com/tensorflow/mlir).

## API

The API for converting TensorFlow models to TensorFlow Lite is
`tf.lite.TFLiteConverterV2`.

### The conversion from TensorFlow to TensorFlow Lite includes the following major passes (sketched in the example below):

- Import from GraphDef, in .pb or .pbtxt format, into MLIR.
- Raise to control-flow graph: converts the TF Control Flow dialect to the TF
  dialect.
- The Canonicalization pass iteratively applies canonicalization
  transformations in a greedy way until no further changes occur.
  Canonicalization includes constant folding.
- The Legalize pass converts TensorFlow operations to TensorFlow Lite
  ones. Operations that cannot be mapped to the TensorFlow Lite dialect
  are left as TensorFlow operations. Unsupported op handling follows the
  proposed TFLite mechanism.
- Optimizations are performed in both the TF and TFLite dialects, aiming for
  small size and high performance (among the core value propositions of
  TensorFlow Lite models).
- The Export pass writes out the TensorFlow Lite FlatBuffer format. This pass
  operates on the MLIR TensorFlow Lite dialect and is a simple/direct
  translation.
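To make the ordering of these stages concrete, here is a minimal C++ sketch of a conversion driver, under stated assumptions: the two callback parameters are hypothetical stand-ins for the GraphDef import and the TFL legalize/optimize pass pipeline added elsewhere in this change, and only `tflite::MlirToFlatBufferTranslateFunction` (declared in `flatbuffer_translate.h` below) is part of this commit; its true-on-error return convention is assumed here rather than documented.

```cpp
// Illustrative pipeline driver only; the callbacks are hypothetical stand-ins
// for the import and pass-pipeline plumbing (see tf_to_tfl_flatbuffer.cc in
// the BUILD file above for where this logic actually lives in this commit).
#include <fstream>
#include <functional>
#include <string>

#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"

int ConvertToTfliteFlatBuffer(
    const std::string &output_path,
    std::function<mlir::Module *()> import_graphdef_to_tf_dialect,
    std::function<bool(mlir::Module *)> run_legalize_and_optimize_passes) {
  // Import: GraphDef (.pb/.pbtxt) -> MLIR TF dialect (hypothetical callback).
  mlir::Module *module = import_graphdef_to_tf_dialect();
  if (!module) return 1;

  // Canonicalize, legalize TF -> TFLite, and optimize (hypothetical callback
  // wrapping the passes declared in transforms/passes.h).
  if (!run_legalize_and_optimize_passes(module)) return 1;

  // Export: serialize the TFLite-dialect module to a FlatBuffer.
  std::string serialized_flatbuffer;
  if (tflite::MlirToFlatBufferTranslateFunction(
          module, &serialized_flatbuffer, /*emit_builtin_tflite_ops=*/true,
          /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false))
    return 1;  // Assumes the true-on-error convention used in this commit.

  std::ofstream out(output_path, std::ios::binary);
  out << serialized_flatbuffer;
  return 0;
}
```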
113 tensorflow/compiler/mlir/lite/flatbuffer_operator.cc (new file)
@@ -0,0 +1,113 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"

#include <vector>

#include "llvm/ADT/StringSwitch.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/lite/schema/schema_generated.h"

// TODO(jpienaar): This is a placeholder. This should be done in a more
// efficient way when it is part of the translation of the module.
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
  return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
      .Case("NONE", tflite::ActivationFunctionType_NONE)
      .Case("RELU", tflite::ActivationFunctionType_RELU)
      .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
      .Case("RELU6", tflite::ActivationFunctionType_RELU6)
      .Case("TANH", tflite::ActivationFunctionType_TANH)
      .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
}

static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
  return llvm::StringSwitch<tflite::Padding>(str)
      .Case("SAME", tflite::Padding_SAME)
      .Case("VALID", tflite::Padding_VALID);
}

static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
    mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
  switch (type.getKind()) {
    case mlir::StandardTypes::F16:
      return tflite::TensorType_FLOAT16;
    case mlir::StandardTypes::F32:
      return tflite::TensorType_FLOAT32;
    case mlir::TF::TensorFlowTypes::STRING:
      return tflite::TensorType_STRING;
    case mlir::TF::TensorFlowTypes::COMPLEX64:
      return tflite::TensorType_COMPLEX64;
    case mlir::StandardTypes::Integer: {
      const auto& itype = type.cast<mlir::IntegerType>();
      switch (itype.getWidth()) {
        case 1:
          return tflite::TensorType_BOOL;
        case 8:
          return tflite::TensorType_INT8;
        case 16:
          return tflite::TensorType_INT16;
        case 32:
          return tflite::TensorType_INT32;
        case 64:
          return tflite::TensorType_INT64;
        default:
          llvm_unreachable("invalid integer Type in conversion");
      }
    }
    default:
      llvm_unreachable("invalid Type in conversion");
  }
}

// I32Attr already returns an int as required by flatbuffer builders.
static int ConvertI32AttrForOptionWriter(
    llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) {
  return i.getSExtValue();
}

// F32Attr already returns a float as required by flatbuffer builders.
static float ConvertF32AttrForOptionWriter(
    llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
  return f.convertToFloat();
}

// BoolAttr already returns a bool as required by flatbuffer builders.
static bool ConvertBoolAttrForOptionWriter(
    bool b, flatbuffers::FlatBufferBuilder* builder) {
  return b;
}

static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(
    llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
  std::vector<int> intVec(r.begin(), r.end());
  return builder->CreateVector(intVec);
}

static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
  return llvm::StringSwitch<tflite::FullyConnectedOptionsWeightsFormat>(str)
      .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
      .Case("SHUFFLED4x16INT8",
            tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
}

// Pull in FlatBuffer writers for TFLite generated using TableGen.
#include "tensorflow/compiler/mlir/lite/operator_writers.inc"
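The converters above all follow a single pattern: a static `Convert<AttrKind>ForOptionWriter` function maps an MLIR attribute value to the corresponding FlatBuffer type, and the TableGen-generated `operator_writers.inc` included at the end calls them by name. As a hedged illustration of that pattern (not part of this commit), a converter for a string-valued mirror-pad mode attribute might look as follows, assuming the `tflite::MirrorPadMode` enum from the TFLite schema and a hypothetical converter name:

```cpp
// Illustrative sketch only; follows the StringSwitch pattern used above and
// assumes tflite::MirrorPadMode exists in the generated TFLite schema.
static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
  return llvm::StringSwitch<tflite::MirrorPadMode>(str)
      .Case("REFLECT", tflite::MirrorPadMode_REFLECT)
      .Case("SYMMETRIC", tflite::MirrorPadMode_SYMMETRIC);
}
```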
47 tensorflow/compiler/mlir/lite/flatbuffer_operator.h (new file)
@@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_

#include <stdint.h>

#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "tensorflow/lite/schema/schema_generated.h"

namespace mlir {

// Returns the builtin op code for the given MLIR operation on success; emits
// error and returns llvm::None on failure.
llvm::Optional<tflite::BuiltinOperator> GetBuiltinOpCode(Operation *mlir_op);

// Packs the given MLIR operation into a TFLite FlatBuffer operator object.
// Returns the FlatBuffer offset for the operator on success; emits error and
// returns llvm::None on failure.
llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
    Operation *mlir_op, uint32_t opcode_index,
    const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
    flatbuffers::FlatBufferBuilder *fbb);

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_
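The two declarations above are meant to be called per operation while the translator walks a function: `GetBuiltinOpCode` maps the op to its TFLite builtin opcode (which the caller can deduplicate into the model's operator_codes table), and `CreateFlatBufferOperator` packs the op and its options once the operand/result tensor indices are known. Below is a minimal sketch of that calling pattern; `SerializeOneOp` and the index bookkeeping are hypothetical stand-ins for what the real translator (`flatbuffer_translate.cc`, whose diff is suppressed below) does.

```cpp
// Illustrative only: shows the intended calling pattern for the two entry
// points declared in flatbuffer_operator.h. The opcode_index and tensor-index
// bookkeeping is assumed to be handled by the surrounding translation code.
#include <vector>

#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/lite/schema/schema_generated.h"

llvm::Optional<flatbuffers::Offset<tflite::Operator>> SerializeOneOp(
    mlir::Operation *op, uint32_t opcode_index,
    const std::vector<int32_t> &operand_tensor_indices,
    const std::vector<int32_t> &result_tensor_indices,
    flatbuffers::FlatBufferBuilder *fbb) {
  // Map the MLIR op to a TFLite builtin opcode; on failure an error has
  // already been emitted, so just propagate llvm::None.
  if (!mlir::GetBuiltinOpCode(op)) return llvm::None;

  // Pack the op (including its builtin options) into a FlatBuffer Operator.
  return mlir::CreateFlatBufferOperator(op, opcode_index,
                                        operand_tensor_indices,
                                        result_tensor_indices, fbb);
}
```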
142 tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc (new file)
@@ -0,0 +1,142 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Dumps a TFLite flatbuffer to a textual output format.
// This tool is intended to be used to simplify unit testing/debugging.

#include <stddef.h>
#include <stdint.h>

#include <fstream>
#include <iostream>
#include <string>

#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
#include "flatbuffers/minireflect.h"  // TF:flatbuffers
#include "tensorflow/lite/schema/reflection/schema_generated.h"

namespace tflite {
namespace {

// Reads a model from the provided file path and verifies that it is a valid
// flatbuffer. Returns false (with the model stored in `serialized_model`) if
// it is valid, and true otherwise.
bool ReadAndVerify(const std::string& file_path,
                   std::string* serialized_model) {
  if (file_path == "-") {
    *serialized_model = std::string{std::istreambuf_iterator<char>(std::cin),
                                    std::istreambuf_iterator<char>()};
  } else {
    std::ifstream t(file_path);
    if (!t.is_open()) {
      std::cerr << "Failed to open input file.\n";
      return true;
    }
    *serialized_model = std::string{std::istreambuf_iterator<char>(t),
                                    std::istreambuf_iterator<char>()};
  }

  flatbuffers::Verifier model_verifier(
      reinterpret_cast<const uint8_t*>(serialized_model->c_str()),
      serialized_model->length());
  if (!model_verifier.VerifyBuffer<Model>()) {
    std::cerr << "Verification failed.\n";
    return true;
  }
  return false;
}

// A FlatBuffer visitor that outputs a FlatBuffer as a string with proper
// indentation for sequence fields.
// TODO(wvo): ToStringVisitor already has indentation functionality, use
// that directly instead of this sub-class?
struct IndentedToStringVisitor : flatbuffers::ToStringVisitor {
  std::string indent_str;
  int indent_level;

  IndentedToStringVisitor(const std::string& delimiter,
                          const std::string& indent)
      : ToStringVisitor(delimiter), indent_str(indent), indent_level(0) {}

  void indent() {
    for (int i = 0; i < indent_level; ++i) s.append(indent_str);
  }

  // Adjust indentation for fields in sequences.

  void StartSequence() override {
    s += "{";
    s += d;
    ++indent_level;
  }

  void EndSequence() override {
    s += d;
    --indent_level;
    indent();
    s += "}";
  }

  void Field(size_t /*field_idx*/, size_t set_idx,
             flatbuffers::ElementaryType /*type*/, bool /*is_vector*/,
             const flatbuffers::TypeTable* /*type_table*/, const char* name,
             const uint8_t* val) override {
    if (!val) return;
    if (set_idx) {
      s += ",";
      s += d;
    }
    indent();
    if (name) {
      s += name;
      s += ": ";
    }
  }

  void StartVector() override { s += "[ "; }
  void EndVector() override { s += " ]"; }

  void Element(size_t i, flatbuffers::ElementaryType /*type*/,
               const flatbuffers::TypeTable* /*type_table*/,
               const uint8_t* /*val*/) override {
    if (i) s += ", ";
  }
};

void ToString(const std::string& serialized_model) {
  IndentedToStringVisitor visitor(/*delimiter=*/"\n", /*indent=*/" ");
  IterateFlatBuffer(reinterpret_cast<const uint8_t*>(serialized_model.c_str()),
                    ModelTypeTable(), &visitor);
  std::cout << visitor.s << "\n\n";
}

}  // end namespace
}  // end namespace tflite

int main(int argc, char** argv) {
  if (argc < 2) {
    std::cerr << "Missing input argument. Usage:\n"
              << argv[0] << " <filename or - for stdin>\n\n"
              << "Converts TensorFlowLite flatbuffer to textual output format. "
              << "One positional input argument representing the source of the "
              << "flatbuffer is supported.\n";
    return 1;
  }

  std::string serialized_model;
  if (tflite::ReadAndVerify(argv[1], &serialized_model)) return 1;
  tflite::ToString(serialized_model);
  return 0;
}
1028 tensorflow/compiler/mlir/lite/flatbuffer_translate.cc (new file)
File diff suppressed because it is too large
42 tensorflow/compiler/mlir/lite/flatbuffer_translate.h (new file)
@@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_

#include <string>

#include "mlir/IR/Module.h"  // TF:local_config_mlir

// These flags control whether each kind of op is emitted during the
// flatbuffer translation.
extern bool emit_builtin_tflite_ops;
extern bool emit_select_tf_ops;
extern bool emit_custom_ops;
// The flag to control whether to lower tensorlist ops into TF ops.
extern bool lower_tensor_list_ops;

namespace tflite {

// Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string.
bool MlirToFlatBufferTranslateFunction(mlir::Module *module,
                                       std::string *serialized_flatbuffer,
                                       bool emit_builtin_tflite_ops,
                                       bool emit_select_tf_ops,
                                       bool emit_custom_ops);
}  // namespace tflite

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_
1512 tensorflow/compiler/mlir/lite/g3doc/tfl_ops.md (new executable file)
File diff suppressed because it is too large
19 tensorflow/compiler/mlir/lite/ir/dialect_registration.cc (new file)
@@ -0,0 +1,19 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Static initialization for TensorFlow Lite op registration.
static mlir::DialectRegistration<mlir::TFL::TensorFlowLiteDialect> tfl_ops;
574 tensorflow/compiler/mlir/lite/ir/tfl_ops.cc (new file)
@@ -0,0 +1,574 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlowLiteDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
|
||||
: Dialect(/*name=*/"tfl", context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common support logic
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns true if the dimensions in `a` is a suffix of the ones in `b`.
|
||||
// For example, dimensions {2}, {1, 2}, and {3, 1, 2} are all suffixes to
|
||||
// {5, 4, 3, 1, 2}, while {1}, {5, 4}, and {1, 3, 2} are all not.
|
||||
inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
|
||||
if (a.size() > b.size()) return false;
|
||||
|
||||
return std::equal(a.rbegin(), a.rend(), b.rbegin());
|
||||
}
|
||||
|
||||
// Performs const folding `calculate` with broadcast behavior on the two
|
||||
// attributes `operand1` and `operand2` and returns the result if possible.
|
||||
// The two operands are expected to both be scalar values.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class CalculationT =
|
||||
std::function<ElementValueT(ElementValueT, ElementValueT)>>
|
||||
Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1,
|
||||
Attribute operand2,
|
||||
const CalculationT &calculate) {
|
||||
auto lhs = operand1.cast<AttrElementT>();
|
||||
auto rhs = operand2.cast<AttrElementT>();
|
||||
|
||||
assert(lhs.getType() == result_type && rhs.getType() == result_type &&
|
||||
"values of incompatible types should be caught by op verification");
|
||||
|
||||
// TODO: Need to handle overflow/underflow cases.
|
||||
return AttrElementT::get(result_type,
|
||||
calculate(lhs.getValue(), rhs.getValue()));
|
||||
}
|
||||
|
||||
// TODO: We have multiple functions to handle different attriubte kinds in the
|
||||
// following. Consider add methods to ElementsAttr to unify these functions.
|
||||
|
||||
// Performs const folding `calculate` with broadcast behavior on the two
|
||||
// attributes `operand1` and `operand2` and returns the result if possible.
|
||||
// This function assumes that both operands are `AttrElementT` attributes.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class CalculationT =
|
||||
std::function<ElementValueT(ElementValueT, ElementValueT)>>
|
||||
Attribute ConstFoldBinaryOpSplatSplat(Type result_type, Attribute operand1,
|
||||
Attribute operand2,
|
||||
const CalculationT &calculate) {
|
||||
auto type = result_type.cast<ShapedType>();
|
||||
auto elem_type = type.getElementType();
|
||||
|
||||
auto element_result = ConstFoldBinaryOpScalarScalar<AttrElementT>(
|
||||
elem_type, operand1, operand2, calculate);
|
||||
if (!element_result) return {};
|
||||
|
||||
return DenseElementsAttr::get(type, element_result);
|
||||
}
|
||||
|
||||
/// Performs const folding `calculate` with broadcast behavior on the two
|
||||
/// attributes `operand1` and `operand2` and returns the result if possible.
|
||||
/// This function assumes the first operand is a DenseElementsAttr and the
|
||||
/// second one is a SplatElementsAttr, and both are verified to have value
|
||||
/// attributes of broadcastable types.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class CalculationT =
|
||||
std::function<ElementValueT(ElementValueT, ElementValueT)>>
|
||||
Attribute ConstFoldBinaryOpDenseSplat(Type result_type, Attribute operand1,
|
||||
Attribute operand2,
|
||||
const CalculationT &calculate) {
|
||||
auto lhs = operand1.cast<DenseElementsAttr>();
|
||||
|
||||
// TODO: Support broadcast behavior
|
||||
if (lhs.getType() != result_type || operand2.getType() != result_type)
|
||||
return {};
|
||||
|
||||
auto rhs = operand2.cast<SplatElementsAttr>().getSplatValue();
|
||||
auto type = result_type.cast<ShapedType>();
|
||||
|
||||
SmallVector<ElementValueT, 16> new_values;
|
||||
new_values.reserve(lhs.rawSize());
|
||||
|
||||
// Add the splat value to each of the values in the dense elements
|
||||
// attribute.
|
||||
auto rhs_val = rhs.cast<AttrElementT>().getValue();
|
||||
for (auto old_val : lhs.getValues<ElementValueT>()) {
|
||||
new_values.push_back(calculate(old_val, rhs_val));
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(type, new_values);
|
||||
}
|
||||
|
||||
/// Performs const folding `calculate` with broadcast behavior on the two
|
||||
/// attributes `operand1` and `operand2` and returns the result if possible.
|
||||
/// This function assumes the both operands are DenseElementsAttr and verified
|
||||
/// to have value attributes of broadcastable types.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class CalculationT =
|
||||
std::function<ElementValueT(ElementValueT, ElementValueT)>>
|
||||
Attribute ConstFoldBinaryOpDenseDense(Type result_type, Attribute operand1,
|
||||
Attribute operand2,
|
||||
const CalculationT &calculate) {
|
||||
auto lhs = operand1.cast<DenseElementsAttr>();
|
||||
auto rhs = operand2.cast<DenseElementsAttr>();
|
||||
|
||||
if (lhs.getType() != rhs.getType()) {
|
||||
// We only support the case that one of the operand's dimensions are
|
||||
// a perfect suffix of the other.
|
||||
// TODO: support the general broadcast behavior.
|
||||
auto lhs_shape = lhs.getType().getShape();
|
||||
auto rhs_shape = rhs.getType().getShape();
|
||||
if (!IsTrailingDimensions(lhs_shape, rhs_shape) &&
|
||||
!IsTrailingDimensions(rhs_shape, lhs_shape))
|
||||
return {};
|
||||
}
|
||||
|
||||
auto lhs_num_elements = lhs.getType().getNumElements();
|
||||
auto rhs_num_elements = rhs.getType().getNumElements();
|
||||
|
||||
auto type = result_type.cast<ShapedType>();
|
||||
auto num_elements = type.getNumElements();
|
||||
|
||||
// We assume the arguments have broadcast-compatible types. Make sure again.
|
||||
assert(std::max(lhs_num_elements, rhs_num_elements) == num_elements);
|
||||
assert(num_elements % std::min(lhs_num_elements, rhs_num_elements) == 0);
|
||||
|
||||
SmallVector<ElementValueT, 16> lhs_old_values(lhs.getValues<ElementValueT>());
|
||||
SmallVector<ElementValueT, 16> rhs_old_values(rhs.getValues<ElementValueT>());
|
||||
SmallVector<ElementValueT, 16> new_values;
|
||||
new_values.reserve(num_elements);
|
||||
|
||||
// Add each pair of the corresponding values in the dense elements
|
||||
// attributes.
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
// We only support a degenerated case here: the dimensions in one operand's
|
||||
// shape is a perfect suffix to the other operand. Then conceptually it's
|
||||
// similar to broadcasting a scalar to a 1-D vector.
|
||||
// TODO: support the general broadcast behavior.
|
||||
// We are tiling the operand with less elements an integral times to match
|
||||
// the operand with more elements. We don't care which operand has less
|
||||
// elements here because we are iterating its elements in circles, which can
|
||||
// be achieved using the result index modulo the element count. For the
|
||||
// operand with more elements, since the result has the same number of
|
||||
// elements, we are only going over its elements once. The modulo operation
|
||||
// also works for that.
|
||||
int lhs_index = i % lhs_num_elements;
|
||||
int rhs_index = i % rhs_num_elements;
|
||||
|
||||
new_values.push_back(
|
||||
calculate(lhs_old_values[lhs_index], rhs_old_values[rhs_index]));
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(type, new_values);
|
||||
}
|
||||
|
||||
/// Performs const folding `calculate` with broadcast behavior on the two
|
||||
/// attributes `operand1` and `operand2` and returns the result if possible.
|
||||
/// This function assumes the two operands are verified to have value
|
||||
/// attributes of broadcastable types.
|
||||
template <class AttrElementT,
|
||||
class ElementValueT = typename AttrElementT::ValueType,
|
||||
class CalculationT =
|
||||
std::function<ElementValueT(ElementValueT, ElementValueT)>>
|
||||
Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
|
||||
Attribute operand2, const CalculationT &calculate,
|
||||
bool is_commutative) {
|
||||
if (operand1.dyn_cast_or_null<AttrElementT>()) {
|
||||
// Scalar op scalar case
|
||||
if (operand2.dyn_cast_or_null<AttrElementT>())
|
||||
return ConstFoldBinaryOpScalarScalar<AttrElementT>(result_type, operand1,
|
||||
operand2, calculate);
|
||||
} else if (auto lhs = operand1.dyn_cast_or_null<SplatElementsAttr>()) {
|
||||
// Splat op splat case
|
||||
if (auto rhs = operand2.dyn_cast_or_null<SplatElementsAttr>())
|
||||
return ConstFoldBinaryOpSplatSplat<AttrElementT>(
|
||||
result_type, lhs.getSplatValue(), rhs.getSplatValue(), calculate);
|
||||
|
||||
// Splat op dense case
|
||||
if (auto rhs = operand2.dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
if (is_commutative) {
|
||||
// Swap the two constant values to fall into the following case
|
||||
return ConstFoldBinaryOpDenseSplat<AttrElementT>(result_type, operand2,
|
||||
operand1, calculate);
|
||||
}
|
||||
}
|
||||
} else if (auto lhs = operand1.dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
// Dense op splat case
|
||||
if (auto rhs = operand2.dyn_cast_or_null<SplatElementsAttr>())
|
||||
return ConstFoldBinaryOpDenseSplat<AttrElementT>(result_type, operand1,
|
||||
operand2, calculate);
|
||||
|
||||
// Dense op dense case
|
||||
if (auto rhs = operand2.dyn_cast_or_null<DenseElementsAttr>())
|
||||
return ConstFoldBinaryOpDenseDense<AttrElementT>(result_type, operand1,
|
||||
operand2, calculate);
|
||||
}
|
||||
|
||||
// TODO: support other attribute kinds
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Performs const folding with broadcast behavior on the two attributes in
|
||||
/// `operands` and returns the result if possible.
|
||||
/// Depending on the given `resultType`, either `floatCalculate` or
|
||||
/// `intCalculate` is chosen to conduct the calculate.
|
||||
Attribute ConstFoldBinaryOp(
|
||||
Type result_type, ArrayRef<Attribute> operands,
|
||||
std::function<APFloat(APFloat, APFloat)> float_calculate,
|
||||
std::function<APInt(APInt, APInt)> int_calculate, bool is_commutative) {
|
||||
// Note: All types are wrapped in tensor types in TFlite. E.g., f32 is
|
||||
// represented as tensor<f32>. So we are only handling tensor types here.
|
||||
auto type = result_type.dyn_cast<ShapedType>();
|
||||
if (!type) return {};
|
||||
|
||||
auto elemType = type.getElementType();
|
||||
|
||||
if (elemType.isa<FloatType>())
|
||||
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
|
||||
float_calculate, is_commutative);
|
||||
|
||||
if (elemType.isa<IntegerType>())
|
||||
return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
|
||||
int_calculate, is_commutative);
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
void buildComparisonBinOp(Builder *builder, OperationState *result, Value *lhs,
|
||||
Value *rhs) {
|
||||
auto result_type =
|
||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
||||
if (!result_type)
|
||||
emitError(result->location)
|
||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
||||
<< rhs->getType();
|
||||
result->addOperands({lhs, rhs});
|
||||
auto resultShape = result_type.cast<ShapedType>().getShape();
|
||||
// Comparison binary ops always return i1 tensor.
|
||||
result->types.push_back(
|
||||
builder->getTensorType(resultShape, builder->getI1Type()));
|
||||
}
|
||||
|
||||
void buildFusedBroadcastableBinOp(Builder *builder, OperationState *result,
|
||||
Value *lhs, Value *rhs,
|
||||
StringAttr fused_activation_function) {
|
||||
auto result_type =
|
||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
||||
|
||||
if (!result_type)
|
||||
emitError(result->location)
|
||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
||||
<< rhs->getType();
|
||||
|
||||
result->addOperands({lhs, rhs});
|
||||
result->addAttribute("fused_activation_function", fused_activation_function);
|
||||
result->types.push_back(result_type);
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Skip fused ops for now.
|
||||
if (fused_activation_function() != "NONE") return {};
|
||||
return ConstFoldBinaryOp(
|
||||
getType(), operands, [](APFloat a, APFloat b) { return a + b; },
|
||||
[](APInt a, APInt b) { return a + b; }, getOperation()->isCommutative());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConcatenationOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO(ashwinm): Implement shape inference for Concatenation
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void BuildGatherOp(Builder *builder, OperationState *result,
|
||||
Value *params, Value *indices, IntegerAttr axis) {
|
||||
auto params_type = params->getType().cast<TensorType>();
|
||||
auto indices_type = indices->getType().cast<TensorType>();
|
||||
|
||||
// If params/indices is unranked, then output is unranked.
|
||||
if (!params_type.hasRank() || !indices_type.hasRank())
|
||||
return TFL::GatherOp::build(
|
||||
builder, result, builder->getTensorType(params_type.getElementType()),
|
||||
params, indices, axis);
|
||||
|
||||
int64_t params_rank = params_type.getRank();
|
||||
int64_t indices_rank = indices_type.getRank();
|
||||
|
||||
// params rank is guaranteed to be at least 1.
|
||||
// Produces an output tensor with shape:
|
||||
// params.shape[:axis] + indices.shape + params.shape[axis + 1:]
|
||||
std::vector<int64_t> shape(params_type.getShape());
|
||||
int64_t axis_i = axis.getInt();
|
||||
|
||||
// For neg axis values, we wrap around params, e.g. axis = -1 => params[:-1]
|
||||
if (axis_i < 0) {
|
||||
axis_i += params_rank;
|
||||
}
|
||||
|
||||
// params must be atleast rank axis + 1
|
||||
if (params_rank < axis_i + 1) {
|
||||
emitError(result->location, "params must be atleast rank axis + 1");
|
||||
}
|
||||
|
||||
if (indices_rank == 0) {
|
||||
// Scalar indices (output is rank(params) - 1).
|
||||
// Erase shape[axis]
|
||||
shape.erase(shape.begin() + axis_i);
|
||||
} else if (indices_rank == 1) {
|
||||
// Vector indices (output is rank(params)).
|
||||
// Copy indices.shape into params.shape[axis]
|
||||
std::copy(std::begin(indices_type.getShape()),
|
||||
std::end(indices_type.getShape()), std::begin(shape) + axis_i);
|
||||
} else {
|
||||
// Higher rank indices (output is rank(params) + rank(indices) - 1).
|
||||
shape.resize(params_rank + indices_rank - 1);
|
||||
// Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
|
||||
std::copy(std::begin(params_type.getShape()) + axis_i + 1,
|
||||
std::end(params_type.getShape()),
|
||||
std::begin(shape) + axis_i + indices_rank);
|
||||
|
||||
// Copy indices.shape into params.shape[axis]
|
||||
std::copy(std::begin(indices_type.getShape()),
|
||||
std::end(indices_type.getShape()), std::begin(shape) + axis_i);
|
||||
}
|
||||
|
||||
TFL::GatherOp::build(
|
||||
builder, result,
|
||||
builder->getTensorType(shape, params_type.getElementType()), params,
|
||||
indices, axis);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Skip fused ops for now.
|
||||
if (fused_activation_function() != "NONE") return {};
|
||||
return ConstFoldBinaryOp(
|
||||
getType(), operands, [](APFloat a, APFloat b) { return a * b; },
|
||||
[](APInt a, APInt b) { return a * b; }, getOperation()->isCommutative());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/133486129): Implement shape inference for pack
|
||||
|
||||
static LogicalResult Verify(PackOp op) {
|
||||
// TODO(antiagainst): Implement other checks as in
|
||||
// tensorflow/lite/kernels/pack.cc
|
||||
|
||||
if (op.getOperation()->getNumOperands() != op.values_count())
|
||||
return op.emitOpError("input count should match 'values_count' attribute");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This pattern matches and merges a tfl.reshape under the following
|
||||
/// condition:
|
||||
/// * The input's defining op is another tfl.reshape.
|
||||
// TODO(antiagainst): This pattern probably should be moved to the peephole
|
||||
// category, after we have the infra for peephole passes.
|
||||
struct RemoveAdjacentReshape : public RewritePattern {
|
||||
RemoveAdjacentReshape(MLIRContext *context)
|
||||
: RewritePattern(ReshapeOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto thisOp = cast<ReshapeOp>(op);
|
||||
auto prevOp = thisOp.getOperand()->getDefiningOp();
|
||||
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
auto thisOp = cast<ReshapeOp>(op);
|
||||
auto prevOp = cast<ReshapeOp>(thisOp.getOperand()->getDefiningOp());
|
||||
|
||||
// Replace
|
||||
// %1 = "tfl.reshape"(%0)
|
||||
// %2 = "tfl.reshape"(%1)
|
||||
// With
|
||||
// %2 = "tfl.reshape"(%0)
|
||||
rewriter.replaceOpWithNewOp<ReshapeOp>(
|
||||
{prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand());
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Remove identity reshape.
|
||||
if (getType() == getOperand()->getType()) return getOperand();
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<RemoveAdjacentReshape>(context));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Skip fused ops for now.
|
||||
if (fused_activation_function() != "NONE") return {};
|
||||
return ConstFoldBinaryOp(
|
||||
getType(), operands, [](APFloat a, APFloat b) { return a - b; },
|
||||
[](APInt a, APInt b) { return a - b; }, getOperation()->isCommutative());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TopKOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void BuildTopKOp(Builder *builder, OperationState *result, Value *input,
|
||||
Value *k) {
|
||||
// Output size is only known if k is a constant value. A negative dimension is
|
||||
// considered dynamic so use -1 here if k is not a constant value.
|
||||
int const_k = -1;
|
||||
ElementsAttr cst;
|
||||
if (matchPattern(k, m_Constant(&cst)))
|
||||
// These casts should all be valid due to how Tensor constants are stored.
|
||||
// TODO(jpienaar): This should use a helper function.
|
||||
const_k = cst.getValue({}).cast<IntegerAttr>().getValue().getSExtValue();
|
||||
|
||||
auto val_type = input->getType().cast<TensorType>();
|
||||
// If the value is unranked, then so is the result.
|
||||
if (!val_type.hasRank())
|
||||
return TFL::TopKV2Op::build(
|
||||
builder, result, builder->getTensorType(val_type.getElementType()),
|
||||
builder->getTensorType(builder->getIntegerType(32)), input, k);
|
||||
|
||||
// Resultant shape is value.shape[:-1] + [k]
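// e.g. a value of shape [3, 4] with a constant k = 2 yields values and indices
// of shape [3, 2].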
|
||||
std::vector<int64_t> shape(val_type.getShape());
|
||||
shape[shape.size() - 1] = const_k;
|
||||
TFL::TopKV2Op::build(
|
||||
builder, result, builder->getTensorType(shape, val_type.getElementType()),
|
||||
builder->getTensorType(shape, builder->getIntegerType(32)), input, k);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FakeQuantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Return true if the op has non-empty "minmax" attribute.
|
||||
static inline bool HasValidMinMaxAttribute(Operation *op) {
|
||||
auto minmax = op->getAttrOfType<ArrayAttr>("minmax");
|
||||
return minmax && minmax.getValue().size() == 2;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// This pattern matches and removes a tfl.fake_quant if this op and all of its
|
||||
/// users have the "minmax" attribute set.
|
||||
struct DropFakeQuant : public RewritePattern {
|
||||
explicit DropFakeQuant(MLIRContext *context)
|
||||
: RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
// We only match the op with valid "minmax" attribute.
|
||||
if (!HasValidMinMaxAttribute(op)) return matchFailure();
|
||||
|
||||
// If all the users of this op have valid "minmax" attributes, it is matched
|
||||
// and can be removed.
|
||||
auto fakeQuantOp = cast<FakeQuantOp>(op);
|
||||
for (auto *operand : fakeQuantOp.getResult()->getUsers())
|
||||
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
// Replace the matched FakeQuantOp by its primary operand.
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<DropFakeQuant>(context));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnpackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/133486129): Implement shape inference for unpack
|
||||
|
||||
static LogicalResult Verify(UnpackOp op) {
|
||||
// TODO(antiagainst): Implement other checks as in
|
||||
// tensorflow/lite/kernels/unpack.cc
|
||||
|
||||
if (op.getOperation()->getNumResults() != op.num())
|
||||
return op.emitOpError("output count should match 'num' attribute");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MeanOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/133854225): Implement shape inference to Mean
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
46
tensorflow/compiler/mlir/lite/ir/tfl_ops.h
Normal file
46
tensorflow/compiler/mlir/lite/ir/tfl_ops.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the operations used in the MLIR TensorFlow Lite dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
|
||||
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
class TensorFlowLiteDialect : public Dialect {
|
||||
public:
|
||||
explicit TensorFlowLiteDialect(MLIRContext *context);
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
|
||||
|
||||
} // end namespace TFL
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
|
1988
tensorflow/compiler/mlir/lite/ir/tfl_ops.td
Normal file
1988
tensorflow/compiler/mlir/lite/ir/tfl_ops.td
Normal file
File diff suppressed because it is too large
127
tensorflow/compiler/mlir/lite/ir/tfl_traits.h
Normal file
127
tensorflow/compiler/mlir/lite/ir/tfl_traits.h
Normal file
@ -0,0 +1,127 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
||||
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
namespace TFL {
|
||||
|
||||
using QuantizedType = mlir::quant::QuantizedType;
|
||||
using UniformQuantizedType = mlir::quant::UniformQuantizedType;
|
||||
|
||||
// The base class that all quantization-related OpTraits implement.
|
||||
template <typename ConcreteType, template <typename> class TraitType>
|
||||
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
|
||||
static bool IsBias(int index) { return false; }
|
||||
static bool IsQuantizable() { return true; }
|
||||
};
|
||||
|
||||
// This class provides the API for TFL ops that require the same input and output
|
||||
// scale for their quantization results. This is used as a trait like this:
|
||||
//
|
||||
// class TransposeOp
|
||||
// : public Op<TransposeOp, OpTrait::TFL::SameOperandsAndResultsScale> {
|
||||
//
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultsScale
|
||||
: public QuantizationSpecTraitBase<ConcreteType,
|
||||
SameOperandsAndResultsScale> {};
|
||||
|
||||
// This class provides the API for TFL ops that have a fixed output value range.
|
||||
// This is used as a trait like this:
|
||||
//
|
||||
// class SoftmaxOp
|
||||
// : public Op<SoftmaxOp,
|
||||
// OpTrait::TFL::FixedResultUniformScale<
|
||||
// 8, -128, 390625, -8, 0, 255, false>::Impl> {
|
||||
//
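// Here the result scale is ScaleMantissa * 10^ScaleExp = 390625 * 1e-8 = 1/256.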
|
||||
// TODO(fengliuai): create a better way to express floating point scale in the
|
||||
// template argument list.
|
||||
template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
|
||||
int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
|
||||
class FixedResultUniformScale {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl
|
||||
: public QuantizationSpecTraitBase<
|
||||
ConcreteType, FixedResultUniformScale<
|
||||
BitWidth, ZeroPoint, ScaleMantissa, ScaleExp,
|
||||
StorageTypeMin, StorageTypeMax, Sign>::Impl> {
|
||||
public:
|
||||
QuantizedType GetResultQuantizedType(int index) {
|
||||
auto op = this->getOperation();
|
||||
auto result_type =
|
||||
op->getResult(index)->getType().template cast<TensorType>();
|
||||
Builder builder(op->getContext());
|
||||
IntegerType storage_type = builder.getIntegerType(BitWidth);
|
||||
const double scale = static_cast<double>(ScaleMantissa) *
|
||||
::exp10(static_cast<double>(ScaleExp));
|
||||
return UniformQuantizedType::getChecked(
|
||||
Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
|
||||
StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// This class provides the API for TFL ops that have a bias input. This is used
|
||||
// as a trait like this:
|
||||
//
|
||||
// class Conv2DOp
|
||||
//     : public Op<Conv2DOp, OpTrait::TFL::AccumulatorUniformScale<2, 0, 1>::Impl> {
|
||||
//
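// Here operand 2 is the bias operand and operands 0 and 1 are the non-bias
// operands that determine the accumulator scale.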
|
||||
// TODO(fengliuai): supports a configurable accumulator bit width.
|
||||
template <int Bias, int... Operands>
|
||||
class AccumulatorUniformScale {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl
|
||||
: public QuantizationSpecTraitBase<
|
||||
ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> {
|
||||
public:
|
||||
// Whether the index-th operand is a bias.
|
||||
static bool IsBias(int index) { return index == Bias; }
|
||||
|
||||
// Returns the indexes of all the non-bias operands.
|
||||
static std::vector<int> GetAllNonBiasOperands() {
|
||||
return std::vector<int>({Operands...});
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// This class provides the API for TFL ops that shouldn't be quantized. This is
|
||||
// used as a trait like this:
|
||||
//
|
||||
// class LessOp : public Op<LessOp, OpTrait::TFL::NoQuantizableResult> {
|
||||
//
|
||||
template <typename ConcreteType>
|
||||
class NoQuantizableResult
|
||||
: public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
|
||||
public:
|
||||
static bool IsQuantizable() { return false; }
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
139
tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
Normal file
139
tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
Normal file
@ -0,0 +1,139 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Tool to run a TFLite computation from an MLIR input using the TFLite
|
||||
// interpreter.
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Parser.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
using llvm::cl::opt;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<std::string> inputFileName(llvm::cl::Positional,
|
||||
llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
// TODO(jpienaar): Move these functions to some debug utils.
|
||||
static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
|
||||
auto begin = tensor.dims ? tensor.dims->data : nullptr;
|
||||
auto end = tensor.dims ? tensor.dims->data + tensor.dims->size : nullptr;
|
||||
return absl::StrJoin(begin, end, ", ");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static std::string TfLiteTypedTensorString(const TfLiteTensor& tensor) {
|
||||
const T* data = reinterpret_cast<T*>(tensor.data.raw);
|
||||
if (!data) return "<null>";
|
||||
int count = tensor.bytes / sizeof(T);
|
||||
return absl::StrJoin(data, data + count, ", ");
|
||||
}
|
||||
|
||||
// TODO(jpienaar): This really feels like something that should exist already.
|
||||
static std::string TfLiteTensorString(const TfLiteTensor& tensor) {
|
||||
switch (tensor.type) {
|
||||
case kTfLiteInt32:
|
||||
return TfLiteTypedTensorString<int32_t>(tensor);
|
||||
case kTfLiteInt64:
|
||||
return TfLiteTypedTensorString<int64_t>(tensor);
|
||||
case kTfLiteFloat32:
|
||||
return TfLiteTypedTensorString<float>(tensor);
|
||||
default:
|
||||
LOG(QFATAL) << "Unsupported type: " << TfLiteTypeGetName(tensor.type);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
llvm::PrettyStackTraceProgram x(argc, argv);
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");
|
||||
|
||||
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
|
||||
if (std::error_code error = file_or_err.getError()) {
|
||||
LOG(ERROR) << argv[0] << ": could not open input file '" << inputFileName
|
||||
<< "': " << error.message() << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Load the MLIR module.
|
||||
mlir::MLIRContext context;
|
||||
llvm::SourceMgr source_mgr;
|
||||
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
|
||||
std::unique_ptr<mlir::Module> module(
|
||||
mlir::parseSourceFile(source_mgr, &context));
|
||||
if (!module) return 1;
|
||||
|
||||
// TODO(jpienaar): Expand to support inputs.
|
||||
mlir::Function* main = module->getNamedFunction("main");
|
||||
QCHECK(main) << "No 'main' function specified.";
|
||||
if (main->getType().getNumInputs() != 0)
|
||||
LOG(QFATAL) << "NYI: Only nullary functions supported.";
|
||||
|
||||
// Convert to flatbuffer.
|
||||
std::string serialized_flatbuffer;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops))
|
||||
return 1;
|
||||
|
||||
// Create TFLite interpreter & invoke converted program.
|
||||
std::unique_ptr<tflite::FlatBufferModel> model =
|
||||
tflite::FlatBufferModel::BuildFromBuffer(serialized_flatbuffer.c_str(),
|
||||
serialized_flatbuffer.size());
|
||||
tflite::ops::builtin::BuiltinOpResolver builtins;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
QCHECK(tflite::InterpreterBuilder(*model, builtins)(&interpreter) ==
|
||||
kTfLiteOk);
|
||||
QCHECK(interpreter->AllocateTensors() == kTfLiteOk);
|
||||
QCHECK(interpreter->Invoke() == kTfLiteOk);
|
||||
|
||||
// Print the resulting outputs.
|
||||
// TODO(jpienaar): Allow specifying output stream/file.
|
||||
QCHECK(interpreter->outputs().size() == main->getType().getNumResults());
|
||||
for (int index : interpreter->outputs()) {
|
||||
const auto& out = *interpreter->tensor(index);
|
||||
// Print name if named.
|
||||
if (out.name) fprintf(stdout, "%s: ", out.name);
|
||||
// Print tensor result.
|
||||
fprintf(stdout, "Tensor<type: %s, shape: %s, values: %s>\n",
|
||||
TfLiteTypeGetName(out.type), TfLiteTensorDimString(out).c_str(),
|
||||
TfLiteTensorString(out).c_str());
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
292
tensorflow/compiler/mlir/lite/operator_writer_gen.cc
Normal file
292
tensorflow/compiler/mlir/lite/operator_writer_gen.cc
Normal file
@ -0,0 +1,292 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Main.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
#include "mlir/TableGen/Attribute.h" // TF:local_config_mlir
|
||||
|
||||
using llvm::DefInit;
|
||||
using llvm::dyn_cast;
|
||||
using llvm::formatv;
|
||||
using llvm::LessRecord;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::RecordRecTy;
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringInit;
|
||||
using llvm::StringRef;
|
||||
|
||||
// Returns the associated option name for the given op definition.
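// e.g. TFL_AddOp maps to "AddOptions", unless a "customOption" string is set
// on the op definition, in which case that value is used instead.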
|
||||
static inline std::string GetOperatorOptionName(const Record &def) {
|
||||
assert(def.getName().startswith("TFL_") && "unexpected op prefix");
|
||||
assert(def.getName().endswith("Op") && "unexpected op suffix");
|
||||
|
||||
auto *custom_option = dyn_cast<StringInit>(def.getValueInit("customOption"));
|
||||
std::ostringstream oss;
|
||||
if (custom_option)
|
||||
oss << custom_option->getValue().str();
|
||||
else
|
||||
oss << def.getName().drop_front(4).drop_back(2).str() << "Options";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Returns the builder function name for the given op definition.
|
||||
static inline std::string GetOperatorBuilderName(StringRef op_name) {
|
||||
assert(op_name.startswith("TFL_") && "unexpected op prefix");
|
||||
assert(op_name.endswith("Op") && "unexpected op suffix");
|
||||
|
||||
// E.g., AddOp -> CreateAddOperator
|
||||
std::ostringstream oss;
|
||||
oss << "Create" << op_name.drop_front(4).str() << "erator";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
static void EmitOptionBuilders(const RecordKeeper &record_keeper,
|
||||
const std::vector<Record *> &defs,
|
||||
raw_ostream *ostream) {
|
||||
raw_ostream &os = *ostream;
|
||||
|
||||
const auto attr_type = record_keeper.getClass("Attr");
|
||||
for (const auto *def : defs) {
|
||||
// TFLite ops without options are skipped over.
|
||||
if (!def->getValueAsBit("hasOptions")) continue;
|
||||
|
||||
StringRef op_name = def->getName().drop_front(4); // Strip 'TFL_' prefix
|
||||
std::string option_name = GetOperatorOptionName(*def);
|
||||
|
||||
os << "flatbuffers::Offset<tflite::" << option_name << "> Create"
|
||||
<< option_name << "(mlir::TFL::" << op_name
|
||||
<< " op, flatbuffers::FlatBufferBuilder *fbb) {\n";
|
||||
|
||||
// Construct all the builder options needed.
|
||||
SmallVector<std::string, 8> options;
|
||||
// Add options due to attributes (not derived).
|
||||
auto *arg_values = def->getValueAsDag("arguments");
|
||||
for (unsigned i = 0, e = arg_values->getNumArgs(); i != e; ++i) {
|
||||
auto arg = arg_values->getArg(i);
|
||||
DefInit *arg_def = dyn_cast<DefInit>(arg);
|
||||
if (!arg_def) continue;
|
||||
if (arg_def->getDef()->isSubClassOf(attr_type)) {
|
||||
// This binds the name of the attribute in the TD file with the name
|
||||
// of the add function of the builder and also with the conversion
|
||||
// function to convert from the internal representation to the format
|
||||
// expected by the flatbuffer builder. While this constrains the
|
||||
// naming of the ops/attributes in the TD file, this also removes the
|
||||
// need for specifying indirection. This tool is specific to TFLite
|
||||
// conversion generation and so the simplicity was chosen over the
|
||||
// flexibility.
|
||||
StringRef arg_name = arg_values->getArgNameStr(i);
|
||||
os << formatv(
|
||||
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
|
||||
arg_name, mlir::tblgen::Attribute(arg_def).getAttrDefName());
|
||||
options.push_back(arg_name.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Add options due to derived attributes.
|
||||
for (const auto &val : def->getValues()) {
|
||||
if (auto *record = dyn_cast<RecordRecTy>(val.getType())) {
|
||||
if (record->isSubClassOf(attr_type)) {
|
||||
if (record->getClasses().size() != 1) {
|
||||
PrintFatalError(
|
||||
def->getLoc(),
|
||||
"unsupported attribute modelling, only single class expected");
|
||||
}
|
||||
os << formatv(
|
||||
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
|
||||
val.getName(), record->getClasses()[0]->getName());
|
||||
options.push_back(val.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
os << " tflite::" << option_name << "Builder b(*fbb);\n";
|
||||
for (const auto &option : options)
|
||||
os << formatv(" b.add_{0}(std::move({0}));\n", option);
|
||||
os << " return b.Finish();\n}\n";
|
||||
}
|
||||
}
|
||||
|
||||
// For each TFLite op, emits a builder function that packs the TFLite op into
|
||||
// the corresponding FlatBuffer object.
|
||||
//
|
||||
// TODO(hinsu): Revisit if only builtin_options and mutating_variable_inputs
|
||||
// arguments that depend on op definitions should be auto-generated and then
|
||||
// the operator should be built by the caller because it does not require
|
||||
// auto-generation.
|
||||
static void EmitOperatorBuilders(const std::vector<Record *> &defs,
|
||||
raw_ostream *ostream) {
|
||||
raw_ostream &os = *ostream;
|
||||
|
||||
for (const auto *def : defs) {
|
||||
StringRef op_name = def->getName().drop_front(4);
|
||||
|
||||
// Signature
|
||||
os << "static flatbuffers::Offset<tflite::Operator> "
|
||||
<< GetOperatorBuilderName(def->getName()) << "(mlir::TFL::" << op_name
|
||||
<< " tflOp, uint32_t opcode_index, "
|
||||
<< "const std::vector<int32_t>& operands,"
|
||||
<< "const std::vector<int32_t>& results,"
|
||||
<< "flatbuffers::FlatBufferBuilder *fbb) {\n";
|
||||
|
||||
// Inputs & outputs
|
||||
os << " auto inputs = fbb->CreateVector(operands);\n"
|
||||
" auto outputs = fbb->CreateVector(results);\n\n";
|
||||
|
||||
// Build the FlatBuffer operator
|
||||
os << " return tflite::CreateOperator(\n"
|
||||
" *fbb, opcode_index, inputs, outputs,\n";
|
||||
if (def->getValueAsBit("hasOptions")) {
|
||||
auto option_name = GetOperatorOptionName(*def);
|
||||
os << " tflite::BuiltinOptions_" << option_name << ", "
|
||||
<< "Create" << option_name << "(tflOp, fbb).Union(),\n";
|
||||
} else {
|
||||
os << " tflite::BuiltinOptions_NONE, /*builtin_options=*/0,\n";
|
||||
}
|
||||
// Only builtin ops' builders are auto-generated. custom_options are only
|
||||
// used by custom or flex ops and those ops are handled manually.
|
||||
os << " /*custom_options=*/0, "
|
||||
"tflite::CustomOptionsFormat_FLEXBUFFERS,\n"
|
||||
" /*mutating_variable_inputs=*/0);\n"
|
||||
"}\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
static inline std::string GetUpperCasedName(const Record &def) {
|
||||
auto name = def.getValueAsString("opName");
|
||||
return name.upper();
|
||||
}
|
||||
|
||||
// Emits a function that returns builtin operator code for each TFLite op.
|
||||
//
|
||||
// The signature of the function is:
|
||||
//
|
||||
// llvm::Optional<tflite::BuiltinOperator>
|
||||
// mlir::GetBuiltinOpCode(mlir::Operation* op);
|
||||
//
|
||||
// TODO(hinsu): Consider converting this to a static constant associative
|
||||
// container instead of a series of if conditions, if required.
|
||||
static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
|
||||
raw_ostream *ostream) {
|
||||
raw_ostream &os = *ostream;
|
||||
|
||||
os << "llvm::Optional<tflite::BuiltinOperator> "
|
||||
"mlir::GetBuiltinOpCode(mlir::Operation* op) {\n";
|
||||
|
||||
for (const auto *def : defs) {
|
||||
StringRef op_name = def->getName().drop_front(4);
|
||||
os << " if (isa<mlir::TFL::" << op_name << ">(op))\n"
|
||||
<< " return tflite::BuiltinOperator_" << GetUpperCasedName(*def)
|
||||
<< ";\n";
|
||||
}
|
||||
|
||||
os << " return llvm::None;\n"
|
||||
"}\n";
|
||||
}
|
||||
|
||||
// Emits a builder function that returns the packed FlatBuffer object given
|
||||
// a general mlir::Operation.
|
||||
//
|
||||
// The signature of the function is:
|
||||
//
|
||||
//   llvm::Optional<flatbuffers::Offset<tflite::Operator>>
|
||||
// mlir::CreateFlatBufferOperator(
|
||||
// mlir::Operation* op,
|
||||
// uint32_t opcode_index,
|
||||
// const std::vector<int32_t>& operands,
|
||||
// const std::vector<int32_t>& results,
|
||||
// flatbuffers::FlatBufferBuilder *fbb);
|
||||
static void EmitBuildOperator(const std::vector<Record *> &defs,
|
||||
raw_ostream *ostream) {
|
||||
raw_ostream &os = *ostream;
|
||||
|
||||
// Signature
|
||||
os << "llvm::Optional<flatbuffers::Offset<tflite::Operator>>\n"
|
||||
"mlir::CreateFlatBufferOperator(mlir::Operation* op, "
|
||||
"uint32_t opcode_index, "
|
||||
"const std::vector<int32_t>& operands,"
|
||||
"const std::vector<int32_t>& results,"
|
||||
"flatbuffers::FlatBufferBuilder *fbb) {\n";
|
||||
|
||||
for (const auto *def : defs) {
|
||||
StringRef op_name = def->getName().drop_front(4);
|
||||
|
||||
// Try to cast to each op case and call the corresponding op builder
|
||||
os << " if (auto tflOp = llvm::dyn_cast<mlir::TFL::" << op_name
|
||||
<< ">(op))\n"
|
||||
<< " return " << GetOperatorBuilderName(def->getName())
|
||||
<< "(tflOp, opcode_index, operands, results, fbb);\n";
|
||||
}
|
||||
|
||||
os << " return llvm::None;\n"
|
||||
"}\n";
|
||||
}
|
||||
|
||||
// The function below has a non-constant reference as that is required by LLVM's
|
||||
// TableGenMain.
|
||||
// NOLINTNEXTLINE
|
||||
static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
|
||||
emitSourceFileHeader("MLIR TFLite FlatBuffer Builders", os);
|
||||
|
||||
// Retrieve all the definitions derived from TFL_Op and sort by record name.
|
||||
std::vector<Record *> defs = records.getAllDerivedDefinitions("TFL_Op");
|
||||
llvm::sort(defs, LessRecord());
|
||||
|
||||
for (const auto *def : defs) {
|
||||
// TFLite ops in the .td file are expected to follow the naming convention:
|
||||
// TFL_<OpName>Op.
|
||||
// The generated TFLite op C++ class should be TFL::<OpName>Op.
|
||||
// The generated operator's options should be tflite::<OpName>Options.
|
||||
// The option builder should be Create<OpName>Options.
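// e.g. TFL_AddOp maps to TFL::AddOp, tflite::AddOptions and CreateAddOptions.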
|
||||
if (!def->getName().startswith("TFL_"))
|
||||
PrintFatalError(def->getLoc(),
|
||||
"unexpected op name format: 'TFL_' prefix missing");
|
||||
if (!def->getName().endswith("Op"))
|
||||
PrintFatalError(def->getLoc(),
|
||||
"unexpected op name format: 'Op' suffix missing");
|
||||
}
|
||||
|
||||
EmitOptionBuilders(records, defs, &os);
|
||||
os << "\n\n";
|
||||
EmitOperatorBuilders(defs, &os);
|
||||
os << "\n\n";
|
||||
EmitGetBuiltinOpCode(defs, &os);
|
||||
os << "\n\n";
|
||||
EmitBuildOperator(defs, &os);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
|
||||
llvm::PrettyStackTraceProgram X(argc, argv);
|
||||
|
||||
llvm::llvm_shutdown_obj Y;
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
return TableGenMain(argv[0], &OperatorWritersMain);
|
||||
}
|
30
tensorflow/compiler/mlir/lite/python/BUILD
Normal file
30
tensorflow/compiler/mlir/lite/python/BUILD
Normal file
@ -0,0 +1,30 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
packages = [
|
||||
"//tensorflow/lite/toco/...",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graphdef_to_tfl_flatbuffer",
|
||||
srcs = ["graphdef_to_tfl_flatbuffer.cc"],
|
||||
hdrs = [
|
||||
"graphdef_to_tfl_flatbuffer.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:types_proto_cc",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
@ -0,0 +1,142 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/types.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||
// conversion mapping for constants defined in TFLite Python API.
|
||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
switch (dtype) {
|
||||
case toco::IODataType::FLOAT:
|
||||
return DT_FLOAT;
|
||||
case toco::IODataType::QUANTIZED_UINT8:
|
||||
return DT_QUINT8;
|
||||
case toco::IODataType::INT32:
|
||||
return DT_INT32;
|
||||
case toco::IODataType::INT64:
|
||||
return DT_INT64;
|
||||
case toco::IODataType::STRING:
|
||||
return DT_STRING;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
// Give a warning for any unused flags that have been specified.
|
||||
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags) {
|
||||
if (toco_flags.inference_input_type()) {
|
||||
LOG(WARNING) << "Ignored inference_input_type.";
|
||||
}
|
||||
if (toco_flags.output_format()) {
|
||||
LOG(WARNING) << "Ignored output_format.";
|
||||
}
|
||||
if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
|
||||
LOG(WARNING) << "Ignored default_ranges_stats.";
|
||||
}
|
||||
if (toco_flags.drop_control_dependency()) {
|
||||
LOG(WARNING) << "Ignored drop_control_dependency.";
|
||||
}
|
||||
if (toco_flags.reorder_across_fake_quant()) {
|
||||
LOG(WARNING) << "Ignored reorder_across_fake_quant.";
|
||||
}
|
||||
if (model_flags.change_concat_input_ranges()) {
|
||||
LOG(WARNING) << "Ignored change_concat_input_ranges.";
|
||||
}
|
||||
if (toco_flags.post_training_quantize()) {
|
||||
LOG(WARNING) << "Ignored post_training_quantize.";
|
||||
}
|
||||
if (!toco_flags.dump_graphviz_dir().empty()) {
|
||||
LOG(WARNING) << "Ignored dump_graphviz_dir.";
|
||||
}
|
||||
if (toco_flags.dump_graphviz_include_video()) {
|
||||
LOG(WARNING) << "Ignored dump_graphviz_video.";
|
||||
}
|
||||
if (model_flags.allow_nonexistent_arrays()) {
|
||||
LOG(WARNING) << "Allow allow_nonexistent_arrays.";
|
||||
}
|
||||
}
|
||||
|
||||
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const GraphDef& input,
|
||||
string* result) {
|
||||
mlir::MLIRContext context;
|
||||
NodeSpecs specs;
|
||||
|
||||
// Parse input arrays.
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<float> node_mins;
|
||||
std::vector<float> node_maxs;
|
||||
tensorflow::DataType inference_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_type());
|
||||
for (auto& flag : model_flags.input_arrays()) {
|
||||
node_names.push_back(flag.name());
|
||||
node_dtypes.push_back(
|
||||
DataType_Name(ConvertIODataTypeToDataType(flag.data_type())));
|
||||
node_shapes.push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
|
||||
const float mean_value = flag.mean_value();
|
||||
const float std_value = flag.std_value();
|
||||
const float qmin = 0, qmax = 255;
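// e.g. mean_value = 127.5 and std_value = 127.5 map the quantized range
// [0, 255] back to the float range [-1.0, 1.0].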
|
||||
node_mins.push_back((qmin - mean_value) / std_value);
|
||||
node_maxs.push_back((qmax - mean_value) / std_value);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo(
|
||||
node_names, node_dtypes, node_shapes, inference_type, node_mins,
|
||||
node_maxs, &specs.inputs));
|
||||
|
||||
// Parse output arrays.
|
||||
std::vector<string> output_arrays(model_flags.output_arrays().begin(),
|
||||
model_flags.output_arrays().end());
|
||||
TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(
|
||||
output_arrays, &specs.output_arrays, &specs.output_arrays_order));
|
||||
|
||||
// Other flags.
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||
bool emit_custom_ops = toco_flags.allow_custom_ops();
|
||||
specs.prune_unused_nodes = true;
|
||||
WarningUnusedFlags(model_flags, toco_flags);
|
||||
|
||||
bool emit_quant_adaptor_ops = false;
|
||||
bool lower_tensor_list_ops = false;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
||||
return ConvertTFControlFlowToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops,
|
||||
lower_tensor_list_ops, result);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts the given GraphDef to a TF Lite FlatBuffer string according to the
|
||||
// given model flags, toco flags and debug information. Returns error status if
|
||||
// it fails to convert the input.
|
||||
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const GraphDef& input, string* result);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_
|
19
tensorflow/compiler/mlir/lite/tests/BUILD
Normal file
19
tensorflow/compiler/mlir/lite/tests/BUILD
Normal file
@ -0,0 +1,19 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm//:FileCheck",
|
||||
],
|
||||
)
|
134
tensorflow/compiler/mlir/lite/tests/broadcastable-trait.mlir
Normal file
134
tensorflow/compiler/mlir/lite/tests/broadcastable-trait.mlir
Normal file
@ -0,0 +1,134 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @broadcast_scalar_scalar_scalar
|
||||
func @broadcast_scalar_scalar_scalar(tensor<i32>, tensor<i32>) -> tensor<i32> {
|
||||
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||
// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<i32>
|
||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @broadcast_tensor_scalar_tensor
|
||||
func @broadcast_tensor_scalar_tensor(tensor<4xi32>, tensor<i32>) -> tensor<4xi32> {
|
||||
^bb0(%arg0: tensor<4xi32>, %arg1: tensor<i32>):
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check only one dimension has size 1
|
||||
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
|
||||
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32> {
|
||||
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32>
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32>
|
||||
return %0 : tensor<4x3x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check multiple dimensions have size 1
|
||||
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
|
||||
func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
|
||||
^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
|
||||
return %0 : tensor<8x7x6x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check leading unknown dimension
|
||||
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
|
||||
func @broadcast_tensor_tensor_tensor(tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
|
||||
^bb0(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
|
||||
return %0 : tensor<?x7x6x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check unknown dimension in the middle
|
||||
// CHECK-LABEL: @broadcast_tensor_tensor_tensor
|
||||
func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32> {
|
||||
^bb0(%arg0: tensor<8x1x?x1xi32>, %arg1: tensor<7x1x5xi32>):
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32>
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32>
|
||||
return %0 : tensor<8x7x?x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check incompatible vector and tensor result type
|
||||
func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
|
||||
// expected-error @+1 {{cannot broadcast vector with tensor}}
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check incompatible operand types with known dimension
|
||||
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32> {
|
||||
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x3xi32>):
|
||||
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32>
|
||||
return %0 : tensor<4x3x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check incompatible result type with known dimension
|
||||
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32> {
|
||||
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
|
||||
// expected-error @+1 {{does not have the same shape as the one computed}}
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
|
||||
return %0 : tensor<4x3x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check incompatible result type with known dimension
|
||||
func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
|
||||
^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
|
||||
// expected-error @+1 {{does not have the same shape as the one computed}}
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
|
||||
return %0 : tensor<8x7x6x1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32> {
|
||||
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<?xi32>):
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
|
||||
return %0 : tensor<4x3x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check incompatible result type with unknown dimension
|
||||
func @broadcast_tensor_tensor_tensor(tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
|
||||
^bb0(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
|
||||
// expected-error @+1 {{does not have the same shape as the one computed}}
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
|
||||
return %0 : tensor<8x7x6x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check unranked operand but ranked result
|
||||
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
|
||||
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<*xi32>):
|
||||
// expected-error @+1 {{broadcast unranked tensor should result in unranked tensor}}
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
|
||||
return %0 : tensor<4x3x2xi32>
|
||||
}
|
92
tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
Normal file
92
tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
Normal file
@ -0,0 +1,92 @@
|
||||
// RUN: tf-opt -canonicalize %s | FileCheck %s
|
||||
|
||||
// Checks that tfl.reshape should be removed if its output's only user is
|
||||
// another tfl.reshape
|
||||
func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
|
||||
^bb0(%arg0: tensor<4x4x4xf32>) :
|
||||
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
|
||||
%1 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
|
||||
return %1 : tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_removeAdjacent
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be removed if its output has more than one
|
||||
// user but all users are tfl.reshape
|
||||
func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32> {
|
||||
^bb0(%arg0: tensor<4x4x4xf32>) :
|
||||
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
|
||||
%1 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
|
||||
%2 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
|
||||
%3 = addf %1, %2 : tensor<64xf32>
|
||||
return %3 : tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
|
||||
// CHECK: %1 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
|
||||
// CHECK: %2 = addf %0, %1
|
||||
// CHECK: return %2
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be kept if its output has more than one
|
||||
// user and not all users are tfl.reshape
|
||||
func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32>, tensor<64xf32>) {
|
||||
^bb0(%arg0: tensor<4x4x4xf32>) :
|
||||
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
|
||||
%1 = "tfl.reshape"(%0) : (tensor<16x4xf32>) -> tensor<64xf32>
|
||||
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<16x4xf32>
|
||||
// CHECK: %1 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<64xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be removed if its output type is the same
|
||||
// as its input type
|
||||
func @reshape_removeIdentity(tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
|
||||
^bb0(%arg0: tensor<4x4x4xf32>) :
|
||||
%0 = "tfl.reshape"(%arg0) : (tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
|
||||
return %0 : tensor<4x4x4xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_removeIdentity
|
||||
// CHECK: return %arg0 : tensor<4x4x4xf32>
|
||||
}
|
||||
|
||||
// Checks that tfl.fake_quant should be removed if all its users have valid
|
||||
// "minmax" attributes.
|
||||
func @fakequant_dropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
|
||||
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
|
||||
%0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
|
||||
%1 = tfl.pow %arg0, %0 {minmax = [0.4, 0.6]} : tensor<i32>
|
||||
%2 = tfl.pow %1, %0 {minmax = [0.5, 0.7]} : tensor<i32>
|
||||
return %2 : tensor<i32>
|
||||
|
||||
// CHECK-LABEL: fakequant_dropfakequant
|
||||
// CHECK-NEXT: %0 = tfl.pow %arg0, %arg0 {minmax = [4.000000e-01, 6.000000e-01]} : tensor<i32>
|
||||
// CHECK-NEXT: %1 = tfl.pow %0, %arg0 {minmax = [5.000000e-01, 0.69999999999999996]} : tensor<i32>
|
||||
|
||||
// CHECK-NEXT: return %1 : tensor<i32>
|
||||
}
|
||||
|
||||
// Checks that tfl.fake_quant should not be removed if some of its users or
|
||||
// itself don't have valid "minmax" attributes.
|
||||
func @fakequant_notdropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
|
||||
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
|
||||
%0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
|
||||
%1 = tfl.pow %arg0, %0 : tensor<i32>
|
||||
%2 = tfl.pow %1, %0 : tensor<i32>
|
||||
|
||||
%5 = "tfl.fake_quant"(%arg0) {name = 1, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
|
||||
%6 = tfl.pow %arg0, %5 : tensor<i32>
|
||||
%7 = tfl.pow %6, %5 : tensor<i32>
|
||||
|
||||
%11 = addi %2, %7 : tensor<i32>
|
||||
return %11 : tensor<i32>
|
||||
|
||||
// CHECK-LABEL: fakequant_notdropfakequant
|
||||
// CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], name = 0 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %3 = "tfl.fake_quant"(%arg0) {minmax = [1.000000e-01, 2.000000e-01], name = 1 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
|
||||
}
|
275
tensorflow/compiler/mlir/lite/tests/const-fold.mlir
Normal file
275
tensorflow/compiler/mlir/lite/tests/const-fold.mlir
Normal file
@ -0,0 +1,275 @@
|
||||
// RUN: tf-opt %s -test-constant-fold | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @add_float
|
||||
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
|
||||
%0 = constant dense<4.5> : tensor<f32>
|
||||
%1 = constant dense<1.5> : tensor<f32>
|
||||
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%7 = "tfl.add"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor< f32>) -> tensor<4xf32>
|
||||
%8 = "tfl.add"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%9 = "tfl.add"(%2, %3) {fused_activation_function = "SIGN_BIT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
return %5, %6, %7, %8, %9 : tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_int
|
||||
func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
%0 = constant dense<8> : tensor<i32>
|
||||
%1 = constant dense<1> : tensor<i32>
|
||||
|
||||
%2 = constant dense< 4> : tensor<4xi32>
|
||||
%3 = constant dense<-2> : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<9> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
|
||||
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
|
||||
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
|
||||
|
||||
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
|
||||
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%7 = "tfl.add"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor< i32>) -> tensor<4xi32>
|
||||
%8 = "tfl.add"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %5, %6, %7, %8 : tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sub_float
|
||||
func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
|
||||
%0 = constant dense<4.5> : tensor<f32>
|
||||
%1 = constant dense<1.5> : tensor<f32>
|
||||
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%7 = "tfl.sub"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor< f32>) -> tensor<4xf32>
|
||||
%8 = "tfl.sub"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
return %5, %6, %7, %8 : tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sub_int
|
||||
func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
%0 = constant dense<8> : tensor<i32>
|
||||
%1 = constant dense<1> : tensor<i32>
|
||||
|
||||
%2 = constant dense< 4> : tensor<4xi32>
|
||||
%3 = constant dense<-2> : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<7> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
|
||||
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
|
||||
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
|
||||
|
||||
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
|
||||
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%7 = "tfl.sub"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor< i32>) -> tensor<4xi32>
|
||||
%8 = "tfl.sub"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %5, %6, %7, %8 : tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @mul_float
|
||||
func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
|
||||
%0 = constant dense<4.5> : tensor<f32>
|
||||
%1 = constant dense<1.5> : tensor<f32>
|
||||
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%7 = "tfl.mul"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor< f32>) -> tensor<4xf32>
|
||||
%8 = "tfl.mul"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
return %5, %6, %7, %8 : tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @mul_int
|
||||
func @mul_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
%0 = constant dense<8> : tensor<i32>
|
||||
%1 = constant dense<1> : tensor<i32>
|
||||
|
||||
%2 = constant dense< 4> : tensor<4xi32>
|
||||
%3 = constant dense<-2> : tensor<4xi32>
|
||||
|
||||
// CHECK-DAG: [[cst0:%.*]] = constant dense<8> : tensor<i32>
|
||||
// CHECK-DAG: [[cst1:%.*]] = constant dense<-16> : tensor<4xi32>
|
||||
// CHECK-DAG: [[cst2:%.*]] = constant dense<4> : tensor<4xi32>
|
||||
// CHECK-DAG: [[cst3:%.*]] = constant dense<-8> : tensor<4xi32>
|
||||
// CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]]
|
||||
|
||||
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
|
||||
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%7 = "tfl.mul"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor< i32>) -> tensor<4xi32>
|
||||
%8 = "tfl.mul"(%2, %3) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %5, %6, %7, %8 : tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_splat_int
|
||||
func @add_dense_splat_int() -> tensor<4xi32> {
|
||||
%0 = constant dense<[-10, -1, 42, 100]> : tensor<4xi32>
|
||||
%1 = constant dense< 5> : tensor<4xi32>
|
||||
|
||||
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_splat_dense_int
|
||||
func @add_splat_dense_int() -> tensor<4xi32> {
|
||||
%0 = constant dense< 5> : tensor<4xi32>
|
||||
%1 = constant dense<[-10, -1, 42, 100]> : tensor<4xi32>
|
||||
|
||||
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_same_shape
|
||||
func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
|
||||
%0 = constant dense<[15, 23, -44, -2]> : tensor<4xi32>
|
||||
%1 = constant dense<[-10, -1, 42, 100]> : tensor<4xi32>
|
||||
|
||||
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
|
||||
func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>) {
|
||||
%cst_0 = constant dense<[10, 20]> : tensor<2xi32>
|
||||
%cst_1 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||
%cst_2 = constant dense<[[[1, 1], [2, 2]], [[3, 3], [4, 4]]]> : tensor<2x2x2xi32>
|
||||
|
||||
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor< 2xi32>, tensor< 2x2xi32>) -> tensor< 2x2xi32>
|
||||
%1 = "tfl.add"(%cst_2, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2xi32>, tensor< 2x2xi32>) -> tensor<2x2x2xi32>
|
||||
%2 = "tfl.add"(%cst_0, %cst_2) {fused_activation_function = "NONE"} : (tensor< 2xi32>, tensor<2x2x2xi32>) -> tensor<2x2x2xi32>
|
||||
|
||||
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
|
||||
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: return %cst, %cst_0, %cst_1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
|
||||
func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
|
||||
%cst_0 = constant dense<[[1, 2]]> : tensor<1x2xi32>
|
||||
%cst_1 = constant dense<[[3], [4]]> : tensor<2x1xi32>
|
||||
|
||||
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
|
||||
return %0 : tensor<2x2xi32>
|
||||
|
||||
// We don't support this case yet.
|
||||
// %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
|
||||
// CHECK: %0 = "tfl.add"
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_splat_float
|
||||
func @add_dense_splat_float() -> tensor<4xf32> {
|
||||
%0 = constant dense<[-10.0, -1.5, 42.0, 7.25]> : tensor<4xf32>
|
||||
%1 = constant dense< 3.5> : tensor<4xf32>
|
||||
|
||||
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_splat_dense_float
|
||||
func @add_splat_dense_float() -> tensor<4xf32> {
|
||||
%0 = constant dense< 3.5> : tensor<4xf32>
|
||||
%1 = constant dense<[-10.0, -1.5, 42.0, 7.25]> : tensor<4xf32>
|
||||
|
||||
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_same_shape
|
||||
func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
|
||||
%0 = constant dense<[1.5, 2.3, -4.4, -2.0]> : tensor<4xf32>
|
||||
%1 = constant dense<[-10.4, -1.3, 42.4, 100.0]> : tensor<4xf32>
|
||||
|
||||
%2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
|
||||
func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>) {
|
||||
%cst_0 = constant dense<[1., -4.]> : tensor<2xf32>
|
||||
%cst_1 = constant dense<[[-5.5, 1.5], [7.5, -4.5]]> : tensor<2x2xf32>
|
||||
%cst_2 = constant dense<[[[1., 1.], [2., 2.]], [[3., 3.], [4., 4.]]]> : tensor<2x2x2xf32>
|
||||
|
||||
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor< 2xf32>, tensor< 2x2xf32>) -> tensor< 2x2xf32>
|
||||
%1 = "tfl.add"(%cst_2, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2xf32>, tensor< 2x2xf32>) -> tensor<2x2x2xf32>
|
||||
%2 = "tfl.add"(%cst_0, %cst_2) {fused_activation_function = "NONE"} : (tensor< 2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
|
||||
|
||||
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: return %cst, %cst_0, %cst_1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_mixing_1_n
|
||||
func @add_dense_dense_float_mixing_1_n() -> tensor<2x2xf32> {
|
||||
%cst_0 = constant dense<[[1.5, -2.5]]> : tensor<1x2xf32>
|
||||
%cst_1 = constant dense<[[-3.], [4.]]> : tensor<2x1xf32>
|
||||
|
||||
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// We don't support this case yet.
|
||||
// CHECK: %0 = "tfl.add"
|
||||
// CHECK: return %0
|
||||
}
|
31
tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD
Normal file
@ -0,0 +1,31 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
":debug_info_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
test_file_exts = ["pbtxt"],
|
||||
)
|
||||
|
||||
# Bundle together all the debug info files that are used by the tests.
|
||||
filegroup(
|
||||
name = "debug_info_files",
|
||||
srcs = glob(
|
||||
["**/*.debug"],
|
||||
),
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm//:FileCheck",
|
||||
],
|
||||
)
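# A new *.pbtxt test dropped into this directory (together with any *.debug
# files it references, which the debug_info_files glob above picks up) should
# not need further BUILD changes, assuming glob_lit_tests discovers tests
# purely by the test_file_exts listed above.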
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,175 @@
|
||||
files: "fake/user/code/file_A.py"
|
||||
files: "fake/user/code/file_B.py"
|
||||
files: "fake/user/code/file_C.py"
|
||||
files: "fake/user/code/file_D.py"
|
||||
files: "fake/user/code/file_E.py"
|
||||
files: "fake/user/code/file_F.py"
|
||||
files: "fake/user/code/file_G.py"
|
||||
files: "fake/user/code/file_H.py"
|
||||
files: "fake/user/code/file_I.py"
|
||||
files: "fake/user/code/file_J.py"
|
||||
files: "fake/user/code/file_K.py"
|
||||
files: "fake/user/code/file_L.py"
|
||||
files: "fake/user/code/file_M.py"
|
||||
files: "fake/user/code/file_N.py"
|
||||
files: "fake/user/code/file_O.py"
|
||||
files: "fake/user/code/file_P.py"
|
||||
files: "fake/user/code/file_Q.py"
|
||||
files: "fake/user/code/file_R.py"
|
||||
files: "fake/user/code/file_S.py"
|
||||
files: "fake/user/code/file_T.py"
|
||||
files: "fake/user/code/file_U.py"
|
||||
files: "fake/user/code/file_V.py"
|
||||
files: "fake/user/code/file_W.py"
|
||||
files: "fake/user/code/file_X.py"
|
||||
files: "fake/user/code/file_Y.py"
|
||||
files: "fake/user/code/file_Z.py"
|
||||
files: "fake/user/code/file_1.py"
|
||||
files: "fake/user/code/file_2.py"
|
||||
files: "fake/user/code/file_3.py"
|
||||
files: "fake/user/code/file_4.py"
|
||||
files: "fake/user/code/file_5.py"
|
||||
files: "fake/user/code/file_6.py"
|
||||
files: "fake/user/code/file_a.py"
|
||||
files: "fake/user/code/file_b.py"
|
||||
files: "fake/user/code/file_c.py"
|
||||
files: "fake/user/code/file_d.py"
|
||||
files: "fake/user/code/file_e.py"
|
||||
files: "fake/user/code/file_f.py"
|
||||
files: "fake/user/code/file_g.py"
|
||||
files: "fake/user/code/file_h.py"
|
||||
files: "fake/user/code/file_i.py"
|
||||
files: "fake/user/code/file_j.py"
|
||||
files: "fake/user/code/file_k.py"
|
||||
files: "fake/user/code/file_l.py"
|
||||
files: "fake/user/code/file_m.py"
|
||||
files: "fake/user/code/file_n.py"
|
||||
files: "fake/user/code/file_o.py"
|
||||
files: "fake/user/code/file_p.py"
|
||||
files: "fake/user/code/file_q.py"
|
||||
files: "fake/user/code/file_r.py"
|
||||
files: "fake/user/code/file_s.py"
|
||||
files: "fake/user/code/file_t.py"
|
||||
files: "fake/user/code/file_u.py"
|
||||
files: "fake/user/code/file_v.py"
|
||||
files: "fake/user/code/file_w.py"
|
||||
files: "fake/user/code/file_x.py"
|
||||
files: "fake/user/code/file_y.py"
|
||||
files: "fake/user/code/file_z.py"
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 33
|
||||
line: 383
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 886
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 38
|
||||
line: 777
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 915
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
line: 793
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 335
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
line: 386
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 492
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 54
|
||||
line: 649
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 421
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 5
|
||||
line: 362
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 2
|
||||
line: 27
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "input"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 40
|
||||
line: 690
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,187 @@
|
||||
files: "fake/user/code/file_A.py"
|
||||
files: "fake/user/code/file_B.py"
|
||||
files: "fake/user/code/file_C.py"
|
||||
files: "fake/user/code/file_D.py"
|
||||
files: "fake/user/code/file_E.py"
|
||||
files: "fake/user/code/file_F.py"
|
||||
files: "fake/user/code/file_G.py"
|
||||
files: "fake/user/code/file_H.py"
|
||||
files: "fake/user/code/file_I.py"
|
||||
files: "fake/user/code/file_J.py"
|
||||
files: "fake/user/code/file_K.py"
|
||||
files: "fake/user/code/file_L.py"
|
||||
files: "fake/user/code/file_M.py"
|
||||
files: "fake/user/code/file_N.py"
|
||||
files: "fake/user/code/file_O.py"
|
||||
files: "fake/user/code/file_P.py"
|
||||
files: "fake/user/code/file_Q.py"
|
||||
files: "fake/user/code/file_R.py"
|
||||
files: "fake/user/code/file_S.py"
|
||||
files: "fake/user/code/file_T.py"
|
||||
files: "fake/user/code/file_U.py"
|
||||
files: "fake/user/code/file_V.py"
|
||||
files: "fake/user/code/file_W.py"
|
||||
files: "fake/user/code/file_X.py"
|
||||
files: "fake/user/code/file_Y.py"
|
||||
files: "fake/user/code/file_Z.py"
|
||||
files: "fake/user/code/file_1.py"
|
||||
files: "fake/user/code/file_2.py"
|
||||
files: "fake/user/code/file_3.py"
|
||||
files: "fake/user/code/file_4.py"
|
||||
files: "fake/user/code/file_5.py"
|
||||
files: "fake/user/code/file_6.py"
|
||||
files: "fake/user/code/file_a.py"
|
||||
files: "fake/user/code/file_b.py"
|
||||
files: "fake/user/code/file_c.py"
|
||||
files: "fake/user/code/file_d.py"
|
||||
files: "fake/user/code/file_e.py"
|
||||
files: "fake/user/code/file_f.py"
|
||||
files: "fake/user/code/file_g.py"
|
||||
files: "fake/user/code/file_h.py"
|
||||
files: "fake/user/code/file_i.py"
|
||||
files: "fake/user/code/file_j.py"
|
||||
files: "fake/user/code/file_k.py"
|
||||
files: "fake/user/code/file_l.py"
|
||||
files: "fake/user/code/file_m.py"
|
||||
files: "fake/user/code/file_n.py"
|
||||
files: "fake/user/code/file_o.py"
|
||||
files: "fake/user/code/file_p.py"
|
||||
files: "fake/user/code/file_q.py"
|
||||
files: "fake/user/code/file_r.py"
|
||||
files: "fake/user/code/file_s.py"
|
||||
files: "fake/user/code/file_t.py"
|
||||
files: "fake/user/code/file_u.py"
|
||||
files: "fake/user/code/file_v.py"
|
||||
files: "fake/user/code/file_w.py"
|
||||
files: "fake/user/code/file_x.py"
|
||||
files: "fake/user/code/file_y.py"
|
||||
files: "fake/user/code/file_z.py"
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 33
|
||||
line: 383
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 886
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 38
|
||||
line: 777
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 915
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
line: 793
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 335
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
line: 386
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 492
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 54
|
||||
line: 649
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights/read"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
line: 421
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 5
|
||||
line: 362
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 2
|
||||
line: 27
|
||||
}
|
||||
file_line_cols {
|
||||
file_index: 3
|
||||
line: 28
|
||||
}
|
||||
file_line_cols {
|
||||
file_index: 4
|
||||
line: 29
|
||||
}
|
||||
file_line_cols {
|
||||
file_index: 5
|
||||
line: 30
|
||||
}
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "input"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 40
|
||||
line: 690
|
||||
}
|
||||
}
|
||||
}
|
20
tensorflow/compiler/mlir/lite/tests/end2end/BUILD
Normal file
@ -0,0 +1,20 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
test_file_exts = ["pbtxt"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm//:FileCheck",
|
||||
],
|
||||
)
|
94
tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt
Normal file
@ -0,0 +1,94 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
# Add two tensor<4xi32> inputs and return the result
|
||||
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "input0"
|
||||
input: "input1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: {
|
||||
# CHECK-NEXT: version: 3,
|
||||
# CHECK-NEXT: operator_codes: [ {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: } ],
|
||||
# CHECK-NEXT: subgraphs: [ {
|
||||
# CHECK-NEXT: tensors: [ {
|
||||
# CHECK-NEXT: shape: [ 4 ],
|
||||
# CHECK-NEXT: type: INT32,
|
||||
# CHECK-NEXT: buffer: 1,
|
||||
# CHECK-NEXT: name: "input0",
|
||||
# CHECK-NEXT: quantization: {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: }, {
|
||||
# CHECK-NEXT: shape: [ 4 ],
|
||||
# CHECK-NEXT: type: INT32,
|
||||
# CHECK-NEXT: buffer: 2,
|
||||
# CHECK-NEXT: name: "input1",
|
||||
# CHECK-NEXT: quantization: {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: }, {
|
||||
# CHECK-NEXT: shape: [ ],
|
||||
# CHECK-NEXT: type: INT32,
|
||||
# CHECK-NEXT: buffer: 3,
|
||||
# CHECK-NEXT: name: "Add",
|
||||
# CHECK-NEXT: quantization: {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: } ],
|
||||
# CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
# CHECK-NEXT: outputs: [ 2 ],
|
||||
# CHECK-NEXT: operators: [ {
|
||||
# CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
# CHECK-NEXT: outputs: [ 2 ],
|
||||
# CHECK-NEXT: builtin_options_type: AddOptions,
|
||||
# CHECK-NEXT: builtin_options: {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }
|
||||
# CHECK-NEXT: } ]
|
||||
# CHECK-NEXT: name: "main"
|
||||
# CHECK-NEXT: } ],
|
||||
# CHECK-NEXT: description: "MLIR Converted.",
|
||||
# CHECK-NEXT: buffers: [ {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }, {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }, {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: }, {
|
||||
# CHECK-EMPTY:
|
||||
# CHECK-NEXT: } ]
|
||||
# CHECK-NEXT: }
|
232
tensorflow/compiler/mlir/lite/tests/end2end/conv_2d.pbtxt
Normal file
@ -0,0 +1,232 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 8
|
||||
}
|
||||
dim {
|
||||
size: 8
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/w"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: ";;\177<5\241i\275\312f\211>#\346j>\033W\325\275\253>\210=Vr\r\276\304\222\313\276\374\346\214>\016e\211>)\253\000>\3241\337\275\235g-\276*(\216\276\326#\367\274\023\213\300\276\227\031\206>PUF=\253\330\263<\337IL\276\334\320\215>\377\306v\276\372C\302\273baM>H\314\270<2\221\352=J\026{\276\221\243\245\276?\314\240=UW2\2755\207\253\274\256\207\333\273\335\372\227>\246\232;\276%\r\374<Z\346\204>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/w/read"
|
||||
op: "Identity"
|
||||
input: "conv_net_2d/conv_2d_0/w"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@conv_net_2d/conv_2d_0/w"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d_1/conv_2d_0/convolution"
|
||||
op: "Conv2D"
|
||||
input: "input"
|
||||
input: "conv_net_2d/conv_2d_0/w/read"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "data_format"
|
||||
value {
|
||||
s: "NHWC"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dilations"
|
||||
value {
|
||||
list {
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "explicit_paddings"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "padding"
|
||||
value {
|
||||
s: "SAME"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "strides"
|
||||
value {
|
||||
list {
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "use_cudnn_on_gpu"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/b"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: "\315\314\314=\315\314\314="
|
||||
}
|
||||
}
|
||||
}
|
||||
}
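# The tensor_content above is the raw little-endian float32 payload: the
# octal-escaped bytes \315\314\314= are 0xCD 0xCC 0xCC 0x3D, i.e. 0.1f, so the
# bias constant is [0.1, 0.1].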
|
||||
node {
|
||||
name: "conv_net_2d/conv_2d_0/b/read"
|
||||
op: "Identity"
|
||||
input: "conv_net_2d/conv_2d_0/b"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@conv_net_2d/conv_2d_0/b"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d_1/conv_2d_0/BiasAdd"
|
||||
op: "BiasAdd"
|
||||
input: "conv_net_2d_1/conv_2d_0/convolution"
|
||||
input: "conv_net_2d/conv_2d_0/b/read"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "data_format"
|
||||
value {
|
||||
s: "NHWC"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "conv_net_2d_1/Relu"
|
||||
op: "Relu"
|
||||
input: "conv_net_2d_1/conv_2d_0/BiasAdd"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "output_0"
|
||||
op: "Identity"
|
||||
input: "conv_net_2d_1/Relu"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
}
|
||||
|
||||
# CHECK: 'main' inputs:
|
||||
# CHECK-NEXT: name: 'input'
|
||||
# CHECK-NEXT: 'main' outputs:
|
||||
# CHECK-NEXT: name: 'output_0'
|
@ -0,0 +1,45 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=output %s -o - --output-mlir -tf-extra-opdefs="name: 'BannaPotatoSaladWithColeslaw' input_arg: { name: 'a' type: DT_INT32 } input_arg: { name: 'b' type: DT_INT32 } output_arg: { name: 'c' type: DT_INT32 }" | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "output"
|
||||
op: "BannaPotatoSaladWithColeslaw"
|
||||
input: "input0"
|
||||
input: "input1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input0, input1", outputs = "output"}} {
|
||||
# CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
# CHECK-NEXT: %1 = "tfl.pseudo_input"(%arg1) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
# CHECK-NEXT: %2 = "tf.BannaPotatoSaladWithColeslaw"(%0, %1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: return %2 : tensor<*xi32>
|
||||
# CHECK-NEXT: }
|
@ -0,0 +1,54 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=4 -tf-input-data-types=DT_INT32 -tf-output-arrays=output %s -o - --output-mlir | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "default"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input"
|
||||
op: "Identity"
|
||||
input: "default"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "output"
|
||||
op: "Identity"
|
||||
input: "input"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<4xi32>) -> tensor<4xi32>
|
||||
# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
|
||||
# CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
# CHECK-NEXT: return %0 : tensor<4xi32>
|
||||
# CHECK-NEXT: }
|
757
tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
Normal file
@ -0,0 +1,757 @@
|
||||
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
|
||||
|
||||
func @addRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "tf.Add"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%4 = "tf.Add"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%7 = "tf.Relu6"(%6) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %7: tensor<1xi32>
|
||||
|
||||
// CHECK-LABEL: addRelu
|
||||
// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
|
||||
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
||||
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
||||
// CHECK: return %4 : tensor<1xi32>
|
||||
}
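// A condensed sketch of the pattern exercised above (assuming the legalization
// folds a trailing tf.Relu/tf.Relu6 into the producing op's fused activation):
//
//   %sum = "tf.Add"(%a, %b) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
//   %act = "tf.Relu"(%sum) : (tensor<1xi32>) -> tensor<1xi32>
//
// becomes
//
//   %act = tfl.add %a, %b {fused_activation_function = "RELU"} : tensor<1xi32>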
|
||||
|
||||
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %2: tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: LeakyRelu
|
||||
// CHECK: %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
||||
}
|
||||
|
||||
func @biasAdd(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> {
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
|
||||
%1 = "tf.BiasAdd"(%0, %arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32>
|
||||
%2 = "tf.Relu6"(%1) : (tensor<1x10x10x32xf32>) -> tensor<1x10x10x32xf32>
|
||||
return %2 : tensor<1x10x10x32xf32>
|
||||
|
||||
// CHECK-LABEL: biasAdd
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
|
||||
// CHECK: %1 = tfl.add %0, %arg0 {fused_activation_function = "RELU6"} : tensor<1x10x10x32xf32>
|
||||
}
|
||||
|
||||
func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> {
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xi32>, tensor<32xi32>) -> tensor<1x10x10x32xi32>
|
||||
return %0 : tensor<1x10x10x32xi32>
|
||||
|
||||
// CHECK-LABEL: biasAddInt
|
||||
// CHECK: %0 = "tf.BiasAdd"(%arg0, %arg1)
|
||||
}
|
||||
|
||||
func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i32 {
|
||||
%0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
|
||||
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||
%2 = constant dense<[2, 5]> : tensor<2xi32>
|
||||
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
|
||||
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
|
||||
return %4 : i32
|
||||
// CHECK-LABEL: squeezeAndReshape
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
|
||||
// CHECK: %1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||
// CHECK: %2 = "tfl.reshape"(%0) : (tensor<1x10xf32>) -> tensor<2x5xf32>
|
||||
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
|
||||
// CHECK: return %3 : i32
|
||||
}
|
||||
|
||||
func @dynamicReshape(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
|
||||
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
// CHECK-LABEL: dynamicReshape
|
||||
// CHECK: %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||
// CHECK: return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
func @avgPool2D(%arg0: tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
|
||||
// OK
|
||||
%0 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
// Unsupported data format
|
||||
%1 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
// Unsupported ksize
|
||||
%2 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [3, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
// Unsupported strides
|
||||
%3 = "tf.AvgPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 3]} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
|
||||
%5 = addf %0, %1 : tensor<1x1x1x16xf32>
|
||||
%6 = addf %2, %3 : tensor<1x1x1x16xf32>
|
||||
%7 = addf %5, %6 : tensor<1x1x1x16xf32>
|
||||
return %7 : tensor<1x1x1x16xf32>
|
||||
|
||||
// CHECK-LABEL: func @avgPool2D
|
||||
// CHECK: %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
// CHECK: %1 = "tf.AvgPool"(%arg0)
|
||||
// CHECK: %2 = "tf.AvgPool"(%arg0)
|
||||
// CHECK: %3 = "tf.AvgPool"(%arg0)
|
||||
}
|
||||
|
||||
func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL: softmax
|
||||
// CHECK: %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||
return %0 : tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantArgsFalse
|
||||
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>) -> tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = true} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||
return %0 : tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantArgsTrue
|
||||
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform<u8<1:255>:f32, 0.001181102379804521:86>>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform<u8<1:255>:f32, 0.001181102379804521:86>>
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform<u8<1:255>:f32, 0.001181102379804521:86>>) -> tensor<8x8x8x8xf32>
|
||||
}
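// The quantization parameters checked in the two tests above follow from
// affine quantization over [min, max] = [-0.1, 0.2] (a sketch of the
// arithmetic, assuming an 8-bit storage type): with the full range 0..255,
// scale = 0.3 / 255 ~= 0.0011764706 and zero_point = round(-min / scale) = 85;
// with narrow_range, the range shrinks to 1..255, so scale = 0.3 / 254 ~=
// 0.0011811024 and zero_point = 1 + round(-min / scale) = 86.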
|
||||
|
||||
func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
|
||||
%arg1 = constant dense<-0.1> : tensor<f32>
|
||||
%arg2 = constant dense<0.2> : tensor<f32>
|
||||
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
|
||||
return %0 : tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantVarsFalse
|
||||
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>) -> tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantVarsTrue(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<8x8x8x8xf32> {
|
||||
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {min = 0.0 : f32, max = 1.0 : f32, num_bits = 3, narrow_range = true} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
|
||||
return %0 : tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantVarsTrue
|
||||
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {max = 1.000000e+00 : f32, min = 0.000000e+00 : f32, narrow_range = true, num_bits = 3 : i64}
|
||||
}
|
||||
|
||||
func @const() -> tensor<2xi32> {
|
||||
%0 = "tf.Const"() {device = "", name = "weights_quant/min", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> (tensor<2xi32>)
|
||||
return %0: tensor<2xi32>
|
||||
|
||||
// CHECK-LABEL: @const
|
||||
// CHECK: %0 = "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
}
|
||||
|
||||
func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<f32>) -> tensor<f32>
|
||||
return %0: tensor<f32>
|
||||
|
||||
// CHECK-LABEL: @placeholder
|
||||
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
}
|
||||
|
||||
func @placeholder_min(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32} : (tensor<f32>) -> tensor<f32>
|
||||
return %0: tensor<f32>
|
||||
|
||||
// CHECK-LABEL: @placeholder_min
|
||||
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
}
|
||||
|
||||
func @placeholder_type(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", type = i8} : (tensor<f32>) -> tensor<f32>
|
||||
return %0: tensor<f32>
|
||||
|
||||
// CHECK-LABEL: @placeholder_type
|
||||
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
}
|
||||
|
||||
func @placeholder_min_max(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32, max = 0.1 : f32, type = i8} : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %0: tensor<2x3xf32>
|
||||
|
||||
// CHECK-LABEL: @placeholder_min_max
|
||||
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<2x3x!quant.uniform<u8:f32, 7.8431373717738134E-4:128>>}
|
||||
// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<2x3x!quant.uniform<u8:f32, 7.8431373717738134E-4:128>>)
|
||||
}
|
||||
|
||||
func @shape(%arg0: tensor<?x1001xf32>) -> tensor<2xi32> {
|
||||
%0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<?x1001xf32>) -> tensor<2xi32>
|
||||
%1 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT"} : (tensor<?x1001xf32>) -> tensor<2xi32>
|
||||
%2 = "tf.Add"(%0, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
|
||||
return %2: tensor<2xi32>
|
||||
|
||||
// CHECK-LABEL: shape
|
||||
// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
|
||||
// CHECK: %1 = "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
|
||||
}
|
||||
|
||||
func @fill(%arg0: tensor<3xi32>, %arg1: tensor<f32>) -> tensor<?x?x?xf32> {
|
||||
%0 = "tf.Fill"(%arg0, %arg1) : (tensor<3xi32>, tensor<f32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
|
||||
// CHECK-LABEL:fill
|
||||
// CHECK: %0 = "tfl.fill"(%arg0, %arg1) : (tensor<3xi32>, tensor<f32>) -> tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
func @sigmoid(%arg0: tensor<?x88xf16>) -> tensor<?x88xf16> {
|
||||
%0 = "tf.Sigmoid"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
|
||||
return %0 : tensor<?x88xf16>
|
||||
// CHECK-LABEL: sigmoid
|
||||
// CHECK: %0 = "tfl.logistic"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
|
||||
}
|
||||
|
||||
func @log_softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.LogSoftmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
// CHECK-LABEL: log_softmax
|
||||
// CHECK: %0 = "tfl.log_softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.ZerosLike"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
// CHECK-LABEL: zeros_like
|
||||
// CHECK: %0 = "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @divRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "tf.Div"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%4 = "tf.Div"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %5: tensor<1xi32>
|
||||
|
||||
// CHECK-LABEL: divRelu
|
||||
// CHECK: %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
|
||||
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
||||
// CHECK: return %3 : tensor<1xi32>
|
||||
}
|
||||
|
||||
func @squaredDifferenceRelu(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> {
|
||||
^bb0(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>):
|
||||
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "tf.Relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %1: tensor<1xi32>
|
||||
|
||||
// CHECK-LABEL: squaredDifferenceRelu
|
||||
// CHECK: %0 = tfl.squared_difference %arg0, %arg1 : tensor<1xi32>
|
||||
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: return %1 : tensor<1xi32>
|
||||
}
|
||||
|
||||
func @maxPool2D(%arg0: tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32> {
|
||||
// OK
|
||||
%0 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
|
||||
// Unsupported data_format
|
||||
%1 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
|
||||
// Unsupported ksize
|
||||
%2 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [3, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 1]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
|
||||
// Unsupported strides
|
||||
%3 = "tf.MaxPool"(%arg0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", ksize = [1, 3, 6, 1], padding = "VALID", strides = [1, 3, 1, 3]} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
|
||||
%5 = addf %0, %1 : tensor<1x1x1x16xf32>
|
||||
%6 = addf %2, %3 : tensor<1x1x1x16xf32>
|
||||
%7 = addf %5, %6 : tensor<1x1x1x16xf32>
|
||||
return %7 : tensor<1x1x1x16xf32>
|
||||
|
||||
// CHECK-LABEL: func @maxPool2D
|
||||
// CHECK: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x1x1x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
// CHECK: %1 = "tf.MaxPool"(%arg0)
|
||||
// CHECK: %2 = "tf.MaxPool"(%arg0)
|
||||
// CHECK: %3 = "tf.MaxPool"(%arg0)
|
||||
}
|
||||
|
||||
func @abs(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL:abs
|
||||
// CHECK: %0 = "tfl.abs"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL: ceil
|
||||
// CHECK: %0 = "tfl.ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
// CHECK: return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @cos(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.Cos"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
|
||||
// CHECK-LABEL:cos
|
||||
// CHECK: %0 = "tfl.cos"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
}
|
||||
|
||||
func @elu(%arg0: tensor<11x16xf32>) -> tensor<11x16xf32> {
|
||||
%0 = "tf.Elu"(%arg0) : (tensor<11x16xf32>) -> tensor<11x16xf32>
|
||||
return %0 : tensor<11x16xf32>
|
||||
|
||||
// CHECK-LABEL:elu
|
||||
// CHECK: %0 = "tfl.elu"(%arg0) : (tensor<11x16xf32>) -> tensor<11x16xf32>
|
||||
}
|
||||
|
||||
func @expandDims(%arg0: tensor<2x2xf32>, %arg1: tensor<i32>) -> tensor<1x2x2xf32> {
|
||||
%0 = "tf.ExpandDims"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<i32>) -> tensor<1x2x2xf32>
|
||||
return %0 : tensor<1x2x2xf32>
|
||||
|
||||
// CHECK-LABEL:expandDims
|
||||
// CHECK: %0 = "tfl.expand_dims"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<i32>) -> tensor<1x2x2xf32>
|
||||
}
|
||||
|
||||
func @gatherScalarIndices(%arg0 : tensor<3x2xf32>, %arg1 : tensor<i32>) -> tensor<2xf32> {
|
||||
%0 = "tf.Gather"(%arg0, %arg1) : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
|
||||
// CHECK-LABEL:gatherScalarIndices
|
||||
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
|
||||
}
|
||||
|
||||
func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> tensor<3xf32> {
|
||||
%0 = "tf.Gather"(%arg0, %arg1) : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
|
||||
return %0 : tensor<3xf32>
|
||||
|
||||
// CHECK-LABEL:gatherVectorIndices
|
||||
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
|
||||
}
|
||||
|
||||
func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5xi32>) -> tensor<4x5x3x6xf32> {
|
||||
%0 = "tf.Gather"(%arg0, %arg1) : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
|
||||
return %0 : tensor<4x5x3x6xf32>
|
||||
|
||||
// CHECK-LABEL:gatherHigherRankIndices
|
||||
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
|
||||
}
|
||||
|
||||
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
|
||||
%0 = constant dense<[1]> : tensor<1xi32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
|
||||
return %1 : tensor<1x3x5x20xf32>
|
||||
|
||||
// CHECK-LABEL:gatherV2VectorIndices
|
||||
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
|
||||
}
|
||||
|
||||
func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
|
||||
%0 = constant dense<[-1]> : tensor<1xi32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
|
||||
return %1 : tensor<1x2x3x5xf32>
|
||||
|
||||
// CHECK-LABEL:gatherV2VectorIndicesNegAxis
|
||||
// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = -1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32>
|
||||
}
|
||||
|
||||
func @greater(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
return %0 : tensor<8x16xi1>
|
||||
|
||||
// CHECK-LABEL: greater
|
||||
// CHECK: %0 = "tfl.greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
// CHECK: return %0 : tensor<8x16xi1>
|
||||
}
|
||||
|
||||
func @greater_equal(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
return %0 : tensor<8x16xi1>
|
||||
|
||||
// CHECK-LABEL: greater_equal
|
||||
// CHECK: %0 = "tfl.greater_equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
// CHECK: return %0 : tensor<8x16xi1>
|
||||
}
|
||||
|
||||
func @rank(%arg0: tensor<11x16xf32>) -> tensor<1xi32> {
|
||||
%0 = "tf.Rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32>
|
||||
return %0 : tensor<1xi32>
|
||||
|
||||
// CHECK-LABEL:rank
|
||||
// CHECK: %0 = "tfl.rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32>
|
||||
}
|
||||
|
||||
func @floor(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Floor"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL: floor
|
||||
// CHECK: %0 = "tfl.floor"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
// CHECK: return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @floor_div(tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
^bb0(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>):
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL: floor_div
|
||||
// CHECK: %0 = tfl.floor_div %arg0, %arg1 : tensor<8x16xf32>
|
||||
// CHECK: return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @not_equal(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
return %0 : tensor<8x16xi1>
|
||||
|
||||
// CHECK-LABEL: not_equal
|
||||
// CHECK: %0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
// CHECK: return %0 : tensor<8x16xi1>
|
||||
}
|
||||
|
||||
func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
return %0: tensor<8xf32>
|
||||
|
||||
// CHECK-LABEL: select
|
||||
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
|
||||
// CHECK: return %0 : tensor<8xf32>
|
||||
}
|
||||
|
||||
func @sin(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
|
||||
// CHECK-LABEL:sin
|
||||
// CHECK: %0 = "tfl.sin"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
}
|
||||
|
||||
func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
|
||||
%0, %1 = "tf.TopKV2"(%arg0, %arg1) : (tensor<8xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
|
||||
return %0, %1: tensor<?xf32>, tensor<?xi32>
|
||||
|
||||
// CHECK-LABEL: topk
|
||||
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %arg1)
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
|
||||
%0 = constant dense<2> : tensor<i32>
|
||||
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<8xf32>, tensor<i32>) -> (tensor<2xf32>, tensor<2xi32>)
|
||||
return %1#0, %1#1: tensor<2xf32>, tensor<2xi32>
|
||||
|
||||
// CHECK-LABEL: topk_2
|
||||
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst)
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
|
||||
%0 = constant dense<2> : tensor<i32>
|
||||
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
|
||||
return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>
|
||||
|
||||
// CHECK-LABEL: topk_3
|
||||
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) {
|
||||
%0 = constant dense<2> : tensor<i32>
|
||||
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<1x2x3x4xf32>, tensor<i32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>)
|
||||
return %1#0, %1#1: tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>
|
||||
|
||||
// CHECK-LABEL: topk_4
|
||||
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst)
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
func @topk_5(%arg0: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
|
||||
%0 = constant dense<2> : tensor<i32>
|
||||
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
|
||||
return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>
|
||||
|
||||
// CHECK-LABEL: topk_5
|
||||
// CHECK: %0:2 = "tfl.topk_v2"(%arg0, %cst)
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
func @logicalAnd(%arg0: tensor<8xi1>, %arg1: tensor<8xi1>) -> tensor<8xi1> {
|
||||
%0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<8xi1>, tensor<8xi1>) -> tensor<8xi1>
|
||||
return %0: tensor<8xi1>
|
||||
|
||||
// CHECK-LABEL: logicalAnd
|
||||
// CHECK: %0 = tfl.logical_and %arg0, %arg1 : tensor<8xi1>
|
||||
// CHECK: return %0 : tensor<8xi1>
|
||||
}
|
||||
|
||||
func @logicalNot(%arg0: tensor<8xi1>) -> tensor<8xi1> {
|
||||
%0 = "tf.LogicalNot"(%arg0) : (tensor<8xi1>) -> tensor<8xi1>
|
||||
return %0 : tensor<8xi1>
|
||||
// CHECK-LABEL: logicalNot
|
||||
// CHECK: %0 = "tfl.logical_not"(%arg0) : (tensor<8xi1>) -> tensor<8xi1>
|
||||
}
|
||||
|
||||
func @logicalOr(%arg0: tensor<8xi1>, %arg1: tensor<8xi1>) -> tensor<8xi1> {
|
||||
%0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<8xi1>, tensor<8xi1>) -> tensor<8xi1>
|
||||
return %0: tensor<8xi1>
|
||||
|
||||
// CHECK-LABEL: logicalOr
|
||||
// CHECK: %0 = tfl.logical_or %arg0, %arg1 : tensor<8xi1>
|
||||
// CHECK: return %0 : tensor<8xi1>
|
||||
}
|
||||
|
||||
func @addV2(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
return %0 : tensor<1xi32>
|
||||
|
||||
// CHECK-LABEL: addV2
|
||||
// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
|
||||
}
|
||||
|
||||
func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2x3x4xf32> {
|
||||
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi32>) -> tensor<1x2x3x4xf32>
|
||||
return %0 : tensor<1x2x3x4xf32>
|
||||
|
||||
// CHECK-LABEL:reverse_v2
|
||||
// CHECK: %0 = "tfl.reverse_v2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi32>) -> tensor<1x2x3x4xf32>
|
||||
// CHECK: return %0 : tensor<1x2x3x4xf32>
|
||||
}
|
||||
|
||||
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL:maximum
|
||||
// CHECK: %0 = "tfl.maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @minimum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL:minimum
|
||||
// CHECK: %0 = "tfl.minimum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @realDiv(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL: realDiv
|
||||
// CHECK: %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @equal(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
return %0 : tensor<8x16xi1>
|
||||
|
||||
// CHECK-LABEL: equal
|
||||
// CHECK: %0 = "tfl.equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
// CHECK: return %0 : tensor<8x16xi1>
|
||||
}
|
||||
|
||||
func @pad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
|
||||
%0 = "tf.Pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
|
||||
// CHECK-LABEL: pad
|
||||
// CHECK: %0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
|
||||
// CHECK: return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
|
||||
%cst = constant dense<2.0> : tensor<f32>
|
||||
%0 = "tf.PadV2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
|
||||
// CHECK-LABEL: padv2
|
||||
// CHECK: %0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<?xf32>
|
||||
// CHECK: return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @pack2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1) {N = 2 : i64} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
|
||||
// CHECK-LABEL: pack2Tensors
|
||||
// CHECK: %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
}
|
||||
|
||||
func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {N = 3 : i64, axis = 1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: pack3Tensors
|
||||
// CHECK: %0 = "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {N = 3 : i64, axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: packNegAxis
|
||||
// CHECK: %0 = "tfl.pack"(%arg0, %arg1, %arg2) {axis = -1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @unpack2Tensors(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
|
||||
%0:2 = "tf.Unpack"(%arg0) {num = 2 : i64} : (tensor<2x2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
|
||||
// CHECK-LABEL: unpack2Tensors
|
||||
// CHECK: %0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
|
||||
}
|
||||
|
||||
func @unpack3Tensors(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
%0:3 = "tf.Unpack"(%arg0) {num = 3 : i64, axis = 1 : i64} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
|
||||
// CHECK-LABEL: unpack3Tensors
|
||||
// CHECK: %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
}
|
||||
|
||||
func @unpackNegAxis(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
%0:3 = "tf.Unpack"(%arg0) {num = 3 : i64, axis = -1 : i64} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
|
||||
// CHECK-LABEL: unpackNegAxis
|
||||
// CHECK: %0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
}
|
||||
|
||||
func @mean(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
|
||||
%0 = "tf.Mean"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
|
||||
return %0 : tensor<1x2xf32>
|
||||
|
||||
// CHECK-LABEL: mean
|
||||
// CHECK: %0 = "tfl.mean"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
|
||||
}
|
||||
|
||||
func @mean_true(%arg0: tensor<2x2xf32>, %arg1: tensor<1xi32>) -> tensor<1x2xf32> {
|
||||
%0 = "tf.Mean"(%arg0, %arg1) {keep_dims = true} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
|
||||
return %0 : tensor<1x2xf32>
|
||||
|
||||
// CHECK-LABEL: mean_true
|
||||
// CHECK: %0 = "tfl.mean"(%arg0, %arg1) {keep_dims = true} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
|
||||
}
|
||||
|
||||
func @sum(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: sum
|
||||
// CHECK: %0 = "tfl.sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: sum_true
|
||||
// CHECK: %0 = "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: reduce_min
|
||||
// CHECK: %0 = "tfl.reduce_min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: reduce_min_true
|
||||
// CHECK: %0 = "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: reduce_max
|
||||
// CHECK: %0 = "tfl.reduce_max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: reduce_max_true
|
||||
// CHECK: %0 = "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK-LABEL: batch_to_space_nd
|
||||
// CHECK: %0 = "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK-LABEL: space_to_batch_nd
|
||||
// CHECK: %0 = "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
|
||||
(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
|
||||
return %0 : tensor<40x40xf32>
|
||||
// CHECK-LABEL: matmul_transposed
|
||||
// CHECK: %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
|
||||
func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
%0 = constant dense<[1]> : tensor<1xi32>
|
||||
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %1 : tensor<2x2xi32>
|
||||
|
||||
// CHECK-LABEL: concat2Tensors
|
||||
// CHECK: %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
}
|
||||
|
||||
func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
|
||||
%0 = constant dense<[-1]> : tensor<1xi32>
|
||||
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concat3Tensors
|
||||
// CHECK: %0 = "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @concatv2With3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
|
||||
%0 = constant dense<[-1]> : tensor<1xi32>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xi32>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concatv2With3Tensors
|
||||
// CHECK: %0 = "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK-LABEL: resize_with_bilinear
|
||||
// CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
// Note: half_pixel_centers isn't supported by TFLite, so it's not
|
||||
// legalized.
|
||||
func @resize_with_bilinear_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK-LABEL: resize_with_bilinear_with_half_pixel_centers
|
||||
// CHECK: "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true}
|
||||
}
|
||||
|
||||
func @strided_slice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
|
||||
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
|
||||
return %0 : tensor<1x2x2x5xf32>
|
||||
// CHECK-LABEL: strided_slice
|
||||
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
|
||||
}
|
@ -0,0 +1,101 @@
|
||||
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
|
||||
|
||||
func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
|
||||
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<*x!tf.variant>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
|
||||
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<*x!tf.variant>, tensor<1xi32>) -> tensor<3x10xf32>
|
||||
return %1, %2 : tensor<10xf32>, tensor<3x10xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistGetItem
|
||||
// CHECK: %0 = "tf.Gather"(%arg0, %arg2) {validate_indices = true} : (tensor<3x10xf32>, tensor<i32>) -> tensor<10xf32>
|
||||
// CHECK: return %0, %arg0 : tensor<10xf32>, tensor<3x10xf32>
|
||||
}
|
||||
|
||||
func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<*xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<*x!tf.variant>, tensor<i32>, tensor<1xi32>) -> tensor<*xf32>
|
||||
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<*x!tf.variant>, tensor<1xi32>) -> tensor<*xf32>
|
||||
return %1, %2 : tensor<*xf32>, tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistGetItemWithUnknownRank
|
||||
// CHECK: %0 = "tf.Gather"(%arg0, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK: return %0, %arg0 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10xf32>) -> tensor<3x10xf32> {
|
||||
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>):
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
|
||||
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<*x!tf.variant>, tensor<i32>, tensor<10xf32>) -> tensor<*x!tf.variant>
|
||||
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<*x!tf.variant>, tensor<1xi32>) -> tensor<3x10xf32>
|
||||
return %2 : tensor<3x10xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistSetItem
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<3x10xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<10xf32>) -> tensor<i32>
|
||||
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
|
||||
// CHECK: return %15 : tensor<3x10xf32>
|
||||
}
|
||||
|
||||
func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i32>, tensor<f32>) -> tensor<5xf32> {
|
||||
^bb0(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<5xf32>, tensor<0xi32>) -> tensor<*x!tf.variant>
|
||||
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<*x!tf.variant>, tensor<i32>, tensor<f32>) -> tensor<*x!tf.variant>
|
||||
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<*x!tf.variant>, tensor<0xi32>) -> tensor<5xf32>
|
||||
return %2 : tensor<5xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistSetItemWithScalarElements
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<5xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<f32>) -> tensor<i32>
|
||||
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
|
||||
// CHECK: return %15 : tensor<5xf32>
|
||||
}
|
||||
|
||||
func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<3xf32> {
|
||||
^bb0(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
|
||||
%0 = "tf.TensorListReserve"(%arg0, %arg1) {element_dtype = f32} : (tensor<3xi32>, tensor<i32>) -> tensor<*x!tf.variant>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<*x!tf.variant>, tensor<i32>, tensor<3xi32>) -> tensor<3xf32>
|
||||
return %1 : tensor<3xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistReserve
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
|
||||
// CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<3xf32>
|
||||
// CHECK: return %3 : tensor<3xf32>
|
||||
}
|
20
tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD
Normal file
@ -0,0 +1,20 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

licenses(["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@local_config_mlir//:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
        "//tensorflow/compiler/mlir/lite:flatbuffer_translate",
        "@llvm//:FileCheck",
    ],
)
@ -0,0 +1,101 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "MyCustomOp"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "mul",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "MyCustomOp",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "exp",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: builtin_options_type: MulOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 2, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 3 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: ExpOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
|
||||
// tf.MyCustomOp is the result of conversion to a Custom op
|
||||
%3 = "tf.MyCustomOp"(%2, %1) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("MyCustomOp")
|
||||
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||
return %4 : tensor<4xf32>
|
||||
}
|
@ -0,0 +1,10 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1
# CHECK: loc("disable_builtin.mlir":2:1): is a TFLite builtin op but builtin emission is not enabled
# CHECK-NEXT: Verification failed.

func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
^bb0(%arg0: tensor<3x2xi32>):
  %0 = "std.constant" () {name = "Const2", value = dense<10> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Add" (%0, %arg0) {name = "add"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  return %1 : tensor<3x2xi32>
}
@ -0,0 +1,14 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1
# CHECK: loc("disable_flex.mlir":96:8): error: 'tf.div' op is a Flex op but Flex ops are not enabled for emission
# CHECK-NEXT: Verification failed.

func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
  %0 = "tfl.pseudo_input" (%arg0) {name = "Input"} : (tensor<4xf32>) -> tensor<4xf32>
  %1 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
  %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // tf.div is the result of conversion to a Flex TF op
  %3 = "tf.Div"(%2, %1) {name = "div"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  %4 = "tfl.exp"(%3) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
  return %4 : tensor<4xf32>
}
@ -0,0 +1,98 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "mul0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "mul1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "exp",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: builtin_options_type: MulOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: inputs: [ 2, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: builtin_options_type: MulOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 3 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: ExpOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul0")
|
||||
%3 = "tfl.mul"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul1")
|
||||
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||
return %4 : tensor<4xf32>
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops=true -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexAddV2"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 3, 2 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "tf.Placeholder.input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "tf.AddV2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1 ],
|
||||
// CHECK-NEXT: custom_options: [ 5, 65, 100, 100, 86, 50, 0, 18, 18, 5, 65, 100, 100, 86, 50, 42, 7, 10, 1, 84, 18, 2, 48, 1, 50, 0, 0, 2, 27, 21, 20, 20, 4, 40, 1 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tf.Placeholder.input"(%arg0) {name = "Placeholder"} : (tensor<3x2xf32>) -> tensor<3x2xf32>
|
||||
%1 = "tf.AddV2"(%0, %0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
@ -0,0 +1,100 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexDiv"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "mul",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "div",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "exp",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: builtin_options_type: MulOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 2, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: custom_options: [ 3, 68, 105, 118, 0, 16, 18, 3, 68, 105, 118, 42, 7, 10, 1, 84, 18, 2, 48, 1, 50, 0, 0, 2, 23, 19, 20, 20, 4, 40, 1 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 3 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: ExpOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
|
||||
// tf.div is the result of conversion to a Flex TF op
|
||||
%3 = "tf.Div"(%2, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
|
||||
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||
return %4 : tensor<4xf32>
|
||||
}
|
171
tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir
Normal file
@ -0,0 +1,171 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LESS
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "Experimental_If"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 1 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "tfl.pseudo_input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "tfl.pseudo_input1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "tfl.less",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "tf.If",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 2, 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: custom_options: [ 116, 104, 101, 110, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 101, 108, 115, 101, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "tfl.add",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: builtin_options_type: AddOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "cond_true"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 8,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 9,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 10,
|
||||
// CHECK-NEXT: name: "tfl.mul",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: opcode_index: 3,
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: builtin_options_type: MulOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "cond_false"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
|
||||
%3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
return %3 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(tensor<4xi1>) -> tensor<4xi1> {
|
||||
^bb0(%arg0: tensor<4xi1>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LOGICAL_OR
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: LOGICAL_AND
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "Const1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "Const2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "logical_or",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "logical_and",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 3, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ]
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 1, 1, 1, 1 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xi1>) -> tensor<4xi1> loc("Input")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<true> : tensor<4xi1>} : () -> tensor<4xi1> loc("Const1")
|
||||
%2 = "tfl.pseudo_const" () {value = dense<false> : tensor<4xi1>} : () -> tensor<4xi1> loc("Const2")
|
||||
%3 = "tfl.logical_or"(%0, %2) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> loc("logical_or")
|
||||
%4 = "tfl.logical_and"(%3, %1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> loc("logical_and")
|
||||
return %4 : tensor<4xi1>
|
||||
}
|
137
tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir
Normal file
@ -0,0 +1,137 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: DIV
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: NEG
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "squared_difference",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "mul",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "div",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "exp",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "neg",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 6 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 0, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: builtin_options_type: MulOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 3, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: DivOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 3,
|
||||
// CHECK-NEXT: inputs: [ 4 ],
|
||||
// CHECK-NEXT: outputs: [ 5 ],
|
||||
// CHECK-NEXT: builtin_options_type: ExpOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 4,
|
||||
// CHECK-NEXT: inputs: [ 5 ],
|
||||
// CHECK-NEXT: outputs: [ 6 ],
|
||||
// CHECK-NEXT: builtin_options_type: NegOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%2 = "tfl.squared_difference"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
|
||||
%3 = "tfl.mul"(%0, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
|
||||
%4 = "tfl.div"(%3, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
|
||||
%5 = "tfl.exp"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||
%6 = "tfl.neg"(%5) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
|
||||
return %6 : tensor<4xf32>
|
||||
}
|
55
tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir
Normal file
55
tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir
Normal file
@ -0,0 +1,55 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
|
||||
^bb0(%arg0: tensor<1x6x6x16xf32>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 1, 6, 6, 16 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "avgpool",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1 ],
|
||||
// CHECK-NEXT: builtin_options_type: Pool2DOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: padding: VALID,
|
||||
// CHECK-NEXT: stride_w: 1,
|
||||
// CHECK-NEXT: stride_h: 3,
|
||||
// CHECK-NEXT: filter_width: 6,
|
||||
// CHECK-NEXT: filter_height: 3
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> loc("Input")
|
||||
%1 = "tfl.average_pool_2d"(%0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool")
|
||||
return %1 : tensor<1x1x1x16xf32>
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant unit
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
|
||||
%1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
|
||||
%2:2 = "tfl.fully_connected"(%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>, tensor<40x40xf32>)
|
||||
return %2#0 : tensor<40x40xf32>
|
||||
}
|
||||
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1, -1 ],
|
||||
// CHECK-NEXT: outputs: [ 2, 3 ],
|
||||
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
@ -0,0 +1,158 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: QUANTIZE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: CONV_2D
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: RESHAPE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: SOFTMAX
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: DEQUANTIZE
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 1, 224, 224, 3 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "tfl.pseudo_input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: UINT8,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "tfl.quantize",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.007812 ],
|
||||
// CHECK-NEXT: zero_point: [ 128 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 32, 3, 3, 3 ],
|
||||
// CHECK-NEXT: type: UINT8,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "tfl.pseudo_qconst",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.021827 ],
|
||||
// CHECK-NEXT: zero_point: [ 151 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 32 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "tfl.pseudo_qconst1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.000171 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: UINT8,
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "tfl.conv_2d",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.023528 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: UINT8,
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "tfl.reshape",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.023528 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: UINT8,
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "tfl.softmax",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.003906 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 8,
|
||||
// CHECK-NEXT: name: "tfl.dequantize",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 7 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 1, 2, 3 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: Conv2DOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: stride_w: 5,
|
||||
// CHECK-NEXT: stride_h: 4,
|
||||
// CHECK-NEXT: dilation_w_factor: 3,
|
||||
// CHECK-NEXT: dilation_h_factor: 2
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 4 ],
|
||||
// CHECK-NEXT: outputs: [ 5 ],
|
||||
// CHECK-NEXT: builtin_options_type: ReshapeOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: new_shape: [ 1, 1001 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 3,
|
||||
// CHECK-NEXT: inputs: [ 5 ],
|
||||
// CHECK-NEXT: outputs: [ 6 ],
|
||||
// CHECK-NEXT: builtin_options_type: SoftmaxOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: beta: 1.0
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 4,
|
||||
// CHECK-NEXT: inputs: [ 6 ],
|
||||
// CHECK-NEXT: outputs: [ 7 ]
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 
180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
|
||||
%2 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
|
||||
%4 = "tfl.conv_2d"(%1, %2, %3) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
%5 = "tfl.reshape"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
%6 = "tfl.softmax"(%5) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
|
||||
%7 = "tfl.dequantize"(%6) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
|
||||
return %7 : tensor<1x1001xf32>
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(tensor<3x2xi32>) -> tensor<6xi32> {
|
||||
^bb0(%arg0: tensor<3x2xi32>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: RESHAPE
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 3, 2 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "tfl.reshape",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1 ],
|
||||
// CHECK-NEXT: builtin_options_type: ReshapeOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: new_shape: [ 6 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
|
||||
%2 = "tfl.reshape" (%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
return %2 : tensor<6xi32>
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||
^bb0(%arg0: tensor<3x2xi32>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: SUB
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 3, 2 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "Input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 3, 2 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "sub",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "Const2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "add",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: builtin_options_type: SubOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: fused_activation_function: RELU6
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 3, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: AddOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 10, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const")
|
||||
%2 = "tfl.sub" (%0, %1) {fused_activation_function = "RELU6"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("sub")
|
||||
%3 = "std.constant" () {value = dense<10> : tensor<i32>} : () -> tensor<i32> loc("Const2")
|
||||
%4 = "tfl.add" (%3, %2) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("add")
|
||||
return %4 : tensor<3x2xi32>
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0
|
||||
|
||||
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||
^bb0(%arg0: tensor<3x2xi32>):
|
||||
%0 = "tfl.pseudo_input" (%arg0) {name = "Input"} : (tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||
// CHECK: error: 'unknown_op' op dialect is not registered
|
||||
%1 = "unknown_op"(%0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||
return %1 : tensor<3x2xi32>
|
||||
}
|
@ -0,0 +1,211 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "Experimental_While"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: GREATER
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: SUB
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "tfl.pseudo_input",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "tfl.pseudo_input1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "tf.While",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "tf.While:1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2, 3 ],
|
||||
// CHECK-NEXT: custom_options: [ 99, 111, 110, 100, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 98, 111, 100, 121, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 8,
|
||||
// CHECK-NEXT: name: "tfl.greater",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 0, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ]
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "cond"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 9,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 10,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 11,
|
||||
// CHECK-NEXT: name: "Const1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 12,
|
||||
// CHECK-NEXT: name: "tfl.sub",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 13,
|
||||
// CHECK-NEXT: name: "tfl.add",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3, 4 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 0, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: builtin_options_type: SubOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 3,
|
||||
// CHECK-NEXT: inputs: [ 1, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: AddOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: name: "body"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 1, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
%1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
|
||||
// While %0 is greater than zero, element-wise add %1 to itself.
|
||||
%2:2 = "tf.While"(%0, %1) {
|
||||
cond = @cond, body = @body
|
||||
} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
|
||||
return %2#1 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
|
||||
%0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
|
||||
%1 = "tfl.greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
|
||||
%0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
|
||||
%1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
return %1, %2 : tensor<*xi32>, tensor<*xf32>
|
||||
}
|
858
tensorflow/compiler/mlir/lite/tests/ops.mlir
Normal file
858
tensorflow/compiler/mlir/lite/tests/ops.mlir
Normal file
@ -0,0 +1,858 @@
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// Unary math ops
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testCos
|
||||
func @testCos(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.cos"(%arg0)
|
||||
%0 = "tfl.cos"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid Cos input
|
||||
func @testCosWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
^bb0(%arg0: tensor<?xi32>):
|
||||
// expected-error @+1 {{tfl.cos' op operand #0 must be tensor of floating-point values}}
|
||||
%0 = "tfl.cos"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0#0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testExp
|
||||
func @testExp(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.exp"(%arg0)
|
||||
%0 = "tfl.exp"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFloor
|
||||
func @testFloor(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.floor"(%arg0)
|
||||
%0 = "tfl.floor"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testGather
|
||||
func @testGather(%arg0 : tensor<?xf32>, %arg1 : tensor<?xi32>) -> tensor<?xf32> {
|
||||
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<?xf32>,tensor<?xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testGather
|
||||
func @testGather(%arg0 : tensor<2xf32>, %arg1 : tensor<2xi32>) -> tensor<2xf32> {
|
||||
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<2xf32>,tensor<2xi32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testGatherUnknownRank
|
||||
func @testGatherUnknownRank(%arg0 : tensor<*xf32>, %arg1 : tensor<1xi32>) -> tensor<*xf32> {
|
||||
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<*xf32>,tensor<1xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testGatherUnsupportedType(%arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xf32> {
|
||||
// expected-error @+1 {{op failed to verify that params and output must have same element type}}
|
||||
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<?xi32>,tensor<?xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testGatherUnsupportedRank(%arg0 : tensor<f32>, %arg1 : tensor<1xi32>) -> tensor<?xf32> {
|
||||
// expected-error @+1 {{op failed to verify that operand 0 is 1-D}}
|
||||
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<f32>,tensor<1xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testAbs
|
||||
func @testAbs(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.abs"(%arg0)
|
||||
%0 = "tfl.abs"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddN
|
||||
func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>):
|
||||
// CHECK: "tfl.add_n"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid AddN
|
||||
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
|
||||
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
|
||||
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}}
|
||||
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
|
||||
return %0 : tensor<? x f16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLog
|
||||
func @testLog(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.log"(%arg0)
|
||||
%0 = "tfl.log"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testNeg
|
||||
func @testNeg(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.neg"(%arg0)
|
||||
%0 = "tfl.neg"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testRsqrt
|
||||
func @testRsqrt(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.rsqrt"(%arg0)
|
||||
%0 = "tfl.rsqrt"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSin
|
||||
func @testSin(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.sin"(%arg0)
|
||||
%0 = "tfl.sin"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid Sin input
|
||||
func @testSinWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
^bb0(%arg0: tensor<?xi32>):
|
||||
// expected-error @+1 {{tfl.sin' op operand #0 must be tensor of floating-point values}}
|
||||
%0 = "tfl.sin"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0#0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testSqrt
|
||||
func @testSqrt(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.sqrt"(%arg0)
|
||||
%0 = "tfl.sqrt"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSquare
|
||||
func @testSquare(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.square"(%arg0)
|
||||
%0 = "tfl.square"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testTanh
|
||||
func @testTanh(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.tanh"(%arg0)
|
||||
%0 = "tfl.tanh"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testZerosLike
|
||||
func @testZerosLike(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: "tfl.zeros_like"(%arg0)
|
||||
%0 = "tfl.zeros_like"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testDequantize
|
||||
func @testDequantize(tensor<? x i32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x i32>):
|
||||
// CHECK: "tfl.dequantize"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
|
||||
%0 = "tfl.dequantize"(%arg0): (tensor<? x i32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLogicalNot
|
||||
func @testLogicalNot(tensor<? x i1>) -> tensor<? x i1> {
|
||||
^bb0(%arg0: tensor<? x i1>):
|
||||
// CHECK: "tfl.logical_not"(%arg0)
|
||||
%0 = "tfl.logical_not"(%arg0): (tensor<? x i1>) -> tensor<? x i1>
|
||||
return %0 : tensor<? x i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testLogicalNotWrongOperandType(tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>):
|
||||
// expected-error @+1 {{'tfl.logical_not' op operand #0 must be tensor of 1-bit integer values}}
|
||||
%0 = "tfl.logical_not"(%arg0) : (tensor<? x i32>) -> tensor<? x i32>
|
||||
return %0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// Binary math ops
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testAdd
|
||||
func @testAdd(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// TODO(jpienaar): Enable specifying label of enum for parsing.
|
||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"}
|
||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSub
|
||||
func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
|
||||
%0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMul
|
||||
func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"}
|
||||
%0 = tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testDiv
|
||||
func @testDiv(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "RELU6"}
|
||||
%0 = tfl.div %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLess
|
||||
func @testLess(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i1> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// CHECK: "tfl.less"(%arg0, %arg1)
|
||||
%0 = "tfl.less"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i1>
|
||||
return %0#0 : tensor<? x i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testFloorDivI32
|
||||
func @testFloorDivI32(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// CHECK: tfl.floor_div %arg0, %arg1
|
||||
%0 = tfl.floor_div %arg0, %arg1 : tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testFloorDivF32
|
||||
func @testFloorDivF32(tensor<? x f32>, tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
|
||||
// CHECK: tfl.floor_div %arg0, %arg1
|
||||
%0 = tfl.floor_div %arg0, %arg1 : tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testFloorDivF32(%arg0: tensor<2 x f32>, %arg1: tensor<2 x i32>) -> tensor<2 x f32> {
|
||||
// expected-error @+1 {{failed to verify that operands have same element type}}
|
||||
%0 = "tfl.floor_div"(%arg0, %arg1) : (tensor<2 x f32>, tensor<2 x i32>) -> tensor<2 x f32>
|
||||
return %0#0 : tensor<2 x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testFloorMod
|
||||
func @testFloorMod(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// CHECK: tfl.floor_mod %arg0, %arg1
|
||||
%0 = tfl.floor_mod %arg0, %arg1 : tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testPow
|
||||
func @testPow(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// CHECK: tfl.pow %arg0, %arg1
|
||||
%0 = tfl.pow %arg0, %arg1 : tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testConv2D
|
||||
func @testConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>):
|
||||
// CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "RELU6"} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %0 : tensor<256x30x30x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFakeQuant
|
||||
func @testFakeQuant(tensor<? x f32>, f32, f32) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>, %arg1: f32, %arg2: f32):
|
||||
// CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], narrow_range = true, num_bits = 2 : i32} : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.fake_quant"(%arg0) {minmax = [], num_bits = 2 : i32, narrow_range = true} : (tensor<? x f32>) -> tensor<? x f32>
|
||||
// CHECK: %1 = "tfl.fake_quant"(%0) {minmax = [3.000000e-01, 1.400000e+00], narrow_range = false, num_bits = 6 : i32} : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%1 = "tfl.fake_quant"(%0) {num_bits = 6 : i32, narrow_range = false, minmax = [0.3, 1.4]} : (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %1 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testQuantize
|
||||
func @testQuantize(tensor<? x f32>) -> tensor<? x !quant.uniform<u8:f32, 0.1:128>> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<?x!quant.uniform<u8:f32, 1.000000e-01:128>>}
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<? x !quant.uniform<u8:f32, 0.1:128>>} : (tensor<? x f32>) -> tensor<? x !quant.uniform<u8:f32, 0.1:128>>
|
||||
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1:128>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLogicalAnd
|
||||
func @testLogicalAnd(tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1> {
|
||||
^bb0(%arg0: tensor<? x i1>, %arg1: tensor<? x i1>):
|
||||
// CHECK: tfl.logical_and %arg0, %arg1
|
||||
%0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1>
|
||||
return %0#0 : tensor<? x i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testLogicalAndWrongOperandType(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// expected-error @+1 {{'tfl.logical_and' op operand #0 must be tensor of 1-bit integer values}}
|
||||
%0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32>
|
||||
return %0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLogicalOr
|
||||
func @testLogicalOr(tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1> {
|
||||
^bb0(%arg0: tensor<? x i1>, %arg1: tensor<? x i1>):
|
||||
// CHECK: tfl.logical_or %arg0, %arg1
|
||||
%0 = "tfl.logical_or"(%arg0, %arg1) : (tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1>
|
||||
return %0#0 : tensor<? x i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testLogicalOrWrongOperandType(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
// expected-error @+1 {{'tfl.logical_or' op operand #0 must be tensor of 1-bit integer values}}
|
||||
%0 = "tfl.logical_or"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32>
|
||||
return %0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testEluF32
|
||||
func @testEluF32(%arg0: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.elu"(%arg0)
|
||||
%0 = "tfl.elu"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testTileF32
|
||||
func @testTileF32(%arg0: tensor<4 x 1 x f32>, %arg1: tensor<4 x i32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.tile"(%arg0, %arg1)
|
||||
%0 = "tfl.tile"(%arg0, %arg1): (tensor<4 x 1 x f32>, tensor<4 x i32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testEluI32(%arg0: tensor<? x i32>) -> tensor<? x i32> {
|
||||
// expected-error @+1 {{operand #0 must be tensor of floating-point values}}
|
||||
%0 = "tfl.elu"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
// CHECK: "NONE"
|
||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32>
|
||||
// CHECK: "RELU"
|
||||
%1 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU"} : tensor<4xi32>
|
||||
// CHECK: "RELU_N1_TO_1"
|
||||
%2 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"} : tensor<4xi32>
|
||||
// CHECK: "RELU6"
|
||||
%3 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<4xi32>
|
||||
// CHECK: "TANH"
|
||||
%4 = tfl.add %arg0, %arg1 {fused_activation_function = "TANH"} : tensor<4xi32>
|
||||
// CHECK: "SIGN_BIT"
|
||||
%5 = tfl.add %arg0, %arg1 {fused_activation_function = "SIGN_BIT"} : tensor<4xi32>
|
||||
return %0, %1, %2, %3, %4, %5: tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}}
|
||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32>
|
||||
return %0: tensor<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
|
||||
// CHECK: "SAME"
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK: "VALID"
|
||||
%1 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
// expected-error @+1 {{attribute 'padding' failed to satisfy constraint: padding enum}}
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SOMETHING", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %0 : tensor<256x30x30x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testMaxPool2D
|
||||
func @testMaxPool2D(tensor<256x32x32x3xf32>) -> tensor<?xf32> {
|
||||
^bb0(%arg0: tensor<256x32x32x3xf32>):
|
||||
// CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testMaxPool2DQuantized
|
||||
func @testMaxPool2DQuantized(tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<?x!quant.uniform<i8:f32, 0.1:128>> {
|
||||
^bb0(%arg0: tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>):
|
||||
// CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
|
||||
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<?x!quant.uniform<i8:f32, 0.1:128>>
|
||||
return %0 : tensor<?x!quant.uniform<i8:f32, 0.1:128>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid MaxPool2D
|
||||
func @testMaxPool2DWrongOperandResultType(tensor<1x7x7x16xi32>) -> tensor<1x7x7x16xi32> {
|
||||
^bb0(%arg0: tensor<1x7x7x16xi32>):
|
||||
// expected-error @+1 {{failed to verify that MaxPool2D operand and result types match specified constraints}}
|
||||
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x16xi32>) -> tensor<1x7x7x16xi32>
|
||||
return %0 : tensor<1x7x7x16xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid MaxPool2D
|
||||
func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>) -> tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>> {
|
||||
^bb0(%arg0: tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>):
|
||||
// expected-error @+1 {{failed to verify that MaxPool2D operand and result types match specified constraints}}
|
||||
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>) -> tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>
|
||||
return %0 : tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLogistic
|
||||
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
|
||||
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
|
||||
// CHECK: "tfl.logistic"(%arg0)
|
||||
%0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
|
||||
return %0 : tensor<1x2x3x4x5xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid Logistic input
|
||||
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
^bb0(%arg0: tensor<?xi32>):
|
||||
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point values}}
|
||||
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0#0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceLstm
|
||||
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
|
||||
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid none type applied to a tensor type arg
|
||||
func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.unidirectional_sequence_lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

// test violation of projection weight and projection bias pred op trait
func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.unidirectional_sequence_lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: testReverseV2
func @testReverseV2(%arg0: tensor<1x2x3x4xf32>, %arg1 : tensor<2xi32>) -> tensor<1x2x3x4xf32> {
// CHECK: "tfl.reverse_v2"(%arg0, %arg1)
%0 = "tfl.reverse_v2"(%arg0, %arg1): (tensor<1x2x3x4xf32>, tensor<2xi32>) -> tensor<1x2x3x4xf32>
return %0 : tensor<1x2x3x4xf32>
}

// -----

// test select
// CHECK-LABEL: testSelect
func @testSelect(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}

// -----

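// test select with a condition operand that is not a 1-bit tensor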
func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi32>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}

// -----

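// test select with mismatched operand element types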
func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xf32>) -> tensor<?xi32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}

// -----

// CHECK-LABEL: topk
func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
%0, %1 = "tfl.topk_v2"(%arg0, %arg1) : (tensor<8xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
return %0, %1: tensor<?xf32>, tensor<?xi32>
}

// -----

// CHECK-LABEL: topk
func @topk(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0, %1 = "tfl.topk_v2"(%arg0, %arg1) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %0, %1: tensor<*xf32>, tensor<*xi32>
}

// -----

// CHECK-LABEL: topk_2
func @topk_2(%arg0: tensor<3x4x8xf32>) -> (tensor<3x4x2xf32>, tensor<3x4x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<3x4x8xf32>, tensor<i32>) -> (tensor<3x4x2xf32>, tensor<3x4x2xi32>)
return %1#0, %1#1: tensor<3x4x2xf32>, tensor<3x4x2xi32>
}

// -----

// CHECK-LABEL: topk_d
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>
}

// -----

// CHECK-LABEL: topk_d
// TODO(jpienaar): This should fail but doesn't as the op definition does not
// include shape verification.
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<?x3xf32>, tensor<?x3xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x3xf32>, tensor<?x3xi32>)
return %1#0, %1#1: tensor<?x3xf32>, tensor<?x3xi32>
}

// -----

// CHECK-LABEL: topk_d
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>
}

// -----

// CHECK-LABEL: testEqual
|
||||
func @testEqual(tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1> {
|
||||
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
|
||||
// CHECK: "tfl.equal"(%arg0, %arg1)
|
||||
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1>
|
||||
return %0#0 : tensor<? x i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testPad
|
||||
func @testPad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
|
||||
// CHECK: "tfl.pad"(%arg0, %arg1)
|
||||
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test Pad with invalid paddings size
|
||||
func @testPadWithInvalidPaddingsDim(tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x2xi32>):
|
||||
// expected-error @+1 {{'tfl.pad' op failed to verify that operand 0's rank equals operand 1's size}}
|
||||
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test Pad with invalid paddings rank
|
||||
func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<1x3x2xi32>):
|
||||
// expected-error @+1 {{'tfl.pad' op failed to verify that operand 1 is 2-D}}
|
||||
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testPadV2
|
||||
func @testPadV2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
|
||||
%cst = constant dense<2.0> : tensor<f32>
|
||||
// CHECK: "tfl.padv2"(%arg0, %arg1, %cst)
|
||||
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test PadV2 with invalid paddings size
|
||||
func @testPadV2WithInvalidPaddingsDim(tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x2xi32>):
|
||||
%cst = constant dense<2.0> : tensor<f32>
|
||||
// expected-error @+1 {{'tfl.padv2' op failed to verify that operand 0's rank equals operand 1's size}}
|
||||
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test PadV2 with invalid paddings rank
|
||||
func @testPadV2WithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<1x3x2xi32>):
|
||||
%cst = constant dense<2.0> : tensor<f32>
|
||||
// expected-error @+1 {{'tfl.padv2' op failed to verify that operand 1 is 2-D}}
|
||||
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<1x3x2xi32>, tensor<f32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test PadV2 with invalid constant rank
|
||||
func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
|
||||
%cst = constant dense<[2.0]> : tensor<1xf32>
|
||||
// expected-error @+1 {{'tfl.padv2' op failed to verify that operand 2 is 0-D}}
|
||||
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<1xf32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test PadV2 with invalid constant data type
|
||||
func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
|
||||
%cst = constant dense<2> : tensor<i32>
|
||||
// expected-error @+1 {{'tfl.padv2' op failed to verify that input and constant value operands must have same element type}}
|
||||
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<i32>) -> tensor<? x f32>
|
||||
return %0#0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
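// test Pack with a valid 'values_count' attribute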
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
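// test Pack with an invalid 'values_count' attribute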
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// expected-error @+1 {{input count should match 'values_count' attribute}}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
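// test Unpack with a valid 'num' attribute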
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
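// test Unpack with an invalid 'num' attribute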
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{output count should match 'num' attribute}}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testMean
|
||||
func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
|
||||
// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false}
|
||||
%0 = "tfl.mean"(%arg0, %arg1) {keep_dims = false}: (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testMean_true
|
||||
func @testMean_true(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
|
||||
// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = true}
|
||||
%0 = "tfl.mean"(%arg0, %arg1) {keep_dims = true}: (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testMean_missing_keep_dims(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
|
||||
// expected-error @+1 {{'tfl.mean' op requires attribute 'keep_dims'}}
|
||||
%0 = "tfl.mean"(%arg0, %arg1): (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testBatchToSpaceND
|
||||
func @testBatchToSpaceND(%arg0 : tensor<4x2x2x3xf32>, %arg1 : tensor<2xi32>, %arg2 : tensor<2x2xi32>) -> tensor<?xf32> {
|
||||
// CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testSpaceToBatchND
|
||||
func @testSpaceToBatchND(%arg0 : tensor<1x4x4x3xf32>, %arg1 : tensor<2xi32>, %arg2 : tensor<2x2xi32>) -> tensor<?xf32> {
|
||||
// CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcat(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcatQuantized(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i8:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
|
||||
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcatInvalidOutputElementalType(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
// expected-error @+1 {{'tfl.concatenation' op failed to verify that values and output must have same element type}}
|
||||
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcatInvalidStorageType(%arg0: tensor<2x!quant.uniform<i9:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
|
||||
// expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values}}
|
||||
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i9:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
|
||||
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testResizeBilinear
|
||||
func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor<?xf32> {
|
||||
// CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false}
|
||||
%0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor<?xi32> {
|
||||
// expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testStridedSlice
|
||||
func @testStridedSlice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
|
||||
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
|
||||
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
|
||||
return %0 : tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {
|
||||
// expected-error @+1 {{op failed to verify that input and output must have same element type}}
|
||||
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xi32>
|
||||
return %0 : tensor<1x2x2x5xi32>
|
||||
}
|
||||
|
131
tensorflow/compiler/mlir/lite/tests/optimize.mlir
Normal file
131
tensorflow/compiler/mlir/lite/tests/optimize.mlir
Normal file
@ -0,0 +1,131 @@
|
||||
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: fusedConv2dRelu
|
||||
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.relu"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fusedDepthwiseConv2dRelu6
|
||||
func @fusedDepthwiseConv2dRelu6(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.relu6"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fusedConv2dTanh
|
||||
func @fusedConv2dTanh(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.tanh"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "TANH", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fuseAddIntoConv2d
|
||||
func @fuseAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%cst = constant dense<1.5> : tensor<16xf32>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
|
||||
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
|
||||
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fuseAddWithRelu6IntoConv2d
|
||||
func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%cst = constant dense<1.5> : tensor<16xf32>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
|
||||
// CHECK-SAME: fused_activation_function = "RELU6"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
|
||||
func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
|
||||
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst)
|
||||
// CHECK-SAME: fused_activation_function = "RELU6"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: intermOpUsedTwice
|
||||
func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
|
||||
%cst = constant dense<1.5> : tensor<16xf32>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00,
|
||||
// CHECK: %cst_0 = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00,
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
|
||||
// CHECK: %1 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
|
||||
// CHECK: return %0, %1
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
|
||||
func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
|
||||
%cst1 = constant dense<2.0> : tensor<2xf32>
|
||||
%cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
|
||||
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
|
||||
return %1 : tensor<1x112x112x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00], [5.000000e+00, 1.200000e+01]], {{\[\[}}7.000000e+00, 1.600000e+01], [9.000000e+00, 2.000000e+01], [1.100000e+01, 2.400000e+01]], {{\[\[}}1.300000e+01, 2.800000e+01], [1.500000e+01, 3.200000e+01], [1.700000e+01, 3.600000e+01]]]]> : tensor<1x3x3x2xf32>
|
||||
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
|
||||
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @notFuseMulIntoDepthwiseConv2d
|
||||
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
|
||||
%cst1 = constant dense<2.0> : tensor<2xf32>
|
||||
%cst2 = constant dense<3.0> : tensor<112x2xf32>
|
||||
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible with %cst0.
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
|
||||
|
||||
return %1 : tensor<1x112x112x2xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0)
|
||||
// CHECK: %1 = "tfl.mul"(%0, %cst_1)
|
||||
// CHECK: return %1
|
||||
}
|
40
tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
Normal file
40
tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
Normal file
@ -0,0 +1,40 @@
|
||||
// RUN: tf-opt %s -tfl-post-quantize | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
|
||||
%2 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
|
||||
%4 = "tfl.conv_2d"(%1, %2, %3) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
%5 = "tfl.reshape"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
%6 = "tfl.softmax"(%5) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
|
||||
%7 = "tfl.dequantize"(%6) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
|
||||
return %7 : tensor<1x1001xf32>
|
||||
}
|
||||
|
||||
func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
%0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
|
||||
%2 = "tfl.pseudo_input"(%arg1) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
%3 = "tfl.quantize"(%2) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
|
||||
%4 = tfl.add %1, %3 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
|
||||
%5 = "tfl.dequantize"(%4) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32>
|
||||
return %5 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK: func @main(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
|
||||
// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
|
||||
// CHECK-NEXT: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
|
||||
// CHECK-NEXT: %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>}
|
||||
// CHECK-NEXT: %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
|
||||
// CHECK-NEXT: %4 = "tfl.reshape"(%3) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>)
|
||||
// CHECK-NEXT: %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>)
|
||||
// CHECK-NEXT: return %5 : tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
|
||||
// CHECK-NEXT:}
|
||||
|
||||
// CHECK: func @main2(%arg0: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>, %arg1: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>> {
|
||||
// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg1) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
|
||||
// CHECK-NEXT: %1 = "tfl.pseudo_input"(%arg0) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
|
||||
// CHECK-NEXT: %2 = tfl.add %1, %0 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
|
||||
// CHECK-NEXT: return %2 : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
|
||||
// CHECK-NEXT:}
|
154
tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
Normal file
154
tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
Normal file
@ -0,0 +1,154 @@
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: DequantizeAndQuantize
|
||||
func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
|
||||
%cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%0 = "tfl.dequantize"(%cst) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
|
||||
// CHECK: %0 = "tfl.pseudo_qconst"()
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: %2 = "tfl.quantize"(%1)
|
||||
// CHECK: return %2
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeConv2D
|
||||
func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
|
||||
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
|
||||
%5 = "tfl.conv_2d"(%2, %4, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
|
||||
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
|
||||
// CHECK: %2 = "tfl.dequantize"(%arg0)
|
||||
// CHECK: %3 = "tfl.pseudo_qconst"()
|
||||
// CHECK: %4 = "tfl.dequantize"(%3)
|
||||
// CHECK: %5 = "tfl.conv_2d"(%2, %4, %1)
|
||||
// CHECK: %6 = "tfl.quantize"(%5)
|
||||
// CHECK: return %6
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: QuantizeDepthwiseConv2D
|
||||
func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
|
||||
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
|
||||
%5 = "tfl.depthwise_conv_2d"(%2, %4, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
|
||||
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
|
||||
// CHECK: %2 = "tfl.dequantize"(%arg0)
|
||||
// CHECK: %3 = "tfl.pseudo_qconst"()
|
||||
// CHECK: %4 = "tfl.dequantize"(%3)
|
||||
// CHECK: %5 = "tfl.depthwise_conv_2d"(%2, %4, %1)
|
||||
// CHECK: %6 = "tfl.quantize"(%5)
|
||||
// CHECK: return %6
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeAveragePool2D
|
||||
func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32> {
|
||||
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
|
||||
%1 = "tfl.average_pool_2d"(%0) {
|
||||
name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32
|
||||
} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
return %1 : tensor<1x1x1x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.dequantize"(%arg0)
|
||||
// CHECK: %1 = "tfl.average_pool_2d"(%0)
|
||||
// CHECK: %2 = "tfl.quantize"(%1)
|
||||
// CHECK: %3 = "tfl.dequantize"(%2)
|
||||
// CHECK: return %3 : tensor<1x1x1x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeReshape2D
|
||||
func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
|
||||
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
|
||||
%1 = "tfl.reshape"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
|
||||
return %1 : tensor<1x36x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
|
||||
// CHECK: %1 = "tfl.reshape"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
|
||||
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>}
|
||||
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
|
||||
// CHECK: return %3 : tensor<1x36x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeSoftmax
|
||||
func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32> {
|
||||
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
|
||||
%1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
|
||||
return %1 : tensor<1x6x6x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.dequantize"(%arg0)
|
||||
// CHECK: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
|
||||
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03:-128>>}
|
||||
// CHECK: %3 = "tfl.dequantize"(%2)
|
||||
// CHECK: return %3 : tensor<1x6x6x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeChain
|
||||
func @QuantizeChain(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
|
||||
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
|
||||
%5 = "tfl.average_pool_2d"(%2) {
|
||||
name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32
|
||||
} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
|
||||
%6 = "tfl.conv_2d"(%5, %4, %cst) {
|
||||
dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32
|
||||
} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
%8 = "tfl.dequantize"(%7) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x6x6x16xf32>
|
||||
%9 = "tfl.reshape"(%8) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
|
||||
%10 = "tfl.softmax"(%9) {beta = 1.000000e+00 : f32} : (tensor<1x36x16xf32>) -> tensor<1x36x16xf32>
|
||||
return %10 : tensor<1x36x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
|
||||
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
|
||||
// CHECK: %3 = "tfl.pseudo_qconst"()
|
||||
// CHECK: %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>)
|
||||
// CHECK: %5 = "tfl.average_pool_2d"(%2)
|
||||
// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>}
|
||||
// CHECK: %7 = "tfl.dequantize"(%6) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
|
||||
// CHECK: %8 = "tfl.conv_2d"(%7, %4, %1)
|
||||
// CHECK: %9 = "tfl.quantize"(%8) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>}
|
||||
// CHECK: %10 = "tfl.dequantize"(%9) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>)
|
||||
// CHECK: %11 = "tfl.reshape"(%10)
|
||||
// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 0.023528476789885875>>}
|
||||
// CHECK: %13 = "tfl.dequantize"(%12) : (tensor<1x36x16x!quant.uniform<u8:f32, 0.023528476789885875>>)
|
||||
// CHECK: %14 = "tfl.softmax"(%13)
|
||||
// CHECK: %15 = "tfl.quantize"(%14) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 3.906250e-03:-128>>}
|
||||
// CHECK: %16 = "tfl.dequantize"(%15) : (tensor<1x36x16x!quant.uniform<u8:f32, 3.906250e-03:-128>>)
|
||||
// CHECK: return %16 : tensor<1x36x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeConstant
|
||||
func @QuantizeConstant() -> tensor<2x3xf32> {
|
||||
%cst = constant dense<[[-3.0, -1.0, 0.0], [0.0, 1.0, 3.0]]> : tensor<2x3xf32>
|
||||
return %cst : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %cst = constant dense{{.*}}tensor<2x3xf32>
|
||||
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.023529411764705882:128>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: return %1 : tensor<2x3xf32>
|
||||
}
|
197
tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
Normal file
197
tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
Normal file
@ -0,0 +1,197 @@
|
||||
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
|
||||
|
||||
func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
|
||||
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) :
|
||||
// OK
|
||||
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// Unsupported data format
|
||||
%1 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// OK
|
||||
%2 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", padding = "VALID", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// Unsupported padding
|
||||
%3 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "EXPLICIT", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// Unsupported strides
|
||||
%4 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
|
||||
return %0, %1, %2, %3, %4 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK-LABEL: conv
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32>
|
||||
// CHECK: %cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
|
||||
// CHECK: %0 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
|
||||
// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK: %2 = "tf.Conv2D"
|
||||
// CHECK: %3 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
|
||||
// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK: %5 = "tf.Conv2D"
|
||||
// CHECK: %6 = "tf.Conv2D"
|
||||
}
|
||||
|
||||
func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> (tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>) {
|
||||
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x4xf32>) :
|
||||
// OK
|
||||
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
|
||||
// Unsupported data format
|
||||
%1 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
|
||||
// OK
|
||||
%2 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", padding = "VALID", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
|
||||
// Unsupported strides
|
||||
%3 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xf32>
|
||||
|
||||
return %0, %1, %2, %3 : tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>
|
||||
|
||||
// CHECK-LABEL: depthwiseConv2D
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<12xf32>
|
||||
// CHECK: %cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
|
||||
// CHECK: %0 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
|
||||
// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
|
||||
// CHECK: %2 = "tf.DepthwiseConv2dNative"
|
||||
// CHECK: %3 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
|
||||
// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
|
||||
// CHECK: %5 = "tf.DepthwiseConv2dNative"
|
||||
}
|
||||
|
||||
func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
|
||||
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
|
||||
// OK
|
||||
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
// Unsupported training
|
||||
%1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
// Use other output
|
||||
%2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
|
||||
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
|
||||
|
||||
// CHECK-LABEL: fusedBatchNorm
|
||||
// CHECK:%cst = constant dense<1.000000e-03> : tensor<f32>
|
||||
// variance + epsilon
|
||||
// CHECK: %0 = "tf.Add"(%arg4, %cst) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// rsqrt(variance + epsilon)
|
||||
// CHECK: %1 = "tf.Rsqrt"(%0) : (tensor<8xf32>) -> tensor<8xf32>
|
||||
// scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %2 = "tf.Mul"(%arg1, %1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
// x * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %3 = "tf.Mul"(%arg0, %2) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
|
||||
// mean * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %4 = "tf.Mul"(%arg3, %2) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
// offset - mean * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %5 = "tf.Sub"(%arg2, %4) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
// x * scale * rsqrt(variance + epsilon) +
|
||||
// offset - mean * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %6 = "tf.Add"(%3, %5) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK: %7:5 = "tf.FusedBatchNorm"(%6, %arg1, %arg2, %arg3, %arg4)
|
||||
// CHECK: %8:5 = "tf.FusedBatchNorm"(%7#0, %arg1, %arg2, %arg3, %arg4)
|
||||
}
|
||||
|
||||
func @fakeQuantNotFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) {
|
||||
^bb0(%arg0: tensor<8x8x8x8xf32>):
|
||||
%arg1 = constant dense<-0.1> : tensor<f32>
|
||||
%arg2 = constant dense<0.2> : tensor<f32>
|
||||
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
|
||||
return %0 : tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantNotFollowedByQuant
|
||||
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
|
||||
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
|
||||
// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>)
|
||||
// CHECK: return %2 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) {
|
||||
^bb0(%arg0: tensor<8x8x8x8xf32>):
|
||||
%arg1 = constant dense<-0.1> : tensor<f32>
|
||||
%arg2 = constant dense<0.2> : tensor<f32>
|
||||
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>) -> tensor<8x8x8x8xf32>
|
||||
return %2 : tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantFollowedByQuant
|
||||
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
|
||||
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>}
|
||||
// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform<u8:f32, 0.0011764706057660721:85>>)
|
||||
// CHECK: return %2 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantVarsNotConst(tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8x8x8x8xf32>) {
|
||||
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
|
||||
%1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
|
||||
return %1 : tensor<8x8x8x8xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantVarsNotConst
|
||||
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
|
||||
// CHECK: return %0 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantFollowedByTranspose(tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> (tensor<16x3x3x3xf32>) {
|
||||
^bb0(%arg0: tensor<3x3x3x16xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
|
||||
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
|
||||
%1 = "tf.Transpose"(%0, %cst_0): (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
|
||||
return %1 : tensor<16x3x3x3xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantFollowedByTranspose
|
||||
// CHECK: %cst = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
|
||||
// CHECK: %0 = "tf.Transpose"(%arg0, %cst) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
|
||||
// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
|
||||
// CHECK: return %1 : tensor<16x3x3x3xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantFollowedByReshape(tensor<3x3x3x4xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x3x3x12xf32>) {
|
||||
^bb0(%arg0: tensor<3x3x3x4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
|
||||
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x4xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x4xf32>
|
||||
%1 = "tf.Reshape"(%0, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
|
||||
return %1 : tensor<1x3x3x12xf32>
|
||||
|
||||
// CHECK-LABEL: fakeQuantFollowedByReshape
|
||||
// CHECK: %cst = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
|
||||
// CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
|
||||
// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64}
|
||||
// CHECK: return %1 : tensor<1x3x3x12xf32>
|
||||
}
|
||||
|
||||
func @identity(tensor<10xi32>) -> tensor<10xi32> {
|
||||
^bb0(%arg0: tensor<10xi32>):
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<10xi32>) -> tensor<10xi32>
|
||||
return %0: tensor<10xi32>
|
||||
|
||||
// CHECK-LABEL: identity
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
|
||||
func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
|
||||
%166 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", _output_shapes = ["tfshape$dim { size = 1} dim { size = 1000}"], device = "", name = "matmul", transpose_a = false, transpose_b = false} : (tensor<1x1280xf32>, tensor<1280x1000xf32>) -> tensor<1x1000xf32>
|
||||
return %166 : tensor<1x1000xf32>
|
||||
|
||||
// CHECK-LABEL: matmulNoTransposeAOrB
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %3 = "tf.Transpose"(%arg1, %2) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %4 = "tf.MatMul"(%arg0, %3) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||
}
|
||||
|
||||
func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
|
||||
%166 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", _output_shapes = ["tfshape$dim { size = 1} dim { size = 1000}"], device = "", name = "matmul", transpose_a = true, transpose_b = false} : (tensor<1x1280xf32>, tensor<1280x1000xf32>) -> tensor<1x1000xf32>
|
||||
return %166 : tensor<1x1000xf32>
|
||||
|
||||
// CHECK-LABEL: matmulNoTransposeB
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<1x1280xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %3 = "tf.Transpose"(%arg0, %2) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %4 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
|
||||
// CHECK: %5 = "tf.Range"(%4, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %6 = "tf.Sub"(%5, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||
}
|
132
tensorflow/compiler/mlir/lite/tests/quantize.mlir
Normal file
@ -0,0 +1,132 @@
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: QuantizeFloatConst
|
||||
func @QuantizeFloatConst() -> tensor<f32> {
|
||||
%0 = constant dense<-0.1> : tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
|
||||
return %2 : tensor<f32>
|
||||
|
||||
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeDenseFloatConst
|
||||
func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
|
||||
%0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
|
||||
return %2 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: return %1 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeSplatFloatConst
|
||||
func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
|
||||
%0 = constant dense<3.0> : tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
|
||||
return %2 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: return %1 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DequantizeAndQuantize
|
||||
func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
|
||||
%cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%0 = "tfl.dequantize"(%cst) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
|
||||
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
|
||||
// CHECK: return %0 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeConv2D
|
||||
func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
|
||||
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
|
||||
%5 = "tfl.conv_2d"(%2, %4, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
|
||||
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<-7254> : tensor<32xi32>}
|
||||
// CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
|
||||
// CHECK: %2 = "tfl.conv_2d"(%arg0, %1, %0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
// CHECK: return %2 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: QuantizeDepthwiseConv2D
|
||||
func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
|
||||
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
|
||||
%5 = "tfl.depthwise_conv_2d"(%2, %4, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
|
||||
// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<-7254> : tensor<32xi32>}
|
||||
// CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
|
||||
// CHECK: %2 = "tfl.depthwise_conv_2d"(%arg0, %1, %0) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32}
|
||||
// CHECK: return %2
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeAveragePool2D
|
||||
func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32> {
|
||||
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
|
||||
%1 = "tfl.average_pool_2d"(%0) {name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
|
||||
return %1 : tensor<1x1x1x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.average_pool_2d"(%arg0)
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x1x1x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x1x1x16xf32>
|
||||
// CHECK: return %1 : tensor<1x1x1x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeReshape2D
|
||||
func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
|
||||
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
|
||||
%1 = "tfl.reshape"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x36x16xf32>
|
||||
return %1 : tensor<1x36x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0)
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32>
|
||||
// CHECK: return %1 : tensor<1x36x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeSoftmax
|
||||
func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32> {
|
||||
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x6x6x16xf32>
|
||||
%1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
|
||||
return %1 : tensor<1x6x6x16xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.softmax"(%arg0)
|
||||
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03:-128>>) -> tensor<1x6x6x16xf32>
|
||||
// CHECK: return %1 : tensor<1x6x6x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeAdd
|
||||
func @QuantizeAdd(tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>) -> tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>> {
|
||||
^bb0(%arg0: tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, %arg1: tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>):
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>) -> tensor<1x56x56x24xf32>
|
||||
%1 = "tfl.dequantize"(%arg1) : (tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>) -> tensor<1x56x56x24xf32>
|
||||
%2 = tfl.add %0, %1 {fused_activation_function = "NONE"} : tensor<1x56x56x24xf32>
|
||||
%3 = "tfl.quantize"(%2) {qtype = tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>} : (tensor<1x56x56x24xf32>) -> tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>
|
||||
return %3 : tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>
|
||||
|
||||
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x56x56x24x!quant.uniform<u8:f32, 0.27583434161017922:119>>, tensor<1x56x56x24x!quant.uniform<u8:f32, 0.40149296779258581:136>>)
|
||||
// CHECK: return %0 : tensor<1x56x56x24x!quant.uniform<u8:f32, 0.4321689530914905:133>>
|
||||
}
|
151
tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
Normal file
@ -0,0 +1,151 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
using mlir::MLIRContext;
|
||||
using mlir::Module;
|
||||
using stream_executor::port::StatusOr;
|
||||
using tensorflow::Status;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> print_function_result_mapping(
|
||||
"print-function-result-mapping",
|
||||
llvm::cl::desc(
|
||||
"Print the mapping of function result to flatbuffer output buffer"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
enum TranslationStatus { kTrSuccess, kTrFailure };
|
||||
|
||||
static int PrintFunctionResultMapping(const std::string &result,
|
||||
Module *module) {
|
||||
// Build model from the resultant string to extract the return values from
|
||||
// their source of truth.
|
||||
auto model =
|
||||
tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
|
||||
if (!model) return kTrFailure;
|
||||
|
||||
// Use an unknown location when we don't have a terminator from which to get
// the location of the return value.
|
||||
auto unknown_loc = mlir::UnknownLoc::get(module->getContext());
|
||||
|
||||
auto print_buffer = [&](const tflite::SubGraph &subgraph, int id, int buffer,
|
||||
std::function<mlir::Location(int)> loc) {
|
||||
const auto &output_tensor = (*subgraph.tensors())[buffer];
|
||||
std::cout << "\tname: '"
|
||||
<< (output_tensor->name() ? output_tensor->name()->str()
|
||||
: "<<unnamed>>")
|
||||
<< "' buffer: " << buffer;
|
||||
if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
|
||||
std::cout << '\n';
|
||||
};
|
||||
|
||||
// For every subgraph print out the name (if available), each result's output
|
||||
// buffer number and location of the return value (if available).
|
||||
for (auto *subgraph : *(*model)->subgraphs()) {
|
||||
std::string subgraph_name =
|
||||
subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";
|
||||
|
||||
std::cout << '\'' << subgraph_name << "' inputs:\n";
|
||||
int i = 0;
|
||||
for (auto input : *subgraph->inputs())
|
||||
print_buffer(*subgraph, i++, input, nullptr);
|
||||
|
||||
std::cout << '\'' << subgraph_name << "' outputs:\n";
|
||||
mlir::Operation *terminator = nullptr;
|
||||
if (subgraph->name()) {
|
||||
if (auto fn = module->getNamedFunction(subgraph->name()->str()))
|
||||
terminator = fn->back().getTerminator();
|
||||
}
|
||||
i = 0;
|
||||
for (auto output : *subgraph->outputs()) {
|
||||
print_buffer(*subgraph, i, output, [&](int i) {
|
||||
return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc;
|
||||
});
|
||||
}
|
||||
}
|
||||
return kTrSuccess;
|
||||
}
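// For illustration only: with -print-function-result-mapping, the helper above
// prints, per subgraph, each input and output tensor name together with its
// flatbuffer buffer index, roughly along these lines (the subgraph name,
// tensor names, buffer ids and location are hypothetical, not produced by this
// change's tests):
//
//   'main' inputs:
//           name: 'input' buffer: 1
//   'main' outputs:
//           name: 'output' buffer: 2 loc(unknown)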
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::PrettyStackTraceProgram x(argc, argv);
|
||||
// TODO(jpienaar): Revise the command line option parsing here.
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
// TODO(antiagainst): We are pulling in multiple transformations as follows.
|
||||
// Each transformation has its own set of command-line options; options of one
|
||||
// transformation can essentially be aliases to another. For example, the
|
||||
// -tfl-annotate-inputs has -tfl-input-arrays, -tfl-input-data-types, and
|
||||
// -tfl-input-shapes, which are the same as -graphdef-to-mlir transformation's
|
||||
// -tf_input_arrays, -tf_input_data_types, and -tf_input_shapes, respectively.
|
||||
// We need to disable duplicated ones to provide a cleaner command-line option
|
||||
// interface. That also means we need to relay the value set in one option to
|
||||
// all its aliases.
|
||||
|
||||
llvm::cl::ParseCommandLineOptions(
|
||||
argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");
|
||||
|
||||
// TODO(ashwinm): Enable command line parsing for both sides.
|
||||
int fake_argc = 1;
|
||||
tensorflow::port::InitMain(argv[0], &fake_argc, &argv);
|
||||
|
||||
MLIRContext context;
|
||||
llvm::SourceMgr source_mgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||
|
||||
StatusOr<std::unique_ptr<Module>> module =
|
||||
tensorflow::LoadFromGraphdefOrMlirSource(
|
||||
input_file_name, input_mlir, use_splatted_constant, extra_opdefs,
|
||||
debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, inference_type, min_values, max_values,
|
||||
/*prune_unused_nodes=*/true, &source_mgr, &context);
|
||||
|
||||
// If errors occur, the library call in the above already logged the error
|
||||
// message. So we can just return here.
|
||||
if (!module.ok()) return kTrFailure;
|
||||
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFControlFlowToTFLOrFlatbuffer(
|
||||
module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops,
|
||||
lower_tensor_list_ops, &result);
|
||||
if (!status.ok()) return kTrFailure;
|
||||
|
||||
auto output = mlir::openOutputFile(output_file_name);
|
||||
output->os() << result;
|
||||
output->keep();
|
||||
|
||||
// Print out debugging info related to function mapping.
|
||||
if (print_function_result_mapping)
|
||||
return PrintFunctionResultMapping(result, module.ValueOrDie().get());
|
||||
return kTrSuccess;
|
||||
}
|
67
tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc
Normal file
@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"

using llvm::cl::opt;

// TODO(jpienaar): Revise the command line option parsing here.
// NOLINTNEXTLINE
opt<std::string> input_file_name(llvm::cl::Positional,
                                 llvm::cl::desc("<input file>"),
                                 llvm::cl::init("-"));
// NOLINTNEXTLINE
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
                                  llvm::cl::value_desc("filename"),
                                  llvm::cl::init("-"));
// NOLINTNEXTLINE
opt<bool> use_splatted_constant(
    "use-splatted-constant",
    llvm::cl::desc(
        "Replace constants with randomly generated splatted tensors"),
    llvm::cl::init(false), llvm::cl::Hidden);
// NOLINTNEXTLINE
opt<bool> input_mlir(
    "input-mlir",
    llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
                   "GraphDef format"),
    llvm::cl::init(false), llvm::cl::Hidden);
// NOLINTNEXTLINE
opt<bool> output_mlir(
    "output-mlir",
    llvm::cl::desc(
        "Output MLIR rather than FlatBuffer for the generated TFLite model"),
    llvm::cl::init(false));

// The following is a temporary approach to allow injecting opdefs in addition
// to those that are already part of the global TF registry linked in. The
// primary goal is testing. This is not intended to be a general solution for
// unregistered ops. More appropriate mechanisms, such as op hints, should be
// used instead.
// NOLINTNEXTLINE
llvm::cl::list<std::string> extra_opdefs(
    "tf-extra-opdefs", llvm::cl::desc("List of extra opdefs when importing "
                                      "graphdef (testing purposes only)"));

// A Quantize/Dequantize op pair can optionally be emitted before and after the
// quantized model as adaptors that receive and produce floating point data for
// the quantized model. Set this to `false` if the model inputs are integer
// types.
// NOLINTNEXTLINE
opt<bool> emit_quant_adaptor_ops(
    "emit-quant-adaptor-ops",
    llvm::cl::desc(
        "Emit Quantize/Dequantize before and after the generated TFLite model"),
    llvm::cl::init(false));
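// For illustration only: these flags are consumed by the translation tool
// built from tf_tfl_translate.cc above, together with the import flags defined
// elsewhere in this change. A possible invocation might look like the
// following, where the binary name and file paths are hypothetical:
//
//   tf_tfl_translate --input-mlir --output-mlir --emit-quant-adaptor-ops \
//       -o /tmp/model.out.mlir /tmp/model.mlir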
40
tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h
Normal file
@ -0,0 +1,40 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_

// This file contains the command-line options that provide the parameters
// required by the TensorFlow Graph(Def) to TF Lite FlatBuffer conversion. It
// is only intended to be included by binaries.

#include <string>

#include "llvm/Support/CommandLine.h"

// The command-line options are defined in LLVM style, so the caller should
// use llvm::InitLLVM to initialize the options.
//
// Please see the implementation file for details of these options.
// TODO(jpienaar): Revise the command line option parsing here.
extern llvm::cl::opt<std::string> input_file_name;
extern llvm::cl::opt<std::string> output_file_name;
extern llvm::cl::opt<bool> use_splatted_constant;
extern llvm::cl::opt<bool> input_mlir;
extern llvm::cl::opt<bool> output_mlir;
extern llvm::cl::list<std::string> extra_opdefs;
extern llvm::cl::opt<bool> emit_quant_adaptor_ops;

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_
166
tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
Normal file
@ -0,0 +1,166 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Parser.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using mlir::MLIRContext;
|
||||
using mlir::Module;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
StatusOr<std::unique_ptr<Module>> LoadFromGraphdefOrMlirSource(
|
||||
const std::string &input_filename, bool input_mlir,
|
||||
bool use_splatted_constant, const std::vector<std::string> &extra_tf_opdefs,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, absl::string_view inference_type,
|
||||
absl::string_view min_values, absl::string_view max_values,
|
||||
bool prune_unused_nodes, llvm::SourceMgr *source_mgr,
|
||||
MLIRContext *context) {
|
||||
if (input_mlir) {
|
||||
// Set up the input file.
|
||||
std::string error_message;
|
||||
auto file = mlir::openInputFile(input_filename, &error_message);
|
||||
if (!file) {
|
||||
llvm::errs() << error_message << "\n";
|
||||
return errors::InvalidArgument("fail to open input file");
|
||||
}
|
||||
|
||||
source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
|
||||
return std::unique_ptr<Module>(mlir::parseSourceFile(*source_mgr, context));
|
||||
}
|
||||
for (const auto &tf_opdefs_string : extra_tf_opdefs) {
|
||||
tensorflow::OpDef opdef;
|
||||
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
|
||||
&opdef)) {
|
||||
LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
|
||||
return errors::InvalidArgument("fail to parse extra OpDef");
|
||||
}
|
||||
// Register extra opdefs.
|
||||
// TODO(b/133770952): Support shape functions.
|
||||
tensorflow::OpRegistry::Global()->Register(
|
||||
[opdef](tensorflow::OpRegistrationData *op_reg_data) -> Status {
|
||||
*op_reg_data = tensorflow::OpRegistrationData(opdef);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
|
||||
if (use_splatted_constant) {
|
||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
input_filename, debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, inference_type, min_values, max_values,
|
||||
prune_unused_nodes, context);
|
||||
}
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, inference_type, min_values, max_values, prune_unused_nodes,
|
||||
context);
|
||||
}
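// For illustration only: each entry passed through the extra opdefs mechanism
// above is a tensorflow::OpDef text proto. A hypothetical example entry (not
// part of this change) would look like:
//
//   name: "MyCustomOp"
//   input_arg { name: "x" type: DT_FLOAT }
//   output_arg { name: "y" type: DT_FLOAT }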
|
||||
|
||||
bool ShouldRunQuantizePasses(mlir::Module *m) {
|
||||
if (mlir::Function *main_fn = m->getNamedFunction("main")) {
|
||||
return main_fn->getAttrOfType<mlir::UnitAttr>("tf.quantize") !=
|
||||
mlir::Attribute();
|
||||
}
|
||||
return false;
|
||||
}
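// For illustration only: ShouldRunQuantizePasses returns true when the main
// function carries the unit attribute "tf.quantize", e.g. for a module along
// these lines (hypothetical MLIR):
//
//   func @main(%arg0: tensor<1xf32>) -> tensor<1xf32>
//       attributes {tf.quantize} {
//     return %arg0 : tensor<1xf32>
//   }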
|
||||
|
||||
void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
|
||||
bool emit_quant_adaptor_ops,
|
||||
bool lower_tensor_list_ops,
|
||||
mlir::PassManager *pass_manager) {
|
||||
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
||||
// TODO(jpienaar): Revise post dialect constants.
|
||||
pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
|
||||
// Canonicalization includes const folding, which is utilized here to optimize
|
||||
// away ops that can't get constant folded after PrepareTF pass. For example,
|
||||
// tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
|
||||
pass_manager->addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
// The below passes only make sense if Builtin TFLite ops are enabled
|
||||
// for emission.
|
||||
if (emit_builtin_tflite_ops) {
|
||||
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
|
||||
// the TFLite dialect.
|
||||
// TODO(haoliang): Add this pass by default.
|
||||
if (lower_tensor_list_ops) {
|
||||
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
||||
}
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareTFPass());
|
||||
pass_manager->addPass(mlir::createCanonicalizerPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||
if (run_quantize) {
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass());
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
}
|
||||
pass_manager->addPass(mlir::createCanonicalizerPass());
|
||||
pass_manager->addPass(mlir::createCSEPass());
|
||||
}
|
||||
}
|
||||
|
||||
Status ConvertTFControlFlowToTFLOrFlatbuffer(
|
||||
mlir::Module *module, bool export_to_mlir, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
|
||||
bool lower_tensor_list_ops, std::string *result) {
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(module->getContext(),
|
||||
/*propagate=*/true);
|
||||
mlir::PassManager pm;
|
||||
bool run_quantize = ShouldRunQuantizePasses(module);
|
||||
|
||||
AddTFToTFLConversionPasses(emit_builtin_tflite_ops, run_quantize,
|
||||
emit_quant_adaptor_ops, lower_tensor_list_ops,
|
||||
&pm);
|
||||
|
||||
if (failed(pm.run(module))) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
|
||||
if (export_to_mlir) {
|
||||
llvm::raw_string_ostream os(*result);
|
||||
module->print(os);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Write MLIR TFLite dialect into FlatBuffer
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops)) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
77
tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h
Normal file
@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_

#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

// Loads a TF model from a GraphDef definition or a TF control flow dialect
// MLIR source into an MLIR module. If `input_mlir` is true, load from an MLIR
// source file; otherwise, load from a GraphDef.
// Setting `prune_unused_nodes` to true prunes unreachable nodes if
// `output_arrays` is specified.
stream_executor::port::StatusOr<std::unique_ptr<mlir::Module>>
LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    absl::string_view debug_info_file, absl::string_view input_arrays,
    absl::string_view input_dtypes, absl::string_view input_shapes,
    absl::string_view output_arrays, absl::string_view inference_type,
    absl::string_view min_values, absl::string_view max_values,
    bool prune_unused_nodes, llvm::SourceMgr* source_mgr,
    mlir::MLIRContext* context);

// Quantization passes will run only when the user specifies a quantized type
// in the `-tf-inference-type` flag, which is converted to the function
// attribute "tf.quantize" by the importer module.
// TODO(fengliuai): switch to the cmd flag once the flags are moved to this
// file with main method.
bool ShouldRunQuantizePasses(mlir::Module* m);

// Adds the MLIR passes that convert the TF control flow dialect to the TF Lite
// dialect to an MLIR `pass_manager`. These passes first raise the control flow
// in the TF control flow dialect, decode the constant tensors, and then
// legalize the module to the TF Lite dialect with some optimizations
// afterwards.
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produce TF Lite ops. If `run_quantize` is true, quantization
// passes will be added. If `emit_quant_adaptor_ops` is true, Quantize and
// Dequantize ops are added to the inputs and outputs of the quantized model.
// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
// TF ops before legalization to the TF Lite dialect.
void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
                                bool emit_quant_adaptor_ops,
                                bool lower_tensor_list_ops,
                                mlir::PassManager* pass_manager);

// Takes an MLIR module in the TF control flow dialect and a set of parameters,
// applies a set of passes to convert the module to the TF Lite dialect, and
// serializes the result to a string. Depending on an attribute in the module's
// main function, quantization is applied. If `export_to_mlir` is true, the
// result is exported in MLIR text format, otherwise it is exported as a
// FlatBuffer.
Status ConvertTFControlFlowToTFLOrFlatbuffer(
    mlir::Module* module, bool export_to_mlir, bool emit_builtin_tflite_ops,
    bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
    bool lower_tensor_list_ops, std::string* result);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
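// For illustration only: a caller such as the tf_tfl_translate tool earlier in
// this change is expected to wire these entry points together roughly as
// follows; error handling is elided and the variable names are hypothetical:
//
//   llvm::SourceMgr source_mgr;
//   mlir::MLIRContext context;
//   auto module = LoadFromGraphdefOrMlirSource(/*...flag values...*/,
//                                              &source_mgr, &context);
//   std::string serialized;
//   ConvertTFControlFlowToTFLOrFlatbuffer(
//       module.ValueOrDie().get(), /*export_to_mlir=*/false,
//       /*emit_builtin_tflite_ops=*/true, /*emit_select_tf_ops=*/false,
//       /*emit_custom_ops=*/false, /*emit_quant_adaptor_ops=*/false,
//       /*lower_tensor_list_ops=*/true, &serialized);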
@ -0,0 +1,98 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/TableGen/Main.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
|
||||
|
||||
using llvm::LessRecord;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
// Helper macro that returns indented os.
|
||||
#define OUT(X) os.indent((X))
|
||||
|
||||
// The function below has a non-constant reference as that is required by LLVM's
|
||||
// TableGenMain.
|
||||
// NOLINTNEXTLINE
|
||||
static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
|
||||
llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"};
|
||||
|
||||
emitSourceFileHeader("TensorFlow Lite Ops Quant Spec Getters", os);
|
||||
|
||||
// Retrieve all the definitions derived from TFL_Op and sort by record name.
|
||||
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
|
||||
llvm::sort(defs, LessRecord());
|
||||
|
||||
OUT(0) << "static std::unique_ptr<OpQuantSpec> "
|
||||
"GetOpQuantSpec(mlir::Operation *op) {\n";
|
||||
OUT(2) << "auto spec = absl::make_unique<OpQuantSpec>();\n";
|
||||
for (auto *def : defs) {
|
||||
Operator op(def);
|
||||
for (const auto t : op.getTraits()) {
|
||||
if (auto opTrait = llvm::dyn_cast<mlir::tblgen::NativeOpTrait>(&t)) {
|
||||
auto trait = opTrait->getTrait();
|
||||
// We only handle TFL specific native op traits.
|
||||
if (!trait.startswith("TFL::")) continue;
|
||||
trait.consume_front("TFL::");
|
||||
|
||||
OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName()
|
||||
<< ">(op)) {\n";
|
||||
|
||||
// There is a "NoQuantizableResult" trait, set the flag.
|
||||
if (trait.equals("NoQuantizableResult")) {
|
||||
OUT(4) << "spec->is_quantizable = false;\n";
|
||||
}
|
||||
// There is a "SameOperandsAndResultScale" trait, set the flag.
|
||||
if (trait.equals("SameOperandsAndResultsScale")) {
|
||||
OUT(4) << "spec->requires_same_scale = true;\n";
|
||||
}
|
||||
// There is a "FixedResultUniformScale" trait, set the type for result.
|
||||
if (trait.startswith("FixedResultUniformScale")) {
|
||||
OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n";
|
||||
OUT(6) << "spec->restricted_output_params.push_back(tfl."
|
||||
"GetResultQuantizedType(i));\n";
|
||||
}
|
||||
// There is a "AccumulatorUniformScale" trait, set the type for bias.
|
||||
auto trait_str = opTrait->getTrait().str();
|
||||
llvm::SmallVector<llvm::StringRef, 1> matches;
|
||||
if (acc_uniform_trait_regex.match(trait_str, &matches)) {
|
||||
OUT(4) << "spec->biases_params.emplace(std::make_pair(" << matches[1]
|
||||
<< ", std::make_pair(tfl.GetAllNonBiasOperands(),"
|
||||
<< "GetUniformQuantizedTypeForBias)));\n";
|
||||
}
|
||||
|
||||
OUT(2) << "}\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
OUT(2) << "return spec;\n";
|
||||
OUT(0) << "}\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::PrettyStackTraceProgram X(argc, argv);
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
return TableGenMain(argv[0], &OpQuantSpecWriter);
|
||||
}
|
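// For illustration only: given a hypothetical TFL op declared with the
// NoQuantizableResult and AccumulatorUniformScale<2> traits, the backend above
// would emit a getter roughly along these lines (the op name is made up):
//
//   static std::unique_ptr<OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
//     auto spec = absl::make_unique<OpQuantSpec>();
//     if (auto tfl = llvm::dyn_cast<TFL::SomeOp>(op)) {
//       spec->is_quantizable = false;
//       spec->biases_params.emplace(std::make_pair(2, std::make_pair(
//           tfl.GetAllNonBiasOperands(), GetUniformQuantizedTypeForBias)));
//     }
//     return spec;
//   }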
243
tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
Normal file
@ -0,0 +1,243 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// TFLite legalization patterns
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
def F32ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
||||
|
||||
// Extract the ith int element from an ArrayAttr $0 as a 32-bit IntegerAttr
|
||||
// with builder.
|
||||
class ExtractI32At<int i> : NativeCodeCall<
|
||||
"$_builder.getI32IntegerAttr($_self.cast<ArrayAttr>().getValue()[" # i #
|
||||
"].cast<IntegerAttr>().getInt())">;
|
||||
|
||||
// Merge the two Attributes into an ArrayAttr.
|
||||
def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
|
||||
|
||||
// Use the tensor type information from $0 and convert min $1, max $2 and
|
||||
// storage type $3 to a QuantizedType.
|
||||
def ConvertToQuantType : NativeCodeCall<
|
||||
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3.getValue())">;
|
||||
|
||||
// Use the tensor type information from $0 and convert min $1, max $2 and
|
||||
// numBits $3 and narrowRange $4 to a QuantizedType.
|
||||
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
|
||||
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3, $4)">;
|
||||
|
||||
// Predicate that holds if all the three attributes are set.
|
||||
def HasAll3Attrs : Constraint<CPred<"HasAll3Attrs($0, $1, $2)">>;
|
||||
|
||||
// Converts an integer attribute $0 to 32-bit with builder.
|
||||
def convertIntAttrTo32Bit : NativeCodeCall<
|
||||
"$_builder.getI32IntegerAttr($0.cast<IntegerAttr>().getInt())">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Nullary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
|
||||
def IsIntList1XY1 : AttrConstraint<CPred<"TFIntListIs1XY1($_self)">>;
|
||||
|
||||
def : Pat<(TF_AbsOp $arg), (TFL_AbsOp $arg)>;
|
||||
|
||||
def : Pat<(TF_AvgPoolOp $value,
|
||||
IsIntList1XY1:$ksize,
|
||||
IsIntList1XY1:$strides,
|
||||
$padding,
|
||||
IsDataFormatNHWC:$format),
|
||||
(TFL_AveragePool2DOp $value,
|
||||
/*filter_height=*/ExtractI32At<1>:$ksize,
|
||||
/*filter_width=*/ExtractI32At<2>:$ksize,
|
||||
/*padding=*/$padding,
|
||||
/*stride_h=*/ExtractI32At<1>:$strides,
|
||||
/*stride_w=*/ExtractI32At<2>:$strides,
|
||||
/*fused_activation_function=*/TFL_AF_None)>;
|
||||
|
||||
def : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;
|
||||
|
||||
def : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;
|
||||
|
||||
def : Pat<(TF_EluOp $arg), (TFL_EluOp $arg)>;
|
||||
|
||||
def : Pat<(TF_ExpandDimsOp $input, $dim), (TFL_ExpandDimsOp $input, $dim)>;
|
||||
|
||||
def : Pat<(TF_FakeQuantWithMinMaxArgsOp $inputs,
|
||||
$min, $max,
|
||||
$num_bits, $narrow_range),
|
||||
(TFL_DequantizeOp
|
||||
(TFL_QuantizeOp $inputs,
|
||||
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
|
||||
$num_bits, $narrow_range)))>;
|
||||
|
||||
def : Pat<(TF_FillOp $arg, $value), (TFL_FillOp $arg, $value)>;
|
||||
|
||||
def : Pat<(TF_FloorOp $arg), (TFL_FloorOp $arg)>;
|
||||
|
||||
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
|
||||
def : Pat<(TF_LogicalNotOp $arg), (TFL_LogicalNotOp $arg)>;
|
||||
def : Pat<(TF_LogSoftmaxOp $arg), (TFL_LogSoftmaxOp $arg)>;
|
||||
|
||||
def : Pat<(TF_MaxPoolOp $value,
|
||||
IsIntList1XY1:$ksize,
|
||||
IsIntList1XY1:$strides,
|
||||
$padding,
|
||||
IsDataFormatNHWC:$format),
|
||||
(TFL_MaxPool2DOp $value,
|
||||
/*padding=*/$padding,
|
||||
/*stride_w=*/ExtractI32At<2>:$strides,
|
||||
/*stride_h=*/ExtractI32At<1>:$strides,
|
||||
/*filter_width=*/ExtractI32At<2>:$ksize,
|
||||
/*filter_height=*/ExtractI32At<1>:$ksize,
|
||||
/*fused_activation_function=*/TFL_AF_None)>;
|
||||
|
||||
def : Pat<(TF_MaximumOp $arg1, $arg2), (TFL_MaximumOp $arg1, $arg2)>;
|
||||
def : Pat<(TF_MinimumOp $arg1, $arg2), (TFL_MinimumOp $arg1, $arg2)>;
|
||||
def : Pat<(TF_RangeOp $start, $limit, $delta), (TFL_RangeOp $start, $limit, $delta)>;
|
||||
def : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>;
|
||||
def : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>;
|
||||
// The second operand is captured in the type for this transform.
|
||||
def : Pat<(TF_ReshapeOp:$res AnyStaticShapeTensor:$arg, $ignored),
|
||||
(TFL_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
|
||||
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
|
||||
// TODO(jpienaar): this is not true for all selects, TF's select supports rank 0
|
||||
// condition
|
||||
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
|
||||
def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
|
||||
def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
|
||||
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
|
||||
def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
|
||||
def : Pat<(TF_SqueezeOp AnyStaticShapeTensor:$arg, $ignored_dims),
|
||||
(TFL_ReshapeOp $arg)>;
|
||||
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
|
||||
def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
|
||||
|
||||
// The following two rules can both match a tf.Placeholder.input node with
|
||||
// min/max/type attributes, so we increase the benefit of the first rule by one
|
||||
// so the tfl.quantize and tfl.dequantize ops will be inserted if it matches.
|
||||
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
|
||||
(TFL_DequantizeOp
|
||||
(TFL_QuantizeOp
|
||||
(TFL_InputOp $inputs),
|
||||
(ConvertToQuantType $inputs, $min, $max, $type))),
|
||||
[(HasAll3Attrs $min, $max, $type)], (addBenefit 1)>;
|
||||
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
|
||||
(TFL_InputOp $inputs)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(TF_LessOp $l, $r), (TFL_LessOp $l, $r)>;
|
||||
def : Pat<(TF_GreaterOp $l, $r), (TFL_GreaterOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_GreaterEqualOp $l, $r), (TFL_GreaterEqualOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_NotEqualOp $l, $r), (TFL_NotEqualOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
|
||||
|
||||
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
|
||||
multiclass FusedBinaryActivationFuncOpPat<dag FromOp, dag ToOp> {
|
||||
def : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
|
||||
(ToOp $l, $r, TFL_AF_None)>;
|
||||
foreach actFnPair = [[TF_ReluOp, TFL_AF_Relu],
|
||||
[TF_Relu6Op, TFL_AF_Relu6]] in {
|
||||
def : Pat<(actFnPair[0] (FromOp $lhs, $rhs)),
|
||||
(ToOp $lhs, $rhs, actFnPair[1])>;
|
||||
// TODO: Maybe move these below to general pass?
|
||||
def : Pat<(actFnPair[0] (ToOp $lhs, $rhs, TFL_AF_None)),
|
||||
(ToOp $lhs, $rhs, actFnPair[1])>;
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiated FusedBinary patterns for the from-to pairs of ops.
|
||||
foreach fromToPair = [[TF_AddOp, TFL_AddOp],
|
||||
[TF_AddV2Op, TFL_AddOp],
|
||||
[TF_DivOp, TFL_DivOp],
|
||||
[TF_MulOp, TFL_MulOp],
|
||||
[TF_RealDivOp, TFL_DivOp],
|
||||
[TF_SubOp, TFL_SubOp]] in
|
||||
defm : FusedBinaryActivationFuncOpPat<fromToPair[0], fromToPair[1]>;
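// For illustration (explanatory comment only, not an extra pattern):
// instantiating the multiclass above with [TF_AddOp, TFL_AddOp] is equivalent
// to writing
//   def : Pat<(TF_AddOp AnyTensor:$l, AnyTensor:$r),
//             (TFL_AddOp $l, $r, TFL_AF_None)>;
//   def : Pat<(TF_ReluOp (TF_AddOp $lhs, $rhs)),
//             (TFL_AddOp $lhs, $rhs, TFL_AF_Relu)>;
//   def : Pat<(TF_Relu6Op (TF_AddOp $lhs, $rhs)),
//             (TFL_AddOp $lhs, $rhs, TFL_AF_Relu6)>;
// plus the corresponding rewrites that fold a trailing relu/relu6 into an
// already converted tfl.add with no fused activation.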
|
||||
|
||||
def : Pat<(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r,
|
||||
IsDataFormatNHWC:$data_format),
|
||||
(TFL_AddOp $l, $r, TFL_AF_None)>;
|
||||
// TODO(jpienaar): These should be handled by the pattern rewriter, find out
|
||||
// why they aren't.
|
||||
def : Pat<(TF_Relu6Op (TF_BiasAddOp F32Tensor:$l, F32Tensor:$r,
|
||||
IsDataFormatNHWC:$data_format)),
|
||||
(TFL_AddOp $l, $r, TFL_AF_Relu6)>;
|
||||
|
||||
def : Pat<(TF_FakeQuantWithMinMaxVarsOp $inputs,
|
||||
(ConstantOp F32ElementsAttr:$min),
|
||||
(ConstantOp F32ElementsAttr:$max),
|
||||
$num_bits, $narrow_range),
|
||||
(TFL_DequantizeOp
|
||||
(TFL_QuantizeOp $inputs,
|
||||
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
|
||||
$num_bits, $narrow_range)))>;
|
||||
|
||||
def : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
|
||||
|
||||
def : Pat<(TF_SquaredDifferenceOp $l, $r), (TFL_SquaredDifferenceOp $l, $r)>;
|
||||
|
||||
// Note(ycling): We can eliminate Relu from Relu(SquaredDifference(x, y)),
|
||||
// since the result of SquaredDifference is always non-negative.
|
||||
// The TFLite interpreter doesn't support Relu+int32 for now, so without the
|
||||
// following pattern to optimize the Relu away, the affected test cases would
|
||||
// fail.
|
||||
def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)),
|
||||
(TFL_SquaredDifferenceOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>;
|
||||
|
||||
def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>;
|
||||
|
||||
def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
|
||||
|
||||
def : Pat<(TF_PadV2Op $arg0, $arg1, $cst), (TFL_PadV2Op $arg0, $arg1, $cst)>;
|
||||
|
||||
def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $arg2)>;
|
||||
|
||||
def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>;
|
||||
|
||||
def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
|
||||
|
||||
def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
|
||||
|
||||
def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
|
||||
|
||||
def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
|
||||
|
||||
def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners)>;
|
||||
|
||||
def : Pat<
|
||||
(TF_StridedSliceOp $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, $new_axis_mask, $shrink_axis_mask),
|
||||
(TFL_StridedSliceOp $input, $begin, $end, $strides,
|
||||
(convertIntAttrTo32Bit $begin_mask), (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
|
||||
(convertIntAttrTo32Bit $new_axis_mask), (convertIntAttrTo32Bit $shrink_axis_mask))>;
|
222
tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
Normal file
222
tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
Normal file
@ -0,0 +1,222 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass converts operations in TensorFlow dialect into
|
||||
// operations that are legal in the TensorFlow Lite dialect. Operations that
|
||||
// can be legalized to TensorFlow Lite dialect with simple replacements are part
|
||||
// of this pass; other operations that may create extra ops should instead be
|
||||
// part of the PrepareTF pass, which is run before this pass. That way any
|
||||
// constant folding opportunities from the extra ops can be exploited by the
|
||||
// constant folding support for the TensorFlow ops.
|
||||
|
||||
#include <climits>
|
||||
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The actual LegalizeTF Pass.
|
||||
namespace {
|
||||
|
||||
// Legalize operations in functions.
|
||||
struct LegalizeTF : public FunctionPass<LegalizeTF> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
|
||||
|
||||
#define DECL_CONVERT_OP(tf_op) \
|
||||
struct ConvertTF##tf_op##Op : public RewritePattern { \
|
||||
explicit ConvertTF##tf_op##Op(MLIRContext* context) \
|
||||
: RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
|
||||
PatternMatchResult matchAndRewrite( \
|
||||
Operation* op, PatternRewriter& rewriter) const override; \
|
||||
}
|
||||
|
||||
// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
|
||||
// operands are properly supported in declarative rewrite rule specification.
|
||||
|
||||
DECL_CONVERT_OP(Concat);
|
||||
DECL_CONVERT_OP(ConcatV2);
|
||||
DECL_CONVERT_OP(Gather);
|
||||
DECL_CONVERT_OP(GatherV2);
|
||||
DECL_CONVERT_OP(MatMul);
|
||||
DECL_CONVERT_OP(Pack);
|
||||
DECL_CONVERT_OP(TopKV2);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
|
||||
#undef DECL_CONVERT_OP
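// For reference (explanatory comment only): DECL_CONVERT_OP(Concat) above
// expands to a declaration along the lines of
//
//   struct ConvertTFConcatOp : public RewritePattern {
//     explicit ConvertTFConcatOp(MLIRContext* context)
//         : RewritePattern(TF::ConcatOp::getOperationName(), 1, context) {}
//     PatternMatchResult matchAndRewrite(
//         Operation* op, PatternRewriter& rewriter) const override;
//   };
//
// whose matchAndRewrite bodies are defined below.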
|
||||
|
||||
PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_concat_op = cast<TF::ConcatOp>(op);
|
||||
|
||||
SmallVector<Value*, 4> values(tf_concat_op.values());
|
||||
auto output_type = tf_concat_op.output()->getType();
|
||||
// Extract axis attribute from constant concat_dims tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
|
||||
return matchFailure();
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
|
||||
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
|
||||
fused_activation_function);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
|
||||
|
||||
SmallVector<Value*, 4> values(tf_concat_op.values());
|
||||
auto output_type = tf_concat_op.output()->getType();
|
||||
// Extract axis attribute from constant axis tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
|
||||
return matchFailure();
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
rewriter.replaceOpWithNewOp<ConcatenationOp>(
|
||||
op, output_type, values, ExtractSingleElementAsInteger(axis),
|
||||
fused_activation_function);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFGatherOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
// Gather in TF -> Gather in TFL with axis=0
|
||||
IntegerType type = IntegerType::get(32, rewriter.getContext());
|
||||
rewriter.replaceOpWithNewOp<TFL::GatherOp>(
|
||||
op, op->getOperand(0), op->getOperand(1), IntegerAttr::get(type, 0));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFGatherV2Op::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_op = cast<TF::GatherV2Op>(op);
|
||||
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_op.axis(), m_Constant(&axis))) return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<GatherOp>(op, op->getOperand(0),
|
||||
op->getOperand(1),
|
||||
ExtractSingleElementAsInteger(axis));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
// The following is effectively:
|
||||
// def : Pat<
|
||||
// (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
|
||||
// ConstBoolAttrTrue:$transpose_b),
|
||||
// (TFL_FullyConnectedOp:$__0 $a, $b,
|
||||
// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
|
||||
PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_matmul_op = cast<TF::MatMulOp>(op);
|
||||
if (tf_matmul_op.transpose_a()) return matchFailure();
|
||||
if (!tf_matmul_op.transpose_b()) return matchFailure();
|
||||
|
||||
Type output_type = tf_matmul_op.getResult()->getType();
|
||||
// TODO(jpienaar): Follow up post shuffle discussion.
|
||||
auto no_input = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
|
||||
auto fc_op = rewriter.create<FullyConnectedOp>(
|
||||
op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0),
|
||||
op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
|
||||
rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
|
||||
rewriter.replaceOp(op, {fc_op.getResult(0)});
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFPackOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_pack_op = cast<TF::PackOp>(op);
|
||||
|
||||
SmallVector<Value*, 4> values(tf_pack_op.values());
|
||||
auto output_type = tf_pack_op.output()->getType();
|
||||
auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N().getZExtValue());
|
||||
// Axis can be negative.
|
||||
auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
|
||||
|
||||
rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
|
||||
axis);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFTopKV2Op::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
// TopK in TFL is always sorted so we ignore that attribute here.
|
||||
rewriter.replaceOpWithNewOp<TFL::TopKV2Op>(op, op->getOperand(0),
|
||||
op->getOperand(1));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_unpack_op = cast<TF::UnpackOp>(op);
|
||||
|
||||
auto* input = tf_unpack_op.value();
|
||||
auto output_types = functional::map([](Value* v) { return v->getType(); },
|
||||
tf_unpack_op.output());
|
||||
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num().getZExtValue());
|
||||
// Axis can be negative.
|
||||
auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue());
|
||||
|
||||
rewriter.replaceOpWithNewOp<UnpackOp>(op, output_types, input, num, axis);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto* ctx = &getContext();
|
||||
auto& func = getFunction();
|
||||
|
||||
// Add the generated patterns to the list.
|
||||
populateWithGenerated(ctx, &patterns);
|
||||
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFGatherOp,
|
||||
ConvertTFGatherV2Op, ConvertTFMatMulOp, ConvertTFPackOp,
|
||||
ConvertTFTopKV2Op, ConvertTFUnpackOp>::build(patterns,
|
||||
ctx);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
||||
FunctionPassBase* CreateLegalizeTFPass() { return new LegalizeTF(); }
|
||||
|
||||
static PassRegistration<LegalizeTF> pass(
|
||||
"tfl-legalize-tf", "Legalize from TensorFlow to TensorFlow Lite dialect");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
338
tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
Normal file
338
tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
Normal file
@ -0,0 +1,338 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass prepares for legalization to the TFLite dialect by
|
||||
// converting TensorList operations in TensorFlow dialect into operations that
|
||||
// can be legalized to TensorFlow Lite dialect with simple replacements. The
|
||||
// newly created operations are in the TensorFlow dialect if the operation can
|
||||
// be represented using a TensorFlow op. Otherwise, a TensorFlow Lite dialect
|
||||
// op is used.
|
||||
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-tfl-legalization"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The actual LowerStaticTensorList Pass.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
// Lower TensorList ops in functions for subsequent legalization.
|
||||
struct LowerStaticTensorListPass
|
||||
: public FunctionPass<LowerStaticTensorListPass> {
|
||||
void runOnFunction() override;
|
||||
LogicalResult ModifyTensorList();
|
||||
};
|
||||
|
||||
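// Returns a constant op holding an i32 tensor of the given static `shape`,
// with every element set to `val`.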
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
|
||||
ArrayRef<int64_t> shape, int32_t val) {
|
||||
auto type = rewriter->getTensorType(shape, rewriter->getIntegerType(32));
|
||||
auto attr = DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
|
||||
return rewriter->create<ConstantOp>(op->getLoc(), type, attr);
|
||||
}
|
||||
|
||||
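// Returns a `tf.Fill` op producing an i32 tensor whose shape is given by
// `shape_tensor` and whose elements are all `val` (used here to build 1-D
// position and size vectors).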
Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter,
|
||||
Value *shape_tensor, int32_t val) {
|
||||
auto scalar_val = CreateI32SplatConst(op, rewriter, {}, val);
|
||||
return rewriter->create<TF::FillOp>(
|
||||
op->getLoc(), rewriter->getTensorType({-1}, rewriter->getIntegerType(32)),
|
||||
shape_tensor, scalar_val);
|
||||
}
|
||||
|
||||
struct ConvertTFTensorListSetItem : public RewritePattern {
|
||||
explicit ConvertTFTensorListSetItem(MLIRContext *context)
|
||||
: RewritePattern(TF::TensorListSetItemOp::getOperationName(), 1,
|
||||
context) {}
|
||||
// This function rewrites the original op into a series of slice and concat ops
|
||||
// to produce the same result. It first slices the first `$index` rows. Then
|
||||
// expands the dimension of the `$item`, followed by another slice of the
|
||||
// remaining rows starting from `$index` + 1. Lastly it concatenates the
|
||||
// three parts together.
|
||||
// On a high level, it's doing something like:
|
||||
// def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
|
||||
// (Concat
|
||||
// concat_dim = 0,
|
||||
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
|
||||
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
|
||||
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
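// As a concrete illustration (hypothetical shapes, not taken from a test):
// for a list of shape [4, 2], index = 1, and an item of shape [2], the rewrite
// produces
//   slice1        = Slice($input, [0, 0], [1, -1])   // row 0
//   expanded_item = ExpandDims($item, 0)             // shape [1, 2]
//   slice2        = Slice($input, [2, 0], [-1, -1])  // rows 2..3
// and concatenates the three pieces along dimension 0, giving back a [4, 2]
// tensor with row 1 replaced by the item.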
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TF::TensorListSetItemOp tf_op = cast<TF::TensorListSetItemOp>(op);
|
||||
|
||||
auto input = tf_op.input_handle();
|
||||
auto shape_dtype = rewriter.getIntegerType(32);
|
||||
auto input_rank = rewriter.create<TF::RankOp>(
|
||||
op->getLoc(), rewriter.getTensorType({}, shape_dtype), input);
|
||||
auto item = tf_op.item();
|
||||
auto item_rank = rewriter.create<TF::RankOp>(
|
||||
op->getLoc(), rewriter.getTensorType({}, shape_dtype), item);
|
||||
|
||||
// Prepare the start position for the first slice op, which is [0, 0, ..,
|
||||
// 0].
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto position_shape = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
|
||||
scalar_zero);
|
||||
// Fill all 0s into the first position tensor.
|
||||
auto first_start_position =
|
||||
CreateI32SplatTensor(op, &rewriter, position_shape, 0);
|
||||
|
||||
// Prepare the start position for the second slice op, which is
|
||||
// [index + 1, 0, 0 .. 0].
|
||||
// Calculate the first dimension, which is index + 1.
|
||||
auto index = tf_op.index();
|
||||
auto vector_type = rewriter.getTensorType({1}, shape_dtype);
|
||||
auto begin =
|
||||
rewriter.create<TF::AddOp>(op->getLoc(), vector_type, index,
|
||||
CreateI32SplatConst(op, &rewriter, {1}, 1));
|
||||
|
||||
// Followed by the first dimension `begin`, are `item_rank` of 0s.
|
||||
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
|
||||
scalar_zero);
|
||||
auto partial_second_start_position =
|
||||
CreateI32SplatTensor(op, &rewriter, item_position_shape, 0);
|
||||
auto position_type = first_start_position->getType();
|
||||
// Concatenate `begin` with the remaining 0s.
|
||||
auto second_start_position = rewriter.create<TF::ConcatOp>(
|
||||
op->getLoc(), position_type, scalar_zero,
|
||||
ArrayRef<Value *>({begin, partial_second_start_position}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
|
||||
// Create the size parameter for the first slice op, which is [index, -1,
|
||||
// -1, .., -1].
|
||||
auto size1_leading_dim = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), vector_type, index, scalar_zero);
|
||||
auto partial_size1 =
|
||||
CreateI32SplatTensor(op, &rewriter, item_position_shape, -1);
|
||||
auto size1 = rewriter.create<TF::ConcatOp>(
|
||||
op->getLoc(), position_type, scalar_zero,
|
||||
ArrayRef<Value *>({size1_leading_dim, partial_size1}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
|
||||
// Create the size parameter for the second slice, which is [-1, -1, ..,
|
||||
// -1].
|
||||
auto size2 = CreateI32SplatTensor(op, &rewriter, position_shape, -1);
|
||||
|
||||
// Create two slice ops.
|
||||
auto element_type = input->getType().cast<TensorType>().getElementType();
|
||||
auto unranked_tensor = rewriter.getTensorType(element_type);
|
||||
auto slice1 = rewriter.create<TF::SliceOp>(
|
||||
op->getLoc(), unranked_tensor, input, first_start_position, size1);
|
||||
auto slice2 = rewriter.create<TF::SliceOp>(
|
||||
op->getLoc(), unranked_tensor, input, second_start_position, size2);
|
||||
|
||||
// Expand the dimension of item so that it will have the same rank with
|
||||
// input.
|
||||
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), unranked_tensor, item, scalar_zero);
|
||||
|
||||
// Concatenate three parts together to generate the final result.
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
op, input->getType(), scalar_zero,
|
||||
ArrayRef<Value *>({slice1, expanded_item, slice2}),
|
||||
rewriter.getI64IntegerAttr(3));
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(hinsu): Fix end-to-end test when passing string `element_dtype`
|
||||
// attribute.
|
||||
struct ConvertTFTensorListReserve : public RewritePattern {
|
||||
explicit ConvertTFTensorListReserve(MLIRContext *context)
|
||||
: RewritePattern(TF::TensorListReserveOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
// Rewrites the original op into `tf.fill`. The result tensor shape is
|
||||
// [num_element, element_shape]. All the values in the result tensor will be
|
||||
// initialized to 0.
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TF::TensorListReserveOp tf_op = cast<TF::TensorListReserveOp>(op);
|
||||
|
||||
auto element_shape = tf_op.element_shape();
|
||||
auto shape_dtype =
|
||||
element_shape->getType().cast<TensorType>().getElementType();
|
||||
auto num_elements = tf_op.num_elements();
|
||||
int64_t input_rank = -1; // -1 means unknown dimension.
|
||||
if (auto type = element_shape->getType().dyn_cast<RankedTensorType>()) {
|
||||
// Note that the first item of the shape array is the element's rank; add 1
|
||||
// to it to get the input's rank.
|
||||
if (type.hasStaticShape()) {
|
||||
input_rank = type.getShape()[0] + 1;
|
||||
}
|
||||
}
|
||||
auto element_dtype = tf_op.element_dtype();
|
||||
|
||||
// The output shape of the result tensor should be [num_elements +
|
||||
// element_shape].
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
|
||||
scalar_zero);
|
||||
auto shape_type = rewriter.getTensorType({input_rank}, shape_dtype);
|
||||
auto list_shape = rewriter.create<TF::ConcatOp>(
|
||||
op->getLoc(), shape_type, scalar_zero,
|
||||
ArrayRef<Value *>({leading_dim, element_shape}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
|
||||
// Create a zero-initialized constant tensor that has the same type
|
||||
// as specified by element_dtype.
|
||||
auto zero_type = rewriter.getTensorType({}, element_dtype);
|
||||
auto zero_attr = rewriter.getZeroAttr(zero_type);
|
||||
auto zero = rewriter.create<ConstantOp>(op->getLoc(), zero_type, zero_attr);
|
||||
|
||||
rewriter.replaceOpWithNewOp<TF::FillOp>(
|
||||
op, rewriter.getTensorType(element_dtype), list_shape, zero);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace TFL {
|
||||
namespace {
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
|
||||
} // namespace
|
||||
} // namespace TFL
|
||||
|
||||
LogicalResult LowerStaticTensorListPass::ModifyTensorList() {
|
||||
// In `runOnFunction`, there is no guarantee about
|
||||
// in which order those patterns will be applied. Our transformation requires
|
||||
// that at runtime each `TensorListSetItem` op takes in a normal tensor type
|
||||
// rather than a `DT_VARIANT` tensor. So here we need to manually walk-through
|
||||
// the IR and change the argument/return types of each `TensorListSetItemOp`.
|
||||
// TODO(haoliang): 1) support modifying more `TensorList` ops that consume/
|
||||
// produce `DT_VARIANT` tensors. 2) More robust support for handling multiple
|
||||
// different tensorlist types. For example, consider the case like:
|
||||
// l1 = list_ops.tensor_list_from_tensor(t, element_shape1)
|
||||
// l2 = list_ops.tensor_list_from_tensor(t, element_shape2)
|
||||
// l1 = list_ops.tensor_list_set_item(l1, 0, item1)
|
||||
// l2 = list_ops.tensor_list_set_item(l2, 0, item2)
|
||||
// 3) Handle the case where a tensorlist output is passed to multiple
|
||||
// functions.
|
||||
for (Block &block : getFunction()) {
|
||||
Type tensor_type;
|
||||
for (Operation &op : block) {
|
||||
if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) {
|
||||
tensor_type = tf_op.tensor()->getType();
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
|
||||
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
|
||||
tf_op.element_dtype().isF64() ||
|
||||
tf_op.element_dtype().isa<IntegerType>())) {
|
||||
return tf_op.emitError(
|
||||
"requires element_dtype to be integer or 16-bit/32-bit/64-bit "
|
||||
"float type during TF Lite transformation pass");
|
||||
}
|
||||
// TODO(haoliang): figure out better way of specify shape.
|
||||
tensor_type = UnrankedTensorType::get(tf_op.element_dtype());
|
||||
}
|
||||
|
||||
if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
|
||||
tf_op.input_handle()->setType(tensor_type);
|
||||
tf_op.getResult()->setType(tensor_type);
|
||||
}
|
||||
// Currently we will raise an error if an op other than the following
|
||||
// contains a DT_VARIANT tensor as its input or output. The ops below already
|
||||
// have proper transformation patterns that eliminate the need for
|
||||
// `DT_VARIANT`, so we consider it safe not to raise an error on them.
|
||||
if (llvm::isa<TF::TensorListFromTensorOp>(op) ||
|
||||
llvm::isa<TF::TensorListReserveOp>(op) ||
|
||||
llvm::isa<TF::TensorListSetItemOp>(op) ||
|
||||
llvm::isa<TF::TensorListStackOp>(op) ||
|
||||
llvm::isa<TF::TensorListGetItemOp>(op)) {
|
||||
continue;
|
||||
}
|
||||
// Check if any of the input operand is a DT_VARIANT.
|
||||
for (Type type : op.getOperandTypes()) {
|
||||
if (type.isa<TF::VariantType>()) {
|
||||
return op.emitError(
|
||||
"op's input contains a DT_VARIANT tensor. Currently we only "
|
||||
"allow "
|
||||
"TensorListFromTensor/TensorListReserve/TensorListStack/"
|
||||
"TensorListSetItem/"
|
||||
"TensorListGetItem to have DT_VARIANT input/output");
|
||||
}
|
||||
}
|
||||
// Check if any of the output is a DT_VARIANT.
|
||||
for (Type type : op.getResultTypes()) {
|
||||
if (type.isa<TF::VariantType>()) {
|
||||
return op.emitError(
|
||||
"op's output contains a DT_VARIANT tensor. Currently we only "
|
||||
"allow "
|
||||
"TensorListFromTensor/TensorListReserve/TensorListStack/"
|
||||
"TensorListSetItem/"
|
||||
"TensorListGetItem to have DT_VARIANT input/output");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void LowerStaticTensorListPass::runOnFunction() {
|
||||
if (failed(ModifyTensorList())) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
TFL::populateWithGenerated(&getContext(), &patterns);
|
||||
patterns.push_back(
|
||||
llvm::make_unique<ConvertTFTensorListReserve>(&getContext()));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<ConvertTFTensorListSetItem>(&getContext()));
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
// pass.
|
||||
FunctionPassBase *TFL::CreateLowerStaticTensorListPass() {
|
||||
return new LowerStaticTensorListPass();
|
||||
}
|
||||
|
||||
static PassRegistration<LowerStaticTensorListPass> pass(
|
||||
"tfl-lower-static-tensor-list",
|
||||
"Lower TensorList ops within TensorFlow Lite dialect");
|
||||
|
||||
} // namespace mlir
|
66
tensorflow/compiler/mlir/lite/transforms/optimize.cc
Normal file
66
tensorflow/compiler/mlir/lite/transforms/optimize.cc
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass takes operations in the TensorFlow Lite dialect
|
||||
// and optimizes them, producing operations in the same dialect.
|
||||
|
||||
#include <climits>
|
||||
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The actual Optimize Pass.
|
||||
namespace {
|
||||
|
||||
// Optimize TFLite operations in functions.
|
||||
struct Optimize : public FunctionPass<Optimize> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Returns whether the given `a` and `b` ElementsAttr have broadcast-compatible
|
||||
// types.
|
||||
bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) {
|
||||
return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
|
||||
|
||||
void Optimize::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
// Add the generated patterns to the list.
|
||||
TFL::populateWithGenerated(&getContext(), &patterns);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
||||
FunctionPassBase *CreateOptimizePass() { return new Optimize(); }
|
||||
|
||||
static PassRegistration<Optimize> pass(
|
||||
"tfl-optimize", "Optimize within the TensorFlow Lite dialect");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
112
tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
Normal file
112
tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
Normal file
@ -0,0 +1,112 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the optimization pattern definition file for TensorFlow Lite.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
|
||||
def F32ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ternary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multi-pattern consisting of matching stand-alone convolution op followed by
|
||||
// activation op.
|
||||
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
|
||||
def : Pat<(ActFnOp (TFL_Conv2DOp $input, $filter, $bias,
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w)),
|
||||
(TFL_Conv2DOp $input, $filter, $bias,
|
||||
$h_factor, $w_factor, ActFnAttr,
|
||||
$padding, $stride_h, $stride_w)>;
|
||||
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp $input, $filter, $bias,
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier)),
|
||||
(TFL_DepthwiseConv2DOp $input, $filter, $bias,
|
||||
$h_factor, $w_factor, ActFnAttr,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier)>;
|
||||
}
|
||||
|
||||
// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
|
||||
// activation functions.
|
||||
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||
[TFL_TanhOp, TFL_AF_Tanh]] in
|
||||
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
|
||||
|
||||
|
||||
// If we see an add op adding a constant value to a convolution op with constant
|
||||
// bias, we can fuse the add into the convolution op by constant folding the
|
||||
// bias and the add op's constant operand.
|
||||
// The following pattern restricts to float constant values for now.
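// For example (explanatory comment only): add(conv2d(x, filter, bias), c)
// with constant `bias` and `c` is rewritten to conv2d(x, filter, bias + c),
// so the bias addition can be constant folded away.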
|
||||
def : Pat<(TFL_AddOp (TFL_Conv2DOp $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_Conv2DOp $input, $filter,
|
||||
(TFL_AddOp (ConstantOp $bias),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w)>;
|
||||
def : Pat<(TFL_AddOp (TFL_DepthwiseConv2DOp $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input, $filter,
|
||||
(TFL_AddOp (ConstantOp $bias),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier)>;
|
||||
|
||||
def BroadcastableElementsAttrs :
|
||||
Constraint<CPred<"IsBroadcastableElementsAttrs($0, $1)">>;
|
||||
|
||||
// If we see a mul op multiplying a constant value to a convolution op with
|
||||
// constant filter and bias, we can fuse the multiplication into the convolution
|
||||
// op by constant folding the filter/bias and the mul op's constant operand.
|
||||
// The following pattern restricts to float constant values for now.
|
||||
def : Pat<(TFL_MulOp (TFL_DepthwiseConv2DOp $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input,
|
||||
(TFL_MulOp (ConstantOp $filter),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
(TFL_MulOp (ConstantOp $bias),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
[(BroadcastableElementsAttrs $filter, $value)]>;
|
||||
|
||||
// This pattern applies when the same quantize/dequantize have been used twice
|
||||
// with the same scale. We want to remove the redundancy.
|
||||
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
|
||||
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
|
49
tensorflow/compiler/mlir/lite/transforms/passes.h
Normal file
49
tensorflow/compiler/mlir/lite/transforms/passes.h
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_
|
||||
|
||||
namespace mlir {
|
||||
class FunctionPassBase;
|
||||
|
||||
namespace TFL {
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
||||
FunctionPassBase *CreateLegalizeTFPass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
||||
FunctionPassBase *CreateOptimizePass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
||||
FunctionPassBase *CreatePrepareTFPass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
// pass.
|
||||
FunctionPassBase *CreateLowerStaticTensorListPass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
|
||||
FunctionPassBase *CreateQuantizePass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
|
||||
FunctionPassBase *CreatePrepareQuantizePass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
|
||||
FunctionPassBase *CreatePostQuantizePass(bool emit_quant_adaptor_ops);
|
||||
} // namespace TFL
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_
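Usage note (not part of this header): the factory functions above can be chained into a single conversion pipeline. The sketch below is illustrative only; the PassManager/Module API details and the helper name RunTFToTFLitePasses are assumptions, not code from this change.

#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Hypothetical helper that runs the TF-to-TFLite passes in a plausible order:
// lower tensor lists, prepare, legalize, then optimize.
mlir::LogicalResult RunTFToTFLitePasses(mlir::Module *module) {
  mlir::PassManager pm;  // Assumed default construction.
  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  pm.addPass(mlir::TFL::CreatePrepareTFPass());
  pm.addPass(mlir::TFL::CreateLegalizeTFPass());
  pm.addPass(mlir::TFL::CreateOptimizePass());
  return pm.run(module);  // Assumed run() signature.
}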
|
137
tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
Normal file
137
tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
Normal file
@ -0,0 +1,137 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass applies some clean up steps after quantization.
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The post-quantize Pass.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
// Applies all the clean up steps after quantization.
|
||||
class PostQuantizePass : public FunctionPass<PostQuantizePass> {
|
||||
public:
|
||||
// Constructor used by the PassRegistration. This will remove the adaptor ops.
|
||||
explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
|
||||
|
||||
// Constructor used when manually creating the pass.
|
||||
explicit PostQuantizePass(bool emit_quant_adaptor_ops)
|
||||
: emit_quant_adaptor_ops_(emit_quant_adaptor_ops) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
// Set this flag to true if the inputs and outputs are in floating point. The
|
||||
// quant adaptor ops convert them to fixed point values (i.e. quantize) before
|
||||
// feeding them to the model and convert them back to floating point
|
||||
// (i.e. dequantize) as the output.
|
||||
bool emit_quant_adaptor_ops_;
|
||||
};
|
||||
|
||||
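// Removes the pseudo input/quantize adaptor ops at the function entry and the
// dequantize ops feeding the terminator, then rewrites the function type so
// that it takes and returns quantized types directly.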
void RemoveQuantizationAdaptorOps(Function* func) {
|
||||
mlir::OpBuilder builder(func->getBody());
|
||||
auto& bb = func->getBlocks().front();
|
||||
auto* terminator = bb.getTerminator();
|
||||
|
||||
int num_args = bb.getNumArguments();
|
||||
llvm::SmallVector<Type, 4> input_types;
|
||||
input_types.reserve(num_args);
|
||||
// Edit the block arguments and create the new input ops in place to replace
|
||||
// the old input ops and quantize ops.
|
||||
for (int i = 0; i != num_args; ++i) {
|
||||
// Previous loop iteration may invalidate the insertion point so we have to
|
||||
// reset insertion point each iteration.
|
||||
builder.setInsertionPointToStart(&bb);
|
||||
|
||||
// In each iteration, a new argument is appended to the end of the list
|
||||
// and the current argument is erased, so here we always process the first
|
||||
// argument in the list.
|
||||
auto* arg = bb.getArgument(0);
|
||||
auto* input_op = *arg->user_begin();
|
||||
auto input_result = input_op->getResult(0);
|
||||
// We can drop the quantization adaptor only when the pseudo input op has
|
||||
// one user and it is the quantize op. Otherwise, we have to keep the
|
||||
// adaptor and allow the floating point inputs.
|
||||
if (input_result->hasOneUse() &&
|
||||
isa<QuantizeOp>(*input_result->user_begin())) {
|
||||
auto* second_op = *input_result->user_begin();
|
||||
auto quantize_output = second_op->getResult(0);
|
||||
auto quantize_type = quantize_output->getType();
|
||||
input_types.push_back(quantize_type);
|
||||
auto* new_arg = bb.addArgument(quantize_type);
|
||||
// Make a copy of input op with quantized input and output type.
|
||||
auto new_input =
|
||||
builder.create<InputOp>(input_op->getLoc(), quantize_type, new_arg);
|
||||
quantize_output->replaceAllUsesWith(new_input);
|
||||
second_op->erase();
|
||||
input_op->erase();
|
||||
} else {
|
||||
// Make a copy of current argument and append it to the end of the list.
|
||||
Type arg_type = arg->getType();
|
||||
input_types.push_back(arg_type);
|
||||
auto* new_arg = bb.addArgument(arg_type);
|
||||
arg->replaceAllUsesWith(new_arg);
|
||||
}
|
||||
arg->dropAllUses();
|
||||
bb.eraseArgument(0);
|
||||
}
|
||||
|
||||
// Edit the return ops and remove the dequantize ops in place.
|
||||
int num_return_operands = terminator->getNumOperands();
|
||||
llvm::SmallVector<Type, 4> output_types;
|
||||
output_types.reserve(num_return_operands);
|
||||
for (int i = 0; i != num_return_operands; ++i) {
|
||||
auto* returned_value = terminator->getOperand(i);
|
||||
Operation* returned_op = returned_value->getDefiningOp();
|
||||
if (isa<DequantizeOp>(returned_op)) {
|
||||
auto* dequantized_result = returned_op->getOperand(0);
|
||||
output_types.push_back(dequantized_result->getType());
|
||||
terminator->setOperand(i, dequantized_result);
|
||||
returned_op->erase();
|
||||
} else {
|
||||
output_types.push_back(returned_value->getType());
|
||||
}
|
||||
}
|
||||
auto new_func_type = builder.getFunctionType(input_types, output_types);
|
||||
func->setType(new_func_type);
|
||||
}
|
||||
|
||||
void PostQuantizePass::runOnFunction() {
|
||||
auto& func = getFunction();
|
||||
if (!emit_quant_adaptor_ops_) {
|
||||
RemoveQuantizationAdaptorOps(&func);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
|
||||
FunctionPassBase* CreatePostQuantizePass(bool emit_quant_adaptor_ops) {
|
||||
return new PostQuantizePass(emit_quant_adaptor_ops);
|
||||
}
|
||||
|
||||
static PassRegistration<PostQuantizePass> pass(
|
||||
"tfl-post-quantize", "Apply post quantization clean up after quantization");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
100
tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
Normal file
100
tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
Normal file
@ -0,0 +1,100 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "tensorflow/compiler/mlir/tensorflow/transforms/optimize.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
|
||||
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
|
||||
|
||||
// Converts tf.FusedBatchNorm into a sequence of more primitive arithmetic
|
||||
// operations. Specifically, performs the following calculation:
|
||||
//
|
||||
// (x - mean) * scale / sqrt(variance + epsilon) + offset
|
||||
//
|
||||
// Let multiplier = scale / sqrt(variance + epsilon).
|
||||
// Then computing
|
||||
//   (x - mean) * scale / sqrt(variance + epsilon) + offset
|
||||
// is equivalent to computing
|
||||
//   (x * multiplier) + (offset - mean * multiplier).
|
||||
def : Pattern<
|
||||
(TF_FusedBatchNormOp $x, $scale, $offset, $mean, $variance,
|
||||
F32Attr:$epsilon, $data_format,
|
||||
FalseBoolAttr:$is_training),
|
||||
[(TF_AddOp
|
||||
(TF_MulOp
|
||||
$x,
|
||||
(TF_MulOp:$multiplier
|
||||
$scale,
|
||||
(TF_RsqrtOp
|
||||
(TF_AddOp $variance,
|
||||
(TF_ConstOp $epsilon))))),
|
||||
(TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
|
||||
/*batch_mean=*/(verifyUnusedValue),
|
||||
/*batch_variance=*/(verifyUnusedValue),
|
||||
/*reserve_space_1=*/(verifyUnusedValue),
|
||||
/*reserve_space_2=*/(verifyUnusedValue)
|
||||
]>;
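// The equivalence used above is plain algebra: with
// multiplier = scale / sqrt(variance + epsilon),
//   (x - mean) * multiplier + offset
//     = x * multiplier - mean * multiplier + offset
//     = x * multiplier + (offset - mean * multiplier).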
|
||||
|
||||
// TODO(jpienaar): Move to opbase something more general.
|
||||
def TFi32ElementsAttr : Attr<CPred<"$_self.isa<DenseIntElementsAttr>">,
|
||||
"scalar int attribute"> {
|
||||
let storageType = [{ DenseIntElementAttr }];
|
||||
let constBuilderCall = "$_builder.getDenseElementsAttr("
|
||||
"$_builder.getTensorType({}, $_builder.getIntegerType(32)), "
|
||||
"{$_builder.getI32IntegerAttr($0)})";
|
||||
}
|
||||
class TFi32<int v> : ConstantAttr<TFi32ElementsAttr, !cast<string>(v)>;
|
||||
|
||||
// Matmul without transpose on b to matmul with explicit transpose op and
|
||||
// transposed b.
|
||||
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
|
||||
(TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp
|
||||
/*start=*/(TF_RankOp $b),
|
||||
/*limit=*/(ConstantOp TFi32<0>),
|
||||
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))),
|
||||
$at, ConstBoolAttrTrue)>;
|
||||
|
||||
// Matmul with transpose on a to matmul with explicit transpose op and a not
|
||||
// transposed.
|
||||
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
|
||||
(TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp
|
||||
/*start=*/(TF_RankOp $a),
|
||||
/*limit=*/(ConstantOp TFi32<0>),
|
||||
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
|
||||
ConstBoolAttrFalse, $bt)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op removal patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(TF_IdentityOp $arg), (replaceWithValue $arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op quantization pass-through patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO(fengliuai): Implement similar rule in the QuantizePass if the constant
|
||||
// folding hook of tfl.transpose and tfl.reshape are implemented.
|
||||
def : Pat<(TF_TransposeOp
|
||||
(TF_FakeQuantWithMinMaxVarsOp
|
||||
$input, $min, $max, $num_bits, $narrow_range),
|
||||
$perm),
|
||||
(TF_FakeQuantWithMinMaxVarsOp (TF_TransposeOp $input, $perm),
|
||||
$min, $max, $num_bits, $narrow_range)>;
|
||||
|
||||
def : Pat<(TF_ReshapeOp
|
||||
(TF_FakeQuantWithMinMaxVarsOp
|
||||
$input, $min, $max, $num_bits, $narrow_range),
|
||||
$shape),
|
||||
(TF_FakeQuantWithMinMaxVarsOp (TF_ReshapeOp $input, $shape),
|
||||
$min, $max, $num_bits, $narrow_range)>;
|
54
tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
Normal file
54
tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass applies quantization propagation on TFLite dialect.
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The prepare-quantize Pass.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
namespace {
|
||||
// Applies prepare quantization on the model in TFL dialect. This pass runs
|
||||
// before the quantization pass and propagates the quantization parameters
|
||||
// across ops. This step is necessary for post-training quantization and also
|
||||
// makes the quantization rules for some operations in quantization-aware
|
||||
// training simpler.
|
||||
struct PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void PrepareQuantizePass::runOnFunction() {
|
||||
ApplyQuantizationParamsPropagation(&getFunction());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
|
||||
FunctionPassBase *CreatePrepareQuantizePass() {
|
||||
return new PrepareQuantizePass();
|
||||
}
|
||||
|
||||
static PassRegistration<PrepareQuantizePass> pass(
|
||||
"tfl-prepare-quantize", "Prepare TFL dialect for quantization");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
379
tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
Normal file
379
tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
Normal file
@ -0,0 +1,379 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass prepares for legalization to the TFLite dialect by
|
||||
// converting operations in TensorFlow dialect into operations that can be
|
||||
// legalized to TensorFlow Lite dialect with simple replacements. The newly
|
||||
// created operations are in the TensorFlow dialect if the operation can be
|
||||
// represented using a TensorFlow op. Otherwise, a TensorFlow Lite dialect op is
|
||||
// used. For example, Conv2D in TFLite which uses OHWI data format for filters
|
||||
// is not supported in TensorFlow because TensorFlow requires filters in the
|
||||
// HWIO data format.
|
||||
//
|
||||
// Motivation to prepare for the TFLite legalization before the actual
|
||||
// legalization is to exploit constant folding opportunities in any newly
|
||||
// created ops by leveraging constant folding support for the TensorFlow ops.
|
||||
// This way TFLite can be used as a serialization format only and does not
|
||||
// require access to the TFLite runtime for optimizations as required by the
|
||||
// TFLite team.
|
||||
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-tfl-legalization"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The actual PrepareTF Pass.
|
||||
//
|
||||
// TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
|
||||
// pass.
|
||||
namespace {
|
||||
|
||||
// Prepare TF operations in functions for subsequent legalization.
|
||||
struct PrepareTFPass : public FunctionPass<PrepareTFPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// TODO(fengliuai): move this rule to PreparePatterns.td
|
||||
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair after the
|
||||
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
|
||||
// folding logic will use a "std.constant" op to replace the
|
||||
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
|
||||
// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
|
||||
// convert the output type to the next op.
|
||||
struct InsertTFLQuantOpsAfterTFFakeQuantOp : public RewritePattern {
|
||||
InsertTFLQuantOpsAfterTFFakeQuantOp(MLIRContext *context)
|
||||
: RewritePattern(TF::FakeQuantWithMinMaxVarsOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct MatchedState : public PatternState {
|
||||
FloatAttr min;
|
||||
FloatAttr max;
|
||||
APInt num_bits;
|
||||
bool narrow_range;
|
||||
};
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto tf_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
|
||||
auto res = tf_op.outputs();
|
||||
if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
|
||||
return matchFailure();
|
||||
auto state = absl::make_unique<MatchedState>();
|
||||
ElementsAttr min_value, max_value;
|
||||
if (!matchPattern(tf_op.min(), m_Constant(&min_value)))
|
||||
return matchFailure();
|
||||
if (!matchPattern(tf_op.max(), m_Constant(&max_value)))
|
||||
return matchFailure();
|
||||
state->min = ExtractSingleElementAsFloat(min_value);
|
||||
state->max = ExtractSingleElementAsFloat(max_value);
|
||||
if (!state->min || !state->max) return matchFailure();
|
||||
state->num_bits = tf_op.num_bits();
|
||||
state->narrow_range = tf_op.narrow_range();
|
||||
return matchSuccess(std::move(state));
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto &s = *static_cast<MatchedState *>(state.get());
|
||||
Location loc = op->getLoc();
|
||||
Value *copied = OpBuilder(op).clone(*op)->getResult(0);
|
||||
Type res_type = copied->getType();
|
||||
Type storage_type = rewriter.getIntegerType(s.num_bits.getSExtValue());
|
||||
TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, s.min, s.max,
|
||||
storage_type, s.narrow_range);
|
||||
Value *quantize_op =
|
||||
rewriter.create<TFL::QuantizeOp>(loc, qtype.getValue(), copied, qtype);
|
||||
rewriter.replaceOpWithNewOp<TFL::DequantizeOp>(op, res_type, quantize_op);
|
||||
}
|
||||
};
|
||||
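// For reference, a schematic view of the rewrite above (IR abbreviated and
// attribute details omitted for illustration only):
//
//   Before:  %r = "tf.FakeQuantWithMinMaxVars"(%x, %min, %max)
//   After:   %f = "tf.FakeQuantWithMinMaxVars"(%x, %min, %max)
//            %q = "tfl.quantize"(%f) {qtype = ...}
//            %r = "tfl.dequantize"(%q)
//
// All former users of the tf.FakeQuantWithMinMaxVars result now consume the
// tfl.dequantize result, while the cloned FakeQuant op remains available for
// constant folding.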

// Templated class for declaring a converter from some TensorFlow convolution
// op into its counterpart in TensorFlow Lite.
//
// The `ConcreteType` deriving from this template must provide the following
// method for constructing the TensorFlow Lite op:
//
//   TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
//                         PatternRewriter &rewriter, Location loc,
//                         Type result_type, Value *input,
//                         Value *filter, Value *bias) const;
//
// And also the following method for getting the dimension for the bias tensor:
//
//   int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
template <typename ConcreteType, typename TFConvOpType>
struct ConvertTFConvOp : public RewritePattern {
  // Transient state for preserving data from match to rewrite.
  struct ConvertTFConvOpMatchState : public PatternState {
    IntegerAttr dilation_height_factor;
    IntegerAttr dilation_width_factor;
    StringAttr padding;
    IntegerAttr stride_height;
    IntegerAttr stride_width;
  };

  ConvertTFConvOp(MLIRContext *context)
      : RewritePattern(TFConvOpType::getOperationName(), 1, context),
        intAttrOne(Builder(context).getI32IntegerAttr(1)) {}

  PatternMatchResult match(Operation *op) const override {
    // Assumes the TensorFlow convolution op is already verified to be
    // in valid form.

    // Match a TFConvOpType under the following conditions:
    //   * The 'T' attribute must exist and be of value DT_FLOAT.
    //   * The 'data_format' attribute must exist and be of value "NHWC".
    //   * The 'strides' attribute must exist and be of the form [1, X, Y, 1].
    //   * The 'dilations' attribute is optional, but it must be of the form
    //     [1, X, Y, 1] if it exists.

    TFConvOpType tf_op = cast<TFConvOpType>(op);

    if (!TFTypeIsFloatTensor(tf_op.input()) || !TFDataFormatIsNHWC(op))
      return matchFailure();

    IntegerAttr height, width;
    if (!TFIntListIs1XY1(op, "strides", &height, &width)) return matchFailure();

    auto state = llvm::make_unique<ConvertTFConvOpMatchState>();

    state->stride_height = height;
    state->stride_width = width;

    if (TFIntListIs1XY1(op, "dilations", &height, &width)) {
      state->dilation_height_factor = height;
      state->dilation_width_factor = width;
    } else {
      // If the 'dilations' attribute is missing, we use the default value (1)
      // for both dilation height and width factor.
      state->dilation_height_factor = intAttrOne;
      state->dilation_width_factor = intAttrOne;
    }

    StringAttr padding_attr;
    if (!TFPaddingIsSameOrValid(op, &padding_attr)) return matchFailure();
    state->padding = padding_attr;

    // Additionally, we require the filter operand to be of 4-D tensor type so
    // that we can extract info from the shape (e.g., for constructing the bias
    // tensor, for setting the depth_multiplier attribute, etc.).
    auto filter_type =
        tf_op.filter()->getType().template dyn_cast<RankedTensorType>();
    if (filter_type && filter_type.getRank() == 4)
      return matchSuccess(std::move(state));

    return matchFailure();
  }

  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
               PatternRewriter &rewriter) const override {
    // The TensorFlow convolution op only has two inputs, while the TFLite one
    // has three, with the bias vector marked as optional. However, TOCO has a
    // dedicated pass, EnsureBiasVectors, to create default bias vectors for
    // all those missing. So we model the TFLite convolution op as requiring
    // three inputs to achieve the legalization task of EnsureBiasVectors. This
    // requires the filter tensor to have a static shape.

    // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)

    TFConvOpType tf_op = cast<TFConvOpType>(op);

    // Get a splat zero tensor with the expected dimension for the bias tensor.
    auto filter = tf_op.filter();
    auto filter_type = filter->getType().template cast<RankedTensorType>();
    auto elem_type = filter_type.getElementType();
    auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
        filter_type.getShape());
    auto bias_type = rewriter.getTensorType({bias_dim}, elem_type);
    auto bias_attr = rewriter.getZeroAttr(bias_type);
    auto bias = rewriter.create<ConstantOp>(op->getLoc(), bias_type, bias_attr);

    auto *conv_state = static_cast<ConvertTFConvOpMatchState *>(state.get());
    auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
        conv_state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(),
        filter, bias);

    rewriter.replaceOp(op, conv_op.getResult());
  }

  const IntegerAttr intAttrOne;
};
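// Illustration of the zero-bias construction above (shapes chosen arbitrarily
// for the example): for a Conv2D filter of type tensor<3x3x8x16xf32> (HWIO),
// ConvertTFConv2D::getBiasDim below returns 16, so the pattern materializes a
// constant of type tensor<16xf32> filled with zeros and passes it as the
// required third operand of the TFLite convolution op.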
class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
 public:
  using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;

  ConvertTFConv2D(MLIRContext *context) : BaseType(context) {}

  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
    return filterShape.back();
  }

  TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
                            PatternRewriter &rewriter, Location loc,
                            Type result_type, Value *input, Value *filter,
                            Value *bias) const {
    filter = legalizeFilter(rewriter, loc, filter);
    return rewriter.create<TFL::Conv2DOp>(
        loc, result_type, input, filter, bias,
        /*dilation_h_factor=*/state->dilation_height_factor,
        /*dilation_w_factor=*/state->dilation_width_factor,
        /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
        /*padding=*/state->padding,
        /*stride_h=*/state->stride_height,
        /*stride_w=*/state->stride_width);
  }

 private:
  // Legalize the given filter by converting it from TensorFlow filter data
  // format HWIO to TFLite Conv2D op filter data format OHWI and return a Value
  // for the converted filter. Requires that the filter is verified by the
  // match method to be a 4-D RankedTensorType.
  Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
                        Value *filter) const {
    // Create a constant op for HWIO to OHWI transpose permutation.
    SmallVector<int, 4> perm = {3, 0, 1, 2};
    auto perm_type = rewriter.getTensorType({static_cast<int>(perm.size())},
                                            rewriter.getIntegerType(32));
    auto perm_attr =
        DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
    auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);

    // Create tensor type for the transpose result.
    auto filter_type = filter->getType().cast<RankedTensorType>();
    auto result_shape = functional::map(
        [filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
        perm);
    auto elem_type = filter_type.getElementType();
    auto result_type = rewriter.getTensorType(result_shape, elem_type);

    return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op);
  }
};
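// Worked example of the filter legalization above (shapes are illustrative):
// a TensorFlow Conv2D filter of type tensor<3x3x8x16xf32> in HWIO layout is
// transposed with permutation [3, 0, 1, 2], yielding tensor<16x3x3x8xf32> in
// the OHWI layout that the TFLite Conv2D op expects.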
class ConvertTFDepthwiseConv2dNative
|
||||
: public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
|
||||
TF::DepthwiseConv2dNativeOp> {
|
||||
public:
|
||||
using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
|
||||
TF::DepthwiseConv2dNativeOp>;
|
||||
|
||||
ConvertTFDepthwiseConv2dNative(MLIRContext *context) : BaseType(context) {}
|
||||
|
||||
int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
|
||||
return filterShape[2] * filterShape[3];
|
||||
}
|
||||
|
||||
TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
|
||||
PatternRewriter &rewriter, Location loc,
|
||||
Type result_type, Value *input,
|
||||
Value *filter, Value *bias) const {
|
||||
// Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
|
||||
// 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
|
||||
// have a corresponding 'depth_multiplier' attribute; the multiplier is the
|
||||
// fourth dimension in the 4-D filter tensor. We query the multiplier from
|
||||
// tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
|
||||
auto multiplier = filter->getType().cast<RankedTensorType>().getDimSize(3);
|
||||
|
||||
filter = legalizeFilter(rewriter, loc, filter);
|
||||
return rewriter.create<TFL::DepthwiseConv2DOp>(
|
||||
loc, result_type, input, filter, bias,
|
||||
/*dilation_h_factor=*/state->dilation_height_factor,
|
||||
/*dilation_w_factor=*/state->dilation_width_factor,
|
||||
/*fused_activation_function=*/rewriter.getStringAttr("NONE"),
|
||||
/*padding=*/state->padding,
|
||||
/*stride_h=*/state->stride_height,
|
||||
/*stride_w=*/state->stride_width,
|
||||
/*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier));
|
||||
}
|
||||
|
||||
private:
|
||||
/// Legalize the given filter by converting it from TensorFlow filter data
|
||||
/// format to TFLite DepthwiseConv2D op filter data format and return Value
|
||||
/// for the converted filter. TensorFlow filter data format is
|
||||
/// [filter_height, filter_width, in_channels, channel_multiplier] and TFLite
|
||||
/// filter data format is [1, filter_height, filter_width, out_channels].
|
||||
/// Requires that filter is verified by the match method that it is a 4-D
|
||||
/// RankedTensorType.
|
||||
Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
|
||||
Value *filter) const {
|
||||
auto filter_type = filter->getType().cast<RankedTensorType>();
|
||||
auto filterShape = filter_type.getShape();
|
||||
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
|
||||
filterShape[2] * filterShape[3]};
|
||||
auto elem_type = filter_type.getElementType();
|
||||
auto result_type = rewriter.getTensorType(result_shape, elem_type);
|
||||
auto shape_type = rewriter.getTensorType({4}, rewriter.getIntegerType(64));
|
||||
auto shape_attr =
|
||||
DenseElementsAttr::get(shape_type, llvm::makeArrayRef(result_shape));
|
||||
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
|
||||
|
||||
return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
|
||||
}
|
||||
};
|
||||
|
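// Worked example of the depthwise filter legalization above (shapes are
// illustrative): a tf.DepthwiseConv2dNative filter of type
// tensor<3x3x8x2xf32> ([height, width, in_channels, channel_multiplier]) is
// reshaped to tensor<1x3x3x16xf32>, the bias gets 8 * 2 = 16 elements, and
// the depth_multiplier attribute is set to 2 (the fourth filter dimension).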
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"

void PrepareTFPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto &func = getFunction();
  TFL::populateWithGenerated(&getContext(), &patterns);
  // TODO(karimnosseir): Split to separate pass probably after
  // deciding on long term plan for this optimization.
  // This will allow optimizing any TF_Mul->TF_Conv in the graph, including
  // any that were expanded from FusedBatchNorm. We need to do this
  // before converting TF_Conv to TFL_Conv.
  applyPatternsGreedily(func, std::move(patterns));
  patterns.push_back(llvm::make_unique<ConvertTFConv2D>(&getContext()));
  patterns.push_back(
      llvm::make_unique<ConvertTFDepthwiseConv2dNative>(&getContext()));
  patterns.push_back(
      llvm::make_unique<InsertTFLQuantOpsAfterTFFakeQuantOp>(&getContext()));
  applyPatternsGreedily(func, std::move(patterns));
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
FunctionPassBase *CreatePrepareTFPass() { return new PrepareTFPass(); }

static PassRegistration<PrepareTFPass> pass(
    "tfl-prepare-tf", "Prepare TF for legalization to TensorFlow Lite dialect");

}  // namespace TFL
}  // namespace mlir
67
tensorflow/compiler/mlir/lite/transforms/quantize.cc
Normal file
@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization on TFLite dialect.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"

namespace mlir {
namespace TFL {

//===----------------------------------------------------------------------===//
// The actual Quantize Pass.
//
namespace {

/// Applies quantization on the model in TFL dialect.
struct QuantizePass : public FunctionPass<QuantizePass> {
  void runOnFunction() override;
};

#include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"

void QuantizePass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto &func = getFunction();
  auto *context = func.getContext();
  populateWithGenerated(context, &patterns);
  applyPatternsGreedily(func, std::move(patterns));
}
}  // namespace

/// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
FunctionPassBase *CreateQuantizePass() { return new QuantizePass(); }

static PassRegistration<QuantizePass> pass(
    "tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect");

}  // namespace TFL
}  // namespace mlir
128
tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
Normal file
@ -0,0 +1,128 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the quantization pattern definition file for TensorFlow Lite.

include "mlir/IR/OpBase.td"
include "mlir/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"

// Quantize attribute $0 by using the quantization parameter from $1.
def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;

// Call the generic builder of `op`. Use the result type of $0 in the new op.
class ReplaceWith<string op> : NativeCodeCall<"$_builder.create<" # op #
    ">($0->getLoc(), $0->getResult(0)->getType(), $1, $2, $3)">;

// Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be
// squashed to a requantize op if the scales are different.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;

// Quantize the value of a constant op if the quantization parameters have been
// propagated to the output.
def : Pat<(TFL_QuantizeOp
            (ConstantOp ElementsAttr:$value),
            $qtype),
          (TFL_QConstOp
            $qtype,
            (QuantizeByQuantizedType $value, $qtype))>;

// Quantize the AddOp if both inputs are dequantized and the output is
// quantized.
def : Pat<(TFL_QuantizeOp:$q
            (TFL_AddOp (TFL_DequantizeOp $lhs), (TFL_DequantizeOp $rhs),
                       $fused_activation_function),
            $output_type),
          (ReplaceWith<"TFL::AddOp"> $q, $lhs, $rhs,
                                     $fused_activation_function)>;

// Quantize the Conv2DOp if the input and weight are dequantized. The scale of
// the bias input is determined by the scales of input and weight operands.
// TODO(fengliuai): propagate the quantization parameters to the bias input.
def : Pat<(TFL_QuantizeOp
            (TFL_Conv2DOp
              (TFL_DequantizeOp $in),
              (TFL_DequantizeOp $weight),
              (TFL_DequantizeOp $bias),
              $dilation_h_factor,
              $dilation_w_factor,
              $fused_activation_function,
              $padding,
              $stride_h,
              $stride_w),
            $output_type),
          (TFL_Conv2DOp
            $in,
            $weight,
            $bias,
            $dilation_h_factor,
            $dilation_w_factor,
            $fused_activation_function,
            $padding,
            $stride_h,
            $stride_w)>;

// Quantize the DepthwiseConv2DOp if the input and weight are dequantized. The
// scale of the bias input is determined by the scales of input and weight
// operands.
// TODO(fengliuai): propagate the quantization parameters to the bias input.
def : Pat<(TFL_QuantizeOp
            (TFL_DepthwiseConv2DOp
              (TFL_DequantizeOp $in),
              (TFL_DequantizeOp $weight),
              (TFL_DequantizeOp $bias),
              $dilation_h_factor,
              $dilation_w_factor,
              $fused_activation_function,
              $padding,
              $stride_h,
              $stride_w,
              $multiplier),
            $output_type),
          (TFL_DepthwiseConv2DOp
            $in,
            $weight,
            $bias,
            $dilation_h_factor,
            $dilation_w_factor,
            $fused_activation_function,
            $padding,
            $stride_h,
            $stride_w,
            $multiplier)>;

// Quantize the ReshapeOp if the input is dequantized and output is quantized.
// The pre-quantize pass can guarantee both quantization parameters are the
// same.
def : Pat<(TFL_QuantizeOp (TFL_ReshapeOp (TFL_DequantizeOp $in)), $output_type),
          (TFL_ReshapeOp $in)>;

// Quantize the SoftmaxOp if the input is dequantized and output is quantized.
// The pre-quantize pass has set the output quantization parameters to a
// pre-defined value.
def : Pat<(TFL_QuantizeOp (TFL_SoftmaxOp (TFL_DequantizeOp $in), $beta),
            $output_type),
          (TFL_SoftmaxOp $in, $beta)>;

// Quantize the AveragePool2DOp if the input is dequantized and output is
// quantized. The pre-quantize pass can guarantee both quantization parameters
// are the same.
def : Pat<(TFL_QuantizeOp (TFL_AveragePool2DOp (TFL_DequantizeOp $in),
            $filter_height, $filter_width, $fused_activation_function,
            $padding, $stride_h, $stride_w), $output_type),
          (TFL_AveragePool2DOp $in,
            $filter_height, $filter_width, $fused_activation_function,
            $padding, $stride_h, $stride_w)>;
@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

//===----------------------------------------------------------------------===//
// TensorList transformation patterns.
// Note that the patterns below rewrite `TensorList` tensors (which have type
// DT_VARIANT) into regular tensors. We also assume that each element in the
// `TensorList` has the same constant shape.
//===----------------------------------------------------------------------===//
def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape),
          (replaceWithValue $tensor)>;

def : Pat<(TF_TensorListStackOp $input, $element_shape, $num_elements),
          (replaceWithValue $input)>;

def : Pat<(TF_TensorListGetItemOp $input, $index, $element_shape),
          (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
49
tensorflow/compiler/mlir/lite/utils/attribute_utils.cc
Normal file
@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir

namespace mlir {
namespace TFL {

FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr) {
  if (attr.getType().getNumElements() != 1 ||
      !attr.getType().getElementType().isa<FloatType>()) {
    return {};
  }
  SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);
  return attr.getValue(index).cast<FloatAttr>();
}

FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) {
  if (auto m = attr.dyn_cast_or_null<ElementsAttr>()) {
    return ExtractSingleElementAsFloat(m);
  } else {
    return attr.dyn_cast_or_null<FloatAttr>();
  }
}

IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
  if (attr.getType().getNumElements() != 1 ||
      !attr.getType().getElementType().isa<IntegerType>()) {
    return {};
  }
  SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);
  return attr.getValue(index).cast<IntegerAttr>();
}

}  // namespace TFL
}  // namespace mlir
50
tensorflow/compiler/mlir/lite/utils/attribute_utils.h
Normal file
@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_

#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir

namespace mlir {
namespace TFL {

// Returns true if none of the three attributes are empty.
inline bool HasAll3Attrs(Attribute a, Attribute b, Attribute c) {
  return a != Attribute() && b != Attribute() && c != Attribute();
}

// Returns the single float element from an ElementsAttr. Returns an empty
// attribute if the number of elements in the attribute is not 1 or the
// element isn't a float attribute.
FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr);

// Returns the single float element if the input is an ElementsAttr, or returns
// the input itself if it is already a float attribute. Returns an empty
// attribute if the number of elements in the attribute is not 1, or if the
// element (or the attribute itself) isn't a float attribute.
FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr);

// Returns the single integer element from an ElementsAttr. Returns an empty
// attribute if the number of elements in the attribute is not 1 or the
// element isn't an integer attribute.
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr);

}  // end namespace TFL
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
607
tensorflow/compiler/mlir/lite/utils/quantization_driver.cc
Normal file
@ -0,0 +1,607 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unordered_map>

#include "absl/memory/memory.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace TFL {
namespace {

using QuantParams = quant::QuantizedType;
using AccumulatorScaleFunc =
    std::function<QuantParams(const std::vector<QuantParams> &)>;

// Quantization specs of ops, driving the TF Lite quantization algorithm.
struct OpQuantSpec {
  // Whether the op has a quantizable result. This flag is set to false if the
  // op has the "TFL::NoQuantizableResult" trait.
  bool is_quantizable = true;

  // Whether it requires the same inputs and result scale. This flag is set to
  // true if the op has the "TFL::SameOperandsAndResultScale" trait.
  bool requires_same_scale = false;

  // Maps the operand index of a bias input to its quantization specifications,
  // including the non-bias operand indexes and the method retrieving
  // quantization parameters from the list of parameters of the non-bias
  // operands. This map is empty if the op doesn't have a bias operand.
  std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
      biases_params;

  // Quantization parameters for value-restricted outputs. These are the
  // "hard-coded" parameters and should be used unconditionally for the
  // quantized op. This vector is empty if the op doesn't have value-restricted
  // outputs.
  llvm::SmallVector<QuantParams, 1> restricted_output_params;
};

static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag indicating that this state (the params) shouldn't be changed after
  // it is initialized. This flag will be set to true if the quantization
  // parameters are from quantization-aware training.
  const bool immutable;

  bool IsEmpty() { return EmptyParams(params); }
};

// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or output
  // of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // Quantization parameters that will be used to add the requantize ops.
  QuantParams params;
};

// This is a worklist-driven driver for propagating quantization parameters
// across operations.
//
// The initial quantization parameters are extracted from the quantized type
// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
// parameters are marked as immutable because they are from quantization-aware
// training.
//
// The algorithm traverses each op and sets the quantization parameters of its
// operands and results, according to its quantization specification, and then
// adds the operands and results to the worklist. If there are any conflicts
// (for example, quantization parameters propagated from a previous iteration),
// this process stops if the existing parameters are immutable, or a
// `requantize` op is added to resolve the conflicts.
//
// After the algorithm has converged, pairs of tfl.quantize and tfl.dequantize
// are inserted at the right positions to materialize the propagation and
// requantize results.
//
class QuantizationDriver {
 public:
  explicit QuantizationDriver(Function *fn) : builder_(fn->getBody()) {}

  // The entry point of the quantization parameters propagation.
  void Run();

 private:
  // This is used to identify an operand or result of an op. The second element
  // of this pair is the index of the operand or result.
  using OpValue = std::pair<mlir::Operation *, int>;

  // Sets up the states for all the op results in the function.
  void Initialize();

  // Propagates the quantization parameters across all the ops.
  bool PropagateParams();

  // Inserts the Quantize and Dequantize ops according to the propagation
  // result.
  void Finalize();

  // Whether the constant is used as a bias input of another op. Here we assume
  // the bias is used immediately by the user. This assumption is always
  // correct after constant folding.
  bool UsedAsBias(ConstantOp cst) {
    Value *value = cst.getResult();
    for (auto &use : value->getUses()) {
      auto biases = GetQuantSpec(use.getOwner())->biases_params;
      if (biases.find(use.getOperandNumber()) != biases.end()) return true;
    }
    return false;
  }

  // Returns all the related quantization constraints of the op.
  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);

  // Whether quantization parameters have been propagated to the results of
  // this op.
  bool IsQuantized(Operation *op);

  // Adds all the users of the index-th result of op to the work list.
  void AddUserToList(Operation *op, int index) {
    for (auto &user : op->getResult(index)->getUses()) {
      work_list_.push_back(user.getOwner());
    }
  }

  // Adds the defining op of the index-th operand of op to the work list.
  void AddOperandToList(Operation *op, int index) {
    if (auto *inst = op->getOperand(index)->getDefiningOp())
      work_list_.push_back(inst);
  }

  // Returns the quantization params for the bias input from the non-bias
  // operands which have their indexes in the `non_biases` vector. The returned
  // parameters are calculated by `func`.
  QuantParams GetBiasParams(Operation *op, int bias,
                            const std::vector<int> &non_biases,
                            AccumulatorScaleFunc func);

  // Sets the quantization parameters of the result to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the input of the propagated quantization.
  bool SetResultParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the operand to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the output of the propagated quantization.
  bool SetOperandParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the constant result according to its
  // content.
  bool SetConstantResultParams(Operation *op, unsigned storage_type_width,
                               bool narrow_range);

  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
  // of the op.
  void QuantizeOpResult(Operation *op, int index, QuantParams params);

  // Retrieves all the operand and result quantization states of the op.
  // Mutable and immutable states are collected in two vectors. Returns false
  // if there is more than one immutable state and their scales are different,
  // because then there is no way the same-scale constraint can be satisfied.
  bool GetAllQuantStatesCanBeSameScale(
      Operation *op, std::vector<QuantState *> *mutable_states,
      std::vector<QuantState *> *immutable_states);

  // A heuristic to determine what the quantization parameters are to satisfy
  // the same-scale constraints. Returns NULL if it isn't determined, mainly
  // because none of the values are quantized.
  QuantState *GetFinalSameState(
      Operation *op, const std::vector<QuantState *> &mutable_states,
      const std::vector<QuantState *> &immutable_states);

  // Returns the state of the index-th operand of the op.
  QuantState &GetOperandState(Operation *op, int index) {
    return states_[operand_states_[{op, index}]];
  }

  // Returns the state of the index-th result of the op.
  QuantState &GetResultState(Operation *op, int index) {
    return states_[result_states_[{op, index}]];
  }

  // Returns the requantize state of the index-th operand of the op.
  RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
    return rescale_states_[operand_states_[{op, index}]];
  }

  // Returns the requantize state of the index-th result of the op.
  RequantizeState &GetResultRequantizeState(Operation *op, int index) {
    return rescale_states_[result_states_[{op, index}]];
  }

  // Uses the type of `val` to set the initial state of the index-th result if
  // `as_result` is true, or of the index-th operand if `as_result` is false.
  // The state is immutable if the type is a quantized type. Returns the index
  // of this new state in the state vector.
  int InitializeState(Operation *op, int index, Value *val, bool as_result);

  // Sets the state of the index-th operand of the op. If this operand is
  // cached, uses the cached result without creating a new entry in the state
  // vector. Otherwise, allocates a new entry in the state vector.
  void InitializeOperandState(Operation *op, int index, Value *in,
                              llvm::DenseMap<Value *, int> *cache) {
    auto cached = cache->insert({in, 0});
    if (!cached.second) {
      operand_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
  }

  // Sets the state of the index-th result of the op. If this result is cached,
  // uses the cached result without creating a new entry in the state vector.
  // Otherwise, allocates a new entry in the state vector.
  void InitializeResultState(Operation *op, int index, Value *res,
                             llvm::DenseMap<Value *, int> *cache) {
    auto cached = cache->insert({res, 0});
    if (!cached.second) {
      result_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  OpBuilder builder_;

  // All the ops that need the quantization parameters propagated to them.
  std::vector<Operation *> work_list_;

  // This vector contains all the quantization parameters propagated from the
  // defining operations of the value, or from quantization-aware training.
  std::vector<QuantState> states_;

  // This map contains all the quantization parameters which are required to
  // satisfy the same operands and results constraint. The keys of this map are
  // the values from `operand_states_` and `result_states_`.
  std::unordered_map<int, RequantizeState> rescale_states_;

  // Maps of indexes to the propagation state vector from the op results and
  // op operands. Both maps are unmodified after initialization.
  llvm::DenseMap<OpValue, int> operand_states_;
  llvm::DenseMap<OpValue, int> result_states_;
};

#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
}  // namespace

// TODO(fengliuai): cache the quantization parameters.
std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
  return GetOpQuantSpec(op);
}

bool QuantizationDriver::IsQuantized(Operation *op) {
  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    if (GetResultState(op, i).IsEmpty()) return false;
  }
  return true;
}

int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
                                        bool as_result) {
  QuantParams params =
      quant::QuantizedType::getQuantizedElementType(val->getType());
  bool immutable = !EmptyParams(params);
  int next_state_index = states_.size();
  states_.push_back({params, immutable});
  if (as_result)
    result_states_.insert({{op, index}, next_state_index});
  else
    operand_states_.insert({{op, index}, next_state_index});

  return next_state_index;
}

bool QuantizationDriver::SetConstantResultParams(Operation *op,
                                                 unsigned storage_type_width,
                                                 bool narrow_range) {
  ElementsAttr attr;
  Value *res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
  }
  auto final_type = GetUniformQuantizedTypeForElementsAttr(
                        attr, storage_type_width, narrow_range)
                        .dyn_cast_or_null<quant::QuantizedType>();
  if (!final_type) return false;
  return SetResultParams(op, 0, final_type);
}

bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
                                         QuantParams params) {
  auto &state = GetResultState(op, res_index);
  if (state.immutable) return false;

  if (state.IsEmpty()) {
    state.params = params;
    AddUserToList(op, res_index);
    return true;
  }

  if (state.params != params) {
    auto rescale = GetResultRequantizeState(op, res_index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_INPUT;
    return true;
  }
  return false;
}

QuantParams QuantizationDriver::GetBiasParams(
    Operation *op, int bias, const std::vector<int> &non_biases,
    AccumulatorScaleFunc func) {
  auto &bias_state = GetOperandState(op, bias);
  if (!bias_state.IsEmpty()) {
    return bias_state.params;
  }
  std::vector<QuantParams> op_types;
  op_types.reserve(non_biases.size());
  for (auto non_bias : non_biases) {
    auto &non_bias_type = GetOperandState(op, non_bias);
    op_types.push_back(non_bias_type.params);
  }
  if (op_types.empty()) return {};
  return func(op_types);
}

bool QuantizationDriver::SetOperandParams(Operation *op, int index,
                                          QuantParams params) {
  auto &state = GetOperandState(op, index);
  if (state.immutable) return false;

  if (state.IsEmpty()) {
    state.params = params;
    AddOperandToList(op, index);
    return true;
  }

  if (state.params != params) {
    auto rescale = GetOperandRequantizeState(op, index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_OUTPUT;
    return true;
  }
  return false;
}

void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
                                          QuantParams params) {
  Value *original_result = op->getResult(index);
  Type expressed_type = original_result->getType();
  Type new_type = params.castFromExpressedType(expressed_type);
  TypeAttr type_attr = builder_.getTypeAttr(new_type);
  builder_.setInsertionPoint(op);
  auto quantize = builder_.create<TFL::QuantizeOp>(op->getLoc(), new_type,
                                                   original_result, type_attr);
  auto dequantize = builder_.create<TFL::DequantizeOp>(
      op->getLoc(), expressed_type, quantize.output());

  // New ops are inserted before `op`, so adjust the order here.
  op->moveBefore(quantize);

  // `original_result` has a use in `quantize`, so this will replace that use
  // by the result of `dequantize`. Remember to reset that use afterwards.
  original_result->replaceAllUsesWith(dequantize);
  quantize.getOperation()->replaceUsesOfWith(dequantize, original_result);
}
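// After QuantizeOpResult runs on a result %r of `op`, the surrounding IR is
// shaped roughly as follows (schematic only, types omitted):
//
//   %r = <op>(...)
//   %q = "tfl.quantize"(%r) {qtype = <params>}
//   %d = "tfl.dequantize"(%q)
//   ... all other original users of %r now consume %d ...
//
// The moveBefore/replaceUsesOfWith sequence above exists only to arrive at
// this ordering while keeping the quantize op anchored on the original result.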
bool QuantizationDriver::GetAllQuantStatesCanBeSameScale(
    Operation *op, std::vector<QuantState *> *mutable_states,
    std::vector<QuantState *> *immutable_states) {
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    auto &state = GetOperandState(op, i);
    if (state.immutable) {
      if (immutable_states->empty() ||
          state.params == immutable_states->front()->params) {
        immutable_states->push_back(&state);
      } else {
        // Multiple immutable states have different scales, quantization fails.
        return false;
      }
    } else {
      mutable_states->push_back(&state);
    }
  }

  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    auto &state = GetResultState(op, i);
    if (state.immutable) {
      if (immutable_states->empty() ||
          state.params == immutable_states->front()->params) {
        immutable_states->push_back(&state);
      } else {
        // Multiple immutable states have different scales, quantization fails.
        return false;
      }
    } else {
      mutable_states->push_back(&state);
    }
  }
  return true;
}

// A heuristic to determine what the quantization parameters are to satisfy
// the same-scale constraints:
// - use an immutable state, or,
// - use the single input if it is ready, or,
// - use the single output if it is ready, or,
// - use the first ready one in the collection.
QuantState *QuantizationDriver::GetFinalSameState(
    Operation *op, const std::vector<QuantState *> &mutable_states,
    const std::vector<QuantState *> &immutable_states) {
  if (!immutable_states.empty()) {
    return immutable_states.front();
  }

  if (op->getNumOperands() == 1) {
    auto &state = GetOperandState(op, 0);
    if (!state.IsEmpty()) return &state;
  }

  if (op->getNumResults() == 1) {
    auto &state = GetResultState(op, 0);
    if (!state.IsEmpty()) return &state;
  }

  // The first one which is not empty. This case is rare.
  for (auto *state : mutable_states) {
    if (!state->IsEmpty()) return state;
  }

  return nullptr;
}

// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
  llvm::DenseMap<Value *, int> value_to_state;

  builder_.getRegion()->walk([&](Operation *op) {
    if (op->isKnownTerminator()) return;
    if (!GetQuantSpec(op)->is_quantizable) return;
    work_list_.push_back(op);

    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      auto *operand = op->getOperand(i);
      auto *inst = operand->getDefiningOp();
      if (!inst) continue;

      // If the operand comes from a tfl.dequantize op, we use the quantized
      // input of this tfl.dequantize op to set the state.
      if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
        operand = dq.input();
      }
      InitializeOperandState(op, i, operand, &value_to_state);
    }

    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
      auto *result = op->getResult(res);
      // If the result has been quantized, it should only be used by a
      // tfl.quantize op. For this case, we use the quantized result to create
      // the state and mark it immutable.
      if (result->hasOneUse()) {
        auto user = result->use_begin().getUser();
        if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
          result = q.output();
        }
      }
      InitializeResultState(op, res, result, &value_to_state);
    }
  });
}

bool QuantizationDriver::PropagateParams() {
  // TODO(fengliuai): use a typed indicator instead of a bool value.
  bool changed = false;
  while (!work_list_.empty()) {
    Operation *op = work_list_.back();
    work_list_.pop_back();
    auto spec = GetQuantSpec(op);

    // If the op has no quantizable result, the quantization parameters will
    // not be propagated to the results.
    if (!spec->is_quantizable) continue;

    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
      // If this constant is used as a bias in another op, then the
      // quantization parameters are determined by that op.
      if (UsedAsBias(cst) || IsQuantized(op)) continue;

      // The quantization parameters are determined by the content of the
      // constant.
      // TODO(fengliuai): the storage_type_width should be from higher level.
      changed |= SetConstantResultParams(op, /*storage_type_width=*/8,
                                         /*narrow_range=*/false);
      continue;
    }

    if (spec->requires_same_scale) {
      std::vector<QuantState *> mutable_states, immutable_states;
      if (!GetAllQuantStatesCanBeSameScale(op, &mutable_states,
                                           &immutable_states)) {
        // Constraints couldn't be satisfied, so this needs to return `false`
        // unconditionally, then the Finalize step will be skipped. It
        // shouldn't continue or partially quantize the model.
        return false;
      }

      auto *final_state =
          GetFinalSameState(op, mutable_states, immutable_states);
      if (!final_state) continue;

      // Use the final state to set all the operands' parameters.
      for (int i = 0, e = op->getNumOperands(); i != e; ++i)
        changed |= SetOperandParams(op, i, final_state->params);

      // Use the final state to set all the results' parameters.
      for (int res = 0, e = op->getNumResults(); res != e; ++res)
        changed |= SetResultParams(op, res, final_state->params);
    }

    for (int i = 0, e = spec->restricted_output_params.size(); i != e; ++i)
      changed |= SetResultParams(op, i, spec->restricted_output_params[i]);

    for (auto &it : spec->biases_params) {
      auto params =
          GetBiasParams(op, it.first, it.second.first, it.second.second);
      changed |= SetOperandParams(op, it.first, params);
    }
  }
  return changed;
}

void QuantizationDriver::Finalize() {
  for (auto it : result_states_) {
    Operation *op = it.first.first;
    int res_index = it.first.second;
    auto &state = GetResultState(op, res_index);
    if (state.IsEmpty() || state.immutable) {
      continue;
    }
    QuantizeOpResult(op, res_index, state.params);

    auto &requantize = GetResultRequantizeState(op, res_index);
    DCHECK(requantize.pos == RequantizeState::NO_REQUANTIZE)
        << "Unimplemented requantize handling";
  }
}

void QuantizationDriver::Run() {
  Initialize();
  if (PropagateParams()) {
    Finalize();
  }
}

void ApplyQuantizationParamsPropagation(mlir::Function *func) {
  QuantizationDriver(func).Run();
}

}  // namespace TFL
}  // namespace mlir
122
tensorflow/compiler/mlir/lite/utils/quantization_utils.cc
Normal file
@ -0,0 +1,122 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"
|
||||
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Returns the quantized type for the
// input_type/min/max/storage_type_width/narrow_range.
|
||||
static Type GetQuantizedType(Builder builder, Type input_type, double min,
|
||||
double max, int storage_type_width,
|
||||
bool narrow_range) {
|
||||
auto converter =
|
||||
quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
|
||||
|
||||
quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType(
|
||||
builder.getUnknownLoc(), storage_type_width, min, max, narrow_range,
|
||||
converter.expressedType);
|
||||
return converter.convert(quantizedEleType);
|
||||
}
|
||||
|
||||
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
|
||||
FloatAttr max, Type storage_type,
|
||||
bool narrow_range) {
|
||||
int storage_type_width = storage_type.cast<IntegerType>().getWidth();
|
||||
Type final_type = GetQuantizedType(
|
||||
builder, input_type, min.getValueAsDouble(), max.getValueAsDouble(),
|
||||
storage_type_width, narrow_range);
|
||||
return builder.getTypeAttr(final_type);
|
||||
}
|
||||
|
||||
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
|
||||
Attribute max, IntegerAttr num_bits,
|
||||
BoolAttr narrow_range) {
|
||||
FloatAttr min_value = GetSingleElementAsFloatOrSelf(min);
|
||||
FloatAttr max_value = GetSingleElementAsFloatOrSelf(max);
|
||||
if (!min_value || !max_value) return {};
|
||||
return GetQuantizedTypeAttr(builder, input_type, min_value, max_value,
|
||||
builder.getIntegerType(num_bits.getInt()),
|
||||
narrow_range.getValue());
|
||||
}
|
||||
|
||||
Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
|
||||
unsigned storage_type_width,
|
||||
bool narrow_range) {
|
||||
Builder builder(attr.getContext());
|
||||
double min = std::numeric_limits<double>::max();
// Use lowest() rather than min(): min() is the smallest positive double,
// which would break the running maximum for all-negative data.
double max = std::numeric_limits<double>::lowest();
|
||||
if (auto fp = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
|
||||
double ele_value = FloatAttr::getValueAsDouble(*it);
|
||||
min = std::min(min, ele_value);
|
||||
max = std::max(max, ele_value);
|
||||
}
|
||||
// The range must straddle zero.
|
||||
if (min > 0.0 || max < 0.0) return {};
|
||||
auto type = GetQuantizedType(builder, attr.getType(), min, max,
|
||||
storage_type_width, narrow_range);
|
||||
if (auto ele_type = type.dyn_cast_or_null<TensorType>())
|
||||
return ele_type.getElementType();
|
||||
}
|
||||
|
||||
// The range of a SplatElementsAttr (or any other elements attribute kind)
// cannot straddle zero, so quantization parameters cannot be derived from
// its range.
|
||||
return {};
|
||||
}
|
||||
|
||||
quant::QuantizedType GetUniformQuantizedTypeForBias(
|
||||
const std::vector<quant::QuantizedType>& op_types) {
|
||||
if (op_types.empty()) return {};
|
||||
|
||||
double scale = 1.0;
|
||||
for (unsigned i = 0, e = op_types.size(); i != e; ++i) {
|
||||
auto qtype = op_types[i].dyn_cast_or_null<quant::UniformQuantizedType>();
|
||||
if (!qtype) return {};
|
||||
scale *= qtype.getScale();
|
||||
}
|
||||
auto type = op_types.back().cast<quant::UniformQuantizedType>();
|
||||
Builder builder(type.getContext());
|
||||
IntegerType storageType = builder.getIntegerType(32);
|
||||
return quant::UniformQuantizedType::getChecked(
|
||||
/*flags=*/true, storageType, type.getExpressedType(), scale,
|
||||
/*zeroPoint=*/0,
|
||||
quant::QuantizedType::getDefaultMininumForInteger(/*isSigned=*/true, 32),
|
||||
quant::QuantizedType::getDefaultMaxinumForInteger(/*isSigned=*/true, 32),
|
||||
builder.getUnknownLoc());
|
||||
}
|
||||
|
||||
ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
|
||||
if (auto q_type =
|
||||
quant::QuantizedType::getQuantizedElementType(tensor_type)) {
|
||||
Type converted_type;
|
||||
return quant::quantizeAttr(real_value, q_type, converted_type)
|
||||
.cast<ElementsAttr>();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
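Not part of this change: GetUniformQuantizedTypeForBias above multiplies the scales of the given operand types and fixes the storage type to signed 32-bit, so a caller typically feeds it the quantized element types of the multiply-accumulated operands. A minimal sketch (the helper name and the two argument types are assumptions; they would normally be pulled off the op's operands):

// Hypothetical illustration: derive the 32-bit bias type for a conv-like op
// whose accumulator multiplies an input value and a filter value.
quant::QuantizedType MakeExampleBiasType(quant::QuantizedType input_qtype,
                                         quant::QuantizedType filter_qtype) {
  std::vector<quant::QuantizedType> mac_operands{input_qtype, filter_qtype};
  // The result's scale is the product of the operand scales, with a zero
  // point of 0 and a signed 32-bit storage type.
  return TFL::GetUniformQuantizedTypeForBias(mac_operands);
}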
|
70
tensorflow/compiler/mlir/lite/utils/quantization_utils.h
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header file defines common utils used by TFLite transformation
|
||||
// passes to work with op attributes.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Converts the min/max/storage_type/narrow_range information to a
|
||||
// QuantizedType, and then returns the attribute containing the QuantizedType.
|
||||
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
|
||||
FloatAttr max, Type storage_type,
|
||||
bool narrow_range = false);
|
||||
|
||||
// Converts the min/max/num_bits/narrow_range information to a
|
||||
// QuantizedType, and then returns the attribute containing the QuantizedType.
|
||||
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
|
||||
Attribute max, IntegerAttr num_bits,
|
||||
BoolAttr narrow_range);
|
||||
|
||||
// Quantizes the elements in the attribute `real_value` by the quantization
|
||||
// parameters in `tensor_type`. Returns empty Attribute if the
|
||||
// `tensor_type` is not a QuantizedType or the quantization fails.
|
||||
ElementsAttr Quantize(Attribute real_value, Type tensor_type);
|
||||
|
||||
// Returns the quantized type for an elements attribute. The quantization
// parameters in this type are based on the min and max elements of the
// attribute. When the elements in `attr` are not floating-point, or the value
// range doesn't straddle zero, an empty type is returned.
|
||||
Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
|
||||
unsigned storage_type_width,
|
||||
bool narrow_range = false);
|
||||
|
||||
// Returns the quantized type of a bias input, given the quantized types of
|
||||
// other operands which are multiply-accumulated (the bias is added to the
|
||||
// accumulated value).
|
||||
quant::QuantizedType GetUniformQuantizedTypeForBias(
|
||||
const std::vector<quant::QuantizedType>& op_types);
|
||||
|
||||
// Propagates quantization parameters across ops in this function and satisfies
// the quantization specification of the ops. This method assumes the initial
// quantization parameters are stored as adjacent quantize and dequantize ops
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops into this function.
|
||||
void ApplyQuantizationParamsPropagation(mlir::Function* func);
|
||||
|
||||
} // end namespace TFL
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
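Not part of this change: for reference, a minimal usage sketch of the first GetQuantizedTypeAttr overload declared above. The helper name, tensor shape, and [-1.0, 1.0] range are illustrative assumptions; it presumes a Builder constructed from an MLIRContext, as in the accompanying .cc file.

// Hypothetical illustration: build an 8-bit quantized type attribute for a
// float tensor whose observed value range is [-1.0, 1.0].
TypeAttr MakeExampleQuantAttr(MLIRContext *context) {
  Builder b(context);
  Type expressed = b.getTensorType({1, 224, 224, 3}, b.getF32Type());
  return TFL::GetQuantizedTypeAttr(
      b, expressed, b.getF32FloatAttr(-1.0f), b.getF32FloatAttr(1.0f),
      /*storage_type=*/b.getIntegerType(8), /*narrow_range=*/false);
}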
|
72
tensorflow/compiler/mlir/lite/utils/validators.cc
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
|
||||
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Returns true if the given `op`
|
||||
// * has an attribute with the given `name`,
|
||||
// * and the attribute is an integer list of the form [1, X, Y, 1],
|
||||
// and writes X, Y as 32-bit integer attributes to `x`, `y`.
|
||||
bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
|
||||
IntegerAttr *y) {
|
||||
auto attr = op->getAttrOfType<ArrayAttr>(name);
|
||||
if (!attr) return false;
|
||||
|
||||
auto elements = attr.getValue();
|
||||
if (elements.size() != 4 ||
|
||||
std::any_of(elements.begin(), elements.end(),
|
||||
[](Attribute e) { return !e.isa<IntegerAttr>(); }))
|
||||
return false;
|
||||
|
||||
if (elements.front().cast<IntegerAttr>().getInt() != 1 ||
|
||||
elements.back().cast<IntegerAttr>().getInt() != 1)
|
||||
return false;
|
||||
|
||||
Builder b(op->getContext());
|
||||
*x = b.getI32IntegerAttr(elements[1].cast<IntegerAttr>().getInt());
|
||||
*y = b.getI32IntegerAttr(elements[2].cast<IntegerAttr>().getInt());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if the attribute is an integer list of the form [1, X, Y, 1].
|
||||
bool TFIntListIs1XY1(const ArrayAttr &attr) {
|
||||
const auto &elements = attr.getValue();
|
||||
if (elements.size() != 4 ||
|
||||
std::any_of(elements.begin(), elements.end(),
|
||||
[](Attribute e) { return !e.isa<IntegerAttr>(); }))
|
||||
return false;
|
||||
|
||||
if (elements.front().cast<IntegerAttr>().getValue() != 1 ||
|
||||
elements.back().cast<IntegerAttr>().getValue() != 1)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
|
||||
// This would return false if we had unranked tensors (where they should
// probably be considered broadcastable), but since we are working with
// attributes here, that shouldn't be an issue.
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
73
tensorflow/compiler/mlir/lite/utils/validators.h
Normal file
@ -0,0 +1,73 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header file defines common validators used by TFLite transformation
|
||||
// passes to validate op attributes or values.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// TODO(jpienaar): Change these to being one of these variants and/or generate
|
||||
// these predicates.
|
||||
|
||||
// Returns true if the given TensorFlow op does not have a `data_format`
// attribute (in which case it defaults to "NHWC"), or its `data_format`
// attribute is "NHWC".
|
||||
inline bool TFDataFormatIsNHWC(Operation *op) {
|
||||
auto attr = op->getAttrOfType<StringAttr>("data_format");
|
||||
return !attr || attr.getValue() == "NHWC";
|
||||
}
|
||||
|
||||
// Returns true if the given `op`
|
||||
// * has an attribute with the given `name`,
|
||||
// * and the attribute is an integer list of the form [1, X, Y, 1],
|
||||
// and writes X, Y as 32-bit integer attributes to `x`, `y`.
|
||||
bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
|
||||
IntegerAttr *y);
|
||||
|
||||
// Returns true if the attribute is an integer list of the form [1, X, Y, 1].
|
||||
bool TFIntListIs1XY1(const ArrayAttr &attr);
|
||||
|
||||
// Returns true iff the given value is a float tensor.
|
||||
inline bool TFTypeIsFloatTensor(Value *value) {
|
||||
auto tensorType = value->getType().dyn_cast<TensorType>();
|
||||
if (!tensorType) return false;
|
||||
return tensorType.getElementType().isa<FloatType>();
|
||||
}
|
||||
|
||||
// Returns true iff the given TensorFlow op has a `padding` attribute whose
|
||||
// value is "SAME" or "VALID", and writes the attribute to `padding`.
|
||||
inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
|
||||
auto padding_attr = op->getAttrOfType<StringAttr>("padding");
// Guard against a missing `padding` attribute before reading its value.
if (!padding_attr) return false;
if (padding_attr.getValue() != "SAME" && padding_attr.getValue() != "VALID")
  return false;
|
||||
*padding = padding_attr;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns whether the given `a` and `b` have broadcast-compatible
|
||||
/// types.
|
||||
bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
|
||||
|
||||
} // end namespace TFL
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
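Not part of this change: these predicates are meant to be composed as match preconditions in legalization patterns. A minimal sketch follows; the helper name is hypothetical and the "strides"/"padding" attribute names follow common TF convolution conventions rather than anything defined in this header.

// Hypothetical illustration: gate a TF-to-TFLite conversion on NHWC layout,
// [1, X, Y, 1] strides, and a SAME/VALID padding attribute.
bool MatchesSimpleNHWCConv(Operation *op, IntegerAttr *stride_h,
                           IntegerAttr *stride_w, StringAttr *padding) {
  return TFL::TFDataFormatIsNHWC(op) &&
         TFL::TFIntListIs1XY1(op, "strides", stride_h, stride_w) &&
         TFL::TFPaddingIsSameOrValid(op, padding);
}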
|
57
tensorflow/compiler/mlir/runlit.cfg.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Lit runner configuration."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import lit.formats
|
||||
from lit.llvm import llvm_config
|
||||
from lit.llvm.subst import ToolSubst
|
||||
|
||||
# Lint for undefined variables is disabled as config is not defined inside this
# file; instead, config is injected by way of evaluating runlit.cfg.py from
# runlit.site.cfg.py, which in turn is evaluated by lit.py. The structure is
# common for lit tests and intended to only persist temporarily (b/136126535).
|
||||
# pylint: disable=undefined-variable
|
||||
# Configuration file for the 'lit' test runner.
|
||||
|
||||
# name: The name of this test suite.
|
||||
config.name = 'MLIR ' + os.path.basename(config.mlir_test_dir)
|
||||
|
||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = config.mlir_test_dir
|
||||
|
||||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.environ['RUNFILES_DIR']
|
||||
|
||||
llvm_config.use_default_substitutions()
|
||||
|
||||
# Tweak the PATH to include the tools dir.
|
||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||
|
||||
tool_dirs = config.mlir_tf_tools_dirs + [
|
||||
config.mlir_tools_dir, config.llvm_tools_dir
|
||||
]
|
||||
tool_names = [
|
||||
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
||||
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
# pylint: enable=undefined-variable
|
56
tensorflow/compiler/mlir/runlit.site.cfg.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2019 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Lit runner site configuration."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import lit.llvm
|
||||
|
||||
# Lint for undefined variables is disabled as config is not defined inside this
# file; instead, config is injected by lit.py. The structure is common for lit
# tests and intended to only persist temporarily (b/136126535).
|
||||
# pylint: disable=undefined-variable
|
||||
config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm')
|
||||
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR'])
|
||||
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'],
|
||||
'local_config_mlir')
|
||||
# TODO(jpienaar): Replace with suffixes in build rule.
|
||||
config.suffixes = ['.td', '.mlir', '.pbtxt']
|
||||
|
||||
mlir_tf_tools_dirs = [
|
||||
'tensorflow/compiler/mlir',
|
||||
'tensorflow/compiler/mlir/lite',
|
||||
'tensorflow/compiler/mlir/tensorflow',
|
||||
'tensorflow/compiler/mlir/xla',
|
||||
]
|
||||
config.mlir_tf_tools_dirs = [
|
||||
os.path.join(os.environ['TEST_SRCDIR'], os.environ['TEST_WORKSPACE'], s)
|
||||
for s in mlir_tf_tools_dirs
|
||||
]
|
||||
test_dir = os.environ['TEST_TARGET']
|
||||
test_dir = test_dir.strip('/').rsplit(':', 1)[0]
|
||||
config.mlir_test_dir = os.path.join(os.environ['TEST_SRCDIR'],
|
||||
os.environ['TEST_WORKSPACE'], test_dir)
|
||||
lit.llvm.initialize(lit_config, config)
|
||||
|
||||
# Let the main config do the real work.
|
||||
lit_config.load_config(
|
||||
config,
|
||||
os.path.join(
|
||||
os.path.join(os.environ['TEST_SRCDIR'], os.environ['TEST_WORKSPACE'],
|
||||
'tensorflow/compiler/mlir/runlit.cfg.py')))
|
||||
# pylint: enable=undefined-variable
|
567
tensorflow/compiler/mlir/tensorflow/BUILD
Normal file
@ -0,0 +1,567 @@
|
||||
load("@local_config_mlir//:tblgen.bzl", "gentbl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/...",
|
||||
"//learning/brain/google/xla/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tensorflow_ops_td_files",
|
||||
srcs = [
|
||||
"ir/tf_generated_ops.td",
|
||||
"ir/tf_op_base.td",
|
||||
"ir/tf_ops.td",
|
||||
"@local_config_mlir//:OpBaseTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tensorflow_ops_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-decls",
|
||||
"ir/tf_ops.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-defs",
|
||||
"ir/tf_ops.cc.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-doc",
|
||||
"g3doc/tf_ops.md",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
td_file = "ir/tf_ops.td",
|
||||
td_srcs = [
|
||||
":tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tensorflow_executor_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-decls",
|
||||
"ir/tf_executor.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-defs",
|
||||
"ir/tf_executor.cc.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-doc",
|
||||
"g3doc/tf_executor.md",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
td_file = "ir/tf_executor_ops.td",
|
||||
td_srcs = [
|
||||
"@local_config_mlir//:include/mlir/IR/OpBase.td",
|
||||
"@local_config_mlir//:include/mlir/StandardOps/Ops.td",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tensorflow_canonicalize_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-rewriters",
|
||||
"transforms/generated_canonicalize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
td_file = "transforms/canonicalize.td",
|
||||
td_srcs = [
|
||||
":tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow",
|
||||
srcs = [
|
||||
"ir/control_flow_ops.cc",
|
||||
"ir/tf_executor.cc",
|
||||
"ir/tf_executor.cc.inc",
|
||||
"ir/tf_executor.h.inc",
|
||||
"ir/tf_ops.cc",
|
||||
"ir/tf_ops.cc.inc",
|
||||
"ir/tf_ops.h.inc",
|
||||
"transforms/functional_control_flow_to_cfg.cc",
|
||||
"transforms/generated_canonicalize.inc",
|
||||
"transforms/generated_optimize.inc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/raise_control_flow.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/control_flow_ops.h",
|
||||
"ir/tf_executor.h",
|
||||
"ir/tf_ops.h",
|
||||
"ir/tf_types.def",
|
||||
"ir/tf_types.h",
|
||||
"transforms/passes.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":tensorflow_canonicalize_inc_gen",
|
||||
":tensorflow_executor_inc_gen",
|
||||
":tensorflow_ops_inc_gen",
|
||||
":tensorflow_optimize_inc_gen",
|
||||
"//tensorflow/compiler/mlir/lite:validators",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:Dialect",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:TransformUtils",
|
||||
"@local_config_mlir//:TypeUtilities",
|
||||
],
|
||||
# TODO(jpienaar): Merge in the dialect registration.
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library with TensorFlow dialect static initialization.
|
||||
cc_library(
|
||||
name = "tensorflow_dialect_registration",
|
||||
srcs = ["ir/dialect_registration.cc"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_graphdef",
|
||||
srcs = [
|
||||
"translate/export_graphdef.cc",
|
||||
"translate/import_graphdef.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"translate/export_graphdef.h",
|
||||
"translate/import_graphdef.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":convert_type",
|
||||
":export_tf_dialect_op",
|
||||
":export_utils",
|
||||
":mangling_util",
|
||||
":mlir_roundtrip_flags",
|
||||
":tensorflow",
|
||||
"//tensorflow/compiler/jit:shape_inference_helpers",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardDialectRegistration",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "import_utils",
|
||||
srcs = [
|
||||
"utils/import_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/import_utils.h",
|
||||
],
|
||||
deps = [
|
||||
":error_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "export_utils",
|
||||
srcs = [
|
||||
"utils/export_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/export_utils.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":convert_type",
|
||||
":mangling_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardDialectRegistration",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "export_tf_dialect_op",
|
||||
srcs = [
|
||||
"translate/derived_attr_populator.inc",
|
||||
"translate/export_tf_dialect_op.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"translate/export_tf_dialect_op.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_type",
|
||||
":export_utils",
|
||||
":tensorflow",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "translate_tf_dialect_op",
|
||||
srcs = ["translate/translate_tf_dialect_op.cc"],
|
||||
deps = [
|
||||
":export_tf_dialect_op",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Translation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_roundtrip_pass",
|
||||
srcs = ["translate/mlir_roundtrip_pass.cc"],
|
||||
hdrs = ["translate/mlir_roundtrip_pass.h"],
|
||||
deps = [
|
||||
":convert_graphdef",
|
||||
":mlir_roundtrip_flags",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_roundtrip_flags",
|
||||
srcs = ["translate/mlir_roundtrip_flags.cc"],
|
||||
hdrs = ["translate/mlir_roundtrip_flags.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_type",
|
||||
srcs = ["utils/convert_type.cc"],
|
||||
hdrs = ["utils/convert_type.h"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":tensorflow_dialect_registration",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_tensor",
|
||||
srcs = ["utils/convert_tensor.cc"],
|
||||
hdrs = ["utils/convert_tensor.h"],
|
||||
deps = [
|
||||
":convert_type",
|
||||
":mangling_util",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mangling_util",
|
||||
srcs = ["utils/mangling_util.cc"],
|
||||
hdrs = ["utils/mangling_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "error_util",
|
||||
srcs = ["utils/error_util.cc"],
|
||||
hdrs = ["utils/error_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_dialect_passes",
|
||||
srcs = [
|
||||
"transforms/constant_fold.cc",
|
||||
"transforms/decode_constant.cc",
|
||||
"transforms/dialect_hooks.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/constant_fold.h",
|
||||
"transforms/decode_constant.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":eval_util",
|
||||
":tensorflow",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/stream_executor",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_dialect_lib",
|
||||
deps = [
|
||||
":tensorflow_dialect_registration",
|
||||
":tf_dialect_passes",
|
||||
"@local_config_mlir//:StandardDialectRegistration",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "eval_util",
|
||||
srcs = ["utils/eval_util.cc"],
|
||||
hdrs = ["utils/eval_util.h"],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":convert_type",
|
||||
":export_tf_dialect_op",
|
||||
":mangling_util",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "translate_lib",
|
||||
srcs = [
|
||||
"translate/tf_mlir_translate.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"translate/tf_mlir_translate.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_graphdef",
|
||||
":error_util",
|
||||
":import_utils",
|
||||
":mangling_util",
|
||||
":mlir_roundtrip_flags",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Parser",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "translate_cl_options",
|
||||
srcs = [
|
||||
"translate/tf_mlir_translate_cl.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"translate/tf_mlir_translate_cl.h",
|
||||
],
|
||||
deps = [
|
||||
"@llvm//:support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "translate_registration",
|
||||
srcs = [
|
||||
"translate/tf_mlir_translate_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":convert_graphdef",
|
||||
":mlir_roundtrip_flags",
|
||||
":translate_cl_options",
|
||||
":translate_lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Translation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-mlir-translate",
|
||||
deps = [
|
||||
":convert_graphdef",
|
||||
":mlir_roundtrip_flags",
|
||||
":translate_cl_options",
|
||||
":translate_lib",
|
||||
":translate_registration",
|
||||
":translate_tf_dialect_op",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Translation",
|
||||
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "error_util_test",
|
||||
srcs = ["utils/error_util_test.cc"],
|
||||
deps = [
|
||||
":error_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
tf_native_cc_binary(
|
||||
name = "derived_attr_populator_gen",
|
||||
srcs = [
|
||||
"translate/derived_attr_populator_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm//:support",
|
||||
"@llvm//:tablegen",
|
||||
"@local_config_mlir//:TableGen",
|
||||
],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "derived_attr_populator_inc",
|
||||
srcs = [
|
||||
"@local_config_mlir//:include/mlir/IR/OpBase.td",
|
||||
"ir/tf_generated_ops.td",
|
||||
"ir/tf_op_base.td",
|
||||
"ir/tf_ops.td",
|
||||
],
|
||||
outs = [
|
||||
"translate/derived_attr_populator.inc",
|
||||
],
|
||||
cmd = ("$(location :derived_attr_populator_gen) " +
|
||||
"-I external/local_config_mlir/include " +
|
||||
"$(location //tensorflow/compiler/mlir/tensorflow:ir/tf_ops.td) " + " -o $@"),
|
||||
tools = [":derived_attr_populator_gen"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tensorflow_optimize_td_files",
|
||||
srcs = [
|
||||
"transforms/optimize.td",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tensorflow_optimize_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-rewriters",
|
||||
"transforms/generated_optimize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
td_file = "transforms/optimize.td",
|
||||
td_srcs = [
|
||||
":tensorflow_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
],
|
||||
)
|
2721
tensorflow/compiler/mlir/tensorflow/g3doc/tf_ops.md
Executable file
File diff suppressed because it is too large
67
tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements the named operations for the "Control Flow" dialect of
|
||||
// TensorFlow graphs
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace TFControlFlow {
|
||||
|
||||
// TODO(ycao): Implement following verify methods when we know more about their
|
||||
// invariant.
|
||||
LogicalResult EnterOp::verify() { return success(); }
|
||||
|
||||
LogicalResult MergeOp::verify() { return success(); }
|
||||
|
||||
LogicalResult NextIterationSourceOp::verify() { return success(); }
|
||||
|
||||
LogicalResult NextIterationSinkOp::verify() { return success(); }
|
||||
|
||||
LogicalResult LoopCondOp::verify() { return success(); }
|
||||
|
||||
LogicalResult SwitchOp::verify() { return success(); }
|
||||
|
||||
LogicalResult ExitOp::verify() { return success(); }
|
||||
|
||||
TFControlFlowDialect::TFControlFlowDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"_tf", context) {
|
||||
addOperations<SwitchOp, MergeOp, EnterOp, NextIterationSourceOp,
|
||||
NextIterationSinkOp, ExitOp, LoopCondOp>();
|
||||
addTypes<TFControlType>();
|
||||
|
||||
// We allow unregistered TensorFlow operations in the control dialect.
|
||||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
// Parses a type registered to this dialect.
|
||||
Type TFControlFlowDialect::parseType(StringRef tyData, Location loc) const {
|
||||
if (tyData != "control")
|
||||
return (emitError(loc, "unknown TFControl type: " + tyData), nullptr);
|
||||
return TFControlType::get(getContext());
|
||||
}
|
||||
|
||||
// Prints a type registered to this dialect.
|
||||
void TFControlFlowDialect::printType(Type type, raw_ostream &os) const {
|
||||
assert(type.isa<TFControlType>());
|
||||
os << "control";
|
||||
}
|
||||
|
||||
} // namespace TFControlFlow
|
||||
} // namespace mlir
|
278
tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h
Normal file
@ -0,0 +1,278 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the operations for the "Control Flow" dialect of TensorFlow
|
||||
// graphs. The TensorFlow control flow dialect represents control flow with
|
||||
// Switch/Merge and a few related control flow nodes, along with control
|
||||
// dependencies. This dialect can be raised to the standard TensorFlow dialect
|
||||
// by transforming Switch/Merge and other control flow ops into functional
|
||||
// control flow ops and removing control dependencies.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace TFControlFlow {
|
||||
|
||||
class TFControlFlowDialect : public Dialect {
|
||||
public:
|
||||
explicit TFControlFlowDialect(MLIRContext *context);
|
||||
|
||||
// Parses a type registered to this dialect.
|
||||
Type parseType(StringRef tyData, Location loc) const override;
|
||||
|
||||
// Prints a type registered to this dialect.
|
||||
void printType(Type type, raw_ostream &os) const override;
|
||||
};
|
||||
|
||||
namespace TensorFlowControlTypes {
|
||||
enum Kind {
|
||||
Control = Type::FIRST_TENSORFLOW_CONTROL_TYPE,
|
||||
};
|
||||
}
|
||||
|
||||
class TFControlType : public Type::TypeBase<TFControlType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static TFControlType get(MLIRContext *context) {
|
||||
return Base::get(context, TensorFlowControlTypes::Control);
|
||||
}
|
||||
|
||||
// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == TensorFlowControlTypes::Control;
|
||||
}
|
||||
};
|
||||
|
||||
// The "_tf.Enter" operation forwards its input to Tensorflow while loop. Each
|
||||
// tensor needs its own _tf.Enter to be made available inside the while loop.
|
||||
//
|
||||
// More details can be found in Tensorflow Controlflow white paper:
|
||||
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
//
|
||||
// This is defined in Tensorflow as:
|
||||
//
|
||||
// REGISTER_OP("Enter")
|
||||
// .Input("data: T")
|
||||
// .Output("output: T")
|
||||
// .Attr("T: type")
|
||||
// .Attr("frame_name: string")
|
||||
// .Attr("is_constant: bool = false")
|
||||
// .Attr("parallel_iterations: int = 10")
|
||||
//
|
||||
// For example:
|
||||
// %1 = "_tf.Enter"(%0#0) {T: "tfdtype$DT_INT32", frame_name:
|
||||
// "while/while_context",} : (tensor<i32>) -> (tensor<*xi32>)
|
||||
//
|
||||
// Note: Additional result corresponds to the control output.
|
||||
class EnterOp
|
||||
: public Op<EnterOp, OpTrait::AtLeastNOperands<1>::Impl,
|
||||
OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static StringRef getOperationName() { return "_tf.Enter"; }
|
||||
|
||||
Value *getData() { return getOperand(0); }
|
||||
void setData(Value *value) { setOperand(0, value); }
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
// The "_tf.Merge" operation takes a list of input operands and returns a value
|
||||
// of the operand type along with the index of the first match encountered.
|
||||
//
|
||||
// More details can be found in Tensorflow Controlflow white paper:
|
||||
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
//
|
||||
// This is defined in TensorFlow as:
|
||||
//
|
||||
// REGISTER_OP("Merge")
|
||||
// .Input("inputs: N * T")
|
||||
// .Output("output: T")
|
||||
// .Output("value_index: int32")
|
||||
//
|
||||
// For example:
|
||||
// %2 = _tf.Merge %0, %1, %2, %3 : tensor<??xf32>
|
||||
//
|
||||
// Note: Additional result corresponds to the control output.
|
||||
class MergeOp : public Op<MergeOp, OpTrait::VariadicOperands,
|
||||
OpTrait::NResults<3>::Impl> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static StringRef getOperationName() { return "_tf.Merge"; }
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
// The "_tf.NextIteration.source" and "_tf.NextIteration.sink" operations form
|
||||
// a logical pair. Together, they represent NextIteration op in Tensorflow.
|
||||
//
|
||||
// Tensorflow NextIteration operation forwards its input to the next iteration
|
||||
// of a while loop. Each loop variable needs its own NextIteration op.
|
||||
//
|
||||
// More details can be found in Tensorflow Controlflow white paper:
|
||||
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
//
|
||||
// NextIteration op is broken into _tf.NextIteration.sink and
|
||||
// _tf.NextIteration.source because NextIteration is a back-edge in Tensorflow
|
||||
// graph, which would form a data flow cycle if expressed naively in a basic
|
||||
// block. _tf.NextIteration.source takes no input but returns results while
|
||||
// _tf.NextIteration.sink takes input but doesn't return anything. When
|
||||
// optimizing these ops, they are paired by op names and considered as a
|
||||
// single op.
|
||||
//
|
||||
// This is defined in Tensorflow as:
|
||||
//
|
||||
// REGISTER_OP("NextIteration")
|
||||
// .Input("data: T")
|
||||
// .Output("output: T")
|
||||
// .Attr("T: type")
|
||||
//
|
||||
// For example:
|
||||
// %11 = "_tf.NextIteration.source"() {name: "while/NextIteration", T:
|
||||
// "tfdtype$DT_INT32", id: 0} : () -> (tensor<*xi32>, _tf.control)
|
||||
// "_tf.NextIteration.sink"(%10#0) {name: "while/NextIteration", T:
|
||||
// "tfdtype$DT_INT32", id: 0} : (tensor<*xi32>) -> ()
|
||||
//
|
||||
// Note: Additional result corresponds to the control output.
|
||||
class NextIterationSourceOp
|
||||
: public Op<NextIterationSourceOp, OpTrait::NResults<2>::Impl> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static StringRef getOperationName() { return "_tf.NextIteration.source"; }
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
class NextIterationSinkOp
|
||||
: public Op<NextIterationSinkOp, OpTrait::AtLeastNOperands<1>::Impl,
|
||||
OpTrait::OneResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static StringRef getOperationName() { return "_tf.NextIteration.sink"; }
|
||||
|
||||
Value *getData() { return getOperand(0); }
|
||||
void setData(Value *value) { setOperand(0, value); }
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
// The "_tf.LoopCond" operation forwards a boolean value as loop condition of
|
||||
// Tensorflow while loops.
|
||||
//
|
||||
// More details can be found in Tensorflow Controlflow white paper:
|
||||
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
//
|
||||
// This is defined in Tensorflow as:
|
||||
//
|
||||
// REGISTER_OP("LoopCond")
|
||||
// .Input("input: bool")
|
||||
// .Output("output: bool")
|
||||
//
|
||||
// For example:
|
||||
// %5 = "_tf.LoopCond"(%4#0) {device: "", name: "while/LoopCond"} :
|
||||
// (tensor<*xi1>) -> (i1, !_tf.control)
|
||||
//
|
||||
// Note: Additional result corresponds to the control output.
|
||||
class LoopCondOp
|
||||
: public Op<LoopCondOp, OpTrait::AtLeastNOperands<1>::Impl,
|
||||
OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() { return "_tf.LoopCond"; }
|
||||
|
||||
Value *getData() { return getOperand(0); }
|
||||
void setData(Value *value) { setOperand(0, value); }
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
// The "_tf.Switch" operation takes a data operand and a boolean predicate
|
||||
// condition, and returns two values matching the type of the data predicate.
|
||||
//
|
||||
// More details can be found in Tensorflow Controlflow white paper:
|
||||
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
//
|
||||
// This is defined in TensorFlow as:
|
||||
//
|
||||
// REGISTER_OP("Switch")
|
||||
// .Input("data: T")
|
||||
// .Input("pred: bool")
|
||||
// .Output("output_false: T")
|
||||
// .Output("output_true: T")
|
||||
//
|
||||
// For example:
|
||||
// %2 = _tf.Switch %0, %1 : tensor<??xf32>
|
||||
//
|
||||
// Note: Additional result corresponds to the control output.
|
||||
class SwitchOp : public Op<SwitchOp, OpTrait::AtLeastNOperands<2>::Impl,
|
||||
OpTrait::NResults<3>::Impl> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static StringRef getOperationName() { return "_tf.Switch"; }
|
||||
|
||||
Value *getData() { return getOperand(0); }
|
||||
void setData(Value *value) { setOperand(0, value); }
|
||||
|
||||
Value *getPredicate() { return getOperand(1); }
|
||||
void setPredicate(Value *value) { setOperand(1, value); }
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
// The "_tf.Exit" operation forwards a value from an while loop to its consumer
|
||||
// outside of loop. Each returned tensor needs its own _tf.Exit.
|
||||
//
|
||||
// More details can be found in Tensorflow Controlflow white paper:
|
||||
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
//
|
||||
// This is defined in Tensorflow as:
|
||||
//
|
||||
// REGISTER_OP("Exit")
|
||||
// .Input("data: T")
|
||||
// .Output("output: T")
|
||||
// .Attr("T: type")
|
||||
//
|
||||
// For example:
|
||||
// %1 = "_tf.Exit"(%0#0) {T: "tfdtype$DT_INT32",} : (tensor<*xi32>) ->
|
||||
// (tensor<*xi32>, !_tf.control)
|
||||
//
|
||||
// Note: Additional result corresponds to the control output.
|
||||
class ExitOp : public Op<ExitOp, OpTrait::AtLeastNOperands<1>::Impl,
|
||||
OpTrait::NResults<2>::Impl, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() { return "_tf.Exit"; }
|
||||
|
||||
Value *getData() { return getOperand(0); }
|
||||
void setData(Value *value) { setOperand(0, value); }
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
} // namespace TFControlFlow
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_CONTROL_FLOW_OPS_H_
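Not part of this change: a small usage sketch of the accessors declared above. It assumes the LLVM-style casting already used on operations elsewhere in this change (e.g. isa<FetchOp>); the helper name is hypothetical.

// Hypothetical illustration: peel the data and predicate inputs off a switch
// node while walking a control-dialect graph.
void InspectSwitch(Operation &op) {
  if (auto switch_op = dyn_cast<TFControlFlow::SwitchOp>(op)) {
    Value *data = switch_op.getData();
    Value *pred = switch_op.getPredicate();
    (void)data;
    (void)pred;
  }
}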
29
tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc
Normal file
@ -0,0 +1,29 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

// Static initialization for TF dialect registration.
static DialectRegistration<TFControlFlow::TFControlFlowDialect>
    tf_control_flow_ops;
static DialectRegistration<TF::TensorFlowDialect> tf_ops;
static DialectRegistration<tf_executor::TensorFlowExecutorDialect>
    tf_excutor_dialect;

}  // namespace mlir
|
643
tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
Normal file
@ -0,0 +1,643 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_executor {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF Executor Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"tf_executor", context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
|
||||
>();
|
||||
addTypes<ControlType>();
|
||||
}
|
||||
|
||||
Type TensorFlowExecutorDialect::parseType(StringRef data_type,
|
||||
Location loc) const {
|
||||
if (data_type == "control") return ControlType::get(getContext());
|
||||
emitError(loc) << "unknown tf_executor type: " << data_type;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void TensorFlowExecutorDialect::printType(Type type, raw_ostream &os) const {
|
||||
assert(type.isa<ControlType>());
|
||||
os << "control";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Implementation for all the operations defined in ODS (op definition spec).
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
// Inserts `tf_executor.Terminator` at the end of the region's only block if it
|
||||
// does not have a terminator already. If the region is empty, insert a new
|
||||
// block first.
|
||||
template <typename Terminator>
|
||||
void EnsureExecutorTerminator(Region *region, Builder *builder, Location loc) {
|
||||
if (region->empty()) region->push_back(new Block);
|
||||
|
||||
Block &block = region->back();
|
||||
if (!block.empty() && block.back().isKnownTerminator()) return;
|
||||
|
||||
OperationState terminator_state(loc, Terminator::getOperationName());
|
||||
Terminator::build(builder, &terminator_state, {});
|
||||
block.push_back(Operation::create(terminator_state));
|
||||
}
|
||||
|
||||
// Verifies that all control operands are at the end of the operand list.
|
||||
// Used by the constraint `ControlOperandsAfterAllData` in ODS.
|
||||
LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
|
||||
bool found_control = false;
|
||||
for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
|
||||
if (op->getOperand(operand_idx)->getType().isa<ControlType>()) {
|
||||
found_control = true;
|
||||
continue;
|
||||
}
|
||||
if (found_control)
|
||||
return op->emitOpError() << "found non-control operand #" << operand_idx
|
||||
<< " after control operand";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.graph
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Verify(GraphOp graph) {
|
||||
auto *executorDialect = graph.getDialect();
|
||||
|
||||
if (graph.GetBody().empty())
|
||||
return graph.emitOpError() << "expects a non-empty body";
|
||||
|
||||
// Only tf_executor dialect operations are allowed to be immediately nested
|
||||
// in a tf_executor.graph region.
|
||||
for (Operation &op : graph.GetBody()) {
|
||||
if (op.getDialect() != executorDialect)
|
||||
return op.emitOpError() << "unallowed inside a tf_executor.graph region";
|
||||
}
|
||||
|
||||
Operation &fetch = graph.GetBody().back();
|
||||
if (!isa<FetchOp>(fetch))
|
||||
return fetch.emitOpError()
|
||||
<< "invalid tf_executor.graph terminator, fetch expected";
|
||||
|
||||
// Ensures that the fetch terminator operands match the graph result types.
// All non-control operands of the fetch operation must match the values
// returned by the graph.
|
||||
if (fetch.getNumOperands() < graph.getNumResults())
|
||||
return fetch.emitOpError() << "does not have enough operands to cover the "
|
||||
"graph returned values";
|
||||
for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
|
||||
Value *operand = fetch.getOperand(i);
|
||||
// Breaks out of the loop at the first control operand encountered.
|
||||
if (operand->getType().isa<ControlType>()) {
|
||||
if (i != graph.getNumResults())
|
||||
return fetch.emitOpError()
|
||||
<< "operand #" << i
|
||||
<< " is a control type, can't be bound to a graph result";
|
||||
break;
|
||||
}
|
||||
if (i >= graph.getNumResults())
|
||||
return fetch.emitOpError()
|
||||
<< "operand #" << i << " does not have a graph results to bind";
|
||||
if (graph.getResult(i)->getType() != operand->getType())
|
||||
return fetch.emitOpError()
|
||||
<< "operand #" << i << " type mismatch graph results";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(GraphOp graph, OpAsmPrinter *p) {
|
||||
*p << graph.getOperationName();
|
||||
p->printRegion(graph.getOperation()->getRegion(0));
|
||||
p->printOptionalAttrDict(graph.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseGraphOp(OpAsmParser *parser, OperationState *result) {
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
|
||||
// Parses the body region.
|
||||
Region &body = *result->addRegion();
|
||||
if (parser->parseRegion(body, llvm::None, llvm::None)) return failure();
|
||||
|
||||
if (body.getBlocks().size() > 1)
|
||||
return parser->emitError(loc) << "expects a single block region";
|
||||
|
||||
// Ensures that the region is well formed: it contains at least a block with
|
||||
// a FetchOp terminator.
|
||||
EnsureExecutorTerminator<FetchOp>(&body, &parser->getBuilder(),
|
||||
result->location);
|
||||
|
||||
// Gets the results type from the terminator type inside the graph.
|
||||
Operation &fetch = body.back().back();
|
||||
if (!isa<FetchOp>(fetch))
|
||||
return parser->emitError(loc) << "expects a tf_executor.fetch terminator";
|
||||
|
||||
// The return value of the graph operation are the non-control operands of
|
||||
// the fetch operation.
|
||||
result->types.reserve(fetch.getNumOperands());
|
||||
for (Type type : fetch.getOperandTypes()) {
|
||||
if (type.isa<ControlType>()) break;
|
||||
result->types.push_back(type);
|
||||
}
|
||||
|
||||
// Parses the optional attribute list.
|
||||
if (parser->parseOptionalAttributeDict(result->attributes)) return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.fetch
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Print(FetchOp fetch, OpAsmPrinter *p) {
|
||||
*p << fetch.getOperationName();
|
||||
if (fetch.getNumOperands() > 0) {
|
||||
*p << ' ';
|
||||
p->printOperands(fetch.operand_begin(), fetch.operand_end());
|
||||
*p << " : ";
|
||||
interleaveComma(fetch.getOperandTypes(), *p);
|
||||
}
|
||||
p->printOptionalAttrDict(fetch.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseFetchOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> opInfo;
|
||||
SmallVector<Type, 2> types;
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
return failure(
|
||||
parser->parseOperandList(opInfo) ||
|
||||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
|
||||
parser->resolveOperands(opInfo, types, loc, result->operands) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes)
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.island
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Verify(IslandOp island) {
|
||||
if (island.GetBody().empty())
|
||||
return island.emitOpError() << "expects a non-empty body";
|
||||
|
||||
Operation &yield = island.GetBody().back();
|
||||
if (!isa<YieldOp>(yield))
|
||||
return yield.emitOpError()
|
||||
<< "invalid tf_executor.island terminator, yield expected";
|
||||
|
||||
// Ensures that the yield terminator operands match the island result types.
|
||||
int result_count = island.getNumResults() - 1; // -1 for the control token
|
||||
if (yield.getNumOperands() != result_count)
|
||||
return yield.emitOpError()
|
||||
<< "has " << yield.getNumOperands()
|
||||
<< " operand, but island returns " << result_count;
|
||||
for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
|
||||
if (island.getResult(operand_idx)->getType() !=
|
||||
yield.getOperand(operand_idx)->getType())
|
||||
return yield.emitOpError()
|
||||
<< "operand #" << operand_idx << " type mismatch island results";
|
||||
}
|
||||
|
||||
// Checks that there aren't any control results other than the last one.
|
||||
Type control_type = ControlType::get(island.getContext());
|
||||
for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
|
||||
if (island.getResult(operand_idx)->getType() == control_type)
|
||||
return yield.emitOpError()
|
||||
<< "unexpected control type for operand #" << operand_idx;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(IslandOp op, OpAsmPrinter *p) {
|
||||
*p << op.getOperationName();
|
||||
if (op.getNumOperands()) {
|
||||
// These are always control operands; no explicit type is needed.
|
||||
*p << '(';
|
||||
p->printOperands(op.getOperands());
|
||||
*p << ')';
|
||||
}
|
||||
p->printRegion(op.getOperation()->getRegion(0));
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) {
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
Type control_type = ControlType::get(parser->getBuilder().getContext());
|
||||
|
||||
// Parses optional argument list (control dependencies only).
|
||||
SmallVector<OpAsmParser::OperandType, 4> op_infos;
|
||||
if (parser->parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
|
||||
return failure();
|
||||
if (!op_infos.empty()) {
|
||||
SmallVector<Type, 2> types;
|
||||
types.push_back(control_type);
|
||||
parser->resolveOperands(op_infos, types, loc, result->operands);
|
||||
}
|
||||
|
||||
// Parses the body region.
|
||||
Region &body = *result->addRegion();
|
||||
|
||||
// TODO(b/134773778): the custom parser is missing support to implement the
// short syntax right now.
|
||||
// if (!parser->parseOptionalKeyword("wraps")) {
|
||||
// body.push_back(new Block);
|
||||
// Block &block = body.back();
|
||||
// parser->getBuilder().setInsertionPointToEnd(&block);
|
||||
// if (parser->parseOperation())
|
||||
// return failure();
|
||||
// }
|
||||
|
||||
if (parser->parseRegion(body, llvm::None, llvm::None)) return failure();
|
||||
|
||||
EnsureExecutorTerminator<YieldOp>(&body, &parser->getBuilder(),
|
||||
result->location);
|
||||
|
||||
// Gets the results type for the island from the terminator operands.
|
||||
Operation &yield = body.back().back();
|
||||
result->types.reserve(yield.getNumOperands() + 1);
|
||||
result->types.append(yield.operand_type_begin(), yield.operand_type_end());
|
||||
result->types.push_back(control_type);
|
||||
|
||||
// Parses the optional attribute list.
|
||||
if (parser->parseOptionalAttributeDict(result->attributes)) return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.yield
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Print(YieldOp yield, OpAsmPrinter *p) {
|
||||
*p << yield.getOperationName();
|
||||
if (yield.getNumOperands() > 0) {
|
||||
*p << ' ';
|
||||
p->printOperands(yield.operand_begin(), yield.operand_end());
|
||||
*p << " : ";
|
||||
interleaveComma(yield.getOperandTypes(), *p);
|
||||
}
|
||||
p->printOptionalAttrDict(yield.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseYieldOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_info;
|
||||
SmallVector<Type, 2> types;
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
return failure(
|
||||
parser->parseOperandList(op_info) ||
|
||||
(!op_info.empty() && parser->parseColonTypeList(types)) ||
|
||||
parser->resolveOperands(op_info, types, loc, result->operands) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.Switch
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult ParseSwitchOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
SmallVector<Type, 1> types;
|
||||
if (parser->parseOperandList(op_infos, 2) ||
|
||||
parser->parseColonTypeList(types))
|
||||
return failure();
|
||||
if (types.size() != 1)
|
||||
return parser->emitError(parser->getNameLoc())
|
||||
<< " expects only a single data type";
|
||||
|
||||
// Supports parsing either a functional type (in which case all the types are
|
||||
// fully qualified) or a short form with a single type (in which case the data
|
||||
// input and the outputs are all using this type).
|
||||
if (types.front().isa<FunctionType>()) {
|
||||
FunctionType type = types.front().cast<FunctionType>();
|
||||
if (type.getNumInputs() != 2)
|
||||
return parser->emitError(parser->getNameLoc())
|
||||
<< " expects a single data type and a predicate";
|
||||
result->types.assign(type.getResults().begin(), type.getResults().end());
|
||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||
} else {
|
||||
Type control_type = ControlType::get(parser->getBuilder().getContext());
|
||||
result->types.append(2, types[0]);
|
||||
result->types.push_back(control_type);
|
||||
Type i1_type = parser->getBuilder().getI1Type();
|
||||
RankedTensorType predicate_type = RankedTensorType::get({}, i1_type);
|
||||
types.push_back(predicate_type);
|
||||
types.append(op_infos.size() - 2, control_type);
|
||||
}
|
||||
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
if (parser->resolveOperands(op_infos, types, loc, result->operands))
|
||||
return failure();
|
||||
|
||||
return parser->parseOptionalAttributeDict(result->attributes);
|
||||
}
|
||||
|
||||
void Print(SwitchOp switch_op, OpAsmPrinter *p) {
|
||||
*p << switch_op.getOperationName() << ' ';
|
||||
p->printOperands(switch_op.getOperands());
|
||||
Type data_operand_ty = switch_op.data()->getType();
|
||||
// If the types aren't perfectly matching, print the functional type syntax
|
||||
// else print the shorter single type.
|
||||
*p << " : ";
|
||||
if (switch_op.trueOutput()->getType() != data_operand_ty ||
|
||||
switch_op.falseOutput()->getType() != data_operand_ty) {
|
||||
p->printFunctionalType(switch_op.getOperation());
|
||||
} else {
|
||||
*p << switch_op.getType(0);
|
||||
}
|
||||
p->printOptionalAttrDict(switch_op.getAttrs());
|
||||
}
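
// Illustrative printed forms for tf_executor.Switch (a hedged sketch; the SSA
// names and tensor types below are hypothetical, not taken from this change).
// The short form is printed when both data outputs match the data operand
// type, otherwise the functional form is used:
//   %2:3 = tf_executor.Switch %0, %1 : tensor<*xf32>
//   %2:3 = tf_executor.Switch %0, %1 : (tensor<4xf32>, tensor<i1>)
//            -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)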
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.SwitchN
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Verify(SwitchNOp switchn) {
|
||||
IntegerAttr num_outs = switchn.getAttrOfType<IntegerAttr>("num_outs");
|
||||
if (!num_outs)
|
||||
return switchn.emitOpError() << "expects a `num_outs` integer attribute";
|
||||
|
||||
// Expects num_outs results + 1 control output.
|
||||
if (switchn.getNumResults() != num_outs.getInt() + 1)
|
||||
return switchn.emitOpError()
|
||||
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got "
|
||||
<< (switchn.getNumResults() - 1);
|
||||
|
||||
auto operand0_type = switchn.getOperand(0)->getType();
|
||||
for (Value *result : switchn.outputs())
|
||||
if (operand0_type != result->getType())
|
||||
return switchn.emitOpError()
|
||||
<< "type mismatch between data operand and result: "
|
||||
<< operand0_type << " vs " << result->getType();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(SwitchNOp switchn, OpAsmPrinter *p) {
|
||||
*p << switchn.getOperationName() << ' ';
|
||||
auto operands = switchn.getOperands();
|
||||
// Prints the data and index operands.
|
||||
p->printOperands(operands.begin(), std::next(operands.begin(), 2));
|
||||
*p << " of " << (switchn.getNumResults() - 1);
|
||||
// Prints control dependencies if any
|
||||
if (!llvm::empty(switchn.controlInputs())) {
|
||||
*p << " (";
|
||||
p->printOperands(switchn.controlInputs());
|
||||
*p << ")";
|
||||
}
|
||||
*p << " : " << switchn.getType(0);
|
||||
p->printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
|
||||
}
|
||||
|
||||
ParseResult ParseSwitchNOp(OpAsmParser *parser, OperationState *result) {
|
||||
// Parsing:
//   %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
// Where the first operand is the data to replicate, the second is an i32
// indicating which output to populate, followed by the keyword `of` and the
// number of outputs (+1 for the control token).
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
SmallVector<Type, 1> types;
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
IntegerAttr num_outs;
|
||||
Type i64_type = parser->getBuilder().getIntegerType(64);
|
||||
if (parser->parseOperandList(op_infos, 2) || parser->parseKeyword("of") ||
|
||||
parser->parseAttribute(num_outs, i64_type, "num_outs",
|
||||
result->attributes) ||
|
||||
parser->parseOperandList(op_infos,
|
||||
OpAsmParser::Delimiter::OptionalParen) ||
|
||||
parser->parseColonTypeList(types))
|
||||
return failure();
|
||||
if (types.size() != 1)
|
||||
return parser->emitError(parser->getNameLoc())
|
||||
<< " expects only a single data type";
|
||||
|
||||
if (num_outs.getInt() <= 0)
|
||||
return parser->emitError(parser->getNameLoc())
|
||||
<< " expects a positive number of outputs";
|
||||
|
||||
// `types` already contains the type for the data, add an i32 for the
|
||||
// output_index, and then the optional control inputs.
|
||||
types.push_back(parser->getBuilder().getIntegerType(32));
|
||||
Type control_type = ControlType::get(parser->getBuilder().getContext());
|
||||
types.append(op_infos.size() - 2, control_type);
|
||||
|
||||
if (parser->resolveOperands(op_infos, types, loc, result->operands))
|
||||
return failure();
|
||||
|
||||
// The output result types are the data input type replicated `num_outs` times.
|
||||
result->types.append(num_outs.getInt(), types[0]);
|
||||
result->types.push_back(control_type);
|
||||
|
||||
return parser->parseOptionalAttributeDict(result->attributes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.Merge
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Verify(MergeOp merge) {
|
||||
if (!merge.getNumOperands())
|
||||
return merge.emitOpError() << "expects at least one operand";
|
||||
|
||||
Type data_type = merge.getOperand(0)->getType();
|
||||
if (data_type.isa<ControlType>())
|
||||
return merge.emitOpError() << "expects a non-control input";
|
||||
|
||||
// Checks that all operands can be broadcasted to a common type compatible
|
||||
// with the result type.
|
||||
Type broadcasted_type = merge.output()->getType();
|
||||
for (Type operand_type : merge.getOperandTypes()) {
|
||||
if (operand_type.isa<ControlType>()) break;
|
||||
Type new_broadcasted_type =
|
||||
OpTrait::util::getBroadcastedType(broadcasted_type, operand_type);
|
||||
if (!new_broadcasted_type)
|
||||
return merge.emitOpError()
|
||||
<< "expects all operands to be broadcastable"
|
||||
<< " but got " << broadcasted_type << " vs " << operand_type;
|
||||
// Uses the broadcasted type unless we're losing the rank information here.
|
||||
// This is because for example starting with a result of tensor<4xf32>, if
|
||||
// the first operand is unranked, the broadcasted type will be unranked.
|
||||
// Then any tensor operand will be broadcastable to this unranked type.
|
||||
if ((broadcasted_type.isa<TensorType>() &&
|
||||
!broadcasted_type.cast<TensorType>().hasRank()) ||
|
||||
(new_broadcasted_type.isa<TensorType>() &&
|
||||
new_broadcasted_type.cast<TensorType>().hasRank()))
|
||||
broadcasted_type = new_broadcasted_type;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(MergeOp merge, OpAsmPrinter *p) {
|
||||
*p << merge.getOperationName() << ' ';
|
||||
p->printOperands(merge.getOperands());
|
||||
|
||||
// Prints the type signature of the operation.
|
||||
*p << " : " << merge.getType(0);
|
||||
p->printOptionalAttrDict(merge.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseMergeOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
SmallVector<Type, 1> types;
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
if (parser->parseOperandList(op_infos) || parser->parseColonTypeList(types))
|
||||
return failure();
|
||||
if (types.size() != 1)
|
||||
return parser->emitError(parser->getNameLoc())
|
||||
<< " expects only a single data type";
|
||||
|
||||
// The type is specified once but applies to both data operands.
|
||||
types.push_back(types.front());
|
||||
// Extra operands are expected to be control inputs.
|
||||
Type control_type = ControlType::get(parser->getBuilder().getContext());
|
||||
types.append(op_infos.size() - 2, control_type);
|
||||
|
||||
if (parser->resolveOperands(op_infos, types, loc, result->operands))
|
||||
return failure();
|
||||
RankedTensorType i32_tensor =
|
||||
RankedTensorType::get({}, parser->getBuilder().getIntegerType(32));
|
||||
result->types = {types.front(), i32_tensor, control_type};
|
||||
|
||||
return parser->parseOptionalAttributeDict(result->attributes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.Enter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Default value for the parallel_iterations attribute on Enter nodes.
|
||||
constexpr int kDefaultParallelIterations = 10;
|
||||
|
||||
void Print(EnterOp enter, OpAsmPrinter *p) {
|
||||
*p << enter.getOperationName() << ' ';
|
||||
p->printOperands(enter.getOperands());
|
||||
|
||||
*p << " frame \"";
|
||||
printEscapedString(enter.frame_name(), p->getStream());
|
||||
*p << "\"";
|
||||
if (enter.parallel_iterations() != kDefaultParallelIterations)
|
||||
*p << " parallel_iterations " << enter.parallel_iterations();
|
||||
if (enter.is_constant()) *p << " constant ";
|
||||
|
||||
// If the types aren't perfectly matching, print the functional type syntax
|
||||
// else print the shorter single type.
|
||||
*p << " : ";
|
||||
if (enter.data()->getType() != enter.output()->getType()) {
|
||||
p->printFunctionalType(enter.getOperation());
|
||||
} else {
|
||||
*p << enter.getType(0);
|
||||
}
|
||||
|
||||
p->printOptionalAttrDict(
|
||||
enter.getAttrs(), {"frame_name", "parallel_iterations", "is_constant"});
|
||||
}
|
||||
|
||||
ParseResult ParseEnterOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
MLIRContext *context = parser->getBuilder().getContext();
|
||||
if (parser->parseOperandList(op_infos)) return failure();
|
||||
if (op_infos.empty())
|
||||
return parser->emitError(loc) << " expects at least one data operand";
|
||||
|
||||
Attribute frame;
|
||||
if (parser->parseKeyword("frame") ||
|
||||
parser->parseAttribute(frame, NoneType::get(context), "frame_name",
|
||||
result->attributes))
|
||||
return failure();
|
||||
|
||||
Type i64 = parser->getBuilder().getIntegerType(64);
|
||||
if (parser->parseOptionalKeyword("parallel_iterations")) {
|
||||
result->addAttribute("parallel_iterations",
|
||||
IntegerAttr::get(i64, kDefaultParallelIterations));
|
||||
} else {
|
||||
IntegerAttr parallel_iterations;
|
||||
if (parser->parseAttribute(parallel_iterations, i64, "parallel_iterations",
|
||||
result->attributes))
|
||||
return failure();
|
||||
}
|
||||
bool has_constant = succeeded(parser->parseOptionalKeyword("constant"));
|
||||
result->addAttribute("is_constant", BoolAttr::get(has_constant, context));
|
||||
|
||||
SmallVector<Type, 1> types;
|
||||
if (parser->parseColonTypeList(types)) return failure();
|
||||
if (types.size() != 1)
|
||||
return parser->emitError(loc) << " expects only a single data type";
|
||||
|
||||
// Support parsing either a functional type (in which case all the types are
|
||||
// fully qualified) or a short form with a single type (in which case the data
|
||||
// input and the outputs are all using this type).
|
||||
if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
|
||||
if (type.getNumInputs() != 1)
|
||||
return parser->emitError(parser->getNameLoc())
|
||||
<< " expects a single data type";
|
||||
result->types.assign(type.getResults().begin(), type.getResults().end());
|
||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||
} else {
|
||||
Type control_type = ControlType::get(context);
|
||||
types.append(op_infos.size() - 1, control_type);
|
||||
result->addTypes({types.front(), control_type});
|
||||
}
|
||||
|
||||
// Extra operands are expected to be control inputs.
|
||||
|
||||
if (parser->resolveOperands(op_infos, types, loc, result->operands))
|
||||
return failure();
|
||||
|
||||
return parser->parseOptionalAttributeDict(result->attributes);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
|
||||
|
||||
} // namespace tf_executor
|
||||
} // namespace mlir
|
74
tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h
Normal file
74
tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the tf_executor dialect: it models the TensorFlow executor
|
||||
// semantics and can represent arbitrary TensorFlow graphs. As such it follows
|
||||
// the existing execution model that includes deadness propagation, concurrent
|
||||
// semantics, and control dependencies.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_
|
||||
|
||||
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_executor {
|
||||
|
||||
class TensorFlowExecutorDialect : public Dialect {
|
||||
public:
|
||||
explicit TensorFlowExecutorDialect(MLIRContext *context);
|
||||
|
||||
// Parses a type registered to this dialect.
|
||||
Type parseType(StringRef data, Location loc) const override;
|
||||
|
||||
// Prints a type registered to this dialect.
|
||||
void printType(Type type, raw_ostream &os) const override;
|
||||
};
|
||||
|
||||
namespace TFTypes {
|
||||
enum Kind {
|
||||
Control = Type::FIRST_TENSORFLOW_EXECUTOR_TYPE,
|
||||
};
|
||||
} // namespace TFTypes
|
||||
|
||||
// The Control type is a token-like value that models control dependencies from
|
||||
// TensorFlow graphs.
|
||||
class ControlType : public Type::TypeBase<ControlType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static ControlType get(MLIRContext *context) {
|
||||
return Base::get(context, TFTypes::Control);
|
||||
}
|
||||
|
||||
// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) { return kind == TFTypes::Control; }
|
||||
};
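
// For illustration (a hedged sketch, not part of the generated code): in the
// textual IR this type prints with the dialect prefix as
// `!tf_executor.control`, e.g. as the type of a control operand passed to a
// fetch:
//   tf_executor.fetch %ctl : !tf_executor.control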
|
||||
|
||||
// Declares the operations for this dialect using the generated header.
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h.inc"
|
||||
|
||||
} // namespace tf_executor
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_
|
423
tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
Normal file
423
tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
Normal file
@ -0,0 +1,423 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the definition file for the TensorFlow Executor Dialect.
|
||||
//
|
||||
|
||||
#ifdef TF_EXECUTOR_DIALECT
|
||||
#else
|
||||
#define TF_EXECUTOR_DIALECT
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow dialect definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TfExecutor_Dialect : Dialect {
|
||||
let name = "tf_executor";
|
||||
|
||||
let description = [{
|
||||
The TensorFlow Executor dialect.
|
||||
|
||||
This dialect models the TensorFlow executor semantics and can represent
|
||||
arbitrary TensorFlow graphs. As such it follows the existing execution model
|
||||
that includes deadness propagation, concurrent semantics, and control
|
||||
dependencies.
|
||||
|
||||
Operations in this dialect return a value of type `!tf_executor.control` as
|
||||
last returned value (exceptions are `tf_executor.NextIteration.graph`,
|
||||
`tf_executor.NextIteration.sink` and `tf_executor.fetch` which don’t return any
|
||||
value).
|
||||
}];
|
||||
|
||||
let cppNamespace = "tf_executor";
|
||||
}
|
||||
|
||||
// Control type.
|
||||
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow Executor Type Constraint
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Predicate to verify that all control operands appear after every
// non-control operand.
|
||||
def ControlOperandsAfterAllData :
|
||||
PredOpTrait<"all control inputs must appear after any non-control input",
|
||||
CPred<"succeeded(VerifyControlOperandsAfterAllData(&$_op))">>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Base class for the operations in this dialect; it forwards the verifier,
// printer, and parser to free functions.
|
||||
class TfExecutor_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<TfExecutor_Dialect, mnemonic, traits> {
|
||||
|
||||
let verifier = [{ return ::mlir::tf_executor::Verify(*this); }];
|
||||
let printer = [{ return ::mlir::tf_executor::Print(*this, p); }];
|
||||
let parser = [{ return Parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def TfExecutor_GraphOp : TfExecutor_Op<"graph", []> {
|
||||
let summary = [{The `tf_executor.graph` operation contains a region with a
|
||||
single block that lists the operations in a TensorFlow graph.}];
|
||||
|
||||
|
||||
let description = [{
|
||||
The operations are topologically sorted in-order (no cycles are allowed in
|
||||
the values). The execution model for operations in this block follows the
|
||||
TensorFlow executor semantics:
|
||||
1. Operations that don’t have any transitive dependencies through the
|
||||
def/use chains may be executed in parallel
|
||||
(`tf_executor.NextIteration.Source` is the exception).
|
||||
2. SSA values in this block can be implicitly dead. This means that every
|
||||
SSA value defined in a `tf_executor.graph` can be considered implicitly
|
||||
wrapped in a conceptual `dead_or<T>` structure, and includes a runtime
|
||||
flag indicating if the value is dead or present.
|
||||
3. Operations may have special case handling of dead values.
|
||||
|
||||
The `tf_executor.graph` op only allows specific `tf_executor` dialect
|
||||
operations in its body: the `tf_executor.graph` verifier will reject any
|
||||
unknown operation. In order to execute standard `tf` dialect operations
|
||||
(like `tf.Add`) they must be wrapped in the `tf_executor.island` operation.
|
||||
|
||||
The `tf_executor.graph` operation does not accept any operands, inputs are
|
||||
implicitly captured by the region, representing the feeds to the graph.
|
||||
|
||||
The region attached to `tf_executor.graph` is terminated by a
|
||||
`tf_executor.fetch` operation. The operands of the terminator correspond to
|
||||
the result values (or fetches) of the `tf_executor.graph` operation. The
|
||||
behavior is undefined if any of the operands of the `tf_executor.fetch` is
|
||||
dead.
|
||||
}];
|
||||
|
||||
let results = (outs
|
||||
Variadic<AnyType>:$results
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Block &GetBody() { return getOperation()->getRegion(0).front(); }
|
||||
}];
|
||||
}
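
// For illustration only (a hedged sketch, not generated from this file): a
// graph wrapping a single island whose result feeds the fetch; %arg0 and
// %arg1 are hypothetical, implicitly captured function arguments.
//
//   %0 = tf_executor.graph {
//     %1:2 = tf_executor.island {
//       %2 = "tf.Add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//       tf_executor.yield %2 : tensor<f32>
//     }
//     tf_executor.fetch %1#0 : tensor<f32>
//   }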
|
||||
|
||||
def TfExecutor_FetchOp : TfExecutor_Op<"fetch", [Terminator, ControlOperandsAfterAllData]> {
|
||||
let summary = [{
|
||||
    The `tf_executor.fetch` operation terminates the graph and returns values.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The non-control operands of the fetch operation are returned outside of the
|
||||
graph and must match the return type of the graph.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$fetches
|
||||
);
|
||||
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
def TfExecutor_IslandOp : TfExecutor_Op<"island", []> {
|
||||
let summary = [{
|
||||
The `tf_executor.island` operation is a wrapper for operations in other
|
||||
dialects to be nested in a `tf_executor.graph`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The `tf_executor.graph` operation does not allow `tf` dialect operations to
|
||||
be immediately nested underneath it. The `tf_executor.island` is introduced
|
||||
as a wrapper for `tf` dialect operations: this results in a more consistent
|
||||
representation which makes analysis and transformation simpler.
|
||||
The `tf_executor.island` operation has a single region with a single block
|
||||
attached (only functional control flow is allowed). The block is terminated
|
||||
    by a `tf_executor.yield` operation. The operands of the terminator
    correspond to the result values of the `tf_executor.island` operation. An
|
||||
extra result of type `!tf_executor.control` is always produced by every
|
||||
`tf_executor.island`.
|
||||
Within an island, execution semantics follow standard sequential behavior as
|
||||
expected by TF2 and by compiler analyses and transformations, and values
|
||||
can’t be dead. Other nested `tf_executor.graph` operations can be present in
|
||||
the region to re-enable the TensorFlow executor for a subsection of the
|
||||
code.
|
||||
    - Initially the functional control-flow operations call functions
      involving graphs; if `tf_executor.graph` weren't allowed in an island,
      these operations would need an equivalent in the `tf_executor` dialect
      to be modelled in a graph.
|
||||
- Nesting also allows forming islands without involving inter-procedural
|
||||
analyses: any function call may involve a callee with a graph.
|
||||
The `tf_executor.island` region allows implicit capture. If any value
|
||||
captured by a `tf_executor.island` is dead, the whole region does not
|
||||
execute and every produced value is marked as dead as well.
|
||||
An arbitrary number of `tf_executor.control` operands are accepted by a
|
||||
`tf_executor.island` operation.
|
||||
    If any operand or implicitly captured value is dead, the region is not
|
||||
executed and dead values are immediately returned for every result.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TfeControlType>:$controlInputs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<AnyType>:$outputs,
|
||||
TfeControlType:$control
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Block &GetBody() { return getOperation()->getRegion(0).front(); }
|
||||
}];
|
||||
}
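
// For illustration only (a hedged sketch): an island wrapping a single `tf`
// op and taking an extra control dependency on the control result of another
// island (%1#1); the second island result is its own control token.
//
//   %2:2 = tf_executor.island(%1#1) {
//     %3 = "tf.Neg"(%1#0) : (tensor<f32>) -> tensor<f32>
//     tf_executor.yield %3 : tensor<f32>
//   }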
|
||||
|
||||
def TfExecutor_YieldOp :
|
||||
TfExecutor_Op<"yield", [Terminator, ControlOperandsAfterAllData]> {
|
||||
let summary = [{
|
||||
The `tf_executor.yield` operation terminates and returns values for the
|
||||
`tf_executor.island` operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$fetches
|
||||
);
|
||||
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
|
||||
[NoSideEffect, ControlOperandsAfterAllData,
|
||||
PredOpTrait<"data operand must be broadcastable to true result",
|
||||
TCOpIsBroadcastableToRes<0, 0>>,
|
||||
PredOpTrait<"data operand must be broadcastable to false result",
|
||||
TCOpIsBroadcastableToRes<0, 1>>]>{
|
||||
let summary = [{
|
||||
The "tf_executor.Switch" operation takes a data operand and a boolean
|
||||
    predicate condition, and returns two values matching the type of the data
    operand.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
More details can be found in Tensorflow Control Flow white paper:
|
||||
http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
|
||||
This is defined in TensorFlow as:
|
||||
|
||||
REGISTER_OP("Switch")
|
||||
.Input("data: T")
|
||||
.Input("pred: bool")
|
||||
.Output("output_false: T")
|
||||
.Output("output_true: T")
|
||||
|
||||
For example:
|
||||
%2 = tf_executor.Switch %0, %1 : tensor<*xf32>
|
||||
|
||||
Note: Additional result corresponds to the control output.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyType:$data,
|
||||
TensorOf<[I1]>:$predicate,
|
||||
// Optional extra control inputs.
|
||||
Variadic<TfeControlType>:$controlInputs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
AnyType: $trueOutput,
|
||||
AnyType: $falseOutput,
|
||||
TfeControlType: $control
|
||||
);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, ArrayRef<Value *> operands = {}",
|
||||
[{
|
||||
assert(operands.size() >= 2 && "tf_executor.Switch builder expects at "
|
||||
"least two operands");
|
||||
return build(builder, result, operands[0], operands[1], operands.drop_front(2));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
"Builder *builder, OperationState *result, Value *data, Value *predicate, ArrayRef<Value *> controls = {}",
|
||||
[{
|
||||
Type dataTy = data->getType();
|
||||
Type controlTy = ControlType::get(builder->getContext());
|
||||
result->types = { dataTy, dataTy, controlTy };
|
||||
result->operands.push_back(data);
|
||||
result->operands.push_back(predicate);
|
||||
result->operands.insert(result->operands.end(), controls.begin(), controls.end());
|
||||
}]>
|
||||
];
|
||||
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
def TfExecutor_SwitchNOp :
|
||||
TfExecutor_Op<"SwitchN", [NoSideEffect, ControlOperandsAfterAllData]> {
|
||||
let summary = [{
|
||||
The "tf_executor.SwitchN" operation takes two inputs, `data` and `index` and
|
||||
an integer attribute `num_outs` indicating the number of outputs. The `data`
|
||||
input is copied to output indicated by the `index` input. The other outputs
|
||||
are marked as dead. If one of the inputs or a control token is dead, then
|
||||
all of the outputs are marked as dead as well.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
This is defined in TensorFlow as:
|
||||
|
||||
REGISTER_OP("_SwitchN")
|
||||
.Input("data: T")
|
||||
.Input("output_index: int32")
|
||||
.Output("outputs: num_outs * T")
|
||||
.Attr("num_outs: int >= 1")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(SwitchNShape);
|
||||
|
||||
For example:
|
||||
    %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
|
||||
|
||||
Note: One additional result corresponds to the control output.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyType:$data,
|
||||
I32:$index,
|
||||
// Optional extra control inputs.
|
||||
Variadic<TfeControlType>:$controlInputs,
|
||||
I64Attr:$num_outs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<AnyType>:$outputs,
|
||||
TfeControlType: $control
|
||||
);
|
||||
|
||||
}
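
// For illustration only (a hedged sketch): with `num_outs = 5` the operation
// produces five data results plus the trailing control token, and the custom
// assembly uses the `of` keyword for the output count:
//   %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<*xf32>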
|
||||
|
||||
def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAfterAllData]> {
|
||||
let summary = [{
|
||||
The "tf_executor.Merge" operation takes a list of input operands and returns
|
||||
    a value of the operand type along with the index of the first operand
    that becomes available.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
More details can be found in Tensorflow Control Flow white paper:
|
||||
http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
|
||||
This is defined in TensorFlow as:
|
||||
|
||||
REGISTER_OP("Merge")
|
||||
.Input("inputs: N * T")
|
||||
.Output("output: T")
|
||||
.Output("value_index: int32")
|
||||
|
||||
For example:
|
||||
%2 = tf_executor.Merge %0, %1, %2, %3 : tensor<*xf32>
|
||||
|
||||
Note: Additional result corresponds to the control output.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$inputs_and_control
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
AnyType:$output,
|
||||
TensorOf<[I32]>:$valueIndex,
|
||||
TfeControlType:$control
|
||||
);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, ArrayRef<Value *> operands",
|
||||
[{
|
||||
assert(operands.size() >= 1 && "tf_executor.Merge builder expects at "
|
||||
"least one operand");
|
||||
Type data_type = operands[0]->getType();
|
||||
Type control_type = ControlType::get(builder->getContext());
|
||||
result->types = { data_type, builder->getIntegerType(32), control_type};
|
||||
result->operands.append(operands.begin(), operands.end());
|
||||
}]>
|
||||
];
|
||||
}
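
// For illustration only (a hedged sketch): two data inputs and one control
// input (%2 is assumed to be a !tf_executor.control value); the results are
// the forwarded value, its index as a tensor<i32>, and the control token.
//   %3:3 = tf_executor.Merge %0, %1, %2 : tensor<*xf32>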
|
||||
|
||||
def TfExecutor_EnterOp : TfExecutor_Op<"Enter",
|
||||
[NoSideEffect, ControlOperandsAfterAllData,
|
||||
PredOpTrait<"data operand must be broadcastable to result",
|
||||
TCOpIsBroadcastableToRes<0, 0>>]>{
|
||||
let summary = [{
|
||||
The "tf_executor.Enter" operation forwards its input to Tensorflow while
|
||||
loop.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
More details can be found in Tensorflow Control Flow white paper:
|
||||
http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
|
||||
|
||||
Each tensor needs its own tf_executor.Enter to be made available inside a
|
||||
while loop.
|
||||
|
||||
This is defined in Tensorflow as:
|
||||
|
||||
REGISTER_OP("Enter")
|
||||
.Input("data: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("frame_name: string")
|
||||
.Attr("is_constant: bool = false")
|
||||
.Attr("parallel_iterations: int = 10")
|
||||
|
||||
For example:
|
||||
%res:2 = tf_executor.Enter %arg0 frame "some/frame" parallel_iterations 42 constant : tensor<*xf32>
|
||||
|
||||
Note: Additional result corresponds to the control output.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyType:$data,
|
||||
StrAttr:$frame_name,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_constant,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
// Optional extra control inputs.
|
||||
Variadic<TfeControlType>:$controlInputs
|
||||
);
|
||||
|
||||
|
||||
let results = (outs
|
||||
AnyType:$output,
|
||||
TfeControlType:$control
|
||||
);
|
||||
|
||||
let verifier = ?;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, ArrayRef<Value *> operands",
|
||||
[{
|
||||
assert(operands.size() >= 1 && "tf_executor.Enter builder "
|
||||
"expects at least one operand");
|
||||
result->operands.append(operands.begin(), operands.end());
|
||||
|
||||
Type control_type = ControlType::get(builder->getContext());
|
||||
result->types.push_back(operands[0]->getType());
|
||||
result->types.push_back(control_type);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
#endif // TF_EXECUTOR_DIALECT
|
2802
tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
Normal file
2802
tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
Normal file
File diff suppressed because it is too large
265
tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
Normal file
265
tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
Normal file
@ -0,0 +1,265 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the base operation definition file for TensorFlow.
|
||||
//
|
||||
// This file includes the definition for the TensorFlow dialect, base TensorFlow
|
||||
// op, and various commonly used TensorFlow types, attributes, and builders.
|
||||
|
||||
#ifdef TF_OP_BASE
|
||||
#else
|
||||
#define TF_OP_BASE
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow dialect definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TF_Dialect : Dialect {
|
||||
let name = "tf";
|
||||
|
||||
let description = [{
|
||||
The TensorFlow dialect.
|
||||
|
||||
This dialect maps to TensorFlow operations.
|
||||
|
||||
Invariants:
|
||||
|
||||
* All values are of Tensor type (in particular, scalars are
|
||||
      represented using zero-dimensional tensors);
|
||||
|
||||
TODO: Make invariants more structured so that we can reference them in ops.
|
||||
}];
|
||||
|
||||
let cppNamespace = "TF";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<TF_Dialect, mnemonic, traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Any tensor element type defined in the TensorFlow dialect
|
||||
def TF_TFDialectType :
|
||||
Type<CPred<"$_self.isa<TensorFlowType>()">, "TensorFlow type">;
|
||||
|
||||
// Any tensor element type allowed in TensorFlow ops
|
||||
def TF_ElementType : Type<Or<[AnyFloat.predicate, AnyInteger.predicate,
|
||||
TF_TFDialectType.predicate]>,
|
||||
"tf.dtype">;
|
||||
|
||||
// Any TensorFlow tensor type
|
||||
def TF_Tensor : TensorOf<[TF_ElementType]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Integer types
|
||||
|
||||
def TF_I32Or64 : IntOfWidths<[32, 64]>;
|
||||
|
||||
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
|
||||
|
||||
def TF_Int : IntOfWidths<[8, 16, 32, 64]>;
|
||||
|
||||
// Any integer tensor types
|
||||
def TF_IntTensor : TensorOf<[TF_Int]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Floating-point types
|
||||
|
||||
def TF_F32Or64 : FloatOfWidths<[32, 64]>;
|
||||
|
||||
def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;
|
||||
|
||||
// Any floating-point tensor types
|
||||
def TF_FpTensor : TensorOf<[AnyFloat]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Complex types
|
||||
|
||||
def TF_Complex64 :
|
||||
Type<CPred<"$_self.isa<TF::Complex64Type>()">, "complex64 type">;
|
||||
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
|
||||
|
||||
def TF_Complex128 :
|
||||
Type<CPred<"$_self.isa<TF::Complex128Type>()">, "complex128 type">;
|
||||
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
|
||||
|
||||
def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128],
|
||||
"64/128-bit complex type">;
|
||||
|
||||
def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// String/variant/resource types
|
||||
|
||||
def TF_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
|
||||
"TensorFlow string type">,
|
||||
BuildableType<"getType<mlir::TF::StringType>()">;
|
||||
def TF_StrTensor : TensorOf<[TF_Str]>;
|
||||
|
||||
def TF_Variant : Type<CPred<"$_self.isa<mlir::TF::VariantType>()">,
|
||||
"TensorFlow variant type">,
|
||||
BuildableType<"getType<mlir::TF::VariantType>()">;
|
||||
def TF_VariantTensor : TensorOf<[TF_Variant]>;
|
||||
|
||||
def TF_Resource : Type<CPred<"$_self.isa<mlir::TF::ResourceType>()">,
|
||||
"TensorFlow variant type">,
|
||||
BuildableType<"getType<mlir::TF::ResourceType>()">;
|
||||
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multi-category type constraints
|
||||
|
||||
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>;
|
||||
|
||||
def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>;
|
||||
|
||||
// Any integer or floating-point tensor types
|
||||
def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
|
||||
|
||||
def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;
|
||||
|
||||
def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyComplex], "number">;
|
||||
|
||||
def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;
|
||||
|
||||
def TF_NumberOrStrTensor : TensorOf<[TF_AnyNumber, TF_Str]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow attribute definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// String attribute constraints
|
||||
|
||||
// A string attribute whose value is one of the values in `cases`.
|
||||
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
|
||||
CPred<!foldl(
|
||||
"$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
|
||||
!foreach(case, !tail(cases),
|
||||
"$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
|
||||
prev, cur, prev # " || " # cur)>,
|
||||
"string attribute whose value is " #
|
||||
!foldl(/*init*/!head(cases), /*list*/!tail(cases),
|
||||
prev, cur, prev # ", or " # cur)>;
|
||||
|
||||
// TODO: Use EnumAttr to define the common attribute cases
|
||||
|
||||
def TF_ConvnetDataFormatAttr : StringBasedAttr<
|
||||
CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
|
||||
"$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
|
||||
"'NHWC' or 'NCHW' convnet data format">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type attributes
|
||||
|
||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
|
||||
// operand. If the `idx`-th operand is a variadic operand, then this attribute
|
||||
// just returns the element type of its first tensor, which is only meaningful
|
||||
// when the variadic operand has at least one tensor and the tensors all have
|
||||
// the same element type.
|
||||
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
|
||||
"return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;
|
||||
|
||||
// A derived attribute that returns the element types of the tensors in the
|
||||
// dynamic value pack that corresponds to the `idx`-th ODS-declared variadic
|
||||
// operand. This returns a list of element types so it is used for variadic
|
||||
// operands that can have different element types.
|
||||
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::OperandElementTypeRange",
|
||||
"auto values = getODSOperands(" # idx # ");\n"
|
||||
"return {mlir::OperandElementTypeIterator(values.begin()), "
|
||||
"mlir::OperandElementTypeIterator(values.end())};"
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
|
||||
// result. If the `idx`-th result is a variadic result, then this attribute
|
||||
// just returns the element type of its first tensor, which is only meaningful
|
||||
// when the variadic result has at least one tensor and the tensors all have
|
||||
// the same element type.
|
||||
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
|
||||
"return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;
|
||||
|
||||
// A derived attribute that returns the element types of the tensors in the
|
||||
// dynamic value pack that corresponds to the `idx`-th ODS-declared variadic
|
||||
// result. This returns a list of element types so it is used for variadic
|
||||
// results that can have different element types.
|
||||
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::ResultElementTypeRange",
|
||||
"auto values = getODSResults(" # idx # ");\n"
|
||||
"return {mlir::ResultElementTypeIterator(values.begin()), "
|
||||
"mlir::ResultElementTypeIterator(values.end())};"
|
||||
>;
|
||||
|
||||
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
|
||||
let returnType = "Type";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow common builders
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Mixin class defining a builder for binary ops supporting broadcast
|
||||
// behavior. The result type has the same element type as both operands.
|
||||
class WithBroadcastableBinOpBuilder {
|
||||
list<OpBuilder> builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, Value* x, Value* y",
|
||||
[{
|
||||
auto resultType =
|
||||
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
|
||||
if (!resultType)
|
||||
mlir::emitError(result->location, "non-broadcastable operands");
|
||||
return build(builder, result, resultType, x, y);
|
||||
}]
|
||||
>];
|
||||
}
|
||||
|
||||
// Mixin class defining a builder for comparison ops supporting broadcast
|
||||
// behavior. The result type has bool element type.
|
||||
class WithBroadcastableCmpOpBuilder {
|
||||
list<OpBuilder> builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, Value* x, Value* y",
|
||||
[{
|
||||
Type resultType;
|
||||
if (x->getType().isa<UnrankedTensorType>() ||
|
||||
y->getType().isa<UnrankedTensorType>()) {
|
||||
resultType = builder->getTensorType(builder->getI1Type());
|
||||
} else {
|
||||
SmallVector<int64_t, 4> resultShape;
|
||||
if (!OpTrait::util::getBroadcastedShape(
|
||||
x->getType().cast<ShapedType>().getShape(),
|
||||
y->getType().cast<ShapedType>().getShape(), resultShape)) {
|
||||
mlir::emitError(result->location,
|
||||
"operands have no broadcastable shapes");
|
||||
}
|
||||
|
||||
resultType = builder->getTensorType(resultShape, builder->getI1Type());
|
||||
}
|
||||
return build(builder, result, resultType, x, y);
|
||||
}]
|
||||
>];
|
||||
}
|
||||
|
||||
#endif // TF_OP_BASE
|
905
tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
Normal file
905
tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
Normal file
@ -0,0 +1,905 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

#include <algorithm>

#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/Support/TypeUtilities.h"  // TF:local_config_mlir

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//

/// Returns true if the given `value` is of ranked float tensor type with the
/// given `rank`.
static inline bool isOfRankedFloatTensorType(Value *value, int rank) {
  auto type = value->getType().dyn_cast<RankedTensorType>();
  return type && type.getRank() == rank &&
         type.getElementType().isa<FloatType>();
}

// Returns true if the given `value` has the specified rank or has unranked
// type.
static inline bool IsOfRankOrUnranked(Value *value, int64_t rank) {
  if (auto type = value->getType().dyn_cast<RankedTensorType>()) {
    return type.getRank() == rank;
  }
  return true;
}

/// Returns true if the specified element type is a TensorFlow type that is ok
/// in a tensor.
static inline bool isValidTFElementType(Type type) {
  return type.isa<FloatType>() || type.isa<IntegerType>() ||
         type.isa<TensorFlowType>();
}

// Returns true if this is a valid TensorFlow tensor type.
static inline bool isValidTFTensorType(Type type) {
  // TensorFlow types should be tensors of one of the valid TensorFlow element
  // types.
  if (auto tensorTy = type.dyn_cast<TensorType>())
    return isValidTFElementType(tensorTy.getElementType());
  return false;
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  RewriteListBuilder<AddToAddV2>::build(results, context);
}

//===----------------------------------------------------------------------===//
// AddV2Op
//===----------------------------------------------------------------------===//

void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  RewriteListBuilder<AddV2OfNegLeft, AddV2OfNegRight>::build(results, context);
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  RewriteListBuilder<BitcastSameType, BitcastNested>::build(results, context);
}

//===----------------------------------------------------------------------===//
// BroadcastToOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(BroadcastToOp op) {
  // TODO(antiagainst): check that
  // * The 'shape' input is an 1-D int tensor.
  // * Each dimension pair of the source and target shapes are either equal
  //   or one of them is one.
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  RewriteListBuilder<CastSameType>::build(results, context);
}

//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//

void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  RewriteListBuilder<ConjNested>::build(results, context);
}

//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Returns the held attribute value.
  return value();
}

// Builds a constant op with the specified attribute `value`. The result
// op's type is deduced from `value`; if `value` is of scalar type,
// wraps it up with a tensor type of empty shape.
void ConstOp::build(Builder *builder, OperationState *result, Attribute value) {
  ShapedType type;
  if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
    type = elemAttr.getType();
  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
             value.isa<IntegerAttr>()) {
    // All TensorFlow types must be tensor types. In the build() method,
    // we want to provide more flexibility by allowing attributes of scalar
    // types. But we need to wrap it up with ElementsAttr to construct
    // valid TensorFlow constants.
    type = RankedTensorType::get(/*shape=*/{}, value.getType());
    value = DenseElementsAttr::get(type, value);
  }
  // TODO: support other TensorFlow specific types.
  assert(type && "unsupported attribute type for building tf.Const");
  result->types.push_back(type);
  result->addAttribute("value", value);
}

void ConstOp::build(Builder *builder, OperationState *result, Type type,
                    Attribute value) {
  ConstOp::build(builder, result, value);
  assert(type == result->types[0] && "type mismatch in construction");
}

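// As an illustrative sketch (the variable names and attribute value here are
// assumptions, not code from this change): given an OpBuilder `b` and a
// Location `loc`, the scalar-wrapping build() overload above allows creating
// a constant from a plain scalar attribute, e.g.
//
//   auto cst = b.create<TF::ConstOp>(loc, b.getF32FloatAttr(1.0f));
//
// which is expected to produce a tf.Const whose result type is tensor<f32>
// holding a dense splat of 1.0.
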
//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  RewriteListBuilder<DivWithSqrtDivisor>::build(results, context);
}

//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxArgsOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) {
  // TODO(fengliuai): move the following to a utility method.
  const llvm::fltSemantics &semantics = op.min().getSemantics();
  float rmin, rmax;
  if (&semantics == &APFloat::IEEEsingle()) {
    rmin = op.min().convertToFloat();
    rmax = op.max().convertToFloat();
  } else {
    rmin = op.min().convertToDouble();
    rmax = op.max().convertToDouble();
  }
  // Range boundaries must be valid.
  if (rmin >= rmax) {
    return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
                          "," + Twine(std::to_string(rmax)) + "]");
  }
  // Range must straddle zero.
  if (rmin > 0.0 || rmax < 0.0) {
    return op.emitOpError("range failed to straddle zero: [" +
                          Twine(std::to_string(rmin)) + "," +
                          Twine(std::to_string(rmax)) + "]");
  }
  int64_t num_bits = op.num_bits().getSExtValue();
  if (num_bits < 2 || num_bits > 16) {
    return op.emitOpError(
        "requires num_bits to be between 2 and 16, inclusive");
  }
  return success();
}

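// For example (illustrative values only): min = -0.5, max = 1.5, num_bits = 8
// satisfies all three checks above, whereas min = 1.0, max = 2.0 is rejected
// because the range does not straddle zero, and num_bits = 1 is rejected by
// the bit-width check.
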
//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxVarsOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
  if (!isOfRankedFloatTensorType(op.min(), 0))
    return op.emitOpError("requires min to be a 0d float tensor");

  if (!isOfRankedFloatTensorType(op.max(), 0))
    return op.emitOpError("requires max to be a 0d float tensor");

  int64_t num_bits = op.num_bits().getSExtValue();
  if (num_bits < 2 || num_bits > 16) {
    return op.emitOpError(
        "requires num_bits to be between 2 and 16, inclusive");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(FusedBatchNormOp op) {
  if (!isOfRankedFloatTensorType(op.x(), 4))
    return op.emitOpError("requires x to be a 4D float tensor");

  if (!isOfRankedFloatTensorType(op.scale(), 1))
    return op.emitOpError("requires scale to be a 1D float tensor");

  if (!isOfRankedFloatTensorType(op.offset(), 1))
    return op.emitOpError("requires offset to be a 1D float tensor");

  if (!isOfRankedFloatTensorType(op.mean(), 1))
    return op.emitOpError("requires mean to be a 1D float tensor");

  if (!isOfRankedFloatTensorType(op.variance(), 1))
    return op.emitOpError("requires variance to be a 1D float tensor");

  // TODO(antiagainst): check attributes

  return success();
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//

LogicalResult IfOp::verify() {
  auto thenAttr = getAttrOfType<FunctionAttr>("then_branch");
  if (!thenAttr) return emitOpError("requires then_branch attribute");

  auto elseAttr = getAttrOfType<FunctionAttr>("else_branch");
  if (!elseAttr) return emitOpError("requires else_branch attribute");

  auto *module = getOperation()->getFunction()->getModule();
  auto *thenFn = module->getNamedFunction(thenAttr.getValue());
  if (!thenFn)
    return emitOpError("then_branch refers to an undefined function : ")
           << thenAttr;
  auto *elseFn = module->getNamedFunction(elseAttr.getValue());
  if (!elseFn)
    return emitOpError("else_branch refers to an undefined function : ")
           << elseAttr;
  auto thenFuncType = thenFn->getType();
  auto elseFuncType = elseFn->getType();

  // Non-conditional operands starting with the second operand are passed to
  // branches and should be pair-wise compatible with branches' inputs.
  unsigned expectedNumInputs = getNumOperands() - 1;
  if (thenFuncType.getNumInputs() != expectedNumInputs ||
      elseFuncType.getNumInputs() != expectedNumInputs)
    return emitError("branches should have " + Twine(expectedNumInputs) +
                     " inputs");

  for (unsigned i = 0; i < expectedNumInputs; ++i) {
    auto operandType = getOperand(i + 1)->getType().cast<TensorType>();
    auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
    if (!TensorCastOp::areCastCompatible(operandType, thenInputType))
      return emitError(
          llvm::formatv("then branch input type {0} is incompatible with "
                        "operand type {1} at index {2}",
                        thenInputType, operandType, i));

    auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
    if (!TensorCastOp::areCastCompatible(operandType, elseInputType))
      return emitError(
          llvm::formatv("else branch input type {0} is incompatible with "
                        "operand type {1} at index {2}",
                        elseInputType, operandType, i));

    // If branches have incompatible input types that means that no tensor can
    // serve as input to both the functions. Hence, the op is invalid.
    if (!TensorCastOp::areCastCompatible(thenInputType, elseInputType))
      return emitError(llvm::formatv(
          "branches inputs have incompatible types {0} and {1} at index {2}",
          thenInputType, elseInputType, i));
  }

  // Branches' results should be pair-wise compatible with the op results.
  unsigned expectedNumResults = getNumResults();
  if (thenFuncType.getNumResults() != expectedNumResults ||
      elseFuncType.getNumResults() != expectedNumResults)
    return emitError("branches should have " + Twine(expectedNumResults) +
                     " results");

  for (unsigned i = 0; i < expectedNumResults; ++i) {
    auto resultType = getResult(i)->getType().cast<TensorType>();
    auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
    if (!TensorCastOp::areCastCompatible(thenResultType, resultType))
      return emitError(
          llvm::formatv("then branch result type {0} is incompatible with op "
                        "result type {1} at index {2}",
                        thenResultType, resultType, i));

    auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
    if (!TensorCastOp::areCastCompatible(elseResultType, resultType))
      return emitError(
          llvm::formatv("else branch result type {0} is incompatible with op "
                        "result type {1} at index {2}",
                        elseResultType, resultType, i));
  }
  return success();
}

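// As a rough textual sketch of what this verifier accepts (function names,
// shapes, and attribute syntax are illustrative assumptions, not taken from a
// test in this change):
//
//   %res = "tf.If"(%cond, %x) {then_branch: @then_fn, else_branch: @else_fn}
//       : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
//
// where @then_fn and @else_fn are module-level functions of type
// (tensor<2xf32>) -> tensor<2xf32>, so operand, branch input, and result
// types are pair-wise cast compatible.
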
//===----------------------------------------------------------------------===//
// InvertOp
//===----------------------------------------------------------------------===//

void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  RewriteListBuilder<InvertNested>::build(results, context);
}

//===----------------------------------------------------------------------===//
// LeakyReluOp
//===----------------------------------------------------------------------===//

OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "leaky relu has one operand");

  // leaky_relu(x, alpha: 1) -> x
  if (alpha().convertToFloat() == 1.0f) return getOperand();

  auto calculate = [&](FloatAttr arg) {
    APFloat val = arg.getValue();
    if (val.isNegative()) val = alpha() * val;
    return FloatAttr::get(arg.getType(), val);
  };

  if (auto arg = operands[0].dyn_cast_or_null<FloatAttr>()) {
    return calculate(arg);
  } else if (auto arg = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
    if (auto elementAttr = arg.getSplatValue().dyn_cast<FloatAttr>())
      return DenseElementsAttr::get(arg.getType(), calculate(elementAttr));
  }
  return {};
}

//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//

void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  RewriteListBuilder<LogOfSoftmax>::build(results, context);
}

//===----------------------------------------------------------------------===//
// LogicalNotOp
//===----------------------------------------------------------------------===//

void LogicalNotOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  RewriteListBuilder<LogicalNotNested, LogicalNotOfEqual, LogicalNotOfNotEqual,
                     LogicalNotOfGreater, LogicalNotOfGreaterEqual,
                     LogicalNotOfLess, LogicalNotOfLessEqual>::build(results,
                                                                     context);
}

//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//

void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  RewriteListBuilder<NegNested>::build(results, context);
}

//===----------------------------------------------------------------------===//
// ReciprocalOp
//===----------------------------------------------------------------------===//

void ReciprocalOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  RewriteListBuilder<ReciprocalNested>::build(results, context);
}

//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(RandomUniformOp op) {
  if (!IsOfRankOrUnranked(op.shape(), 1))
    return op.emitOpError("shape must be 1D tensor");
  return success();
}

//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//

void RangeOp::build(Builder *builder, OperationState *result, Value *start,
                    Value *limit, Value *delta) {
  assert(start->getType() == limit->getType());
  assert(start->getType() == delta->getType());
  DenseIntElementsAttr start_val;
  DenseIntElementsAttr limit_val;
  DenseIntElementsAttr delta_val;
  if (matchPattern(start, m_Constant(&start_val)) &&
      matchPattern(limit, m_Constant(&limit_val)) &&
      matchPattern(delta, m_Constant(&delta_val))) {
    auto size = llvm::APIntOps::RoundingSDiv(
        *limit_val.begin() - *start_val.begin(), *delta_val.begin(),
        llvm::APInt::Rounding::DOWN);
    return RangeOp::build(
        builder, result,
        builder->getTensorType(
            size.getSExtValue(),
            start->getType().cast<TensorType>().getElementType()),
        start, limit, delta);
  }
  return RangeOp::build(
      builder, result,
      builder->getTensorType(
          {-1}, start->getType().cast<TensorType>().getElementType()),
      start, limit, delta);
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::build(Builder *builder, OperationState *result, Value *input) {
  return RankOp::build(builder, result,
                       builder->getTensorType({}, builder->getIntegerType(32)),
                       input);
}

//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//

void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  RewriteListBuilder<RealDivWithSqrtDivisor>::build(results, context);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

// TODO(b/128020684): Verify the rank of the output and change to use
// m_Constant.
static LogicalResult Verify(ReshapeOp op) {
  auto shapeType = op.shape()->getType().cast<TensorType>();
  if (shapeType.getRank() != 1)
    return op.emitOpError("shape must be 1D tensor");
  auto rankByShape = shapeType.getShape()[0];
  auto typeOfTensor = op.tensor()->getType().cast<TensorType>();
  // No compile time verification for unknown sized shape.
  if (rankByShape == -1 || !typeOfTensor.hasRank()) return success();
  // Checks values if the shape is constant. No compile time verification for
  // non-constant shape.
  auto *shapeOp = op.shape()->getDefiningOp();
  if (!shapeOp) return success();
  Attribute shapeCst;
  if (auto shapeStdOp = dyn_cast<ConstantOp>(shapeOp)) {
    shapeCst = shapeStdOp.getValue();
  } else if (auto shapeTFOp = dyn_cast<ConstOp>(shapeOp)) {
    shapeCst = shapeTFOp.value();
  } else {
    return success();
  }
  auto shapeCstAttr = shapeCst.dyn_cast<ElementsAttr>();
  if (!shapeCstAttr) return op.emitOpError("shape must be a valid tensor");

  if (auto opaqueAttr = shapeCstAttr.dyn_cast<OpaqueElementsAttr>()) {
    opaqueAttr.decode(shapeCstAttr);
  }

  // We know the shape is a 1-D Tensor, then let us get the number of
  // elements it implies.
  unsigned numByShape = 1;
  unsigned unknownDimCount = 0;
  for (int i = 0, e = rankByShape; i != e; ++i) {
    auto num = shapeCstAttr.getValue(i).cast<IntegerAttr>().getInt();
    // The dimension size value can be -1, in which case the real size needs
    // to be computed so that the total size remains constant. At most one
    // component of shape can be -1.
    if (num == -1) {
      if (++unknownDimCount > 1) {
        return op.emitOpError("more than one component of shape are -1");
      }
    } else {
      numByShape *= num;
    }
  }
  auto numByTensor = typeOfTensor.getNumElements();
  // If one component of the shape is -1, that dimension should be computed
  // so that the total size remains constant.
  if (unknownDimCount == 1) {
    if (numByTensor % numByShape != 0)
      return op.emitOpError(
          "one component of shape is -1 but couldn't infer the dimension");
    return success();
  }
  // If the number of elements implied by the tensor type and by the shape
  // don't match, fail this static check.
  if (numByTensor != numByShape) {
    return op.emitOpError(
        "mismatch in tensor elements and shape implied elements");
  }
  return success();
}

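// Worked example of the static check above (values are illustrative): a
// tensor<2x3xf32> operand has 6 elements, so a constant shape of [3, 2] or
// [6] verifies, a shape of [2, -1] infers the unknown dimension as 3, and a
// shape of [4, -1] fails because 6 is not divisible by 4.
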
void ReshapeOp::build(Builder *builder, OperationState *result, Value *tensor,
                      Value *shape) {
  auto etype = tensor->getType().cast<ShapedType>().getElementType();
  DenseIntElementsAttr attr_shape;
  if (matchPattern(shape, m_Constant(&attr_shape))) {
    llvm::SmallVector<int64_t, 4> const_shape;
    if (attr_shape.isSplat()) {
      const_shape.assign(attr_shape.getType().getNumElements(),
                         (*attr_shape.begin()).getSExtValue());
    } else {
      const_shape.reserve(attr_shape.getType().getNumElements());
      for (auto dim : attr_shape) const_shape.push_back(dim.getSExtValue());
    }
    return ReshapeOp::build(builder, result,
                            builder->getTensorType(const_shape, etype), tensor,
                            shape);
  }
  return ReshapeOp::build(builder, result, builder->getTensorType(etype),
                          tensor, shape);
}

//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ShapeOp op) {
  auto inputType = op.input()->getType();
  auto resultType = op.getType().dyn_cast<RankedTensorType>();
  if (!resultType || resultType.getShape().size() != 1)
    return op.emitOpError("requires 1D result type");

  auto rankedTensorType = inputType.dyn_cast<RankedTensorType>();
  if (rankedTensorType) {
    // The operand is a ranked tensor.
    if (resultType.hasStaticShape()) {
      if ((!rankedTensorType.getShape().empty() &&
           resultType.getDimSize(0) != rankedTensorType.getShape().size()) ||
          (rankedTensorType.getShape().empty() &&
           resultType.getDimSize(0) != 1))
        return op.emitOpError(
            "requires dimension size of result to match rank of operand");
    }
  } else {
    // The operand is an unranked tensor, verify that the result is dynamic.
    if (resultType.hasStaticShape())
      return op.emitOpError("requires dynamic shape result for unranked input");
  }

  Type elt = op.getType().cast<ShapedType>().getElementType();
  if (elt.isInteger(32) || elt.isInteger(64)) return success();
  return op.emitOpError("requires int32 or int64 return type");
}

OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
  auto inputType = getOperand()->getType();
  auto rankedTensorType = inputType.dyn_cast<RankedTensorType>();
  if (!rankedTensorType || !rankedTensorType.hasStaticShape()) return {};

  // TODO: This is handling the simple case where the resultant type matches
  // the size of getShape()'s returned type.
  auto shape = rankedTensorType.getShape();
  int rank = shape.size();
  if (rank == 0) return {};

  Builder b(getContext());
  auto elementType = getType().cast<ShapedType>().getElementType();

  SmallVector<Attribute, 4> dimensions;
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i)
    dimensions.push_back(b.getIntegerAttr(elementType, shape[i]));

  auto resultType = b.getTensorType({rank}, elementType);
  return b.getDenseElementsAttr(resultType, dimensions);
}

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SoftmaxOp op) {
  if (!IsOfRankOrUnranked(op.logits(), 2))
    return op.emitOpError("requires operand to be 2D tensor");

  return success();
}

//===----------------------------------------------------------------------===//
// SquareOp
//===----------------------------------------------------------------------===//

void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  RewriteListBuilder<SquareOfSub>::build(results, context);
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//

void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  RewriteListBuilder<SubOfNeg>::build(results, context);
}

//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TensorListReserveOp op) {
  if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
      !IsOfRankOrUnranked(op.element_shape(), 1)) {
    return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
  }

  if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
    return op.emitOpError("requires num_elements operand to be 0D tensor");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TransposeOp op) {
  // TODO(hinsu): Verify using a custom verifier that,
  // * Transpose permutation is 1-D of size equal to the rank of the first
  //   input, if the shapes are partially known. Requires use of a more
  //   restrictive type than TF_Tensor.
  // * Result shape dimensions are possible based on the input shape.
  return success();
}

// TODO(jpienaar): perm could be optional too.
void TransposeOp::build(Builder *builder, OperationState *result, Value *x,
                        Value *perm) {
  auto x_type = x->getType().cast<TensorType>();
  // If the value is unranked, then so is the result.
  if (!x_type.hasRank())
    return TransposeOp::build(builder, result,
                              builder->getTensorType(x_type.getElementType()),
                              x, perm);

  // TODO(jpienaar): Handle unknown perm case.

  // TODO(jpienaar): Extract utility function.
  auto etype = x_type.cast<ShapedType>().getElementType();
  DenseIntElementsAttr attr_shape;
  if (matchPattern(perm, m_Constant(&attr_shape))) {
    llvm::SmallVector<int64_t, 4> const_shape;
    if (attr_shape.isSplat()) {
      const_shape.assign(
          attr_shape.getType().getNumElements(),
          x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
    } else {
      const_shape.reserve(attr_shape.getType().getNumElements());
      for (auto dim : attr_shape)
        const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
    }
    return TransposeOp::build(
        builder, result, builder->getTensorType(const_shape, etype), x, perm);
  }
  return TransposeOp::build(builder, result, builder->getTensorType(etype), x,
                            perm);
}

//===----------------------------------------------------------------------===//
// TruncateDivOp
//===----------------------------------------------------------------------===//

void TruncateDivOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  RewriteListBuilder<TruncateDivWithSqrtDivisor>::build(results, context);
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

LogicalResult WhileOp::verify() {
  auto condAttr = getAttrOfType<FunctionAttr>("cond");
  if (!condAttr) return emitOpError("requires cond attribute");

  auto *module = getOperation()->getFunction()->getModule();
  auto *condFn = module->getNamedFunction(condAttr.getValue());
  auto condFuncType = condFn->getType();

  // Verify that the cond function has exactly one result.
  if (condFuncType.getNumResults() != 1)
    return emitOpError("requires cond function to have exactly one result");

  auto bodyAttr = getAttrOfType<FunctionAttr>("body");
  if (!bodyAttr) return emitOpError("requires body attribute");
  auto *bodyFn = module->getNamedFunction(bodyAttr.getValue());
  auto bodyFuncType = bodyFn->getType();

  SmallVector<Type, 4> operands(getOperandTypes());
  SmallVector<Type, 4> results(getResultTypes());

  // Collect all the type lists for the op so that different pairs of type
  // lists can be compared for compatibility.
  int numTypeLists = 5;
  std::pair<std::string, ArrayRef<Type>> typeLists[] = {
      {"operand", operands},
      {"body function result", bodyFuncType.getResults()},
      {"result", results},
      {"cond function input", condFuncType.getInputs()},
      {"body function input", bodyFuncType.getInputs()},
  };

  // A pair of type lists should be cast compatible with each other if one is
  // converted to the other for a function call or assignment or there is a
  // common source of inputs for both. Therefore, the While op requires the
  // following pairs of type lists to be cast compatible for the tensor_cast
  // operation:
  //
  // * Operands and cond inputs to call the cond function before the
  //   first iteration.
  // * Operands and body inputs to call the body function for the first
  //   iteration if the cond function returns True or equivalent result.
  // * Operands and results to assign cond function arguments to op results if
  //   the cond function returns False or equivalent result.
  // * All three pairs using cond inputs, body inputs and results as operand is
  //   a common source for all three.
  // * Body result and cond inputs to call the cond function for the subsequent
  //   iterations. Similarly, body result should be compatible with body inputs
  //   and op results.
  //
  // Note that the operands and body results need not be compatible as they are
  // never converted from one to the other nor is there a common source of
  // tensors. The compatibility requirement is not transitive.

  for (int i = 0; i < numTypeLists; ++i) {
    // Skips the first pair as the While op operands and body function results
    // do not need to be compatible with each other.
    for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
      auto &a = typeLists[i];
      auto &b = typeLists[j];

      int aSize = a.second.size();
      if (aSize != b.second.size())
        return emitOpError(
            llvm::formatv("requires the number of {0}s to be equal to the "
                          "number of {1}s. Found {2} and {3}, respectively",
                          a.first, b.first, aSize, b.second.size()));

      for (int idx = 0; idx < aSize; ++idx) {
        auto aType = a.second[idx];
        auto bType = b.second[idx];

        if (!TensorCastOp::areCastCompatible(aType, bType))
          return emitError(llvm::formatv(
              "{0} type {1} is incompatible with {2} type {3} at index {4}",
              a.first, aType, b.first, bType, idx));
      }
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//

void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  RewriteListBuilder<XdivyWithSqrtDivisor>::build(results, context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc"

//===----------------------------------------------------------------------===//
// TF Dialect
//===----------------------------------------------------------------------===//

TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf", context) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc"
      , IfOp, WhileOp>();
  addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
      >();

  // Supports unknown operations because not all TensorFlow operations are
  // registered.
  allowUnknownOperations();
}

// Parses a type registered to this dialect.
Type TensorFlowDialect::parseType(StringRef data, Location loc) const {
  auto typeKind = llvm::StringSwitch<unsigned>(data)
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  .Case(name, TensorFlowTypes::enumerant)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
                      .Default(0);
  switch (typeKind) {
    default:
      return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);

#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  case TensorFlowTypes::enumerant:              \
    return tftype##Type::get(getContext());
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
  }
}

// Prints a type registered to this dialect.
void TensorFlowDialect::printType(Type ty, raw_ostream &os) const {
  assert(ty.isa<TensorFlowType>());
  switch (ty.getKind()) {
    default:
      llvm_unreachable("unexpected tensorflow type kind");
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  case TensorFlowTypes::enumerant:              \
    os << name;                                 \
    break;
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
  }
}

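// Illustrative note on the two hooks above: printType emits only the bare
// type keyword (e.g. "string"), and MLIR's dialect type syntax wraps it as
// "!tf.string", so values of such types appear in textual IR with types like
// tensor<4x!tf.string>. The concrete keyword set comes from tf_types.def;
// "string" here is an assumed example.
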
Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
                                                  Attribute value, Type type,
                                                  Location loc) {
  // If this is an opaque elements attribute, then generate a tf.Const.
  if (value.isa<OpaqueElementsAttr>() && value.getType() == type)
    return builder.create<ConstOp>(loc, type, value);
  return nullptr;
}

// This verifies that the Op is a well-formed TensorFlow op, checking
// that all inputs and results are Tensor or other TensorFlow types, etc.
LogicalResult verifyTensorFlowOp(Operation *op) {
  if (op->getName().getDialect() != "tf")
    return op->emitError("TensorFlow op ")
           << op->getName() << " should start with 'tf.'";

  for (Type type : op->getOperandTypes()) {
    if (!isValidTFTensorType(type))
      return op->emitOpError(
          "requires operands to have a valid TensorFlow tensor type");
  }

  for (Type type : op->getResultTypes()) {
    if (!isValidTFTensorType(type))
      return op->emitOpError(
          "requires results to have a valid TensorFlow tensor type");
  }

  return success();
}

} // namespace TF
} // namespace mlir
164
tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
Normal file
@ -0,0 +1,164 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the standard MLIR TensorFlow
// dialect after control dependences are raised to the standard form.

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_

#include "mlir/Dialect/Traits.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Dialect.h"  // TF:local_config_mlir
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/Support/TypeUtilities.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace TF {

class TensorFlowDialect : public Dialect {
 public:
  TensorFlowDialect(MLIRContext *context);

  // The gradient attribute ("tf.gradient") in the list of NamedAttributes of
  // a function references its gradient function. This attribute in the
  // TensorFlow dialect is used to model TF GradientDef. GetGradientAttrName()
  // returns the string description of the gradient attribute.
  static StringRef GetGradientAttrName() { return "tf.gradient"; }

  // Parses a type registered to this dialect.
  Type parseType(StringRef data, Location loc) const override;

  // Prints a type registered to this dialect.
  void printType(Type ty, raw_ostream &os) const override;

  // Registered hook to materialize a constant operation from a given attribute
  // value with the desired resultant type.
  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
                                 Location loc) override;
};

// This verifies that the Op is a well-formed TensorFlow op, checking
// that all inputs and results are Tensor or other TensorFlow types, etc.
static LogicalResult verifyTensorFlowOp(Operation *op);

// This Trait should be used by all TensorFlow Ops.
//
template <typename ConcreteType>
class TensorFlowOp : public OpTrait::TraitBase<ConcreteType, TensorFlowOp> {
 public:
  static LogicalResult verifyTrait(Operation *op) {
    return verifyTensorFlowOp(op);
  }
};

// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
// purpose is to catch misuse of `tensorflow::mutex_lock`. We don't use
// `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and
// `tf.ConsumeMutexLock`) with getter methods named `mutex_lock()`. We need to
// undefine it here to avoid expanding the getter symbol as a macro when
// including both mutex.h and this header file.
#undef mutex_lock

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"

// The "tf.If" operation takes a condition operand, a list of inputs, and a
// function attribute for the then/else branches. The condition operand
// doesn't have to be a boolean tensor. It is handled according to these
// rules, quoting the TensorFlow op definition:
//
//   If the tensor is a scalar of non-boolean type, the scalar is converted to
//   a boolean according to the following rule: if the scalar is a numerical
//   value, non-zero means True and zero means False; if the scalar is a
//   string, non-empty means True and empty means False. If the tensor is not
//   a scalar, being empty means False and being non-empty means True.
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("If")
//     .Input("cond: Tcond")
//     .Input("input: Tin")
//     .Output("output: Tout")
//     .Attr("Tcond: type")
//     .Attr("Tin: list(type) >= 0")
//     .Attr("Tout: list(type) >= 0")
//     .Attr("then_branch: func")
//     .Attr("else_branch: func")
//
// Note: Additional result corresponds to the control output.
class IfOp : public Op<IfOp, TensorFlowOp, OpTrait::AtLeastNOperands<1>::Impl,
                       OpTrait::VariadicResults> {
 public:
  using Op::Op;
  static StringRef getOperationName() { return "tf.If"; }

  Value *getCondition() { return getOperand(0); }

  // TODO(b/132271680): This is not following Google naming style
  StringRef getThen() {
    return getAttrOfType<FunctionAttr>("then_branch").getValue();
  }

  StringRef getElse() {
    return getAttrOfType<FunctionAttr>("else_branch").getValue();
  }

  LogicalResult verify();
};

// The "tf.While" operation takes a list of inputs and function attributes for
// the loop condition and body. Inputs are updated repeatedly by the body
// function while the loop condition evaluated on those inputs is true. The
// condition result doesn't have to be a boolean tensor. It is handled
// according to these rules, quoting the TensorFlow op definition:
//
//   If the tensor is a scalar of non-boolean type, the scalar is converted to
//   a boolean according to the following rule: if the scalar is a numerical
//   value, non-zero means True and zero means False; if the scalar is a
//   string, non-empty means True and empty means False. If the tensor is not
//   a scalar, being empty means False and being non-empty means True.
//
// This is defined in TensorFlow as:
//
// REGISTER_OP("While")
//     .Input("input: T")
//     .Output("output: T")
//     .Attr("T: list(type) >= 0")
//     .Attr("cond: func")
//     .Attr("body: func")
//     .Attr("output_shapes: list(shape) = []")
//
class WhileOp : public Op<WhileOp, TensorFlowOp, OpTrait::VariadicOperands,
                          OpTrait::VariadicResults> {
 public:
  using Op::Op;
  static StringRef getOperationName() { return "tf.While"; }

  StringRef getCond() { return getAttrOfType<FunctionAttr>("cond").getValue(); }
  StringRef getBody() { return getAttrOfType<FunctionAttr>("body").getValue(); }

  LogicalResult verify();
};

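// As a rough textual sketch of tf.While (function names, types, and attribute
// syntax here are illustrative assumptions): given module-level functions
// @cond of type (tensor<i32>) -> tensor<i1> and @body of type
// (tensor<i32>) -> tensor<i32>, a use could look like
//
//   %res = "tf.While"(%init) {cond: @cond, body: @body}
//       : (tensor<i32>) -> tensor<i32>
//
// WhileOp::verify() in tf_ops.cc then checks that the operand, function
// input, and result type lists are pair-wise cast compatible.
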
} // namespace TF
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_