Merge branch 'master' into toupstream/fix-tflite-interpreter-test

This commit is contained in:
Anton Kachatkou 2019-09-13 11:14:29 +01:00 committed by GitHub
commit d3564251e3
2803 changed files with 104716 additions and 55250 deletions

View File

@ -92,17 +92,13 @@ build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true
# Options extracted from configure script
build:gdr --define=with_gdr_support=true
build:ngraph --define=with_ngraph_support=true
build:verbs --define=with_verbs_support=true
build:numa --define=with_numa_support=true
# Options to disable default-on features
build:noaws --define=no_aws_support=true
build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nokafka --define=no_kafka_support=true
build:noignite --define=no_ignite_support=true
build:nonccl --define=no_nccl_support=true
build --define=use_fast_cpp_protos=true

View File

@ -5,7 +5,6 @@
/tensorflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/python/autograph/ @mdanatg @kkimdev
/tensorflow/python/debug @caisq
/tensorflow/python/eager @jaingaurav @alextp
@ -37,9 +36,7 @@
/tensorflow/contrib/hadoop @yongtang
/tensorflow/contrib/hvx/ @satok16
/tensorflow/contrib/integrate/ @shoyer
/tensorflow/contrib/kafka @yongtang
/tensorflow/contrib/kernel_methods/ @petrosmol
/tensorflow/contrib/kinesis @yongtang
/tensorflow/contrib/ios_examples/ @petewarden
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke

View File

@ -49,34 +49,34 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"],
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"],
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.5.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"],
strip_prefix = "swift-protobuf-1.6.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.

View File

@ -3,56 +3,56 @@ package(default_visibility = ["//visibility:public"])
filegroup(
name = "gcc",
srcs = [
"bin/arm-linux-gnueabihf-gcc",
"bin/arm-rpi-linux-gnueabihf-gcc",
],
)
filegroup(
name = "ar",
srcs = [
"bin/arm-linux-gnueabihf-ar",
"bin/arm-rpi-linux-gnueabihf-ar",
],
)
filegroup(
name = "ld",
srcs = [
"bin/arm-linux-gnueabihf-ld",
"bin/arm-rpi-linux-gnueabihf-ld",
],
)
filegroup(
name = "nm",
srcs = [
"bin/arm-linux-gnueabihf-nm",
"bin/arm-rpi-linux-gnueabihf-nm",
],
)
filegroup(
name = "objcopy",
srcs = [
"bin/arm-linux-gnueabihf-objcopy",
"bin/arm-rpi-linux-gnueabihf-objcopy",
],
)
filegroup(
name = "objdump",
srcs = [
"bin/arm-linux-gnueabihf-objdump",
"bin/arm-rpi-linux-gnueabihf-objdump",
],
)
filegroup(
name = "strip",
srcs = [
"bin/arm-linux-gnueabihf-strip",
"bin/arm-rpi-linux-gnueabihf-strip",
],
)
filegroup(
name = "as",
srcs = [
"bin/arm-linux-gnueabihf-as",
"bin/arm-rpi-linux-gnueabihf-as",
],
)

View File

@ -1145,78 +1145,6 @@ def set_trisycl_include_dir(environ_cp):
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def set_mpi_home(environ_cp):
"""Set MPI_HOME."""
default_mpi_home = which('mpirun') or which('mpiexec') or ''
default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
def valid_mpi_path(mpi_home):
exists = (
os.path.exists(os.path.join(mpi_home, 'include')) and
(os.path.exists(os.path.join(mpi_home, 'lib')) or
os.path.exists(os.path.join(mpi_home, 'lib64')) or
os.path.exists(os.path.join(mpi_home, 'lib32'))))
if not exists:
print(
'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found'
% (os.path.join(mpi_home, 'include'),
os.path.exists(os.path.join(mpi_home, 'lib')),
os.path.exists(os.path.join(mpi_home, 'lib64')),
os.path.exists(os.path.join(mpi_home, 'lib32'))))
return exists
_ = prompt_loop_or_load_from_env(
environ_cp,
var_name='MPI_HOME',
var_default=default_mpi_home,
ask_for_var='Please specify the MPI toolkit folder.',
check_success=valid_mpi_path,
error_msg='',
suppress_default_error=True)
def set_other_mpi_vars(environ_cp):
"""Set other MPI related variables."""
# Link the MPI header files
mpi_home = environ_cp.get('MPI_HOME')
symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h')
# Determine if we use OpenMPI or MVAPICH, these require different header files
# to be included here to make bazel dependency checker happy
if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')):
symlink_force(
os.path.join(mpi_home, 'include/mpi_portable_platform.h'),
'third_party/mpi/mpi_portable_platform.h')
# TODO(gunan): avoid editing files in configure
sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = False',
'MPI_LIB_IS_OPENMPI = True')
else:
# MVAPICH / MPICH
symlink_force(
os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h')
symlink_force(
os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h')
# TODO(gunan): avoid editing files in configure
sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = True',
'MPI_LIB_IS_OPENMPI = False')
if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
symlink_force(
os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')):
symlink_force(
os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so')
elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')):
symlink_force(
os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so')
else:
raise ValueError(
'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
(mpi_home, mpi_home, mpi_home))
def system_specific_test_config(env):
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
@ -1549,11 +1477,6 @@ def main():
raise UserInputError('SYCL / CUDA / ROCm are mutually exclusive. '
'At most 1 GPU platform can be configured.')
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
if environ_cp.get('TF_NEED_MPI') == '1':
set_mpi_home(environ_cp)
set_other_mpi_vars(environ_cp)
set_cc_opt_flags(environ_cp)
set_system_libs_flag(environ_cp)
if is_windows():
@ -1580,8 +1503,6 @@ def main():
'details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
config_info_line('gdr', 'Build with GDR support.')
config_info_line('verbs', 'Build with libverbs support.')
config_info_line('ngraph', 'Build with Intel nGraph support.')
config_info_line('numa', 'Build with NUMA support.')
config_info_line(
@ -1593,8 +1514,6 @@ def main():
config_info_line('noaws', 'Disable AWS S3 filesystem support.')
config_info_line('nogcp', 'Disable GCP support.')
config_info_line('nohdfs', 'Disable HDFS support.')
config_info_line('noignite', 'Disable Apache Ignite support.')
config_info_line('nokafka', 'Disable Apache Kafka support.')
config_info_line('nonccl', 'Disable NVIDIA NCCL support.')

View File

@ -267,18 +267,6 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "no_ignite_support",
define_values = {"no_ignite_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_kafka_support",
define_values = {"no_kafka_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "no_nccl_support",
define_values = {"no_nccl_support": "true"},
@ -309,18 +297,6 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "with_gdr_support",
define_values = {"with_gdr_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_verbs_support",
define_values = {"with_verbs_support": "true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_numa_support",
define_values = {"with_numa_support": "true"},
@ -421,12 +397,6 @@ config_setting(
},
)
config_setting(
name = "with_mpi_support",
values = {"define": "with_mpi_support=true"},
visibility = ["//visibility:public"],
)
config_setting(
name = "override_eigen_strong_inline",
values = {"define": "override_eigen_strong_inline=true"},
@ -470,6 +440,7 @@ config_setting(
package_group(
name = "internal",
packages = [
"//perftools/accelerators/xprof/api/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",

View File

@ -56,10 +56,10 @@ elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
@ -125,25 +125,6 @@ if _running_from_pip_package():
if _fi.file_exists(plugin_dir):
_ll.load_library(plugin_dir)
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
# pylint: enable=undefined-variable
# Add module aliases
if hasattr(_current_module, 'keras'):
losses = keras.losses

View File

@ -60,6 +60,10 @@ elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
@ -134,23 +138,3 @@ if _running_from_pip_package():
if _fi.file_exists(plugin_dir):
_ll.load_library(plugin_dir)
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# pylint: enable=undefined-variable

View File

@ -270,6 +270,7 @@ tf_cuda_library(
"//tensorflow/core/platform",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
exports_files(

View File

@ -159,7 +159,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, &session);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
return new TF_DeprecatedSession({session});
} else {
DCHECK_EQ(nullptr, session);
@ -332,7 +332,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
// TODO(nolivia): check this on a subset of the graph instead of all of
// it.
status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
session->graph->mu.unlock();
return false;
}
@ -352,7 +352,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
*graph_def.mutable_library() = graph.flib_def().ToProto();
session->graph->mu.unlock();
status->status = session->session->Extend(std::move(graph_def));
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
// Contract is we always delete input_values[i].
return false;
}
@ -382,7 +382,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
const int ninputs = input_pairs->size();
for (int i = 0; i < ninputs; ++i) {
status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
if (TF_GetCode(status) != TF_OK) return false;
if (!status->status.ok()) return false;
}
return true;
}
@ -439,7 +439,7 @@ static void TF_Run_Helper(
// Serialize back to upstream client, who now owns the new buffer
if (run_metadata != nullptr) {
status->status = MessageToBuffer(run_metadata_proto, run_metadata);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
}
} else {
// NOTE(zongheng): PRun does not support RunOptions yet.
@ -459,7 +459,7 @@ static void TF_Run_Helper(
continue;
}
c_outputs[i] = TF_TensorFromTensor(src, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
}
}
@ -516,7 +516,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
string new_handle;
status->status = s->session->PRunSetup(input_names, output_names,
target_oper_names, &new_handle);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
@ -555,7 +555,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
status->status = tensorflow::LoadLibrary(
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
&lib_handle->op_list.length);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
delete lib_handle;
return nullptr;
}
@ -983,7 +983,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
TF_Tensor* value, TF_Status* status) {
Tensor t;
status->status = TF_TensorToTensor(value, &t);
if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
@ -993,13 +993,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
std::vector<Tensor> t;
t.reserve(num_values);
for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) {
for (int i = 0; i < num_values && status->status.ok(); ++i) {
Tensor v;
status->status = TF_TensorToTensor(values[i], &v);
t.emplace_back(v);
}
if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
@ -1048,11 +1048,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
/*consume=*/true);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
// Run shape inference function for newly added node.
status->status = desc->graph->refiner.AddNode(ret);
}
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
// Add the node to the name-to-node mapping.
desc->graph->name_map[ret->name()] = ret;
} else if (ret != nullptr) {
@ -1101,7 +1101,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
if (TF_GetCode(status) != TF_OK) return -1;
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
status->status = InvalidArgument("Output arg '", arg_name, "' not found");
@ -1123,7 +1123,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
if (TF_GetCode(status) != TF_OK) return -1;
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
status->status = InvalidArgument("Input arg '", arg_name, "' not found");
@ -1142,6 +1142,16 @@ TF_Output TF_OperationInput(TF_Input oper_in) {
return {ToOperation(edge->src()), edge->src_output()};
}
void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
int max_inputs) {
for (auto* edge : oper->node.in_edges()) {
if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
inputs[edge->dst_input()] = {ToOperation(edge->src()),
edge->src_output()};
}
}
}
int TF_OperationOutputNumConsumers(TF_Output oper_out) {
int count = 0;
for (const auto* edge : oper_out.oper->node.out_edges()) {
@ -1221,7 +1231,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
TF_Status* status) {
TF_AttrMetadata metadata;
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return metadata;
if (!status->status.ok()) return metadata;
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
@ -1328,7 +1338,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
void* value, size_t max_length,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kS) {
status->status =
InvalidArgument("Attribute '", attr_name, "' is not a string");
@ -1346,7 +1356,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
int max_values, void* storage,
size_t storage_size, TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kList) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a list");
@ -1379,7 +1389,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
int max_values, TF_Status* status) { \
const auto* attr = GetAttrValue(oper, attr_name, status); \
if (TF_GetCode(status) != TF_OK) return; \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
@ -1401,7 +1411,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
PartialTensorShape shape;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
auto len = std::min(shape.dims(), num_dims);
for (int i = 0; i < len; ++i) {
value[i] = shape.dim_size(i);
@ -1415,7 +1425,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
std::vector<PartialTensorShape> shapes;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
int64_t* p = storage;
int storage_left = storage_size;
@ -1443,7 +1453,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
const char* attr_name,
TF_Buffer* value, TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kShape) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a shape.");
@ -1457,7 +1467,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
TF_Buffer** values, int max_values,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
if (attr->value_case() != tensorflow::AttrValue::kList) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a list");
@ -1467,7 +1477,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
for (int i = 0; i < len; ++i) {
values[i] = TF_NewBuffer();
status->status = MessageToBuffer(attr->list().shape(i), values[i]);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
// Delete everything allocated so far, the operation has failed.
for (int j = 0; j <= i; ++j) {
TF_DeleteBuffer(values[j]);
@ -1482,7 +1492,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
*value = nullptr;
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
*value = TF_TensorFromTensor(t, status);
}
@ -1491,7 +1501,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
TF_Status* status) {
std::vector<Tensor> ts;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
values[i] = TF_TensorFromTensor(ts[i], status);
@ -1502,7 +1512,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
TF_Buffer* output_attr_value,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
status->status = MessageToBuffer(*attr, output_attr_value);
}
@ -1583,7 +1593,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
{
mutex_lock l(graph->mu);
status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
}
status->status = MessageToBuffer(*op_def, output_op_def);
}
@ -1701,7 +1711,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
&graph->refiner, &results);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
// Add new nodes to name_map
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
@ -1755,7 +1765,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
auto results = new TF_ImportGraphDefResults();
mutex_lock l(graph->mu);
GraphImportGraphDefLocked(graph, def, options, results, status);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
delete results;
return nullptr;
}
@ -1813,7 +1823,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
// TODO(skyewm): set placeholder shape
TF_Operation* oper = TF_FinishOperation(desc, status);
if (TF_GetCode(status) != TF_OK) return false;
if (!status->status.ok()) return false;
*input = {oper, 0};
return true;
}
@ -1958,7 +1968,7 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
body_graph, body_inputs, body_outputs, name};
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
FreeWhileResources(&params);
return EmptyWhileParams();
}
@ -2160,7 +2170,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, &session);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
mutex_lock l(graph->mu);
@ -2208,7 +2218,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
status->status =
tensorflow::LoadSavedModel(session_options->options, run_options_proto,
export_dir, tag_set, &bundle);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
// Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
// extends using GraphDefs. The Graph instance is different, but equivalent
@ -2221,11 +2231,11 @@ TF_Session* TF_LoadSessionFromSavedModel(
GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
import_opts, &results, status);
TF_DeleteImportGraphDefOptions(import_opts);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
if (meta_graph_def != nullptr) {
status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
}
TF_Session* session = new TF_Session(bundle.session.release(), graph);
@ -2325,7 +2335,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
string new_handle;
status->status = session->session->PRunSetup(input_names, output_names,
target_names, &new_handle);
if (TF_GetCode(status) == TF_OK) {
if (status->status.ok()) {
char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
@ -2387,9 +2397,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
tensor, graph->refiner, *graph->graph.op_registry(),
graph->graph.versions().producer(), &evaluated, &result_tensor);
if (evaluated) {
DCHECK(TF_GetCode(status) == TF_OK);
DCHECK(status->status.ok());
*result = TF_TensorFromTensor(result_tensor, status);
if (TF_GetCode(status) != TF_OK) evaluated = false;
if (!status->status.ok()) evaluated = false;
}
return evaluated;
}
@ -2444,7 +2454,7 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(*api_def, ret);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2456,7 +2466,7 @@ TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(kernel_list, ret);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2468,7 +2478,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
tensorflow::GetRegisteredKernelsForOp(name);
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(kernel_list, ret);
if (TF_GetCode(status) != TF_OK) {
if (!status->status.ok()) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2498,7 +2508,7 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len,
std::unique_ptr<tensorflow::ServerInterface> out_server;
status->status = tensorflow::NewServer(server_def, &out_server);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
return new TF_Server(std::move(out_server));
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)

View File

@ -435,6 +435,15 @@ TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper,
// producer.index) to consumer.oper's input (given by consumer.index).
TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in);
// Get list of all inputs of a specific operation. `inputs` must point to
// an array of length at least `max_inputs` (ideally set to
// TF_OperationNumInputs(oper)). Beware that a concurrent
// modification of the graph can increase the number of inputs of
// an operation.
TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper,
TF_Output* inputs,
int max_inputs);
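A minimal usage sketch (not part of the diff) for the new TF_OperationAllInputs entry point, sizing the buffer with TF_OperationNumInputs as the comment above recommends; the helper name is hypothetical.
#include <vector>
#include "tensorflow/c/c_api.h"
// Collect the producer of every input edge of `oper`.
std::vector<TF_Output> GetAllInputs(TF_Operation* oper) {
  const int num_inputs = TF_OperationNumInputs(oper);
  std::vector<TF_Output> inputs(num_inputs);
  TF_OperationAllInputs(oper, inputs.data(), num_inputs);
  return inputs;  // inputs[i] = {producer op, producer output index} for input i
}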
// Get the number of current consumers of a specific output of an
// operation. Note that this number can change when new operations
// are added to the graph.

View File

@ -510,10 +510,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
static void CheckOk(TF_Status* status) {
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
auto* status = TF_NewStatus();
if (!TFE_TensorHandleIsConcrete(handle)) {

View File

@ -41,6 +41,7 @@ namespace {
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
// TODO(b/139886381) Unify this and the one in graph_to_functiondef.cc
class NodeNameMapping {
public:
NodeNameMapping() = default;
@ -64,14 +65,14 @@ class NodeNameMapping {
string Lookup(const string& name) const;
private:
string UniquifyHelper(const string& name) const;
string UniquifyHelper(const string& name);
static string Normalize(string name);
// The normalized/uniquified names already used as
// input names (in signature), output names (in signature), and node names
// (in node_def).
// This is a superset of values in name_mapping_.
std::unordered_set<string> used_names_;
std::unordered_map<string, uint64> used_names_;
// Mapping from original node name from the graph to the normalized
// and uniquified version of it.
std::unordered_map<string, string> name_mapping_;
@ -102,13 +103,16 @@ string NodeNameMapping::Normalize(string name) {
return i == n ? "unknown" : name.substr(i);
}
string NodeNameMapping::UniquifyHelper(const string& name) const {
string NodeNameMapping::UniquifyHelper(const string& name) {
auto it = used_names_.emplace(name, 0);
// If the name hasn't been used yet, use it as-is.
if (used_names_.find(name) == used_names_.end()) return name;
if (it.second) return name;
// Add a suffix to name to make it unique.
for (int i = 0;; ++i) {
const string candidate = strings::StrCat(name, "_", i);
if (used_names_.find(candidate) == used_names_.end()) return candidate;
while (true) {
const string candidate = strings::StrCat(name, "_", it.first->second);
it.first->second++;
if (used_names_.emplace(candidate, 0).second) return candidate;
}
}
@ -120,16 +124,13 @@ string NodeNameMapping::GetInputName(const string& name) {
string NodeNameMapping::GetOutputName(const string& name) {
const string& input_name = UniquifyHelper(Normalize(name));
// Record that we used this name, but don't add it to name_mapping_
// since this name is not for a node.
used_names_.insert(input_name);
// Don't add it to name_mapping_ since this name is not for a node.
return input_name;
}
string NodeNameMapping::Uniquify(const string& name) {
const string uniqued = UniquifyHelper(name);
name_mapping_[name] = uniqued;
used_names_.insert(uniqued);
return uniqued;
}
@ -139,7 +140,7 @@ Status NodeNameMapping::UseOutputName(const string& name) {
return InvalidArgument("Cannot have duplicate output names. Name '", name,
"' appears more than once in 'output_names' array.");
}
used_names_.insert(iter, name);
used_names_.emplace(name, 0);
return Status::OK();
}
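A self-contained sketch of the counter-based uniquification this hunk switches to, using plain std types in place of TF's string and strings::StrCat (those substitutions are assumptions): each base name remembers its next suffix, so repeated collisions no longer rescan from _0.
#include <string>
#include <unordered_map>
class UniqueNamer {
 public:
  std::string Uniquify(const std::string& name) {
    auto it = used_.emplace(name, 0);
    if (it.second) return name;  // first use of this name: keep it as-is
    while (true) {
      // Try "<name>_<k>", bumping the per-name counter on every attempt.
      std::string candidate = name + "_" + std::to_string(it.first->second++);
      if (used_.emplace(candidate, 0).second) return candidate;
    }
  }
 private:
  std::unordered_map<std::string, unsigned long long> used_;
};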

View File

@ -9,7 +9,6 @@ load(
)
load(
"//tensorflow/core/platform:default/build_config.bzl",
"tf_additional_device_tracer_test_flags",
"tf_kernel_tests_linkstatic",
)
load(
@ -37,6 +36,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"@com_google_absl//absl/container:fixed_array",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@ -79,6 +79,7 @@ tf_cuda_library(
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
)
tf_cuda_library(
@ -226,6 +227,7 @@ tf_cuda_library(
"//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
)
tf_cuda_cc_test(
@ -234,8 +236,7 @@ tf_cuda_cc_test(
srcs = [
"c_api_experimental_test.cc",
],
args =
["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],

View File

@ -26,12 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#ifdef TENSORFLOW_EAGER_USE_XLA
@ -60,6 +62,7 @@ limitations under the License.
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@ -99,32 +102,34 @@ string DeviceName(const tensorflow::Device* d) {
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
tensorflow::Notification n;
tensorflow::mutex remote_devices_mu;
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
tensorflow::NewRemoteDevices(
tensorflow::Env::Default(), worker_cache, remote_worker,
[&status, &n, &remote_devices](
tensorflow::Env::Default(), worker_cache, remote_workers[i],
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
const tensorflow::Status& s,
std::vector<tensorflow::Device*>* devices) {
status = s;
statuses[i] = s;
if (s.ok()) {
tensorflow::mutex_lock l(remote_devices_mu);
for (tensorflow::Device* d : *devices) {
remote_devices.emplace_back(d);
}
}
n.Notify();
counter.DecrementCount();
});
n.WaitForNotification();
}
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::DeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status);
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
*device_mgr = std::move(remote_device_mgr);
return tensorflow::Status::OK();
}
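This hunk replaces a per-worker Notification wait with a single BlockingCounter so the device RPCs run in parallel. A hedged sketch of that fan-out/wait pattern, assuming tensorflow::BlockingCounter from core/lib/core/blocking_counter.h; the issue_request callback parameter is illustrative only.
#include <functional>
#include <vector>
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
tensorflow::Status FanOutAndWait(
    int num_workers,
    const std::function<void(int, std::function<void(const tensorflow::Status&)>)>&
        issue_request) {
  tensorflow::BlockingCounter counter(num_workers);
  std::vector<tensorflow::Status> statuses(num_workers);
  for (int i = 0; i < num_workers; ++i) {
    issue_request(i, [i, &statuses, &counter](const tensorflow::Status& s) {
      statuses[i] = s;           // each callback records its own result
      counter.DecrementCount();  // and signals completion
    });
  }
  counter.Wait();  // block until every callback has fired
  for (const tensorflow::Status& s : statuses) {
    if (!s.ok()) return s;  // surface the first error, as the diff does
  }
  return tensorflow::Status::OK();
}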
@ -134,11 +139,15 @@ tensorflow::Status CreateRemoteContexts(
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request) {
for (int i = 0; i < remote_workers.size(); i++) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextResponse response;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
@ -158,16 +167,17 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
tensorflow::Notification n;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
eager_client->CreateContextAsync(
&request, &response, [&status, &n](const tensorflow::Status& s) {
status = s;
n.Notify();
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(status);
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return tensorflow::Status::OK();
}
@ -214,7 +224,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
remote_workers.end());
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
std::unique_ptr<tensorflow::DynamicDeviceMgr> remote_device_mgr;
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, grpc_server->master_env()->worker_cache,
&remote_device_mgr));
@ -246,7 +256,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(
CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor()->Async(), base_request));
ctx->context->Executor().Async(), base_request));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
@ -384,7 +394,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
&devices);
if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(std::move(devices)));
new tensorflow::StaticDeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
@ -563,7 +573,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), handle->Context()->Executor(),
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
return nullptr;
@ -893,10 +903,9 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
*num_retvals);
status->status =
tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals);
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
if (!status->status.ok()) {
return;
}
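For context, an illustrative sketch of absl::FixedArray (absl/container/fixed_array.h), which the hunk above adopts for the retval handles: a runtime-sized array that avoids heap allocation for small sizes and exposes a contiguous data() pointer for EagerExecute. The function below is hypothetical.
#include "absl/container/fixed_array.h"
int SumOfSquares(int n) {
  absl::FixedArray<int> values(n);  // size chosen at run time, fixed afterwards
  for (int i = 0; i < n; ++i) values[i] = i * i;
  int total = 0;
  for (int v : values) total += v;  // values.data() would give a contiguous int*
  return total;
}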
@ -916,7 +925,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
ctx->context->Executor(),
&ctx->context->Executor(),
device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
@ -967,7 +976,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
status->status = ctx->context->Executor()->WaitForAllPendingNodes();
status->status = ctx->context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
@ -979,9 +988,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
for (const auto& attr : func.attr()) {
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
}
return func_op;
}
@ -1029,7 +1038,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} break;
case tensorflow::AttrValue::kFunc: {
const auto func_op = GetFunc(ctx, default_value.func(), status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
// TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
// require TFE_Op* and just convert it internally a NameAttrValue, so
// consider adding an overload to the C API to make this case easier.

View File

@ -597,5 +597,5 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(ctx->context->Executor());
return new TFE_Executor(&ctx->context->Executor());
}

View File

@ -84,11 +84,6 @@ void ExecuteWithProfiling(bool async) {
string profile_proto_str = profile_proto.DebugString();
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
// device name with "stream:all" is collected by Device Tracer.
#ifndef TENSORFLOW_USE_ROCM
// ROCm platform does not yet support stream level tracing
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
#endif
}
// "/host:CPU" is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));

View File

@ -1069,10 +1069,13 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
// still fail.
TF_SetStatus(status, TF_OK, "");
TFE_DeleteTensorHandle(retvals[0]);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
EXPECT_NE(TF_OK, TF_GetCode(status));
TF_SetStatus(status, TF_OK, "");
retvals[0] = nullptr;
TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
EXPECT_NE(TF_OK, TF_GetCode(status));
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorClearError(executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

View File

@ -18,6 +18,7 @@ limitations under the License.
// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.
#include <stack>
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
@ -209,7 +210,9 @@ class ForwardAccumulator {
// ForwardAccumulator.
explicit ForwardAccumulator(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
: vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
: vspace_(vspace) {
call_state_.emplace(nullptr, false);
}
virtual ~ForwardAccumulator() {
for (auto accumulated : accumulated_gradients_) {
@ -262,11 +265,11 @@ class ForwardAccumulator {
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);
// Returns true if `Accumulate` is active somewhere above on the stack. This
// is useful for ordering ForwardAccumulators, where more deeply nested
// accumulators should not see computations from less deeply nested
// accumulators.
bool BusyAccumulating() const { return this->accumulating_; }
// Returns true if `Accumulate` is active somewhere above on the stack and
// there isn't an intervening PushState. This is useful for ordering
// ForwardAccumulators, where more deeply nested accumulators should not see
// computations from less deeply nested accumulators.
bool BusyAccumulating() const { return call_state_.top().accumulating; }
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
// a nullptr if none is available.
@ -282,6 +285,15 @@ class ForwardAccumulator {
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes);
// Temporarily push or pop transient state for this accumulator.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. Without pushing and popping, accumulators
// ignore operations executed as a direct result of their own jvp
// computations.
void PushState() { call_state_.emplace(nullptr, false); }
void PopState() { call_state_.pop(); }
private:
// Helper for Accumulate: uses a GradientTape to compute forward gradients
// from a backward gradient function. Fills `out_grads` corresponding to
@ -289,7 +301,7 @@ class ForwardAccumulator {
//
// Executes the backward function in order to trace its gradient, which will
// waste computation if executing eagerly (when graph building the unneeded
// computation is pruned). Temporarily sets `backward_tape_` so that
// computation is pruned). Temporarily sets `backward_tape` so that
// Accumulate will forward op executions to the tape while the backward
// function is running; this effectively adds the backward tape to the active
// set (but does not require complicated callbacks to the language bindings).
@ -305,16 +317,26 @@ class ForwardAccumulator {
// Not owned; provides operations on Tensors which are currently only
// available in language bindings (e.g. Python).
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
// Set temporarily while in the Accumulate method; if backward_tape_ is not
// nullptr then we forward op executions to it so Accumulate can compute a
// backward pass on its backward function.
//
// Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
// keeps ownership.
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
// While the Accumulate method is running (accumulating_ is True), any op
// executions not forwarded to backward_tape_ should be ignored.
bool accumulating_;
struct AccumulatorCallState {
AccumulatorCallState(
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
bool accumulating)
: backward_tape(backward_tape), accumulating(accumulating) {}
// Set temporarily while in the Accumulate method; if backward_tape is not
// nullptr then we forward op executions to it so Accumulate can compute a
// backward pass on its backward function.
//
// Not owned by the ForwardAccumulator. The method which sets
// `backward_tape` keeps ownership.
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
// While the Accumulate method is running (accumulating is True), any op
// executions not forwarded to backward_tape should be ignored.
bool accumulating;
};
// A deque-backed stack, whose element references are not invalidated by
// pushes and pops at the back.
std::stack<AccumulatorCallState> call_state_;
};
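A minimal sketch, under simplified types, of the call-state stack introduced above: Accumulate pushes a fresh frame so operations triggered by its own jvp computation are ignored, then pops it on the way out (the real code performs the pop through gtl::MakeCleanup).
#include <stack>
struct CallState {
  void* backward_tape = nullptr;  // stand-in for the GradientTape pointer
  bool accumulating = false;
};
class AccumulatorSketch {
 public:
  AccumulatorSketch() { frames_.push(CallState{}); }
  bool BusyAccumulating() const { return frames_.top().accumulating; }
  void Accumulate() {
    frames_.push(CallState{nullptr, true});  // hide our own jvp ops from ShouldRecord
    // ... run the forward function / trace the backward function here ...
    frames_.pop();  // restore the caller's frame
  }
 private:
  std::stack<CallState> frames_;
};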
// Template instantiations here
@ -645,16 +667,15 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
Status s = InitialGradients(vspace, target_tensor_ids,
sources_that_are_targets, output_gradients,
tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
auto cleanup = gtl::MakeCleanup([this, &state]() {
if (!persistent_) {
// Release all backprop functions
for (const auto& pair : state.op_tape) {
pair.second.backward_function_deleter(pair.second.backward_function);
}
}
};
});
if (!s.ok()) {
cleanup();
return s;
}
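A short sketch of the gtl::MakeCleanup scope-guard idiom adopted here (tensorflow/core/lib/gtl/cleanup.h): the lambda runs when the guard leaves scope, which is why the explicit cleanup() calls on the error paths can be dropped. The function below is illustrative only.
#include "tensorflow/core/lib/gtl/cleanup.h"
int ProcessWithCleanup(bool fail_early) {
  auto cleanup = tensorflow::gtl::MakeCleanup([] {
    // release backward functions / other resources; runs on every return path
  });
  if (fail_early) return -1;  // guard still fires here
  return 0;                   // ...and here
}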
@ -732,11 +753,17 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
s = vspace.CallBackwardFunction(trace.backward_function,
unneeded_gradients, out_gradients,
&in_gradients);
if (in_gradients.size() != trace.input_tensor_id.size()) {
return tensorflow::errors::Internal(
"Recorded operation '", trace.op_type,
"' returned too few gradients. Expected ",
trace.input_tensor_id.size(), " but received ",
in_gradients.size());
}
if (!persistent_) {
trace.backward_function_deleter(trace.backward_function);
}
if (!s.ok()) {
cleanup();
return s;
}
} else {
@ -847,12 +874,12 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
if (backward_tape_ != nullptr) {
// If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
if (call_state_.top().backward_tape != nullptr) {
// If we're forwarding Accumulate calls to backward_tape's RecordOperation,
// we should also delegate ShouldRecord.
return backward_tape_->ShouldRecord(tensor_ids, dtypes);
return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
}
if (accumulating_) {
if (call_state_.top().accumulating) {
return false;
}
for (int i = 0; i < tensor_ids.size(); ++i) {
@ -884,9 +911,10 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
*/
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
backward_tape_ = tape.get();
AccumulatorCallState& call_state = call_state_.top();
call_state.backward_tape = tape.get();
auto pop_backward_tape =
gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
std::vector<Gradient*> forwardprop_aids;
std::vector<int64> sources;
std::unordered_set<int64> sources_set;
@ -894,6 +922,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
for (const TapeTensor& output_tensor : output_tensors) {
// Ownership of `aid` transferred to CallBackwardFunction below.
Gradient* aid = vspace_.Ones(output_tensor);
if (TF_PREDICT_FALSE(aid == nullptr)) {
return tensorflow::errors::Internal(
"Failed to create ones tensor for tensor ", output_tensor.GetID(),
" with dtype ", output_tensor.GetDType());
}
forwardprop_aids.push_back(aid);
int64 aid_id = vspace_.TensorId(aid);
sources.push_back(aid_id);
@ -961,10 +994,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
const ForwardFunction<Gradient>* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (backward_tape_ != nullptr) {
// If backward_tape_ is not null, then this call to Accumulate is the result
if (call_state_.top().backward_tape != nullptr) {
// If backward_tape is not null, then this call to Accumulate is the result
// of a still-active call to Accumulate which is running operations. We
// forward these operations to backward_tape_ so the outer Accumulate call
// forward these operations to backward_tape so the outer Accumulate call
// can do its work.
//
// Rather than re-entering and delegating Accumulate like this, we could
@ -972,9 +1005,9 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
// (so it can deactivate itself and activate its GradientTape). Currently
// that is managed by the language binding and would require relatively
// messy callbacks.
backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
input_dtypes, backward_function_getter,
backward_function_deleter);
call_state_.top().backward_tape->RecordOperation(
op_type, output_tensors, input_tensor_id, input_dtypes,
backward_function_getter, backward_function_deleter);
return Status::OK();
}
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
@ -1012,9 +1045,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
// Avoid infinite recursion. Whichever forward function we run, it'll end up
// executing ops, and we don't want to watch those with this accumulator.
accumulating_ = true;
auto reset_accumulating =
gtl::MakeCleanup([this] { this->accumulating_ = false; });
call_state_.emplace(nullptr, true);
auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });
std::vector<Gradient*> forward_grads;
if (forward_function == nullptr) {

View File

@ -123,6 +123,7 @@ cc_library(
"//tensorflow/core/util/tensor_bundle:naming",
# mobile not supported yet
]),
alwayslink = 1,
)
tf_cc_test(

View File

@ -48,12 +48,12 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
export_dir);
}
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
const std::unordered_set<string>& tags,
Status FindMetaGraphDef(const std::unordered_set<string>& tags,
SavedModel* saved_model_proto,
MetaGraphDef* meta_graph_def) {
LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
<< " }";
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) {
// Get tags from the graph_def.
std::unordered_set<string> graph_tags;
for (const string& tag : graph_def.meta_info_def().tags()) {
@ -61,7 +61,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
}
// Match with the set of tags provided.
if (graph_tags == tags) {
*meta_graph_def = graph_def;
*meta_graph_def = std::move(graph_def);
return Status::OK();
}
}
@ -81,7 +81,8 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
MetaGraphDef* const meta_graph_def) {
SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
TF_RETURN_IF_ERROR(
FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def));
return Status::OK();
}

View File

@ -1,5 +1,3 @@
# -*- Python -*-
"""Build macro that compiles a TensorFlow graph into a cc_library.
To use from your BUILD file, add the following line to load the macro:

View File

@ -1,12 +1,11 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
package(
default_visibility = [
":internal",
"//tensorflow/core/common_runtime/eager:__pkg__",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
@ -39,7 +38,7 @@ cc_library(
":xla_cpu_device",
":xla_cpu_jit",
"//tensorflow/compiler/plugin",
] + if_cuda([
] + if_cuda_or_rocm([
":xla_gpu_device",
":xla_gpu_jit",
]),
@ -62,7 +61,7 @@ cc_library(
cc_library(
name = "xla_gpu_jit",
visibility = ["//visibility:public"],
deps = if_cuda([
deps = if_cuda_or_rocm([
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -145,8 +144,57 @@ cc_library(
],
)
XLA_DEVICE_DEPS = [
":common",
":xla_launch_util",
":xla_tensor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:control_flow_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:resource_variable_ops_op_lib",
"//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:optional_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor/platform",
]
cc_library(
name = "xla_device",
name = "xla_device_no_jit_rewrite_registration",
srcs = [
"xla_compile_on_demand_op.cc",
"xla_device.cc",
@ -159,56 +207,22 @@ cc_library(
"xla_device_context.h",
"xla_device_ops.h",
],
deps = XLA_DEVICE_DEPS,
)
cc_library(
name = "xla_device",
hdrs = [
"xla_compile_on_demand_op.h",
"xla_device.h",
"xla_device_context.h",
"xla_device_ops.h",
],
# Public visibility is needed for external TF/XLA backends.
visibility = ["//visibility:public"],
deps = [
":common",
deps = XLA_DEVICE_DEPS + [
":jit_compilation_passes",
":xla_launch_util",
":xla_tensor",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:control_flow_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:resource_variable_ops_op_lib",
"//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:optional_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor/platform",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
":xla_device_no_jit_rewrite_registration",
],
)
@ -262,6 +276,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -269,6 +284,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/lib/gtl:array_slice",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
@ -323,22 +340,27 @@ cc_library(
":compilation_passes",
":xla_activity_logging_listener",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/compiler/tf2xla:mlir_bridge_pass_registration",
"//tensorflow/core:core_cpu_internal",
] + tf_jit_compilation_passes_extra_deps(),
alwayslink = 1,
)
# Linked by tensorflow core, without registration of the jit compilation passes,
# which is not necessary to create and run an XlaLocalLaunchBase kernel.
# Linking the jit compilation passes can currently cause programs to get stuck (b/140069592).
cc_library(
name = "xla_kernel_creator",
name = "xla_kernel_creator_util",
srcs = [
"xla_kernel_creator.cc",
"xla_kernel_creator_util.cc",
],
hdrs = ["xla_kernel_creator.h"],
hdrs = ["xla_kernel_creator_util.h"],
visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"],
deps = [
":common",
":compilability_check_util",
":compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@ -351,6 +373,23 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_kernel_creator",
srcs = [
"xla_kernel_creator.cc",
"xla_kernel_creator.h",
],
deps = [
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
tf_cc_test(
name = "xla_kernel_creator_test",
srcs = [
@ -846,6 +885,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/memory",

View File

@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@ -85,7 +86,7 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
} // anonymous namespace
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
const Node& node, FunctionLibraryRuntime* lib_runtime,
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
@ -100,12 +101,14 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
}
}
stack_trace.emplace_back(StackFrameView{node.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes);
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace,
/*encapsulating_function=*/nullptr, &uncompilable_nodes);
return uncompilable_nodes;
}
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
@ -120,22 +123,31 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
}
}
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes);
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
IsCompilableCall(call_def, lib_runtime, &stack_trace,
/*encapsulating_function=*/nullptr, &uncompilable_nodes);
return uncompilable_nodes;
}
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
bool RecursiveCompilabilityChecker::HasXLAKernel(
const Node& node, string* uncompilable_reason) const {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node.type_string() == "SymbolicGradient") return false;
if (node.type_string() == "SymbolicGradient") {
*uncompilable_reason =
"SymbolicGradient should be handled by IsCompilableCall().";
return false;
}
if (node.type_string() == "Const") {
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
// registered Const KernelDef says that it does, to support no-op Assert for
// tfcompile.
const AttrValue* attr = node.attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
*uncompilable_reason =
"Const op with type DT_STRING is not supported by XLA.";
return false;
}
}
@ -145,10 +157,16 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
*uncompilable_reason = "Identity with unsafe cast.";
return false;
}
return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok();
Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
if (!s.ok()) {
*uncompilable_reason = s.error_message();
return false;
}
return true;
}
// Tests whether 'if_node' is compilable. Every operator in the then_branch and
@ -156,16 +174,18 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
bool RecursiveCompilabilityChecker::IsCompilableIf(
const Node& if_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
bool is_compilable = true;
is_compilable &= ExtractNodeDefAndCheckCompilability(
if_node, "then_branch", "if_then", lib_runtime, stack_trace,
uncompilable_nodes);
if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
stack_trace, uncompilable_nodes);
if (!uncompilable_nodes && !is_compilable) return is_compilable;
is_compilable &= ExtractNodeDefAndCheckCompilability(
if_node, "else_branch", "if_else", lib_runtime, stack_trace,
uncompilable_nodes);
if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
stack_trace, uncompilable_nodes);
return is_compilable;
}
@ -176,37 +196,43 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
bool RecursiveCompilabilityChecker::IsCompilableWhile(
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
bool is_compilable = true;
is_compilable &= ExtractNodeDefAndCheckCompilability(
while_node, "cond", "while_cond", lib_runtime, stack_trace,
uncompilable_nodes);
while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
stack_trace, uncompilable_nodes);
if (!uncompilable_nodes && !is_compilable) return is_compilable;
is_compilable &= ExtractNodeDefAndCheckCompilability(
while_node, "body", "while_body", lib_runtime, stack_trace,
uncompilable_nodes);
while_node, "body", "while_body", encapsulating_function, lib_runtime,
stack_trace, uncompilable_nodes);
return is_compilable;
}
bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
const Node& node, const std::string& attr_name,
const std::string& call_name, FunctionLibraryRuntime* lib_runtime,
const std::string& call_name, NameAttrList* encapsulating_function,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
NodeDef call;
call.set_name(call_name);
if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
const auto uncompilable_reason = absl::StrCat(
"missing '", attr_name, "' attribute from node", node.name());
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
<< ".";
return false;
}
if (!IsCompilableCall(call, lib_runtime, stack_trace, uncompilable_nodes)) {
if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes)) {
VLOG(2) << "Rejecting node " << node.name()
<< ": can't compile : " << call.op();
return false;
@ -220,24 +246,33 @@ bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
bool RecursiveCompilabilityChecker::IsCompilableCall(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
if (stack_trace->size() > kMaxRecursionDepth) {
std::string uncompilable_reason = "function depth limit exceeded";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
<< ".";
return false;
}
FunctionLibraryRuntime::Handle handle;
Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle);
if (!status.ok()) {
Status s;
NameAttrList function;
s = NameAndAttrsFromFunctionCall(call_def, &function);
if (s.ok()) {
s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
&handle);
}
if (!s.ok()) {
std::string uncompilable_reason = "could not instantiate call";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
<< uncompilable_reason << " : " << status;
<< uncompilable_reason << " : " << s;
return false;
}
@ -246,9 +281,9 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
bool is_compilable = true;
for (const Node* node : fbody->graph->op_nodes()) {
stack_trace->emplace_back(StackFrameView{node->name(), call_def.op()});
is_compilable &=
IsCompilableNode(*node, lib_runtime, stack_trace, uncompilable_nodes);
stack_trace->emplace_back(StackFrameView{node->name(), function.name()});
is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
&function, uncompilable_nodes);
stack_trace->pop_back();
if (!uncompilable_nodes && !is_compilable) return is_compilable;
}
@ -279,12 +314,14 @@ bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
bool RecursiveCompilabilityChecker::IsCompilableNode(
const Node& node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
auto stack_depth = stack_trace->size();
if (node.IsSource() || node.IsSink()) {
absl::string_view uncompilable_reason = "source or sink node";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -295,7 +332,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
(node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
absl::string_view uncompilable_reason = "top level _Arg or _Retval";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -307,33 +344,36 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
absl::string_view uncompilable_reason =
"_scoped_allocator or _forward_from attribute";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
string uncompilable_reason;
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
uncompilable_nodes)) {
encapsulating_function, uncompilable_nodes)) {
LogNotCompilable(node, "unsupported function");
return false;
}
} else if (!HasXLAKernel(node)) {
absl::string_view uncompilable_reason = "unsupported op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
} else if (!HasXLAKernel(node, &uncompilable_reason)) {
MaybeMarkUncompilableNode(
absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
if (node.IsWhileNode() &&
!IsCompilableWhile(node, lib_runtime, stack_trace, uncompilable_nodes)) {
!IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes)) {
LogNotCompilable(node, "unsupported while");
return false;
}
if (node.IsIfNode() &&
!IsCompilableIf(node, lib_runtime, stack_trace, uncompilable_nodes)) {
!IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes)) {
LogNotCompilable(node, "unsupported if");
return false;
}
@ -342,7 +382,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
IsStatefulRandomOp(node.type_string())) {
absl::string_view uncompilable_reason = "stateful random op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -350,7 +390,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
absl::string_view uncompilable_reason = "not allowed control trigger";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -359,7 +399,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
IsAssertOrCheckNumerics(node.type_string())) {
absl::string_view uncompilable_reason = "Assert or CheckNumerics";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -368,7 +408,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
OpProducesOrConsumesVariant(node)) {
absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -376,7 +416,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
absl::string_view uncompilable_reason = "Stack op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -384,7 +424,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
absl::string_view uncompilable_reason = "TensorArray op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -394,7 +434,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
absl::string_view uncompilable_reason =
"resource variable op in called function";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -406,7 +446,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
node.DebugString())
.IgnoreError();
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -417,7 +457,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
node.DebugString())
.IgnoreError();
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
uncompilable_nodes);
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -446,8 +486,9 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
const absl::string_view reason,
const std::vector<StackFrameView>& stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_node_list) {
if (!uncompilable_node_list) return;
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
if (!uncompilable_nodes) return;
UncompilableNodeInfo node_info;
node_info.uncompilable_reason = std::string(reason);
@ -459,7 +500,20 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
});
node_info.name = std::string(stack_trace.back().name);
(*uncompilable_node_list).push_back(std::move(node_info));
auto function =
encapsulating_function ? *encapsulating_function : NameAttrList();
auto function_identifier = function.ShortDebugString();
auto it = uncompilable_nodes->find(function_identifier);
if (it == uncompilable_nodes->end()) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_info{std::move(node_info)};
uncompilable_nodes->emplace(
std::move(function_identifier),
std::make_pair(function, std::move(uncompilable_node_info)));
} else {
it->second.second.emplace_back(std::move(node_info));
}
}
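// Illustrative sketch only, not part of this change: one way a caller might
// walk the UncompilableNodesMap filled in above. The helper name is made up;
// the map's value type (NameAttrList plus a vector of UncompilableNodeInfo)
// follows the typedef introduced in the header.
static void ExampleLogUncompilableNodes(
    const RecursiveCompilabilityChecker::UncompilableNodesMap& nodes_map) {
  for (const auto& entry : nodes_map) {
    // The function name is empty for nodes that sit directly in the top-level
    // graph rather than inside a function call.
    const NameAttrList& function = entry.second.first;
    for (const auto& node_info : entry.second.second) {
      LOG(INFO) << "Function '" << function.name() << "': node "
                << node_info.name
                << " is uncompilable: " << node_info.uncompilable_reason;
    }
  }
}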
} // namespace tensorflow

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
@ -129,19 +130,35 @@ class RecursiveCompilabilityChecker {
const DeviceType* jit_device_type)
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
// Returns a list of uncompilable nodes. When `node` is inside a function
// body, users can set `node_stack_trace` to provide an additional
// context for `node`'s placement within the outer most graph.
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
using UncompilableNodesMap =
std::map<std::string,
std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;
// Returns a map where the key is the function identifier (short debug
// string) of the function encapsulating the uncompilable nodes, and the
// value is a pair of the function's NameAttrList and a vector of
// uncompilable node info. When an uncompilable node is not inside any
// function call node, the key is the ShortDebugString() of an empty
// NameAttrList.
//
// Also, when `node` is inside a function body, users can set
// `node_stack_trace` to provide additional context for `node`'s
// placement within the outermost graph.
UncompilableNodesMap FindUncompilableNodes(
const Node& node, FunctionLibraryRuntime* lib_runtime,
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
// Returns a list of uncompilable nodes in `call_def` that cannot be
// compiled by XLA. It is assumed that `call_def` is a call operation.
// When `node` is inside a function body, users can set
// Returns a map where the key is the function identifier (short debug
// string) of the function encapsulating the uncompilable nodes, and the
// value is a pair of the function's NameAttrList and a vector of
// uncompilable node info. When an uncompilable node is not inside any
// function call node, the key is the ShortDebugString() of an empty
// NameAttrList.
//
// Also, when `node` is inside a function body, users can set
// `node_stack_trace` to provide additional context for `node`'s
// placement within the outermost graph.
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
UncompilableNodesMap FindUncompilableNodes(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
@ -176,27 +193,31 @@ class RecursiveCompilabilityChecker {
bool IsCompilableNode(
const Node& node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
NameAttrList* encapsulating_function = nullptr,
UncompilableNodesMap* uncompilable_nodes = nullptr) const;
bool IsCompilableCall(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
bool IsCompilableIf(
const Node& if_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
bool IsCompilableWhile(
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
NameAttrList* encapsulating_function = nullptr,
UncompilableNodesMap* uncompilable_nodes = nullptr) const;
bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const;
bool IsCompilableWhile(const Node& while_node,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const;
// Returns compilability of node def retrieved from `node`'s attribute with
// name `attr_name`.
bool ExtractNodeDefAndCheckCompilability(
const Node& node, const std::string& attr_name,
const std::string& call_name, FunctionLibraryRuntime* lib_runtime,
const std::string& call_name, NameAttrList* encapsulating_function,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
UncompilableNodesMap* uncompilable_nodes) const;
bool IsStackOp(const Node& node) const {
const XlaResourceOpInfo* op_info =
@ -226,12 +247,14 @@ class RecursiveCompilabilityChecker {
absl::c_any_of(node.output_types(), is_variant);
}
bool HasXLAKernel(const Node& node) const;
bool HasXLAKernel(const Node& node,
string* uncompilable_reason = nullptr) const;
static void MaybeMarkUncompilableNode(
const absl::string_view reason,
const std::vector<StackFrameView>& stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_node_list);
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes_map);
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;

View File

@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -117,10 +119,16 @@ TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
const auto uncompilable_nodes =
checker_->FindUncompilableNodes(*uncompilable_op, flib_runtime);
ASSERT_EQ(1, uncompilable_nodes.size());
const auto& node_info = uncompilable_nodes.at(0);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
ASSERT_EQ(1, node_info.stack_trace.size());
ASSERT_EQ("", node_info.stack_trace.at(0).function_name);
auto node_info_it =
uncompilable_nodes.find(NameAttrList().ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
const auto& uncompilable_nodes_inside_function = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_nodes_inside_function.size());
const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0);
EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason,
"unsupported op"));
ASSERT_EQ(1, uncompilable_node_info.stack_trace.size());
ASSERT_EQ("", uncompilable_node_info.stack_trace.at(0).function_name);
}
TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
@ -147,14 +155,21 @@ TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
checker_->FindUncompilableNodes(*functional_node, flib_runtime);
EXPECT_EQ(1, uncompilable_nodes.size());
const auto& node_info = uncompilable_nodes.at(0);
NameAttrList function;
function.set_name(kUncompilableFunctionName);
const auto node_info_it =
uncompilable_nodes.find(function.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
const auto& uncompilable_node_list = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_node_list.size());
const auto& node_info = uncompilable_node_list.at(0);
const auto& node_stack = node_info.stack_trace;
ASSERT_EQ(2, node_stack.size());
EXPECT_EQ("D", node_stack.at(0).name);
EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name);
EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
EXPECT_TRUE(
absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
@ -212,7 +227,15 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
checker_->FindUncompilableNodes(**while_node_it, flib_runtime);
ASSERT_EQ(1, uncompilable_nodes.size());
const auto& node_info = uncompilable_nodes.at(0);
NameAttrList function;
function.set_name(kUncompilableFunctionName);
const auto node_info_it =
uncompilable_nodes.find(function.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
const auto& uncompilable_node_list = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_node_list.size());
const auto& node_info = uncompilable_node_list.at(0);
const auto& node_stack = node_info.stack_trace;
ASSERT_EQ(2, node_stack.size());
const auto& stacktrace_first_node_info = node_stack.at(0);
@ -225,7 +248,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
stacktrace_second_node_info.function_name);
EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
EXPECT_TRUE(
absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
@ -280,7 +304,14 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
checker_->FindUncompilableNodes(**if_node_it, flib_runtime);
ASSERT_EQ(2, uncompilable_nodes.size());
const auto& uncompilable_node_one = uncompilable_nodes.at(0);
NameAttrList function_one;
function_one.set_name(kUncompilableFunctionName);
auto it = uncompilable_nodes.find(function_one.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), it);
const auto& uncompilable_node_list = it->second.second;
ASSERT_EQ(1, uncompilable_node_list.size());
const auto& uncompilable_node_one = uncompilable_node_list.at(0);
const auto& node_one_stack = uncompilable_node_one.stack_trace;
ASSERT_EQ(2, node_one_stack.size());
@ -294,9 +325,17 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
stacktrace_second_node_info.function_name);
EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
EXPECT_EQ("unsupported op", uncompilable_node_one.uncompilable_reason);
EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
"unsupported op"));
const auto& uncompilable_node_two = uncompilable_nodes.at(1);
NameAttrList function_two;
function_two.set_name(kUncompilableFunctionTwoName);
it = uncompilable_nodes.find(function_two.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), it);
const auto& uncompilable_node_two_list = it->second.second;
ASSERT_EQ(1, uncompilable_node_two_list.size());
const auto& uncompilable_node_two = uncompilable_node_two_list.at(0);
const auto& node_two_stack = uncompilable_node_two.stack_trace;
ASSERT_EQ(2, node_two_stack.size());
const auto& node_two_stacktrace_first_node = node_two_stack.at(0);
@ -310,7 +349,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
node_two_stacktrace_second_node.function_name);
EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name);
EXPECT_EQ("unsupported op", uncompilable_node_two.uncompilable_reason);
EXPECT_TRUE(absl::StrContains(uncompilable_node_two.uncompilable_reason,
"unsupported op"));
}
} // namespace

View File

@ -29,6 +29,9 @@ limitations under the License.
namespace tensorflow {
namespace jit {
class DeviceInfoCache;
class DeviceSet;
// Instances of DeviceId represent TensorFlow devices as integers.
//
// This helps avoid having to manipulate device names as strings when

View File

@ -1193,7 +1193,7 @@ Status EncapsulateSubgraphsPass::Run(
}
std::unique_ptr<DeviceMgr> device_mgr =
absl::make_unique<DeviceMgr>(std::move(devices));
absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(device_mgr.get(),

View File

@ -510,7 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
OptimizerOptions opts;
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

View File

@ -232,7 +232,7 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
}
Status ExtractOutsideCompilationTest(

View File

@ -5,35 +5,48 @@ package(
licenses = ["notice"], # Apache 2.0
)
XLA_OPS_DEPS = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_activity_listener",
"//tensorflow/compiler/jit:xla_activity_proto_cc",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
]
# Linked by tensorflow core, without registration of jit compilation passes.
cc_library(
name = "xla_ops",
name = "xla_ops_no_jit_rewrite_registration",
srcs = ["xla_ops.cc"],
hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_activity_listener",
"//tensorflow/compiler/jit:xla_activity_proto_cc",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
deps = XLA_OPS_DEPS,
alwayslink = 1,
)
cc_library(
name = "xla_ops",
hdrs = ["xla_ops.h"],
deps = XLA_OPS_DEPS + [
":xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:jit_compilation_passes",
],
alwayslink = 1,
)

View File

@ -313,6 +313,8 @@ static Status CompileToLocalExecutable(
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
}
// TODO(b/138728225): Set options.alias_passthrough_params for clusters
// without ref variables.
std::map<int, Tensor> constant_args;
for (int i : constants) {
@ -397,9 +399,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
ctx, kernel, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0));
/*missing_ctx_input_prefix=*/0, input_output_alias));
VLOG(1) << "Done";
}
@ -595,6 +599,9 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
const xla::HloInputOutputAliasConfig& input_output_alias =
closure.executable()->executable()->module().input_output_alias_config();
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
@ -605,7 +612,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
ctx,
launch_context.PopulateOutputs(
ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/closure.num_constant_args()));
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
input_output_alias));
}
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

View File

@ -1639,10 +1639,9 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
}
} // anonymous namespace
bool IsCompilable(
FunctionLibraryRuntime* flr, const NodeDef& ndef,
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>*
uncompilable_node_info) {
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info) {
Device* device = flr->device();
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
@ -1668,8 +1667,8 @@ bool IsCompilable(
return checker.IsCompilableCall(ndef, flr);
}
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_result = checker.FindUncompilableNodes(ndef, flr);
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
checker.FindUncompilableNodes(ndef, flr);
uncompilable_node_info->swap(uncompilable_node_result);
return uncompilable_node_info->empty();
}

View File

@ -52,10 +52,9 @@ class MarkForCompilationPass : public GraphOptimizationPass {
// function is compilable iff every operator in the function body is
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
bool IsCompilable(
FunctionLibraryRuntime* flr, const NodeDef& ndef,
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>*
uncompilable_node_info = nullptr);
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info = nullptr);
namespace testing {
// DO NOT USE IN PRODUCTION.

View File

@ -186,7 +186,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt(
/*input_buffer_bytes=*/k_buffer_size,
/*output_buffer_bytes=*/k_buffer_size,
io::ZlibCompressionOptions::GZIP());
string decompressed_pbtxt_string;
tstring decompressed_pbtxt_string;
Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string);
if (!s.ok() && !errors::IsOutOfRange(s)) {
// OutOfRange is fine since we set the number of read bytes to INT_MAX.

View File

@ -158,6 +158,7 @@ Status XlaCompilationCache::BuildExecutable(
: client_->default_device_ordinal());
build_options.set_result_layout(result.xla_output_shape);
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);

View File

@ -83,9 +83,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0));
/*missing_ctx_input_prefix=*/0, input_output_alias));
return Status::OK();
}

View File

@ -98,10 +98,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 14> kAllXlaCpuTypes = {
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
DT_BFLOAT16}};
constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);

View File

@ -147,10 +147,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 14> kAllXlaGpuTypes = {
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
DT_BFLOAT16}};
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);

View File

@ -14,243 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position into the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
} // namespace
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
const FunctionDef* function_def =
flr.GetFunctionLibraryDefinition()->Find(node_def.name());
if (function_def == nullptr) {
// The node def is not calling a function. Individual ops can be
// run directly using on-demand mode, no need to create XlaLaunch
// kernel for them.
return false;
}
// If kXlaCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaCompileAttr);
if (it != node_def.attr().end()) {
return it->second.b();
}
// kXlaCompileAttr is not set on node_def, check if it is set on
// FunctionDef.
bool xla_compile = false;
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
node_def, kXlaCompileAttr, &xla_compile);
if (!status.ok() || !xla_compile) {
if (VLOG_IS_ON(3)) {
if (!status.ok()) {
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
<< node_def.op() << ". status=" << status.ToString();
} else {
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
}
}
return false;
}
return true;
}
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order when this function returns.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
const FunctionBody** fbody,
std::vector<int>* constant_arg_indices,
std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
// If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
*fbody = flr->GetFunctionBody(handle);
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
const DataTypeVector& arg_types = (*fbody)->arg_types;
std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args, bail out.
TF_RETURN_IF_ERROR(
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
/*compile_time_const_nodes=*/nullptr, flr));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
constant_arg_indices->push_back(i);
}
}
// There can be hundreds of resource variables. Reserve the space for them.
// We don't reserve for constants above as they are usually few.
resource_arg_indices->reserve(arg_types.size());
for (int i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] == DT_RESOURCE) {
resource_arg_indices->push_back(i);
}
}
return Status::OK();
return CanCreateXlaKernel(flr, node_def);
}
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) const {
if (!CanCreateKernel(*flr, node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_info;
if (!IsCompilable(flr, node_def, &uncompilable_node_info)) {
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
node_def.ShortDebugString(), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:\n");
for (const auto& node_info : uncompilable_node_info) {
string node_message =
absl::StrCat("\t", node_info.name, ": ",
node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
stack_frame.name, stack_frame.function_name);
}
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
// node_def is calling a function that XLA can't compile.
return errors::InvalidArgument(message);
}
// Get function body, constant args, and resource args.
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (int i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder, about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (int i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
// Create the kernel.
NameAttrList function;
function.set_name(node_def.op());
*(function.mutable_attr()) = node_def.attr();
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);
return s;
return CreateXlaKernel(flr, node_def, kernel);
}
namespace {

View File

@ -71,7 +71,7 @@ class XlaKernelCreatorTest : public ::testing::Test {
lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), proto);
OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

View File

@ -0,0 +1,259 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position into the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
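// Illustrative sketch only, not part of this change: SinglePassSearch answers
// membership queries for a non-decreasing sequence of values against a sorted
// list while scanning that list at most once overall. The helper name and the
// indices below are made-up example data.
std::vector<bool> ExampleMarkResourceArgs(int num_args) {
  const std::vector<int> resource_arg_indices = {1, 4, 7};  // must be sorted
  SinglePassSearch resources_search(&resource_arg_indices);
  std::vector<bool> is_resource(num_args, false);
  for (int i = 0; i < num_args; ++i) {
    // Queries must be issued in increasing order of `i`.
    is_resource[i] = resources_search.ScanForValue(i);
  }
  return is_resource;
}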
} // namespace
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) {
const FunctionDef* function_def =
flr.GetFunctionLibraryDefinition()->Find(node_def.name());
if (function_def == nullptr) {
// The node def is not calling a function. Individual ops can be
// run directly in on-demand mode, so there is no need to create an
// XlaLaunch kernel for them.
return false;
}
// If kXlaCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaCompileAttr);
if (it != node_def.attr().end()) {
return it->second.b();
}
// kXlaCompileAttr is not set on node_def, check if it is set on
// FunctionDef.
bool xla_compile = false;
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
node_def, kXlaCompileAttr, &xla_compile);
if (!status.ok() || !xla_compile) {
if (VLOG_IS_ON(3)) {
if (!status.ok()) {
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
<< node_def.op() << ". status=" << status.ToString();
} else {
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
}
}
return false;
}
return true;
}
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// On return, they are filled in and sorted in ascending order.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
const FunctionBody** fbody,
std::vector<int>* constant_arg_indices,
std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
// If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
*fbody = flr->GetFunctionBody(handle);
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
const DataTypeVector& arg_types = (*fbody)->arg_types;
std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args, bail out.
TF_RETURN_IF_ERROR(
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
/*compile_time_const_nodes=*/nullptr, flr));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
constant_arg_indices->push_back(i);
}
}
// There can be hundreds of resource variables. Reserve the space for them.
// We don't reserve for constants above as they are usually few.
resource_arg_indices->reserve(arg_types.size());
for (int i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] == DT_RESOURCE) {
resource_arg_indices->push_back(i);
}
}
return Status::OK();
}
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
if (!CanCreateXlaKernel(*flr, node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_info;
for (const auto& it : uncompilable_nodes_map) {
for (const auto& info : it.second.second) {
uncompilable_node_info.emplace_back(info);
}
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
node_def.ShortDebugString(), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:\n");
for (const auto& node_info : uncompilable_node_info) {
string node_message =
absl::StrCat("\t", node_info.name, ": ",
node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
stack_frame.name, stack_frame.function_name);
}
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
// node_def is calling a function that XLA can't compile.
return errors::InvalidArgument(message);
}
// Get function body, constant args, and resource args.
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments, checking whether each
// argument is a resource or a constant.
// The reason we optimized this code is that functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (int i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// The XlaLaunch kernel keeps all outputs (including constants, which it
// copies) in device memory, except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (int i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
// Create the kernel.
NameAttrList function;
function.set_name(node_def.op());
*(function.mutable_attr()) = node_def.attr();
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);
return s;
}
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class FunctionLibraryRuntime;
class OpKernel;
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
// true if 'node_def' is a call to a compilable function defined in 'flr',
// with the kXlaCompileAttr set.
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def);
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_

View File

@ -247,9 +247,50 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
namespace {
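// Returns true if the XLA computation's alias config marks the output at
// `output_num` as a user-specified (kUserAlias) alias of one of the input
// parameters.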
bool MustAliasOutput(const xla::HloInputOutputAliasConfig& input_output_alias,
int output_num) {
xla::ShapeIndex output_index;
if (input_output_alias.shape().IsTuple()) {
output_index = {output_num};
} else {
DCHECK_EQ(output_num, 0)
<< "output_num must be 0 for non-tuple shapes but is " << output_num;
output_index = {};
}
if (input_output_alias.shape().tuple_shapes_size() == 0) {
return false;
}
return input_output_alias.OutputHasAlias(output_index) &&
input_output_alias.GetAliasedParameter(output_index).value().kind ==
xla::HloInputOutputAliasConfig::kUserAlias;
}
} // namespace
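// Wraps `buffer` into a Tensor of the given type and shape. When the output
// must alias an input, the aliased parameter's root buffer is reused instead
// and the tensor keeps the extra buffer reference (unref_buffer is false).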
Tensor XlaComputationLaunchContext::MakeOutputTensor(
DataType type, const TensorShape& shape, se::DeviceMemoryBase buffer,
int output_num, const xla::HloInputOutputAliasConfig& input_output_alias,
Allocator* allocator) {
bool is_aliased = false;
if (MustAliasOutput(input_output_alias, output_num)) {
int xla_param = input_output_alias.GetAliasedParameter({output_num})
.value()
.parameter_number;
DCHECK(arg_ptrs_[xla_param] != nullptr);
buffer = arg_ptrs_[xla_param]->root_buffer();
is_aliased = true;
}
return XlaTensorBuffer::MakeTensor(type, shape,
/*unref_buffer=*/!is_aliased, buffer,
allocator);
}
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output, int missing_ctx_input_prefix) {
ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@ -343,8 +384,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (MustAliasOutput(input_output_alias, output_num)) {
DCHECK(output.buffer({output_num}).is_null())
<< "Expected output buffer to be aliased, but it is not nil.";
}
if (allocate_xla_tensors_) {
if (MustAliasOutput(input_output_alias, output_num)) {
return errors::Unimplemented(
"Aliasing is not yet supported for allocate_xla_tensors_.");
}
Tensor* output_tensor;
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
@ -359,8 +407,10 @@ Status XlaComputationLaunchContext::PopulateOutputs(
CHECK_EQ(output_tensor->TotalBytes(), 0);
}
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor =
MakeOutputTensor(ctx->expected_output_dtype(i), shape, buffer,
output_num, input_output_alias, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
@ -408,6 +458,10 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
if (allocate_xla_tensors_) {
if (MustAliasOutput(input_output_alias, output_num)) {
return errors::Unimplemented(
"Aliasing is not yet supported for allocate_xla_tensors_.");
}
Tensor output_tensor;
TF_RETURN_IF_ERROR(
ctx->allocate_temp(write.type, write.shape, &output_tensor));
@ -423,8 +477,9 @@ Status XlaComputationLaunchContext::PopulateOutputs(
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator);
Tensor output_tensor =
MakeOutputTensor(write.type, write.shape, buffer, output_num,
input_output_alias, allocator);
*variable_infos[i].var()->tensor() = output_tensor;
}
++output_num;

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
@ -149,16 +150,21 @@ class XlaComputationLaunchContext {
//
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
// missing and adjusts input indices accordingly.
Status PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
xla::ScopedShapedBuffer output,
int missing_ctx_input_prefix);
Status PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
private:
Tensor MakeOutputTensor(
DataType type, const TensorShape& shape, se::DeviceMemoryBase buffer,
int output_num, const xla::HloInputOutputAliasConfig& input_output_alias,
Allocator* allocator);
xla::LocalClient* client_;
se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
@ -193,12 +199,15 @@ class XlaTensorBuffer : public TensorBuffer {
}
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
se::DeviceMemoryBase buffer, Allocator* allocator) {
bool unref_buffer, se::DeviceMemoryBase buffer,
Allocator* allocator) {
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
buffer.size(), allocator);
Tensor t(dtype, shape, tensor_buffer);
tensor_buffer->Unref();
if (unref_buffer) {
tensor_buffer->Unref();
}
return t;
}

View File

@ -4,7 +4,10 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
package(
default_visibility = ["@local_config_mlir//:friends"],
default_visibility = [
"//tensorflow/compiler/tf2xla:__subpackages__",
"@local_config_mlir//:friends",
],
licenses = ["notice"], # Apache 2.0
)
@ -19,10 +22,20 @@ filegroup(
srcs = glob(["**/*.td"]),
)
cc_library(
name = "op_name_mapper",
srcs = ["op_name_mapper.cc"],
hdrs = ["op_name_mapper.h"],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@llvm//:support",
"@local_config_mlir//:IR",
],
)
cc_library(
name = "tf_mlir_opt_main",
srcs = ["tf_mlir_opt_main.cc"],
copts = ["-std=c++14"],
deps = [
":init_mlir",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
@ -32,9 +45,12 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/xla",
"//tensorflow/compiler/mlir/xla:lxla",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",

View File

@ -0,0 +1,3 @@
# TensorFlow MLIR
These are the docs for: https://www.tensorflow.org/mlir

View File

@ -0,0 +1,24 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
path: /resources
is_default: true
menu:
- include: /resources/_menu_toc.yaml
lower_tabs:
# Subsite tabs
other:
- name: Guide & Tutorials
contents:
- title: Overview
path: /mlir/overview
- heading: Dialects
- title: TensorFlow
path: /mlir/tf_ops
- title: TensorFlow Lite
path: /mlir/tfl_ops
- include: /_upper_tabs_right.yaml

View File

@ -0,0 +1,48 @@
book_path: /mlir/_book.yaml
project_path: /mlir/_project.yaml
description: <!--no description-->
landing_page:
custom_css_path: /site-assets/css/style.css
rows:
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
items:
- description: >
The MLIR project defines a common intermediate representation (IR) that
unifies the infrastructure required to execute high performance machine
learning models in TensorFlow and similar ML frameworks. This project
will include the application of HPC techniques, along with integration of
search algorithms like reinforcement learning. MLIR aims to reduce the
cost to bring up new hardware, and improve usability for existing
TensorFlow users.
- code_block: |
<pre class = "prettyprint">
// Syntactically similar to LLVM:
func @testFunction(%arg0: i32) {
%x = call @thingToCall(%arg0) : (i32) -> i32
br ^bb1
^bb1:
%y = addi %x, %x : i32
return %y : i32
}
</pre>
- classname: devsite-landing-row-cards
items:
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
youtube_id: qzljG6DKgic
buttons:
- label: Watch the video
path: https://www.youtube.com/watch?v=qzljG6DKgic
- heading: "A new intermediate representation and compiler framework"
image_path: /resources/images/tf-logo-card-16x9.png
path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
buttons:
- label: Read on TensorFlow blog
path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
- heading: TensorFlow MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/tensorflow/mlir
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/mlir

View File

@ -0,0 +1,11 @@
name: TensorFlow MLIR
breadcrumb_name: MLIR
home_url: /mlir/
parent_project_metadata_path: /_project.yaml
description: >
MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
use_site_branding: true
hide_from_products_list: true
content_license: cc-apache
buganizer_id: 443907
include: /_project_included.yaml

File diff suppressed because one or more lines are too long (new image added, Size: 148 KiB)

View File

@ -0,0 +1,5 @@
# MLIR overview
## Overview
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
load(
"@local_config_mlir//:tblgen.bzl",
"gentbl",
@ -146,7 +146,6 @@ cc_library(
hdrs = [
"utils/validators.h",
],
copts = ["-std=c++14"],
deps = [
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
@ -169,7 +168,6 @@ cc_library(
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
@ -187,6 +185,41 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "lstm_utils",
srcs = [
"utils/lstm_utils.cc",
],
hdrs = [
"utils/lstm_utils.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
)
tf_cc_test(
name = "lstm_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
)
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
@ -194,16 +227,18 @@ cc_library(
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
"transforms/generated_prepare_tf.inc",
"transforms/legalize_ophint_func_op.cc",
"transforms/legalize_tf.cc",
"transforms/lower_static_tensor_list.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
],
hdrs = [
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
copts = ["-std=c++14"],
deps = [
":common",
":tensorflow_lite",
@ -233,7 +268,6 @@ cc_library(
hdrs = [
"transforms/passes.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
":validators",
@ -252,6 +286,7 @@ cc_library(
name = "tensorflow_lite_quantize",
srcs = [
"transforms/generated_quantize.inc",
"transforms/load_quantization_recipe.cc",
"transforms/post_quantize.cc",
"transforms/prepare_quantize.cc",
"transforms/quantize.cc",
@ -260,7 +295,6 @@ cc_library(
hdrs = [
"transforms/passes.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
":validators",
@ -305,7 +339,6 @@ cc_library(
srcs = [
"ir/dialect_registration.cc",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"@local_config_mlir//:IR",
@ -349,7 +382,6 @@ cc_library(
hdrs = [
"flatbuffer_operator.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
@ -379,7 +411,6 @@ cc_library(
hdrs = [
"emit_error_reporter.h",
],
copts = ["-std=c++14"],
deps = [
"//tensorflow/lite/core/api",
"@local_config_mlir//:IR",
@ -398,11 +429,11 @@ cc_library(
"flatbuffer_translate.h",
"utils/convert_type.h",
],
copts = ["-std=c++14"],
deps = [
":flatbuffer_tflite_operator_lib",
":tensorflow_lite",
":tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir:op_name_mapper",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
@ -450,7 +481,6 @@ cc_library(
hdrs = [
"tf_tfl_translate_cl.h",
],
copts = ["-std=c++14"],
deps = [
"@llvm//:support",
],
@ -462,7 +492,6 @@ cc_library(
hdrs = [
"common/tfl_pass_config.h",
],
copts = ["-std=c++14"],
deps = [
"@llvm//:support",
],
@ -523,7 +552,6 @@ cc_library(
hdrs = [
"tf_tfl_passes.h",
],
copts = ["-std=c++14"],
deps = [
":common",
":tensorflow_lite_legalize_tf",
@ -531,6 +559,7 @@ cc_library(
":tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
@ -552,7 +581,6 @@ cc_library(
hdrs = [
"tf_to_tfl_flatbuffer.h",
],
copts = ["-std=c++14"],
deps = [
":flatbuffer_translate_lib",
":tensorflow_lite",

View File

@ -98,8 +98,7 @@ Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
if (tensor.name.empty()) {
return base;
}
return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base,
builder.getContext());
return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base);
}
// Returns the correct type for a quantized tensor
@ -478,8 +477,7 @@ StatusOr<FuncOp> ConvertSubgraph(
llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types;
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc,
builder.getContext());
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
// Construct function type
for (auto input : subgraph.inputs) {

View File

@ -53,9 +53,10 @@ limitations under the License.
#include "mlir/Translation.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/op_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils//convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -89,6 +90,8 @@ using mlir::TranslateFromMLIRRegistration;
using mlir::Type;
using mlir::UnknownLoc;
using mlir::Value;
using tensorflow::OpLocNameMapper;
using tensorflow::OpNameMapper;
using tensorflow::Status;
using tflite::flex::IsWhitelistedFlexOp;
using xla::StatusOr;
@ -96,7 +99,10 @@ using xla::StatusOr;
template <typename T>
using BufferOffset = flatbuffers::Offset<T>;
using CustomOptionsOffset = BufferOffset<flatbuffers::Vector<uint8_t>>;
template <typename T>
using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
namespace error = tensorflow::error;
namespace tfl = mlir::TFL;
@ -341,16 +347,16 @@ class Translator {
bool emit_builtin_tflite_ops,
bool emit_select_tf_ops,
bool emit_custom_ops,
bool strip_debug_info);
OpNameMapper* op_name_mapper);
private:
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
bool strip_debug_info)
OpNameMapper* op_name_mapper)
: module_(module),
builder_(kInitialBufferSize),
strip_debug_info_(strip_debug_info) {
name_mapper_(*op_name_mapper),
builder_(kInitialBufferSize) {
// The first buffer must be empty according to the schema definition.
empty_buffer_ = tflite::CreateBuffer(builder_);
buffers_.push_back(empty_buffer_);
@ -369,10 +375,6 @@ class Translator {
Optional<std::string> TranslateInternal();
// Returns name that should be used by tensors for values generated by this
// operation.
std::string GetName(Operation* inst);
// Returns TFLite buffer populated with constant value if the operation is
// TFLite constant operation. Otherwise, returns an empty buffer. Emits error
// and returns llvm::None on failure.
@ -416,6 +418,15 @@ class Translator {
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
// Builds Metadata with the given `name` and buffer `content`.
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
StringRef content);
// Encodes the `tfl.metadata` dictionary attribute of the module to the
// metadata section in the final model.
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
CreateMetadataVector();
// Uses the tf.entry_function attribute (if set) to initialize the op to name
// mapping.
void InitializeNamesFromAttribute(FuncOp fn);
@ -427,11 +438,10 @@ class Translator {
// Returns a unique name for `op`.
std::string UniqueName(mlir::Operation* op);
// Returns a unique name starting with a given prefix.
std::string UniqueName(llvm::StringRef prefix);
ModuleOp module_;
tensorflow::OpNameMapper& name_mapper_;
flatbuffers::FlatBufferBuilder builder_;
BufferOffset<tflite::Buffer> empty_buffer_;
@ -446,61 +456,14 @@ class Translator {
absl::flat_hash_map<std::string, int> subgraph_index_map_;
absl::flat_hash_set<OpType> enabled_op_types_;
// Maps from op to name.
absl::flat_hash_map<mlir::Operation*, std::string> op_to_name_;
absl::flat_hash_map<std::string, int64_t> name_to_count_;
// Points to TensorFlow and TFLite dialects, respectively. nullptr if the
// dialect is not registered.
const Dialect* tf_dialect_;
const Dialect* tfl_dialect_;
// Suffix used to generate unique tensor names from operation names.
int name_counter_ = 0;
// Whether to strip or not emit debug info.
const bool strip_debug_info_;
};
std::string Translator::GetName(Operation* inst) {
// If strip_debug_info_ is set, then simply return counter value.
if (strip_debug_info_) return Twine(name_counter_++).str();
if (auto name_loc = inst->getLoc().dyn_cast<mlir::NameLoc>())
return name_loc.getName().str();
if (auto call_loc = inst->getLoc().dyn_cast<mlir::CallSiteLoc>()) {
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
// if imported with DebugInfo), else use the fallback naming scheme below.
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
return name_loc.getName().str();
}
// If the location is none of the expected types, then simply use name
// generated using the op type.
return inst->getName().getStringRef().str();
}
std::string Translator::UniqueName(llvm::StringRef prefix) {
// Keep incrementing the counter until we find a unique name.
std::string name = prefix;
int64_t& prefix_count = name_to_count_[name];
int64_t val = prefix_count;
while (val != 0) {
name = (prefix + Twine(prefix_count)).str();
++prefix_count;
val = name_to_count_[name];
}
name_to_count_[name] = 1;
return name;
}
std::string Translator::UniqueName(mlir::Operation* op) {
auto& name = op_to_name_[op];
if (!name.empty()) return name;
// Update the value in the map with unique name.
name = UniqueName(GetName(op));
return name;
return name_mapper_.GetUniqueName(op);
}
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@ -867,8 +830,8 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn) {
return;
}
for (auto it : llvm::enumerate(fn.getArguments())) {
op_to_name_[*it.value()->user_begin()] = input_names[it.index()];
++name_to_count_[input_names[it.index()].str()];
name_mapper_.InitOpName(*it.value()->user_begin(),
input_names[it.index()]);
}
}
@ -888,8 +851,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn) {
// insert an op so that we can have a buffer named such. This cannot
// currently happen due to pseudo_input nodes.
if (auto op = it.value()->getDefiningOp()) {
op_to_name_[op] = output_names[it.index()];
name_to_count_[output_names[it.index()].str()] = 1;
name_mapper_.InitOpName(op, output_names[it.index()]);
} else {
fn.emitWarning() << "output is not due to an op and '"
<< output_names[it.index()]
@ -1027,14 +989,44 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
/*name=*/builder_.CreateString(fn.getName().str()));
}
BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
StringRef content) {
auto buffer_index = buffers_.size();
auto buffer_data = builder_.CreateVector(
reinterpret_cast<const uint8_t*>(content.data()), content.size());
buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
}
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
Translator::CreateMetadataVector() {
auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
if (!dict_attr) return VectorBufferOffset<BufferOffset<tflite::Metadata>>();
std::vector<BufferOffset<tflite::Metadata>> metadata;
for (const auto& named_attr : dict_attr) {
StringRef name = named_attr.first;
mlir::Attribute attr = named_attr.second;
if (auto content = attr.dyn_cast<StringAttr>()) {
metadata.push_back(BuildMetadata(name, content.getValue()));
} else {
module_.emitError(
"all values in tfl.metadata's dictionary key-value pairs should be "
"string attributes");
return llvm::None;
}
}
return builder_.CreateVector(metadata);
}
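// Illustrative example (editor's addition; the key name is hypothetical):
// a module such as
//   module attributes {tfl.metadata = {model_tag = "v1"}} { ... }
// is serialized as a single Metadata entry named "model_tag" whose buffer
// holds the bytes of "v1". Non-string values trigger the error above.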
Optional<std::string> Translator::Translate(ModuleOp module,
bool emit_builtin_tflite_ops,
bool emit_select_tf_ops,
bool emit_custom_ops,
bool strip_debug_info) {
OpNameMapper* op_name_mapper) {
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, strip_debug_info);
emit_custom_ops, op_name_mapper);
return translator.TranslateInternal();
}
@ -1074,12 +1066,17 @@ Optional<std::string> Translator::TranslateInternal() {
} else {
model_description = "MLIR Converted.";
}
// Build the model and finish the model building process.
auto description = builder_.CreateString(model_description.data());
VectorBufferOffset<int32_t> metadata_buffer = 0; // Deprecated
auto metadata = CreateMetadataVector();
if (!metadata) return llvm::None;
auto model = tflite::CreateModel(
builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_),
builder_.CreateVector(subgraphs), description,
builder_.CreateVector(buffers_));
builder_.CreateVector(buffers_), metadata_buffer, *metadata);
tflite::FinishModelBuffer(builder_, model);
// Return serialized string for the built FlatBuffer.
@ -1100,22 +1097,38 @@ Optional<std::string> Translator::TranslateInternal() {
//
bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops) {
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
OpNameMapper* op_name_mapper) {
auto maybe_translated =
Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, strip_debug_info_flag);
emit_custom_ops, op_name_mapper);
if (!maybe_translated) return true;
*serialized_flatbuffer = std::move(*maybe_translated);
return false;
}
bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops) {
OpLocNameMapper op_name_mapper;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, &op_name_mapper);
}
static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction(
ModuleOp module, llvm::StringRef filename) {
std::string serialized_flatbuffer;
std::unique_ptr<OpNameMapper> op_name_mapper;
if (strip_debug_info) {
op_name_mapper = std::make_unique<tensorflow::OpStripNameMapper>();
} else {
op_name_mapper = std::make_unique<OpLocNameMapper>();
}
if (tflite::MlirToFlatBufferTranslateFunction(
module, &serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops))
emit_select_tf_ops, emit_custom_ops, op_name_mapper.get()))
return mlir::failure();
auto file = openOutputFile(filename);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/op_name_mapper.h"
// These flags are used to control the emission or not of different kinds of ops
// during the flatbuffer translation.
@ -33,12 +34,19 @@ extern bool strip_debug_info;
namespace tflite {
// Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string.
// serialized flatbuffer into the string. This uses OpLocNameMapper to convert
// the location of each op into its name in the flatbuffer.
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
std::string *serialized_flatbuffer,
std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops,
bool emit_select_tf_ops,
bool emit_custom_ops);
// Same as the above but with a custom op name mapper.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
tensorflow::OpNameMapper* op_name_mapper);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_

View File

@ -655,14 +655,40 @@ static LogicalResult Verify(UnpackOp op) {
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value *value) {
ElementsAttr attr;
if (!matchPattern(value, m_Constant(&attr))) return {};
IntegerAttr int_attr = attr.getValue(llvm::None).cast<IntegerAttr>();
return int_attr.getValue().getSExtValue();
}
// Returns a RankedTensorType which is similar to `input_type` but replaces the
// dimension size of `dim` with `dim_size`. For example,
// `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
// `tensor<3x2xi32>`.
static RankedTensorType SubstituteRankedTensorTypeDimSize(
RankedTensorType input_type, int64_t dim, int64_t dim_size) {
auto shape = input_type.getShape().vec();
shape[dim] = dim_size;
return RankedTensorType::get(shape, input_type.getElementType());
}
// Verifies the output tensor types of SplitOp or SplitVOp.
template <typename ExpectedOutputTypeGetter>
static LogicalResult VerifySplitOpOutputTypes(
Operation *op, int64_t num_splits,
ExpectedOutputTypeGetter get_expected_output_type) {
for (int64_t i = 0; i < num_splits; ++i) {
auto expected_output_type = get_expected_output_type(i);
Value *output = op->getResult(i);
auto output_type = output->getType().dyn_cast<RankedTensorType>();
if (!output_type || output_type != expected_output_type)
return op->emitOpError()
<< "output #" << i << " should be " << expected_output_type;
}
return success();
}
static LogicalResult Verify(SplitOp op) {
int64_t num_splits = op.num_splits().getSExtValue();
if (op.getOperation()->getNumResults() != num_splits)
if (op.getNumResults() != num_splits)
return op.emitOpError("output count should match 'num_splits' attribute");
// If 'split_dim' is not a constant, there are no other checks.
@ -688,21 +714,100 @@ static LogicalResult Verify(SplitOp op) {
if (dim_size % num_splits != 0)
return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");
// Creates sliced tensor type.
auto slice_shape = input_type.getShape().vec();
slice_shape[split_dim] = dim_size / num_splits;
RankedTensorType slice_type =
RankedTensorType::get(slice_shape, input_type.getElementType());
// Verifies output tensor types.
RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
input_type, split_dim, dim_size / num_splits);
return VerifySplitOpOutputTypes(
op.getOperation(), num_splits,
[expected_output_type](int64_t) { return expected_output_type; });
}
// Verifies result tensor types.
for (int64_t i = 0; i < num_splits; ++i) {
Value *result = op.getResult(i);
auto result_type = result->getType().dyn_cast<RankedTensorType>();
if (!result_type || result_type != slice_type)
return op.emitOpError() << "output #" << i << " should be " << slice_type;
static LogicalResult Verify(SplitVOp op) {
int64_t num_splits = op.num_splits().getSExtValue();
if (op.getNumResults() != num_splits)
return op.emitOpError("output count should match 'num_splits' attribute");
// If 'split_dim' is not a constant, there are no other checks.
llvm::Optional<int64_t> split_dim_opt =
ExtractConstantIntFromTensor(op.split_dim());
if (!split_dim_opt) return success();
// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t split_dim = split_dim_opt.getValue();
const int64_t rank = input_type.getRank();
if (split_dim < 0) split_dim += rank;
if (split_dim < 0 || split_dim >= rank)
return op.emitOpError("'split_dim' should be in [-rank, rank)");
// If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
// there are no other checks.
const int64_t dim_size = input_type.getDimSize(split_dim);
if (ShapedType::isDynamic(dim_size)) return success();
// If 'size_splits' is not a constant, there are no other checks.
ElementsAttr size_splits_attr;
if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
return success();
if (size_splits_attr.getNumElements() != num_splits) {
auto size_splits_type =
op.size_splits()->getType().cast<RankedTensorType>();
RankedTensorType expected_size_splits_type =
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
return op.emitOpError("'size_splits' should be ")
<< expected_size_splits_type;
}
return success();
// Normalizes and verifies 'size_splits'.
// Note: TensorFlow allows one -1 element in 'size_splits'. The -1 element
// stands for the remainder of the dimension size along 'split_dim'.
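// For example (illustrative): with dim_size = 10 and size_splits = [2, -1, 3],
// the -1 entry is replaced by 10 - (2 + 3) = 5, so the effective splits are
// [2, 5, 3] and their sum matches dim_size.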
llvm::SmallVector<int64_t, 4> size_splits;
size_splits.reserve(num_splits);
int64_t negative_size_split_loc = -1;
int64_t total_size_splits = 0;
for (int64_t i = 0; i < num_splits; ++i) {
auto size_split_attr = size_splits_attr.getValue<IntegerAttr>(i);
int64_t size_split = size_split_attr.getValue().getSExtValue();
size_splits.push_back(size_split);
if (size_split >= 0) {
total_size_splits += size_split;
continue;
}
if (size_split < -1)
return op.emitOpError(
"elements of 'size_splits' should be greater than or equal to -1");
if (negative_size_split_loc != -1)
return op.emitOpError("'size_splits' can only have one -1");
negative_size_split_loc = i;
}
if (negative_size_split_loc != -1) {
if (total_size_splits > dim_size)
return op.emitOpError(
"sum of non-negative elements of 'size_splits' is greater than the "
"dimension size of 'split_dim' axis");
size_splits[negative_size_split_loc] = dim_size - total_size_splits;
total_size_splits = dim_size;
}
if (total_size_splits != dim_size)
return op.emitOpError(
"sum of 'size_splits' should match the dimension size of 'split_dim' "
"axis");
// Verifies result tensor types.
auto get_expected_output_type = [input_type, split_dim,
&size_splits](int64_t i) {
return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
size_splits[i]);
};
return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
get_expected_output_type);
}
//===----------------------------------------------------------------------===//

View File

@ -132,15 +132,35 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//
class TFL_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
// TODO: Some of these could be generalized and/or moved to a more general
// location.
// Returns true if the n-th operand has unknown rank or has rank m.
class TFL_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
// CPred version of TFL_OperandHasRank.
class TFL_OperandHasRankPred<int n, int m> :
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>;
// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
# dim>]>;
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
".getShape()[" # dim # " ] == " # size>]>;
// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
@ -155,6 +175,32 @@ class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[0]">>;
// True if x_shape[dim] == y_shape[dim].
class TFL_DimOfOperandEqualsDimOfOperandPred<int x, int y, int dim> : And<[
TFL_OperandIsRankedAndHasDimPred<x, dim>,
TFL_OperandIsRankedAndHasDimPred<y, dim>,
CPred<"$_op.getOperand(" # x #
")->getType().cast<ShapedType>().getShape()[" # dim # "] == "
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[" # dim # "]">]>;
// Select operands must satisfy one of the following constraints:
// All inputs are unranked/scalars
// OR
// All inputs are ranked AND have equal dim[0] AND X & Y have same rank.
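// For example (illustrative): cond: tensor<4xi1>, x: tensor<4x2xf32>,
// y: tensor<4x2xf32> satisfies the second constraint (dim[0] == 4 for all
// three operands and x, y share rank 2), whereas mixing a ranked x with an
// unranked y satisfies neither constraint.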
def SelectShapeConstraints :
PredOpTrait<"Select operands meet shape criteria",
Or<[
And<[
TFL_OperandHasRankPred<0, 0>,
TFL_OperandHasRankPred<1, 0>,
TFL_OperandHasRankPred<2, 0>]>,
And<[
TFL_DimOfOperandEqualsDimOfOperandPred<0, 1, 0>,
TFL_DimOfOperandEqualsDimOfOperandPred<0, 2, 0>,
CPred<"$_op.getOperand(1)->getType().cast<ShapedType>().getRank() == "
"$_op.getOperand(2)->getType().cast<ShapedType>().getRank()">]>]>>;
// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
TCOpResIsShapedTypePred<i, j>,
@ -239,7 +285,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
}
class TFL_ConvOp<string mnemonic, string opSummary> :
TFL_Op<mnemonic, [NoSideEffect, TFL_AccumulatorUniformScale<2, 0, 1>]> {
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
let summary = opSummary # " operator";
let description = [{
@ -315,7 +361,7 @@ def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
// TODO(haoliang): Implement legalization pass after pattern rewrite generator
// supports variadic inputs.
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "add_n operator";
let description = [{
@ -323,11 +369,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
}];
let arguments = (ins
Variadic<TensorOf<[F32, I32]>>:$inputs
Variadic<TensorOf<[F32, I32, QI16, QUI16]>>:$inputs
);
let results = (outs
TensorOf<[F32, I32]>:$sum
TensorOf<[F32, I32, QI16, QUI16]>:$sum
);
}
@ -359,7 +405,7 @@ retained with length 1.
}
def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Average_pool_2d operator";
let description = [{
@ -454,7 +500,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
NoSideEffect,
PredOpTrait<"values and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_SameOperandsAndResultsScale
SameOperandsAndResultsScale
]> {
let summary = "Concatenation operator";
@ -464,14 +510,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
let arguments = (
ins Variadic<TensorOf<
[F32, I64, I32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>>:$values,
[F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>>:$values,
I32Attr:$axis,
TFL_AFAttr:$fused_activation_function
);
let results = (outs
TensorOf<
[F32, I64, I32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$output
[F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output
);
let hasOptions = 1;
@ -529,13 +575,13 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
// TODO(jpienaar): Update post discussion on semantics of FC OP.
// TODO(jpienaar): Include more shape verification.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, TFL_AccumulatorUniformScale<2, 0, 1>]> {
NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
let summary = "Fully connected op";
let arguments = (ins
TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>:$input,
TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>:$filter,
TFL_TensorOfOrNone<[F32, TFL_QI32, TFL_QUI32]>:$bias,
TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter,
TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
TFL_AFAttr:$fused_activation_function,
TFL_FullyConnectedOptionsWeightFormatAttr:$weights_format,
@ -544,7 +590,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
// Depending on the weights format, this op can have one or two outputs.
let results = (outs
Variadic<TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>>:$output
Variadic<TensorOf<[F32, QI8, QUI8, QI16, QUI16]>>:$output
);
let hasOptions = 1;
@ -552,7 +598,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
def TFL_GatherOp : TFL_Op<"gather", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
SameOperandsAndResultsScale,
TFL_OperandHasAtleastRank<0, 1>,
PredOpTrait<"params and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
@ -564,7 +610,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_Str, TFL_QI8, TFL_QUI8]>:$params,
TensorOf<[F32, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
TensorOf<[I32, I64]>:$indices,
I32Attr:$axis
);
@ -577,7 +623,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
];
let results = (outs
TensorOf<[F32, I16, I32, I64, TFL_Str, TFL_QI8, TFL_QUI8]>:$output
TensorOf<[F32, I16, I32, I64, TFL_Str, QI8, QUI8]>:$output
);
let hasOptions = 1;
@ -602,7 +648,7 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
// Same type check of lhs and rhs is handled by the Broadcastable trait.
def TFL_LessEqualOp : TFL_Op<"less_equal", [
Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Less_equal operator";
let description = [{
@ -610,8 +656,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
}];
let arguments = (
ins TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$lhs,
TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$rhs);
ins TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs,
TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs);
let results = (outs TFL_BoolTensor:$output);
@ -643,7 +689,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
}];
let arguments = (ins
TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[F32, QI8, QUI8]>:$input,
I32Attr:$radius,
F32Attr:$bias,
F32Attr:$alpha,
@ -651,14 +697,14 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
);
let results = (outs
TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$output
TensorOf<[F32, QI8, QUI8]>:$output
);
let hasOptions = 1;
}
def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater_equal operator";
let description = [{
@ -680,8 +726,119 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
let hasOptions = 0;
}
// These ops are named NonMaxSuppressionV4 & NonMaxSuppressionV5 to be
// consistent with TensorFlow's naming. They are NOT 'versions' of NMS in the
// sense that one is an incremental change over the other.
// In reality NonMaxSuppressionV5 implements Soft Non Max Suppression and
// NonMaxSuppressionV4 performs hard NMS.
def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [
NoSideEffect,
// Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
TFL_OperandHasRank<0, 2>,
PredOpTrait<"boxes should have dim[1] == 4",
TFL_OperandDimEquals<0, 1, 4>>,
// Operand 1 (scores) should be a 1-dim tensor
TFL_OperandHasRank<1, 1>,
// Other operands are scalar params.
TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
TFL_OperandHasRank<4, 0>]> {
let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
}];
let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
selected_indices = tf.image.non_max_suppression_v2(
boxes, scores, max_output_size, iou_threshold, score_threshold)
selected_boxes = tf.gather(boxes, selected_indices)
}];
let arguments = (ins
TFL_FpTensor:$boxes,
TFL_FpTensor:$scores,
I32Tensor:$max_output_size,
TFL_FpTensor:$iou_threshold,
TFL_FpTensor:$score_threshold
);
let results = (outs
I32Tensor:$selected_indices,
I32Tensor:$valid_outputs
);
}
def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [
NoSideEffect,
// Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
TFL_OperandHasRank<0, 2>,
PredOpTrait<"boxes should have dim[1] == 4",
TFL_OperandDimEquals<0, 1, 4>>,
// Operand 1 (scores) should be a 1-dim tensor
TFL_OperandHasRank<1, 1>,
// Other operands are scalar params.
TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>]> {
let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
}];
let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
selected_indices = tf.image.non_max_suppression_v2(
boxes, scores, max_output_size, iou_threshold, score_threshold)
selected_boxes = tf.gather(boxes, selected_indices)
This op also supports a Soft-NMS (with Gaussian weighting) mode (c.f.
Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0.
}];
let arguments = (ins
TFL_FpTensor:$boxes,
TFL_FpTensor:$scores,
I32Tensor:$max_output_size,
TFL_FpTensor:$iou_threshold,
TFL_FpTensor:$score_threshold,
TFL_FpTensor:$soft_nms_sigma
);
let results = (outs
I32Tensor:$selected_indices,
TFL_FpTensor:$selected_scores,
I32Tensor:$valid_outputs
);
}
def TFL_NotEqualOp : TFL_Op<"not_equal", [
Broadcastable, Commutative, NoSideEffect, TFL_NoQuantizableResult]> {
Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
let summary = "Not_equal operator";
let description = [{
@ -766,7 +923,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
}
def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
TFL_NoQuantizableResult,
NoQuantizableResult,
PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
let summary = "Equal operator";
@ -776,8 +933,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
let arguments = (
ins
TensorOf<[I1, F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$x,
TensorOf<[I1, F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$y
TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$x,
TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$y
);
let results = (outs TFL_BoolTensor:$output);
@ -844,7 +1001,7 @@ size 1.
}
def TFL_SqueezeOp: TFL_Op<"squeeze", [NoSideEffect,
TFL_SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale]> {
let summary = "Removes dimensions of size 1 from the shape of a tensor.";
let description = [{
@ -944,7 +1101,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, TFL_NoQuantizableResult]> {
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
let summary = "Greater operator";
let description = [{
@ -979,6 +1136,25 @@ def TFL_InputOp : Op<TFL_Dialect, "pseudo_input", [SameOperandsAndResultType]> {
let results = (outs AnyTensor:$output);
}
def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect]> {
let summary = "L2 Normalize Operator";
let description = [{
L2Normalization Op
}];
let arguments = (ins
TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input,
TFL_AFAttr:$fused_activation_function
);
let results = (outs TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output);
let hasOptions = 1;
let customOption = "L2NormOptions";
}
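Since the description above is terse, a one-line reference sketch of what L2 normalization computes (assuming, as in the TFLite kernel, that the norm is taken over the last dimension):

import numpy as np

def l2_normalize(x, epsilon=1e-6):
    # Divide by the L2 norm over the last dimension; epsilon guards zero norms.
    norm = np.sqrt(np.sum(np.square(x), axis=-1, keepdims=True))
    return x / np.maximum(norm, epsilon)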
def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Leaky Relu operator";
@ -1000,7 +1176,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
let hasOptions = 0b1;
}
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, TFL_NoQuantizableResult]> {
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
let summary = "Less operator";
let description = [{
@ -1073,17 +1249,17 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
SameOperandsAndResultShape,
// zero_point = 0
// scale = 1. / (max_value + 1)
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
let summary = "Logistic operator";
let description = [{
Computes element-wise Sigmoid of input
}];
let arguments = (ins TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$x);
let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x);
let results = (outs TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$y);
let results = (outs TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y);
}
def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
@ -1106,8 +1282,8 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
SameOperandsAndResultShape,
// zero_point = max_value
// scale = -log_softmax_output_min / (max_value + 1)
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<127, 625, -4>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<255, 625, -4>>]> {
FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
FixedResultScale<UInt8UniformQuantizedType<255, 625, -4>>]> {
let summary = "Log softmax operator";
let description = [{
@ -1133,13 +1309,13 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and "
And<[
// The input and output tensors should have the same elemental type
// and they should be one of the specified types below.
TCopVTEtIs<0, AnyTypeOf<[F32, TFL_QI8, TFL_QUI8]>>,
TCopVTEtIs<0, AnyTypeOf<[F32, QI8, QUI8]>>,
TFL_TCresVTEtIsSameAsOp<0, 0>]>>;
def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
NoSideEffect,
MaxPoolOperandAndResultConstraints,
TFL_SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale]> {
let summary = "Max Pool 2D op";
let description = [{
@ -1167,19 +1343,19 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
}
def TFL_MaximumOp : TFL_Op<"maximum", [
Broadcastable, NoSideEffect, Commutative, TFL_SameOperandsAndResultsScale]> {
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
let summary = "Max operator";
let description = [{
Element-wise max operation.
}];
let arguments = (
ins TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$lhs,
TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$rhs
ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
);
let results = (outs
TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$max
TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max
);
let builders = [TFL_BroadcastableBinaryBuilder];
@ -1187,7 +1363,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
let hasOptions = 0;
}
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Mean operator";
let description = [{
@ -1199,13 +1375,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_SameOperandsAndResultsScale]>
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$input,
TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
TensorOf<[I32, I64]>:$axis,
BoolAttr:$keep_dims
);
let results = (outs
TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$output);
TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
@ -1256,7 +1432,7 @@ Rounds the values of a tensor to the nearest integer, element-wise.
}
def TFL_SliceOp : TFL_Op<"slice", [
NoSideEffect, TFL_SameOperandsAndResultsScale]> {
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Return a slice from 'input'.";
let description = [{
@ -1363,19 +1539,19 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}
def TFL_MinimumOp : TFL_Op<"minimum", [
Broadcastable, NoSideEffect, Commutative, TFL_SameOperandsAndResultsScale]> {
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
let summary = "Min operator";
let description = [{
Element-wise min operation.
}];
let arguments = (
ins TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$lhs,
TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$rhs
ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
);
let results = (outs
TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$min
TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min
);
let builders = [TFL_BroadcastableBinaryBuilder];
@ -1422,7 +1598,7 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
let hasOptions = 0b1;
}
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Packs a list of tensors along a dimension into one tensor";
let description = [{
@ -1453,14 +1629,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
}];
let arguments = (ins
Variadic<TensorOf<[F32, I8, I16, I32, I64]>>:$values,
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
I32Attr:$values_count,
I32Attr:$axis
);
let results = (outs
TensorOf<[F32, I8, I16, I32, I64]>:$output
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -1470,7 +1646,7 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
def TFL_PadOp : TFL_Op<"pad", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
SameOperandsAndResultsScale,
TFL_OperandHasRank<1, 2>,
TFL_OperandRankEquals1DimOfOperand<0, 1>]> {
let summary = "Padding operator";
@ -1500,17 +1676,17 @@ def TFL_PadOp : TFL_Op<"pad", [
}];
let arguments = (
ins TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
TFL_I32OrI64Tensor:$padding);
let results = (outs TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output);
let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
let hasOptions = 1;
}
def TFL_PadV2Op : TFL_Op<"padv2", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
SameOperandsAndResultsScale,
TFL_OperandHasRank<1, 2>,
TFL_OperandHasRank<2, 0>,
TFL_OperandRankEquals1DimOfOperand<0, 1>,
@ -1545,11 +1721,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
}];
let arguments = (
ins TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
TFL_I32OrI64Tensor:$padding,
TensorOf<[F32, I8, I32, I64]>:$constant_values);
let results = (outs TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output);
let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
let hasOptions = 1;
}
@ -1570,6 +1746,8 @@ def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect]> {
let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
@ -1587,7 +1765,7 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
SameOperandsAndResultShape,
TFL_SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale]> {
let summary = "Relu operator";
let description = [{
@ -1602,7 +1780,7 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
SameOperandsAndResultShape,
TFL_SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale]> {
let summary = "Relu6 operator";
let description = [{
@ -1616,7 +1794,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
}
def TFL_ReshapeOp: TFL_Op<"reshape", [
NoSideEffect, TFL_SameOperandsAndResultsScale]> {
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Reshape operator";
let description = [{
@ -1682,7 +1860,7 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
let hasFolder = 1;
}
def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, TFL_NoQuantizableResult]> {
def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, NoQuantizableResult]> {
let summary = "Shape operator";
let description = [{
@ -1756,8 +1934,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
}
def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
// TODO(jpienaar): This is too retrictive, rank 1 input is also allowed.
SameOperandsAndResultShape,
SelectShapeConstraints,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands and result have same element type",
TCresVTEtIsSameAsOp<0, 1>>]> {
@ -1810,12 +1987,12 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
SameOperandsAndResultShape,
// zero_point = 0
// scale = 1. / (max_value + 1)
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
let summary = "Softmax operator";
let description = [{
Computes element-wise softmax activiations with the following formula
Computes element-wise softmax activations with the following formula
exp(input) / tf.reduce_sum(exp(input * beta), dim)
}];
@ -1851,9 +2028,9 @@ def TFL_SquareOp: TFL_Op<"square", [NoSideEffect, SameOperandsAndResultType]> {
Computes element-wise Square of input
}];
let arguments = (ins TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$x);
let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x);
let results = (outs TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$y);
let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y);
let hasOptions = 0b1;
@ -1913,17 +2090,17 @@ def TFL_TanhOp: TFL_Op<"tanh", [
// central_value = min_value / 2 + (max_value - 1) / 2 + 1
// zero_point = central_value
// scale = 1. / (central_value - min_value)
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<0, 78125, -7>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<128, 78125, -7>>]> {
FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
let summary = "Hyperbolic tangent operator";
let description = [{
Computes element-wise Hyperbolic tangent of input
}];
let arguments = (ins TensorOf<[F32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$x);
let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);
let results = (outs TensorOf<[F32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$y);
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
}
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
@ -1988,7 +2165,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
// TFL_OperandRankEquals1DimOfOperand<0, 1>,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>,
TFL_SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale]> {
let summary = "Transpose operator";
let description = [{
@ -2028,14 +2205,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
}];
let arguments = (ins
TensorOf<[F32, I8, I32, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[F32, I8, I32, QI8, QUI8]>:$input,
I32Attr:$num,
I32Attr:$axis
);
let results = (outs
Variadic<TensorOf<[F32, I8, I32, TFL_QI8, TFL_QUI8]>>:$outputs
Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs
);
let verifier = [{ return Verify(*this); }];
@ -2059,7 +2236,7 @@ def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [NoSideEffect]> {
def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
]> {
@ -2070,19 +2247,19 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
TensorOf<[I32]>:$block_shape,
TensorOf<[I32]>:$indices
);
let results = (outs
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$output
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output
);
}
def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
]> {
@ -2093,19 +2270,19 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
TensorOf<[I32]>:$block_shape,
TensorOf<[I32]>:$paddings
);
let results = (outs
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$output
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output
);
}
def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
]> {
@ -2119,12 +2296,12 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$input,
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input,
I32Attr:$block_size
);
let results = (outs
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$output
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output
);
let hasOptions = 1;
@ -2132,7 +2309,7 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
]> {
@ -2148,21 +2325,18 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$input,
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input,
I32Attr:$block_size
);
let results = (outs
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$output
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output
);
let hasOptions = 1;
}
def Rank0I32Tensor : Type<And<[I32Tensor.predicate, HasAnyRankOfPred<[0]>]>,
"tensor<i32>">;
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
let description = [{
@ -2172,13 +2346,13 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale
}];
let arguments = (ins
Rank0I32Tensor:$split_dim,
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
0DTensorOf<[I32]>:$split_dim,
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
PositiveI32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
);
let verifier = [{ return Verify(*this); }];
@ -2186,7 +2360,7 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale
let hasOptions = 1;
}
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
let description = [{
@ -2196,21 +2370,23 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, TFL_SameOperandsAndResultsSc
}];
let arguments = (ins
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
I32Tensor:$size_splits,
I32Tensor:$split_dim,
I32Attr:$num_splits
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
1DTensorOf<[I32]>:$size_splits,
0DTensorOf<[I32]>:$split_dim,
PositiveI32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
);
let verifier = [{ return Verify(*this); }];
let hasOptions = 1;
}
def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
NoSideEffect, TFL_SameOperandsAndResultsScale]> {
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "ResizeBilinear Op";
let description = [{
@ -2219,12 +2395,12 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
let arguments = (ins
// TODO(ycling): Support quantized types.
TensorOf<[F32, I32, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[F32, I32, QI8, QUI8]>:$input,
TensorOf<[I32]>:$size,
BoolAttr:$align_corners);
let results = (outs
TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$output
TensorOf<[F32, QI8, QUI8]>:$output
);
let hasOptions = 1;
@ -2232,7 +2408,7 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
[NoSideEffect,
TFL_SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale]> {
let summary = "ResizeNearestNeighbor Op";
let description = [{
@ -2240,13 +2416,13 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
}];
let arguments = (ins
TensorOf<[F32, I8, TFL_Uint8, TFL_QUI8, TFL_QI8]>:$input,
TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input,
TensorOf<[I32]>:$size,
BoolAttr:$align_corners
);
let results = (outs
TensorOf<[F32, I8, TFL_Uint8, TFL_QUI8, TFL_QI8]>:$output
TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output
);
let hasOptions = 1;
@ -2294,7 +2470,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_SameOperandsAndResultsScale
SameOperandsAndResultsScale
]> {
let summary = "StridedSlice Op";
@ -2303,7 +2479,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
}];
let arguments = (ins
TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$input,
TensorOf<[I32]>:$begin,
TensorOf<[I32]>:$end,
TensorOf<[I32]>:$strides,
@ -2316,13 +2492,14 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
);
let results = (outs
TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8]>:$output
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$output
);
let hasOptions = 1;
}
def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
def TFL_CastOp : TFL_Op<"cast", [
NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> {
let summary = "Cast operator";
let description = [{
@ -2411,7 +2588,7 @@ in the unique output `y`. In other words:
// Quantization ops.
//===----------------------------------------------------------------------===//
def TFL_DequantizeOp: TFL_Op<"dequantize", [
NoSideEffect, TFL_NoQuantizableResult]> {
NoSideEffect, NoQuantizableResult]> {
let summary = "Dequantize operator";
let description = [{
@ -2447,7 +2624,7 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
}
def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
NoSideEffect, FirstAttrDerivedResultType, TFL_NoQuantizableResult]> {
NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
let summary = "Quantized constant pseudo op";
let description = [{
@ -2465,7 +2642,7 @@ def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
}
def TFL_QuantizeOp: TFL_Op<"quantize", [
NoSideEffect, FirstAttrDerivedResultType, TFL_NoQuantizableResult]> {
NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
let summary = "Quantize operator";
let description = [{
@ -2609,6 +2786,10 @@ Ba et al. “Layer Normalization”
let results = (outs AnyTensor:$output);
// TODO(fengliuai): customize printer and parser to not display
// empty region.
let regions = (region AnyRegion:$internal);
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];

View File

@ -15,7 +15,6 @@ cc_library(
hdrs = [
"graphdef_to_tfl_flatbuffer.h",
],
copts = ["-std=c++14"],
deps = [
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",

View File

@ -138,7 +138,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
mlir::PassManager pm;
mlir::PassManager pm(module->getContext());
bool run_quantize = tensorflow::ShouldRunQuantizePasses(module.get());
mlir::TFL::PassConfig pass_config;
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;

View File

@ -33,7 +33,6 @@ cc_library(
hdrs = [
"quantization_utils.h",
],
copts = ["-std=c++14"],
deps = [
"@com_google_absl//absl/memory",
"@llvm//:support",

View File

@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for TensorFlow Lite.
// This is the quantization definition file for TensorFlow.
#ifdef TFL_Quantization
#ifdef TF_Quantization
#else
#define TFL_Quantization
#define TF_Quantization
#ifdef OP_BASE
#else
@ -46,7 +46,7 @@ def MinMaxAttr : Attr<Or<[CPred<"$_self.cast<ArrayAttr>().size() == 0">,
//===----------------------------------------------------------------------===//
// The base class of a quantized type.
class TFL_QuantizedType<string n, list<int> params, bit signed>
class QuantizedType<string n, list<int> params, bit signed>
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
@ -59,21 +59,21 @@ class TFL_QuantizedType<string n, list<int> params, bit signed>
// Uniform quantized types. Two integers "smantissa" and "sexp" are used to
// express the Mantissa and Exponent components of the floating-point scale so
// the scale of the quantized type is "smantissa * 10 ^ sexp".
class TFL_UInt8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
: TFL_QuantizedType<"Uniform",
class UInt8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
: QuantizedType<"Uniform",
[8, zero_pt, smantissa, sexp, 0, 255], 0>;
class TFL_Int8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
: TFL_QuantizedType<"Uniform",
class Int8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
: QuantizedType<"Uniform",
[8, zero_pt, smantissa, sexp, -128, 127], 1>;
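To make the (smantissa, sexp) encoding concrete, a few quick checks in Python; the constants are the ones used by the FixedResultScale traits on the logistic, softmax, log_softmax and tanh ops, and the central_value arithmetic mirrors the comment on the tanh op (the [0, 1) / [-1, 1) range remarks are assumptions about the kernels, not taken from this patch):

# scale = smantissa * 10 ** sexp
assert 390625e-8 == 1.0 / 256  # logistic/softmax output in [0, 1): scale 1/256
assert 625e-4 == 1.0 / 16      # log_softmax: -output_min / 256 with output_min = -16
assert 78125e-7 == 1.0 / 128   # tanh output in [-1, 1): scale 1/128

# tanh zero point via central_value, for int8 storage [-128, 127]:
min_value, max_value = -128, 127
central_value = min_value // 2 + (max_value - 1) // 2 + 1
assert central_value == 0                         # matches Int8...<0, 78125, -7>
assert 1.0 / (central_value - min_value) == 1.0 / 128
# For uint8 storage [0, 255] the same formula gives 128, matching <128, 78125, -7>.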
// General uniform quantized types. The definitions can be used to specify
// operand's tensor types.
def TFL_QUI8 : TFL_QuantizedType<"Uniform", [8], 0>;
def TFL_QI8 : TFL_QuantizedType<"Uniform", [8], 1>;
def TFL_QUI16 : TFL_QuantizedType<"Uniform", [16], 0>;
def TFL_QI16 : TFL_QuantizedType<"Uniform", [16], 1>;
def TFL_QUI32 : TFL_QuantizedType<"Uniform", [32], 0>;
def TFL_QI32 : TFL_QuantizedType<"Uniform", [32], 1>;
def QUI8 : QuantizedType<"Uniform", [8], 0>;
def QI8 : QuantizedType<"Uniform", [8], 1>;
def QUI16 : QuantizedType<"Uniform", [16], 0>;
def QI16 : QuantizedType<"Uniform", [16], 1>;
def QUI32 : QuantizedType<"Uniform", [32], 0>;
def QI32 : QuantizedType<"Uniform", [32], 1>;
//===----------------------------------------------------------------------===//
// TFL native op traits (for quantization).
@ -83,23 +83,23 @@ def TFL_QI32 : TFL_QuantizedType<"Uniform", [32], 1>;
//===----------------------------------------------------------------------===//
// Specify this trait if the op has a fixed output value range.
class TFL_FixedResultScale<TFL_QuantizedType qt> : NativeOpTrait<!strconcat(
"TFL::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
"quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
// Specify this trait if the op requires same inputs and outputs quantization
// scales.
def TFL_SameOperandsAndResultsScale : NativeOpTrait<
"TFL::SameOperandsAndResultsScale">;
def SameOperandsAndResultsScale : NativeOpTrait<
"quant::SameOperandsAndResultsScale">;
// Specify this trait if the b-th input of the op is a bias input, which needs
// a scale based on the scales of op1 and op2.
class TFL_AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
!strconcat("TFL::AccumulatorUniformScale<",
class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
!strconcat("quant::AccumulatorUniformScale<",
StrJoinInt<[bias, op1, op2]>.result,
">::Impl")>;
// Specify this trait if the op doesn't have a quantizable output. We shouldn't
// apply quantization on this op.
def TFL_NoQuantizableResult : NativeOpTrait<"TFL::NoQuantizableResult">;
def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">;
#endif // TFL_Quantization
#endif // TF_Quantization
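As an aside on the AccumulatorUniformScale trait above: the usual TFLite convention (assumed here, not stated in this file) is that a bias is stored as int32 with zero point 0 and a scale equal to the product of the input and weight scales, so the accumulator and the bias share one scale. A minimal sketch:

def bias_quant_params(input_scale, weight_scale):
    # Assumed convention: int32 bias, zero_point = 0,
    # bias_scale = input_scale * weight_scale, so that
    # acc = sum(q_input * q_weight) + q_bias is expressed in a single scale.
    return {"scale": input_scale * weight_scale, "zero_point": 0}

# Example: input scale 0.5/128 and weight scale 0.02 give a bias scale of ~7.8e-5.
params = bias_quant_params(0.5 / 128, 0.02)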

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@ -118,17 +119,19 @@ class QuantizationDriver {
// result.
void Finalize();
// Whether the constant is used as a bias input of another op. Here we assume
// bias is used immediately by the user. This assumption is always correct
// after constant folding.
bool UsedAsBias(ConstantOp cst) {
Value *value = cst.getResult();
for (auto &use : value->getUses()) {
auto biases = GetQuantSpec(use.getOwner())->biases_params;
if (biases.find(use.getOperandNumber()) != biases.end()) return true;
}
return false;
}
// The quantization parameters of a bias operand are usually determined by
// the other operands, so if a constant is used as a bias by different ops, it
// needs to be duplicated so that each op can assign its own quantization
// parameters to this bias. This method also adds all the non-bias constants
// to a set for later lookup.
void PreprocessConstantOps();
// Setup all the data structures for quantization propagation.
void SetupAllStates();
// Whether the constant is a weight, which shouldn't be shared by different
// ops.
bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); }
// Returns all the related quantization constraints of the op.
std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);
@ -266,6 +269,11 @@ class QuantizationDriver {
OpBuilder builder_;
bool is_signed_;
// We should distinguish weights and bias constants. Biases are specified by
// the quantization spec or are the operands of ops with same scale spec. The
// rest are weights.
llvm::DenseSet<Operation *> weights_;
// All the ops to which the quantization parameters need to be propagated.
std::vector<Operation *> work_list_;
std::unordered_set<Operation *> quantized_;
@ -539,12 +547,49 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
return {};
}
// This method scans the operations in the function to setup the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there are only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
void QuantizationDriver::PreprocessConstantOps() {
fn_.walk([&](ConstantOp cst) {
// Non-float tensors are neither weights nor require quantization.
auto type = cst.getType().dyn_cast<ShapedType>();
if (!type || !type.getElementType().isa<FloatType>()) return;
Value *value = cst.getResult();
SmallVector<std::pair<Operation *, int>, 4> bias_users;
bool used_as_weight = false;
for (auto &use : value->getUses()) {
auto spec = GetQuantSpec(use.getOwner());
auto biases = spec->biases_params;
Operation *user = use.getOwner();
int operand_num = use.getOperandNumber();
// If the user doesn't use this value as a bias operand or require the same
// scale, then this constant is considered to be a weight.
if (biases.find(operand_num) == biases.end() &&
!spec->requires_same_scale) {
used_as_weight = true;
} else {
bias_users.push_back({user, operand_num});
}
}
// If the constant is used as a weight, it will be duplicated for each bias
// user, so it isn't shared with the weight usage. Otherwise, the first bias
// user can use the original constant and the rest use the duplicates, so we
// pop the last bias user from the list.
if (used_as_weight) {
weights_.insert(cst);
} else {
bias_users.pop_back();
builder_.setInsertionPoint(cst);
}
for (auto bias_user : bias_users) {
auto copied = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
bias_user.first->setOperand(bias_user.second, copied.getResult());
}
});
}
void QuantizationDriver::SetupAllStates() {
llvm::DenseMap<Value *, int> value_to_state;
fn_.walk([&](Operation *op) {
@ -582,6 +627,21 @@ void QuantizationDriver::Initialize() {
});
}
// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
// Duplicate the bias constants, so the states can be set up correctly.
// TODO(fengliuai): Function definition should also be duplicated if there are
// multiple call sites.
PreprocessConstantOps();
// Setup all the internal states.
SetupAllStates();
}
bool QuantizationDriver::PropagateParams() {
// TODO(fengliuai): uses a typed indicator instead of a bool value.
bool changed = false;
@ -590,7 +650,7 @@ bool QuantizationDriver::PropagateParams() {
work_list_.pop_back();
// This op has been quantized, so we should not consider it again.
if (quantized_.find(op) != quantized_.end()) continue;
if (llvm::is_contained(quantized_, op)) continue;
quantized_.insert(op);
auto spec = GetQuantSpec(op);
@ -600,9 +660,8 @@ bool QuantizationDriver::PropagateParams() {
if (!spec->is_quantizable) continue;
if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
// This constant is used as a bias in another op, then the quantization
// parameters are determined by that op.
if (UsedAsBias(cst) || IsQuantized(op)) continue;
// If it isn't a weight or has been quantized, skip.
if (!IsWeight(cst) || IsQuantized(op)) continue;
// The quantization parameters are determined by the content of the
// constant.

View File

@ -23,7 +23,7 @@ limitations under the License.
namespace mlir {
namespace OpTrait {
namespace TFL {
namespace quant {
using QuantizedType = mlir::quant::QuantizedType;
using UniformQuantizedType = mlir::quant::UniformQuantizedType;
@ -119,7 +119,7 @@ class NoQuantizableResult
static bool IsQuantizable() { return false; }
};
} // namespace TFL
} // namespace quant
} // namespace OpTrait
} // namespace mlir

View File

@ -32,11 +32,12 @@ static Type GetQuantizedType(Builder builder, Type input_type, double min,
double max, int storage_type_width,
bool narrow_range, bool is_signed) {
auto converter =
quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
quant::ExpressedToQuantizedConverter::forInputType(input_type);
quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(), storage_type_width, min, max, narrow_range,
converter.expressedType, is_signed);
if (!quantizedEleType) return {};
return converter.convert(quantizedEleType);
}
@ -79,20 +80,24 @@ Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
double min = std::numeric_limits<double>::max();
double max = std::numeric_limits<double>::min();
if (auto fp = attr.dyn_cast<DenseFPElementsAttr>()) {
for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
double ele_value = FloatAttr::getValueAsDouble(*it);
min = std::min(min, ele_value);
max = std::max(max, ele_value);
// If all the element values are the same, we don't need to scan the content.
if (fp.isSplat()) {
min = max =
FloatAttr::getValueAsDouble(fp.getSplatValue<llvm::APFloat>());
} else {
for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
double ele_value = FloatAttr::getValueAsDouble(*it);
min = std::min(min, ele_value);
max = std::max(max, ele_value);
}
}
// The range must straddle zero.
if (min > 0.0 || max < 0.0) return {};
auto type = GetQuantizedType(builder, attr.getType(), min, max,
storage_type_width, narrow_range, is_signed);
if (auto ele_type = type.dyn_cast_or_null<TensorType>())
return ele_type.getElementType();
}
// The range from SplatElementAttr and other element attribute types couldn't
// The range from SplatElementAttr and other element attribute types couldn't
// straddle zero, so the quantization parameters couldn't be derived from its
// range.
return {};

View File

@ -40,7 +40,8 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
"FixedResultUniformScale<([0-9]+).*(true|false)>"};
emitSourceFileHeader("Generated Ops Quant Spec Getters", os);
// Retrieve all the definitions derived from TFL_Op and sort by record name.
// Retrieve all the definitions derived from the Op definition and sort by
// record name.
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
llvm::sort(defs, LessRecord());
@ -53,8 +54,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
for (const auto t : op.getTraits()) {
if (auto opTrait = llvm::dyn_cast<mlir::tblgen::NativeOpTrait>(&t)) {
auto trait = opTrait->getTrait();
// We only handle TFL specific native op traits.
if (!trait.consume_front("OpTrait::TFL::")) continue;
if (!trait.consume_front("OpTrait::quant::")) continue;
OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName()
<< ">(op)) {\n";
@ -73,7 +73,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n";
OUT(6) << "spec->restricted_output_params[std::make_pair("
<< matches[1] << ", " << matches[2]
<< ")].push_back(tfl.OpTrait::TFL::" << trait << "<"
<< ")].push_back(tfl.OpTrait::quant::" << trait << "<"
<< op.getQualCppClassName()
<< ">::GetResultQuantizedType(i));\n";
matches.clear();

View File

@ -13,6 +13,11 @@ func @extractSimpleOphint() {
return
}
// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation"}
// -----
// CHECK-LABEL: extractPackedInputOphint
func @extractPackedInputOphint() {
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
@ -30,6 +35,11 @@ func @extractPackedInputOphint() {
return
}
// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack"}
// -----
// CHECK-LABEL: extractFirstInputOphint
func @extractFirstInputOphint() {
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b703f0f4b9ec11e99426dc4a3e957995(%0) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -46,6 +56,11 @@ func @extractFirstInputOphint() {
return
}
// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_first"}
// -----
// CHECK-LABEL: extractLastInputOphint
func @extractLastInputOphint() {
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @e31fcf90b9ed11e99426dc4a3e957995(%1) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -62,6 +77,11 @@ func @extractLastInputOphint() {
return
}
// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_last"}
// -----
// CHECK-LABEL: extractPackOneInputOphint
func @extractPackOneInputOphint() {
// CHECK: %[[RESHAPE:[0-9]*]] = "tfl.reshape"(%0) : (tensor<1x16x1xf32>) -> tensor<1x1x16x1xf32>
@ -75,13 +95,16 @@ func @extractPackOneInputOphint() {
return
}
// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_pack_input_one"}
// -----
// CHECK-LABEL: extractStackInputOutputOphint
func @extractStackInputOutputOphint() {
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -98,11 +121,14 @@ func @extractStackInputOutputOphint() {
return
}
// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack_input_output"}
// -----
// CHECK-LABEL: extractMultipleInputsOutputsOphint
func @extractMultipleInputsOutputsOphint() {
// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: %[[MULTI_INPUT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -119,21 +145,33 @@ func @extractMultipleInputsOutputsOphint() {
return
}
// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation"}
// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_stack"}
// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_first"}
// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_last"}
// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_pack_input_one"}
// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_stack_input_output"}
// CHECK: func @a6ca45beb9f411e99426dc4a3e957995(tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK: attributes {_tflite_function_name = "cool_activation_multiple_input_output"}
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32], _tflite_function_name = "cool_activation_multiple_input_output"}
// -----
// CHECK-LABEL: inputsAfterOutputs
func @inputsAfterOutputs() {
// CHECK: %[[PLACE_HOLDER:[0-9]*]] = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
// CHECK: %[[INPUT_PROCESS:[0-9]*]] = "tf.Sigmoid"(%[[PLACE_HOLDER]]) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @d6266124d2dd11e9b52cdc4a3e957995(%0, %1, %[[INPUT_PROCESS]]) : (tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
%0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<f32>) -> tensor<f32>
%2 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
%3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%4 = "tf.Add"(%3, %1) {T = "tfdtype$DT_FLOAT", device = "", name = "Add"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
%5 = "tf.Identity"(%4) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%6 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
%7 = "tf.Sigmoid"(%6) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%8 = "tf.Identity"(%7) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 2 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-2-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%9 = "tf.Add"(%5, %8) {T = "tfdtype$DT_FLOAT", device = "", name = "Add_1"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%10 = "tf.Identity"(%9) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return
}
// CHECK: func @d6266124d2dd11e9b52cdc4a3e957995(tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32, 2 : i32], _tflite_function_name = "CustomOp"}
// -----

View File

@ -0,0 +1,26 @@
// RUN: tf-opt -tfl-legalize-ophint-func-op %s | FileCheck %s
module {
// CHECK-LABEL: func @testConvertUnidirectionalSequenceRNN
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<1x3xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<1x3xf32>)
func @testConvertUnidirectionalSequenceRNN(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x4xf32> {
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<1x4xf32>
// CHECK: %[[CST_0:.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32>
// CHECK: %[[CST_2:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32>
// CHECK: %[[PACKED_INPUT:[a-z0-9]*]] = "tfl.pack"(%[[ARG_0]], %[[ARG_1]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32>
// CHECK: %[[FUSED_OUTPUT:[a-z0-9]*]] = "tfl.unidirectional_sequence_rnn"(%[[PACKED_INPUT]], %[[CST_1]], %[[CST_2]], %[[CST_0]], %[[CST]]) {fused_activation_function = "TANH", time_major = true} : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32>
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[FUSED_OUTPUT]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>)
%cst = constant dense<0.000000e+00> : tensor<1x4xf32>
%cst0 = constant dense<0.000000e+00> : tensor<4xf32>
%cst1 = constant dense<0.000000e+00> : tensor<4x3xf32>
%cst2 = constant dense<0.000000e+00> : tensor<4x4xf32>
%2 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32>
%3 = call @a9211722c23011e9875cdc4a3e957995(%2, %cst1, %cst2, %cst0, %cst) : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32>
%4:2 = "tfl.unpack"(%3) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>)
return %4#0 : tensor<1x4xf32>
}
func @a9211722c23011e9875cdc4a3e957995(tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32>
attributes {_tflite_function_name = "UnidirectionalSequenceRnn"}
}

View File

@ -50,7 +50,7 @@ func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor
func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i32 {
%0 = "tf.Squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
%2 = constant dense<[2, 5]> : tensor<2xi32>
%2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
@ -119,8 +119,8 @@ func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
}
func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
%arg1 = constant dense<-0.1> : tensor<f32>
%arg2 = constant dense<0.2> : tensor<f32>
%arg1 = "tf.Const"() { value = dense<-0.1> : tensor<f32> } : () -> tensor<f32>
%arg2 = "tf.Const"() { value = dense<0.2> : tensor<f32> } : () -> tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
@ -153,6 +153,14 @@ func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}
func @placeholder_int(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<i32>) -> tensor<i32>
return %0: tensor<i32>
// CHECK-LABEL: @placeholder_int
// CHECK-NEXT: "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
}
func @placeholder_min(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32} : (tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
@ -409,7 +417,7 @@ func @gatherNdHigherRankIndices(%arg0 : tensor<4x3x2xf32>, %arg1 : tensor<2x2xi3
}
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
%0 = constant dense<[1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
return %1 : tensor<1x3x5x20xf32>
@ -418,7 +426,7 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
}
func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>
@ -427,7 +435,7 @@ func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x
}
func @gatherV2NonZeroBatchDims(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = constant dense<[1]> : tensor<1xi32>
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>
@ -509,6 +517,15 @@ func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) ->
// CHECK: return %0 : tensor<8xf32>
}
func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
return %0: tensor<8x3xf32>
// CHECK-LABEL: select_multidim
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8x3xf32>
}
func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
return %0: tensor<8xf32>
@ -518,6 +535,15 @@ func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>)
// CHECK: return %0 : tensor<8xf32>
}
func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
return %0: tensor<8x3xf32>
// CHECK-LABEL: select_v2_multidim
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8x3xf32>
}
func @sin(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
@ -536,7 +562,7 @@ func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?
}
func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<8xf32>, tensor<i32>) -> (tensor<2xf32>, tensor<2xi32>)
return %1#0, %1#1: tensor<2xf32>, tensor<2xi32>
@ -546,7 +572,7 @@ func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
}
func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>
@ -556,7 +582,7 @@ func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
}
func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<1x2x3x4xf32>, tensor<i32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>)
return %1#0, %1#1: tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>
@ -566,7 +592,7 @@ func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2
}
func @topk_5(%arg0: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0 = constant dense<2> : tensor<i32>
%0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
%1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>
@ -660,9 +686,18 @@ func @pad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
// CHECK: return %0 : tensor<?xf32>
}
func @pow(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x1x1xf32>) -> tensor<2x1x3xf32> {
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<2x1x1xf32>) -> tensor<2x1x3xf32>
return %0 : tensor<2x1x3xf32>
// CHECK-LABEL: pow
// CHECK: %[[pow:.*]] = "tfl.pow"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<2x1x1xf32>) -> tensor<2x1x3xf32>
// CHECK: return %[[pow]] : tensor<2x1x3xf32>
}
func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {
^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>):
%cst = constant dense<[1, 2]> : tensor<2xi32>
%cst = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> } : () -> tensor<2xi32>
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32>
return %0 : tensor<2x6xf32>
@ -673,7 +708,7 @@ func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {
func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%cst = constant dense<2.0> : tensor<f32>
%cst = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%0 = "tf.PadV2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
@ -832,12 +867,12 @@ func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32>
// CHECK: %0:3 = "tfl.split"(%arg0, %arg1) {num_splits = 3 : i32} : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
}
func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<1xi32>) -> tensor<1x4x2x3xf32> {
%0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) {num_split = 2 : i64} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<1x4x2x3xf32> {
%0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) {num_split = 2 : i64} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
return %0#0 : tensor<1x4x2x3xf32>
// CHECK-LABEL: splitv
// CHECK: %0:2 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
// CHECK: %0:2 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
}
func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
@ -849,8 +884,8 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
}
func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
%0 = constant dense<[1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
// CHECK-LABEL: concat2Tensors
@ -858,8 +893,8 @@ func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi
}
func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat3Tensors
@ -867,8 +902,8 @@ func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2
}
func @concatv2With3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
%0 = constant dense<[-1]> : tensor<1xi32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2With3Tensors
@ -1084,3 +1119,35 @@ func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> {
// CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32>
// CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
}
func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v4
// CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}
func @non_max_suppression_v4_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v4_no_pad
// CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}
func @non_max_suppression_v5(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
%0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v5
// CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}
func @non_max_suppression_v5_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
%0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0#0 : tensor<2xi32>
// CHECK-LABEL: non_max_suppression_v5_no_pad
// CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}


@ -0,0 +1,107 @@
// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?xf32>, %arg7: tensor<?xf32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
%0 = "tfl.lstm"(%arg0, // input
%arg1, %arg2, %arg3, %arg4, // weights
%arg5, %arg6, %arg7, %arg8, // recurrent weights
%arg9, %arg10, %arg11, // cell weights
%arg12, %arg13, %arg14, %arg15, // bias
%arg16, %arg17, // projection weight and bias
%arg18, %arg19, // stateful
%arg20, %arg21, %arg22, %arg23 // layer norm coefficients
) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<? xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: "tfl.lstm"
// CHECK-NEXT: %[[cst:.*]] = constant unit
// input gate
// CHECK-NEXT: %[[in1:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in2:.*]] = "tfl.fully_connected"(%arg18, %arg5, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in3:.*]] = "tfl.mul"(%arg19, %arg9)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in4:.*]] = "tfl.add_n"(%[[in1]], %[[in2]], %[[in3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in5:.*]] = "tfl.l2_normalization"(%[[in4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in6:.*]] = tfl.add %[[in4]], %[[in5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in7:.*]] = "tfl.fully_connected"(%[[in6]], %arg20, %arg12)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[in8:.*]] = "tfl.logistic"(%[[in7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// forget gate
// CHECK-NEXT: %[[fo1:.*]] = "tfl.fully_connected"(%arg0, %arg2, %[[cst]])
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo2:.*]] = "tfl.fully_connected"(%arg18, %arg6, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo3:.*]] = "tfl.mul"(%arg19, %arg10)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo4:.*]] = "tfl.add_n"(%[[fo1]], %[[fo2]], %[[fo3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo5:.*]] = "tfl.l2_normalization"(%[[fo4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo6:.*]] = tfl.add %[[fo4]], %[[fo5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo7:.*]] = "tfl.fully_connected"(%[[fo6]], %arg21, %arg13)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[fo8:.*]] = "tfl.logistic"(%[[fo7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// cell gate
// CHECK-NEXT: %[[ce1:.*]] = "tfl.fully_connected"(%arg0, %arg3, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce2:.*]] = "tfl.fully_connected"(%arg18, %arg7, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce3:.*]] = "tfl.add_n"(%[[ce1]], %[[ce2]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce4:.*]] = "tfl.l2_normalization"(%[[ce3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce5:.*]] = tfl.add %[[ce3]], %[[ce4]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce6:.*]] = "tfl.fully_connected"(%[[ce5]], %arg22, %arg14)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ce7:.*]] = "tfl.tanh"(%[[ce6]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac1:.*]] = "tfl.mul"(%[[fo8]], %arg19)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac2:.*]] = tfl.mul %[[in8]], %[[ce7]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac3:.*]] = tfl.add %[[ac1]], %[[ac2]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// output gate
// CHECK-NEXT: %[[ou1:.*]] = "tfl.fully_connected"(%arg0, %arg4, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou2:.*]] = "tfl.fully_connected"(%arg18, %arg8, %[[cst]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou3:.*]] = "tfl.mul"(%[[ac3]], %arg11)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou4:.*]] = "tfl.add_n"(%[[ou1]], %[[ou2]], %[[ou3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou5:.*]] = "tfl.l2_normalization"(%[[ou4]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou6:.*]] = tfl.add %[[ou4]], %[[ou5]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou7:.*]] = "tfl.fully_connected"(%[[ou6]], %arg23, %arg15)
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ou8:.*]] = "tfl.logistic"(%[[ou7]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// output activation
// CHECK-NEXT: %[[ac4:.*]] = "tfl.tanh"(%[[ac3]])
// CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac5:.*]] = tfl.mul %[[ac4]], %[[ou8]]
// CHECK-SAME: tensor<?x!quant.any<i16:f32>>
// CHECK-NEXT: %[[ac6:.*]] = "tfl.fully_connected"(%[[ac5]], %arg16, %arg17)
// CHECK-SAME: (tensor<?x!quant.any<i16:f32>>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x!quant.any<i8:f32>>
// CHECK-NEXT: %[[ac7:.*]] = "tf_quant.pseudo_return"(%[[ac6]]) : (tensor<?x!quant.any<i8:f32>>) -> tensor<?x!quant.any<i8:f32>>
// CHECK-NEXT: })
// CHECK-NEXT: return
return %0 : tensor<?xf32>
}


@ -143,6 +143,19 @@ func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: t
// CHECK: return [[RESULT]] : tensor<?x10xf32>
}
func @tensorlistLength(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>) -> (tensor<i32>) {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListLength"(%0) : (tensor<!tf.variant<tensor<10xf32>>>) -> tensor<i32>
return %1: tensor<i32>
// CHECK-LABEL: tensorlistLength
// CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>)
// CHECK-DAG: [[SHAPE:%.*]] = "tf.Shape"([[INPUT]]) {{.*}} -> tensor<2xi32>
// CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor<i32>
// CHECK: [[RESULT:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO]]) {validate_indices = true} : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
// CHECK: return [[RESULT]] : tensor<i32>
}
func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>


@ -278,6 +278,6 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
%21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32>
%22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32>
%23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32>
%24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
}


@ -0,0 +1,31 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
module attributes {
tfl.metadata = {key1 = "value1", key2 = "value2"}
} {
func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} {
^bb0(%arg0: tensor<3x2xi32>):
%0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%2 = "tfl.sub" (%0, %1) {fused_activation_function = "NONE"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %2 : tensor<3x2xi32>
}
}
// CHECK: buffers: [ {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "key1",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "key2",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]


@ -1,5 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string -
// | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
@ -9,7 +8,7 @@ func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf
return %2 : tensor<40x40xf32>
}
// CHECK-NEXT: operators: [ {
// CHECK: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, -1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,


@ -103,7 +103,7 @@ func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x
// test invalid AddN
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}}
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}}
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
return %0 : tensor<? x f16>
}
@ -537,7 +537,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
// test invalid Logistic input
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}}
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
@ -591,8 +591,9 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -600,8 +601,9 @@ func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x
// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -610,7 +612,7 @@ func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %
// test invalid none type applied to a tensor type arg
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -619,7 +621,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test violation of projection weight and projection bias pred op trait
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -628,7 +630,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test invalid kernel type
func @testLstmWithInvalidKernelType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op attribute 'kernel_type' failed to satisfy constraint: lstm kernel type enum case FULL}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -652,6 +654,15 @@ func @testSelect(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi
// -----
// test select with multi-dim inputs
// CHECK-LABEL: testSelectMultiDim
func @testSelectMultiDim(%cond : tensor<?xi1>, %arg0 : tensor<?x4xi32>, %arg1 : tensor<?x4xi32>) -> tensor<?x4xi32> {
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?x4xi32>,tensor<?x4xi32>) -> tensor<?x4xi32>
return %0 : tensor<?x4xi32>
}
// -----
func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi32>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
@ -660,6 +671,14 @@ func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>
// -----
func @testSelectWithUnsupportedShapes(%cond : tensor<2xi1>, %arg0 : tensor<3xi32>, %arg1 : tensor<3xi32>) -> tensor<3xi32> {
// expected-error @+1 {{failed to verify that Select operands meet shape criteria}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<2xi1>,tensor<3xi32>,tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xf32>) -> tensor<?xi32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xf32>) -> tensor<?xi32>
@ -762,6 +781,21 @@ func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> te
// -----
// CHECK-LABEL: testPadQuantizedU8
func @testPadQuantizedU8(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}
// CHECK-LABEL: testPadQuantizedI8
func @testPadQuantizedI8(%arg0: tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}
// -----
// CHECK-LABEL: testPadV2
func @testPadV2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
@ -817,6 +851,20 @@ func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) ->
// -----
func @packQuantizedU8(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<u8:f32, 0.1>>, tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<u8:f32, 0.1>>
}
func @packQuantizedI8(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<i8:f32, 0.1>>, tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1>>
}
// -----
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
@ -1143,8 +1191,8 @@ func @testSplitWithQuantizedTypes(%arg0 : tensor<i32>, %arg1 : tensor<10x!quant.
// -----
func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<i32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<1xi32>, %arg2 : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<1xi32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
@ -1331,16 +1379,16 @@ func @testSplitOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8xf
func @testSplitOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%split_dim = constant dense<0> : tensor<2x2xi32>
// expected-error @+1 {{'tfl.split' op operand #0 must be tensor<i32>}}
// expected-error @+1 {{'tfl.split' op operand #0 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<2x2xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
func @testSplitOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<? x i32>) -> tensor<16x4x4xf32> {
// expected-error @+1 {{'tfl.split' op operand #0 must be tensor<i32>}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<?xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
func @testSplitOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<*xi32>) -> tensor<16x4x4xf32> {
// expected-error @+1 {{'tfl.split' op operand #0 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<*xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
@ -1424,3 +1472,252 @@ func @testSplitOpWithValidTensorTypeDynamic(%arg0 : tensor<16x?xf32>) -> (tensor
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>)
return %0, %1 : tensor<8x?xf32>, tensor<8x?xf32>
}
// -----
func @testSplitVOpWithBadNumSplits(%arg0 : tensor<16xf32>) -> () {
%size_splits = constant dense<[]> : tensor<0xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op attribute 'num_splits' failed to satisfy constraint: positive 32-bit integer attribute}}
"tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 0 : i32} : (tensor<16xf32>, tensor<0xi32>, tensor<i32>) -> ()
return
}
// -----
func @testSplitVOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[4, 4, 4, 4]> : tensor<4xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output count should match 'num_splits' attribute}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 4 : i32} : (tensor<16xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
func @testSplitVOpWithBadSizeSplitsTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[[8, 8], [2, 2]]> : tensor<2x2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<2x2xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
func @testSplitVOpWithBadSizeSplitsUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %size_splits: tensor<*xi32>) -> tensor<16x4x4xf32> {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<*xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
func @testSplitVOpWithBadSizeSplitsConstant(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[-2]> : tensor<1xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op elements of 'size_splits' should be greater than or equal to -1}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
func @testSplitVOpWithBadSizeSplitsConstantMultipleNegativeOne(%arg0: tensor<16x4x4xf32>) -> (tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>) {
%size_splits = constant dense<[-1, -1, 14]> : tensor<3xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'size_splits' can only have one -1}}
%0, %1, %2 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 3 : i32} : (tensor<16x4x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>)
return %0, %1, %2 : tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>
}
// -----
func @testSplitVOpWithBadSizeSplitsConstantSum(%arg0: tensor<16x4x4xf32>) -> (tensor<0x4x4xf32>, tensor<16x4x4xf32>) {
%size_splits = constant dense<[-1, 17]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op sum of non-negative elements of 'size_splits' is greater than the dimension size of 'split_dim' axis}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<0x4x4xf32>, tensor<16x4x4xf32>)
return %0, %1 : tensor<0x4x4xf32>, tensor<16x4x4xf32>
}
// -----
func @testSplitVOpWithBadSizeSplitsSize(%arg0: tensor<16x4x4xf32>) -> tensor<15x4x4xf32> {
%size_splits = constant dense<[15, 1]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'size_splits' should be 'tensor<1xi32>'}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<2xi32>, tensor<i32>) -> tensor<15x4x4xf32>
return %0 : tensor<15x4x4xf32>
}
// -----
func @testSplitVOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[16]> : tensor<1xi32>
%split_dim = constant dense<0> : tensor<2x2xi32>
// expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<2x2xi32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
func @testSplitVOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<*xi32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[16]> : tensor<1xi32>
// expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<*xi32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
func @testSplitVOpWithOutOfRangeSplitDim(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<1> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
func @testSplitVOpWithOutOfRangeSplitDimTFLConst(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
func @testSplitVOpWithOutOfRangeSplitDimNegative(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<-2> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
func @testSplitVOpWithMismatchSizeSplitsSum(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 4]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op sum of 'size_splits' should match the dimension size of 'split_dim' axis}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<4xf32>)
return %0, %1 : tensor<8xf32>, tensor<4xf32>
}
// -----
func @testSplitVOpWithMismatchTensorTypeSplitDimOut0(%arg0 : tensor<16xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #0 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<4xf32>, tensor<4xf32>)
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}
// -----
func @testSplitVOpWithMismatchTensorTypeSplitDimOut1(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #1 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<4xf32>)
return %0, %1 : tensor<8xf32>, tensor<4xf32>
}
// -----
func @testSplitVOpWithMismatchTensorTypeNonSplitDim(%arg0 : tensor<16x4xf32>) -> (tensor<8x2xf32>, tensor<8x2xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #0 should be 'tensor<8x4xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x2xf32>, tensor<8x2xf32>)
return %0, %1 : tensor<8x2xf32>, tensor<8x2xf32>
}
// -----
func @testSplitVOpWithValidTensorType(%arg0 : tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>) {
%size_splits_0 = constant dense<[8, 8]> : tensor<2xi32>
%split_dim_0 = constant dense<0> : tensor<i32>
%0, %1 = "tfl.split_v"(%arg0, %size_splits_0, %split_dim_0) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x4xf32>, tensor<8x4xf32>)
%size_splits_1 = constant dense<[2, 2]> : tensor<2xi32>
%split_dim_1 = constant dense<1> : tensor<i32>
%2, %3 = "tfl.split_v"(%arg0, %size_splits_1, %split_dim_1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
return %0, %1, %2, %3 : tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>
}
// -----
func @testSplitVOpWithValidTensorTypeDynamic(%arg0 : tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x?xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x?xf32>, tensor<8x?xf32>)
return %0, %1 : tensor<8x?xf32>, tensor<8x?xf32>
}
// -----
func @testSplitVOpWithValidSizeSplitsUneven(%arg0 : tensor<16x4xf32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x1xf32>, tensor<16x3xf32>) {
%size_splits_0 = constant dense<[7, 3, 6]> : tensor<3xi32>
%split_dim_0 = constant dense<0> : tensor<i32>
%0, %1, %2 = "tfl.split_v"(%arg0, %size_splits_0, %split_dim_0) {num_splits = 3 : i32} : (tensor<16x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>)
%size_splits_1 = constant dense<[1, 3]> : tensor<2xi32>
%split_dim_1 = constant dense<1> : tensor<i32>
%3, %4 = "tfl.split_v"(%arg0, %size_splits_1, %split_dim_1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x3xf32>)
return %0, %1, %2, %3, %4 : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x1xf32>, tensor<16x3xf32>
}
// -----
func @testSplitVOpWithValidSizeSplitsNegative(%arg0 : tensor<16x4xf32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>) {
%size_splits_0 = constant dense<[7, -1, 6]> : tensor<3xi32>
%split_dim_0 = constant dense<0> : tensor<i32>
%0, %1, %2 = "tfl.split_v"(%arg0, %size_splits_0, %split_dim_0) {num_splits = 3 : i32} : (tensor<16x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>)
%size_splits_1 = constant dense<[-1, 4]> : tensor<2xi32>
%split_dim_1 = constant dense<1> : tensor<i32>
%3, %4 = "tfl.split_v"(%arg0, %size_splits_1, %split_dim_1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x0xf32>, tensor<16x4xf32>)
return %0, %1, %2, %3, %4 : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>
}
// -----
func @testNonMaxSuppressionV4WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}
// -----
func @testNonMaxSuppressionV4WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v4' op failed to verify that boxes should have dim[1] == 4}}
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}
// -----
func @testNonMaxSuppressionV5WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}
// -----
func @testNonMaxSuppressionV5WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v5' op failed to verify that boxes should have dim[1] == 4}}
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}


@ -107,10 +107,10 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
return %1 : tensor<4x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %0 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %[[RES]] : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseMulIntoFullyConnectedBroadcast
@ -123,10 +123,10 @@ func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>
return %1 : tensor<1x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %0 : tensor<1x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %[[RES]] : tensor<1x2xf32>
}
// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias
@ -139,9 +139,9 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te
return %1 : tensor<4x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %0 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %[[RES]] : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
@ -255,3 +255,40 @@ func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
// CHECK: %1 = "tfl.reshape"(%0) : (tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
// CHECK: %2 = "tfl.strided_slice"(%1, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}
// CHECK-LABEL: @L2NormalizePattern
func @L2NormalizePattern(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.rsqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.mul"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
return %3: tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @L2NormalizePattern1
func @L2NormalizePattern1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.sqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
return %3: tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @InvalidL2NormalizePattern
// Div and square ops must take the same argument to be eligible.
func @InvalidL2NormalizePattern(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.sqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.div"(%arg1, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
return %3: tensor<2xf32>
// CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
// CHECK: return %3
}


@ -211,6 +211,17 @@ func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>
// CHECK: return %3 : tensor<1x6x6x16xf32>
}
// CHECK-LABEL: NotQuantizeConcatConstantOperand
func @NotQuantizeConcatConstantOperand(%arg0: tensor<2xf32>) -> tensor<2x2xf32> {
%0 = constant dense<1.0> : tensor<2xf32>
%1 = "tfl.concatenation"(%arg0, %0) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[cst]])
// CHECK-NEXT: return %[[cc]]
}
// CHECK-LABEL: QuantizeConcatOperand0ToAll
func @QuantizeConcatOperand0ToAll(tensor<2x!quant.uniform<u8:f32, 0.1:128>>, tensor<2xf32>) -> tensor<2x2xf32> {
^bb0(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1:128>>, %arg1: tensor<2xf32>):
@ -360,4 +371,85 @@ func @QuantizeConstant() -> tensor<2x3xf32> {
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8<1:255>:f32, 0.023622047244094488:128>>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x3xf32>
}
}
// CHECK-LABEL: NotQuantizeNoneType
func @NotQuantizeNoneType() -> none {
%cst = constant unit
return %cst : none
// CHECK-NEXT: %[[cst:.*]] = constant unit
// CHECK-NEXT: return %[[cst]]
}
// CHECK-LABEL: QuantizeZeroSplat
func @QuantizeZeroSplat() -> tensor<2x3xf32> {
%cst = constant dense<0.0> : tensor<2x3xf32>
return %cst : tensor<2x3xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
}
// CHECK-LABEL: QuantizeZeroScalar
func @QuantizeZeroScalar() -> tensor<f32> {
%cst = constant dense<0.0> : tensor<f32>
return %cst : tensor<f32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
}
// CHECK-LABEL: QuantizeSharedBiases
func @QuantizeSharedBiases(
%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>,
%arg2: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> (tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>) {
%cst = constant dense<1.0> : tensor<32xf32>
%1 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x224x224x3xf32>
%2 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%conv1 = "tfl.conv_2d"(%1, %2, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%3 = "tfl.quantize"(%conv1) {qtype = tensor<1x112x112x32xf32>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
%4 = "tfl.dequantize"(%3) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%5 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> tensor<32x3x3x3xf32>
%conv2 = "tfl.conv_2d"(%4, %5, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
%6 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
return %6 : tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}
// CHECK-LABEL: QuantizeSharedBiases2
func @QuantizeSharedBiases2(
%arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>,
%arg2: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> (tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>) {
%cst = constant dense<1.0> : tensor<32xf32>
%1 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
%add = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%3 = "tfl.quantize"(%add) {qtype = tensor<32xf32>} : (tensor<32xf32>) -> tensor<32x!quant.uniform<u8:f32, 1.0>>
%5 = "tfl.dequantize"(%arg1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%6 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> tensor<32x3x3x3xf32>
%conv2 = "tfl.conv_2d"(%5, %6, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
%7 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
return %3, %7 : tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}

View File

@ -16,13 +16,13 @@ func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> (tensor<256x30x30x1
return %0, %1, %2, %3, %4 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
// CHECK-LABEL: conv
// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
// CHECK: %0 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
// CHECK: %0 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %2 = "tf.Conv2D"
// CHECK: %3 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %3 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %5 = "tf.Conv2D"
// CHECK: %6 = "tf.Conv2D"
}
@ -41,13 +41,13 @@ func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> (tensor<2
return %0, %1, %2, %3 : tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>
// CHECK-LABEL: depthwiseConv2D
// CHECK: %cst = constant dense<0.000000e+00> : tensor<12xf32>
// CHECK: %cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
// CHECK: %0 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<12xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
// CHECK: %0 = "tf.Reshape"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %[[CONSTANT]]) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %2 = "tf.DepthwiseConv2dNative"
// CHECK: %3 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %3 = "tf.Reshape"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %[[CONSTANT]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %5 = "tf.DepthwiseConv2dNative"
}
@ -155,10 +155,10 @@ func @fakeQuantFolded() -> (tensor<8xf32>) {
%rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %rst : tensor<8xf32>
// CHECK: %cst = constant dense<0.000000e+00> : tensor<8xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<8xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<8xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
}
// CHECK-LABEL: fakeQuantNotFolded
@ -261,12 +261,12 @@ func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>)
%rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: %2 = "tfl.conv_2d"(%arg0, %1, %cst)
// CHECK: return %2
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
// CHECK: return %[[CONV]]
}
// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
@ -281,12 +281,12 @@ func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30
%rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %cst = constant dense<0.000000e+00> : tensor<48xf32>
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: %2 = "tfl.depthwise_conv_2d"(%arg0, %1, %cst)
// CHECK: return %2
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<48xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
// CHECK: return %[[CONV]]
}
func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor<30xi32>) -> (tensor<10xi32>, tensor<20xi32>, tensor<30xi32>) {
@ -348,3 +348,11 @@ func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
// CHECK-LABEL: stop_gradient
// CHECK: return %arg0 : tensor<3xi32>
}
func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "tf.CheckNumerics"(%arg0) {message = ""}: (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// Should be converted to Identity and then from Identity to the value
// CHECK-LABEL: CheckNumerics
// CHECK: return %arg0 : tensor<3xf32>
}

View File

@ -0,0 +1,223 @@
// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s
func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV2TwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
}
func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV2FlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
}
func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
return %0 : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulV2Matrix
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[v0]] : tensor<4x6xf32>
}
func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulTwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
}
func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulFlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
}
func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
return %0 : tensor<4x6xf32>
// CHECK-LABEL: batchMatMulMatrix
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[v0]] : tensor<4x6xf32>
}

View File

@ -41,10 +41,13 @@ bool ShouldRunQuantizePasses(mlir::ModuleOp m) {
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::PassManager* pass_manager) {
pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
// Ophint extraction will happen after the island extraction pass.
pass_manager->addPass(mlir::TFL::CreateExtractOphintPass());
// The convert composite op pass will happen after the ophint extraction pass.
pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass());
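// With the additions above, the head of the pipeline runs (roughly, as
// configured here): switch folding, executor-to-control conversion, raising
// TF control flow, ophint extraction, then ophint func op legalization,
// before the optional tensor list lowering below.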
if (pass_config.lower_tensor_list_ops) {
// Execute this pass before `CanonicalizerPass` in case some TensorList

View File

@ -131,7 +131,7 @@ int main(int argc, char **argv) {
// message. So we can just return here.
if (!module.ok()) return kTrFailure;
mlir::PassManager pm;
mlir::PassManager pm(&context);
bool run_quantize =
tensorflow::ShouldRunQuantizePasses(module.ValueOrDie().get());
mlir::TFL::PassConfig pass_config;
@ -149,7 +149,12 @@ int main(int argc, char **argv) {
lower_tensor_list_ops, &result, &pm);
if (!status.ok()) return kTrFailure;
auto output = mlir::openOutputFile(output_file_name);
std::string error_msg;
auto output = mlir::openOutputFile(output_file_name, &error_msg);
if (output == nullptr) {
llvm::errs() << error_msg << '\n';
return kTrFailure;
}
output->os() << result;
output->keep();

View File

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include <map>
#include <queue>
#include <vector>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
@ -353,6 +355,127 @@ struct OphintCompositeOp {
std::map<int, AggregatedOperand> outputs;
};
// Preprocess the graph for topo sort (each operation is a node, while
// inputs/outputs indicate edges). Assume the graph is acyclic. The preprocess
// does the following:
// Compute each operation's in-degree (how many input operations it takes).
// Get all consumer operations for every operation (operation_to_ouputs).
// Get the init_queue (those operations will be processed first).
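// As an illustrative sketch (hypothetical ops A, B and C, not taken from any
// real graph): if C consumes results of both A and B, the preprocessing yields
// operation_to_in_degrees = {C: 2}, operation_to_ouputs = {A: {C}, B: {C}},
// and init_queue = [A, B], so A and B are emitted before C in the sort.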
void PreprocessTopoSortGraph(
Block* block, std::queue<Operation*>* init_queue,
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>* operation_to_ouputs,
llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
for (auto& op : *block) {
if (&op == block->getTerminator()) continue;
if (op.getNumOperands() == 0) {
init_queue->push(&op);
} else {
// The operands of the op are not a direct indication of the "edges": we can
// have a pack op after an unpack op (they are connected by multiple operands),
// but we should only count that as one edge.
llvm::DenseSet<Operation*> input_ops;
for (int i = 0; i < op.getNumOperands(); ++i) {
Operation* input_op = op.getOperand(i)->getDefiningOp();
if (input_op) input_ops.insert(input_op);
}
if (input_ops.empty()) {
init_queue->push(&op);
continue;
}
operation_to_in_degrees->try_emplace(&op, input_ops.size());
for (auto* input_op : input_ops) {
auto preceeding_op_it = operation_to_ouputs->find(input_op);
if (preceeding_op_it == operation_to_ouputs->end()) {
auto result = operation_to_ouputs->try_emplace(
input_op, llvm::DenseSet<Operation*>());
preceeding_op_it = result.first;
}
preceeding_op_it->second.insert(&op);
}
}
}
}
bool IsSideEffectOp(Operation* op) {
if (op->hasNoSideEffect()) return false;
// Identity op has no side effect.
// Checking the OperationName might be more elegant here.
auto tf_identity_op = dyn_cast_or_null<TF::IdentityOp>(op);
if (tf_identity_op) return false;
return true;
}
// Other transformations may also benefit from this util function, but since
// there are currently none, we limit this function to the ophint extraction
// pass. We may refactor this function to extend its usage in the future.
//
// Assume the graph is disconnected from outside.
// Also assume the block has no arguments.
LogicalResult TopoSortOperations(OpBuilder* builder) {
std::queue<Operation*> init_queue;
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_ouputs;
llvm::DenseMap<Operation*, int> operation_to_in_degrees;
std::vector<Operation*> sorted_ops;
PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
&operation_to_ouputs, &operation_to_in_degrees);
while (!init_queue.empty()) {
Operation* current_op = init_queue.front();
init_queue.pop();
sorted_ops.push_back(current_op);
auto current_op_to_output_it = operation_to_ouputs.find(current_op);
if (current_op_to_output_it == operation_to_ouputs.end()) {
continue;
}
for (Operation* output_op : current_op_to_output_it->second) {
auto output_op_it = operation_to_in_degrees.find(output_op);
if (output_op_it == operation_to_in_degrees.end()) return failure();
output_op_it->second -= 1;
if (output_op_it->second == 0) {
init_queue.push(output_op);
operation_to_in_degrees.erase(output_op_it);
}
}
operation_to_ouputs.erase(current_op_to_output_it);
}
// Before we perform the sort, we need to make sure we didn't change the
// ordering of the original side-effect operations.
// It's possible those side-effect operations have no topological relations
// at all!
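// As an illustrative example (not from this pass): two stateful ops such as
// two tf.Print ops with no data dependence between them could legally appear
// in either order after the sort, so we conservatively fail below if their
// original relative order changed.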
std::vector<Operation*> original_side_effect_ops;
std::vector<Operation*> after_sort_side_effect_ops;
for (auto& op : *builder->getBlock()) {
if (IsSideEffectOp(&op) && (&op != builder->getBlock()->getTerminator()))
original_side_effect_ops.push_back(&op);
}
for (auto* op : sorted_ops) {
if (IsSideEffectOp(op)) after_sort_side_effect_ops.push_back(op);
}
if (original_side_effect_ops.size() != after_sort_side_effect_ops.size())
return failure();
for (int i = 0; i < original_side_effect_ops.size(); ++i) {
if (original_side_effect_ops[i] != after_sort_side_effect_ops[i])
return failure();
}
// Perform the sort.
// Ideally we would just clear the block and then write back the sorted ops,
// but unfortunately that's hard to do.
for (int i = sorted_ops.size() - 1; i > 0; --i) {
Operation* current_op = sorted_ops[i];
for (int j = i - 1; j >= 0; --j) {
Operation* prev_op = sorted_ops[j];
prev_op->moveBefore(current_op);
}
}
return success();
}
Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
Operation* insert_before_op,
const std::map<int, Value*>& inputs,
@ -360,10 +483,12 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
OpBuilder* builder, ModuleOp* module_op) {
SmallVector<Type, 4> input_types;
SmallVector<Value*, 4> input_values;
SmallVector<int, 4> input_indexes;
for (const auto& kv : inputs) {
Value* input = kv.second;
input_types.push_back(input->getType());
input_values.push_back(input);
input_indexes.push_back(kv.first);
}
SmallVector<Type, 4> func_output_types;
@ -378,6 +503,8 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
SmallVector<NamedAttribute, 4> attrs;
attrs.push_back(builder->getNamedAttr(
kTfLiteFunctionName, builder->getStringAttr(fused_func_type)));
attrs.push_back(builder->getNamedAttr(
kTfLiteFunctionInputIndex, builder->getI32ArrayAttr(input_indexes)));
FuncOp func_op = FuncOp::create(insert_before_op->getLoc(), func_name,
function_type, llvm::makeArrayRef(attrs));
module_op->push_back(func_op);
@ -507,6 +634,10 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
};
builder->getBlock()->walk(removeRemovableOps);
// Step 8: Topo sort to fix any invalid temporary IRs.
if (failed(TopoSortOperations(builder))) return failure();
return success();
}

View File

@ -0,0 +1,209 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
namespace {
constexpr char kTfLiteFunctionName[] = "_tflite_function_name";
constexpr char kUnidirectionalSequenceRnn[] = "UnidirectionalSequenceRnn";
// This pass is used for converting to TFLite composite ops like
// UnidirectionalSequenceRNN, UnidirectionalSequenceLSTM or SVDF op. Currently,
// this pass only handles ophint-converted function ops. See the diagram below:
//
// InputOp1 InputOp2 ...
// \ /
// \ /
// call funcOp (say UnidirectionalSequenceRNN)
// |
// |
// OutputOp1
//
// funcOp() { '_tflite_function_name' = 'UnidirectionalSequenceRNN'}
//
// ||
// ||
// \ /
//
// InputOp1 InputOp2 ...
// \ /
// \ /
// tfl.UnidirectionalSequenceRNN
// |
// |
// OutputOp1
struct LegalizeOphintFuncOpPass : public ModulePass<LegalizeOphintFuncOpPass> {
void runOnModule() override;
};
llvm::StringMap<FuncOp> FindCompositeFuncOps(ModuleOp module) {
llvm::StringMap<FuncOp> composite_func_ops;
for (FuncOp func : module.getOps<FuncOp>()) {
if (func.getAttr(kTfLiteFunctionName))
composite_func_ops[func.getName()] = func;
}
return composite_func_ops;
}
LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
CallOp* call_op,
OpBuilder* builder,
Operation** fused_op) {
// UnidirectionalSequenceRnn takes exactly 5 inputs.
if (composite_func_op.getNumArguments() != 5) return failure();
if (call_op->getNumOperands() != 5) return failure();
// UnidirectionalSequenceRnn has exactly 1 output.
if (call_op->getNumResults() != 1) return failure();
// Input is indexed at 0.
Value* input = call_op->getOperand(0);
// Input_weight is indexed at 1.
Value* weight = call_op->getOperand(1);
// Recurrent_weight is indexed at 2.
Value* recurrent_weight = call_op->getOperand(2);
// Bias is indexed at 3.
Value* bias = call_op->getOperand(3);
// Hidden_state is indexed at 4.
Value* hidden_state = call_op->getOperand(4);
// Build Output.
auto output_type = call_op->getResult(0)->getType();
// Currently, ophinted RNN only supports time_major = True.
const bool time_major = true;
// Activation will always be TanH.
StringAttr fused_activation_function = builder->getStringAttr("TANH");
builder->setInsertionPoint(call_op->getOperation());
*fused_op = builder->create<TFL::UnidirectionalSequenceRNNOp>(
call_op->getLoc(), output_type, input, weight, recurrent_weight, bias,
hidden_state, builder->getBoolAttr(time_major),
fused_activation_function);
return success();
}
LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name,
FuncOp composite_func_op,
CallOp* call_op,
OpBuilder* builder) {
Operation* fused_op = nullptr;
if (func_name == kUnidirectionalSequenceRnn) {
// TODO(renjieliu): Validate the func op inputs.
LogicalResult build_fused_op_result = BuildUnidirectionalSequenceRnnOp(
composite_func_op, call_op, builder, &fused_op);
if (failed(build_fused_op_result)) return build_fused_op_result;
} else {  // If we support more fused ops, we should add the conversions here.
return failure();
}
call_op->replaceAllUsesWith(fused_op);
// Delete call op.
Operation* call = call_op->getOperation();
call->dropAllDefinedValueUses();
call->dropAllReferences();
call->erase();
return success();
}
LogicalResult ConvertCallOps(llvm::StringMap<FuncOp>* composite_func_ops,
ModuleOp* module) {
for (auto func : module->getOps<FuncOp>()) {
// Ideally it would be much simpler if we could just use walk, but we also
// want to return early if we encounter errors. :(
OpBuilder builder(func.getBody());
// The call_op replacement within this loop works like an in-place
// replacement, so it should be safe to do so.
for (auto call_op :
llvm::make_early_inc_range(builder.getBlock()->getOps<CallOp>())) {
auto it = composite_func_ops->find(call_op.getCallee());
if (it == composite_func_ops->end()) return failure();
// Replace the call op with the TFLite fused op.
// Currently this is handled case by case, but ideally it would be
// much better if we could do this automatically.
FuncOp composite_func_op = it->second;
StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName)
.cast<StringAttr>()
.getValue();
if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op,
&call_op, &builder)))
return failure();
composite_func_ops->erase(it);
// Delete func op.
Operation* func = composite_func_op.getOperation();
func->erase();
}
}
return success();
}
void LegalizeOphintFuncOpPass::runOnModule() {
ModuleOp module = getModule();
// Find all composite funcs, then for every call op inside every func op
// within the module, replace the call op with the corresponding TFLite op
// and destroy the func op. This two-phase processing is intentional:
//
// Every func op is meant to be used exactly once.
// Instead of finding a composite func and then looping through the graph to
// convert its call op immediately, we break finding & converting into two
// phases. This changes the complexity from O(op_in_module *
// function_in_module * attr_in_func) to O(op_in_module) * O(map_look_up) +
// O(function_in_module * attr_in_func). O(op_in_module) is the dominant
// factor here and map lookup should be very cheap.
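// A rough illustration with made-up numbers: for 10,000 ops, 100 composite
// funcs and 5 attributes per func, the one-phase approach inspects on the
// order of 10,000 * 100 * 5 = 5,000,000 op/func/attr combinations, while the
// two-phase approach does ~10,000 cheap map lookups plus 100 * 5 = 500
// attribute checks.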
llvm::StringMap<FuncOp> composite_func_ops = FindCompositeFuncOps(module);
if (composite_func_ops.empty()) return;
if (failed(ConvertCallOps(&composite_func_ops, &module))) {
module.emitError() << "Legalize ophint: ConvertCallOps failed.";
return signalPassFailure();
}
}
} // namespace
/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
/// pass.
std::unique_ptr<ModulePassBase> CreateLegalizeOphintFuncOpPass() {
return std::make_unique<LegalizeOphintFuncOpPass>();
}
static PassRegistration<LegalizeOphintFuncOpPass> pass(
"tfl-legalize-ophint-func-op", "Convert composite op for TfLite dialect.");
} // namespace TFL
} // namespace mlir

View File

@ -20,6 +20,10 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def NonOpaqueElementsAttr : ElementsAttrBase<
CPred<"!$_self.isa<OpaqueElementsAttr>()">,
"non-opaque constant tensor">;
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
@ -56,8 +60,13 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
[(AnyStaticShapeTensor $res)], (addBenefit 10)>;
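// As an illustrative example (not an additional pattern): a statically shaped
// "tf.Const"() {value = dense<1.0> : tensor<2xf32>} is rewritten by the rule
// above into a std constant (with benefit 10), while opaque or non-statically
// shaped constants fall back to the TFL_ConstOp rule earlier in this section.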
//===----------------------------------------------------------------------===//
// Unary ops patterns.
//===----------------------------------------------------------------------===//
@ -124,6 +133,7 @@ def : Pat<(TF_MinimumOp $arg1, $arg2), (TFL_MinimumOp $arg1, $arg2)>;
def : Pat<(TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis),
(TFL_OneHotOp $indices, $depth, $on_value, $off_value,
(convertIntAttrTo32Bit $axis))>;
def : Pat<(TF_PowOp $x, $y), (TFL_PowOp $x, $y)>;
def : Pat<(TF_RangeOp $start, $limit, $delta), (TFL_RangeOp $start, $limit, $delta)>;
def : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>;
def : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>;
@ -156,7 +166,8 @@ def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
// The following two rules can both match a tf.Placeholder.input node with
// min/max/type attributes, so we increase the benefit of the first rule by one
// so that the tfl.quantize and tfl.dequantize ops will be inserted if it matches.
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
def : Pat<(TF_PlaceholderInputOp TensorOf<[F16, F32, F64]>:$inputs,
$min, $max, $type),
(TFL_DequantizeOp
(TFL_QuantizeOp
(TFL_InputOp $inputs),
@ -190,7 +201,8 @@ def : Pat<(TF_GatherV2Op $params, $indices,
def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
def : Pat<(TF_NotEqualOp $l, $r), (TFL_NotEqualOp $l, $r)>;
def : Pat<(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue),
(TFL_NotEqualOp $l, $r)>;
def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
@ -251,7 +263,7 @@ def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)),
def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>;
def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>;
def : Pat<(TF_EqualOp $arg0, $arg1, /*incompatible_shape_error=*/ConstBoolAttrTrue), (TFL_EqualOp $arg0, $arg1)>;
def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
@ -307,3 +319,11 @@ def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>;
def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>;
def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>;
def : Pat<
(TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold)>;
def : Pat<
(TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma)>;

View File

@ -0,0 +1,228 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass prepares the TFLite fused ops for quantization.
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
//===----------------------------------------------------------------------===//
// The LoadQuantizationRecipe Pass.
//
namespace mlir {
namespace TFL {
namespace {
// This pass loads the quantization recipe for the TFLite ops to be quantized.
// Specifically, it extends the fused ops with their internal implementation as
// op regions. Each op in the region produces results with element type
// AnyQuantizedType, so bitwidth, narrow_range, etc. are included. The op also
// defines the op quantization traits, which are used to propagate the
// quantization parameters by the following passes.
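// As a rough sketch of the loaded recipe (referring to the helpers defined
// below): for an LSTMOp, each gate is modeled in the region as two
// fully_connected ops (input and recurrent paths), optionally a mul with the
// cell state, combined by add_n, followed by the layer-norm model
// (l2_normalization + add + fully_connected) and a logistic or tanh, with the
// intermediate results typed as AnyQuantizedType (int16 here).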
struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
void runOnFunction() override;
private:
void Initialize(LSTMOp lstm, OpBuilder* builder);
// Create LSTM gates with different weights for input, recurrent and
// cell state, and also the layer normalization parameters.
Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
Value* rec_w,
llvm::Optional<std::pair<Value*, Value*>> cell,
Value* ln_w, Value* ln_bias, OpBuilder* builder);
Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
Value* ln_bias, OpBuilder* builder);
// Add the internal implementation of the LSTM to its regions.
void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);
StringAttr none_af;
StringAttr fc_format;
BoolAttr keep_dims;
Type int8;
Type int16;
ConstantOp none_cst;
};
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
Type expressed_type =
lstm.input()->getType().cast<ShapedType>().getElementType();
Type int8_storage_type = builder->getIntegerType(8);
Type int16_storage_type = builder->getIntegerType(16);
auto flag = quant::QuantizationFlags::FlagValue::Signed;
int64_t int8_min = quant::QuantizedType::getDefaultMininumForInteger(
flag, /*integralWidth=*/8);
int64_t int8_max = quant::QuantizedType::getDefaultMaxinumForInteger(
flag, /*integralWidth=*/8);
int64_t int16_min = quant::QuantizedType::getDefaultMininumForInteger(
flag, /*integralWidth=*/16);
int64_t int16_max = quant::QuantizedType::getDefaultMaxinumForInteger(
flag, /*integralWidth=*/16);
auto any_int8 = quant::AnyQuantizedType::get(
flag, int8_storage_type, expressed_type, int8_min, int8_max);
auto any_int16 = quant::AnyQuantizedType::get(
flag, int16_storage_type, expressed_type, int16_min, int16_max);
int8 = any_int8.castFromExpressedType(lstm.input()->getType());
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
}
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
Value* ln_w, Value* ln_bias,
OpBuilder* builder) {
// Note that the l2_normalization and add ops here are not the execution kernel
// implementation for layer_normalization; we just want to use them to
// model the quantization requirement.
auto l2_norm = builder->create<L2NormalizationOp>(loc, int16, in, none_af);
auto add = builder->create<AddOp>(loc, int16, in, l2_norm, none_af);
return builder->create<FullyConnectedOp>(loc, int16, add, ln_w, ln_bias,
none_af, fc_format, keep_dims);
}
Operation* LoadQuantizationRecipe::CreateGate(
Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
OpBuilder* builder) {
auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
none_af, fc_format, keep_dims);
auto s2 = builder->create<FullyConnectedOp>(loc, int16, rec, rec_w, none_cst,
none_af, fc_format, keep_dims);
AddNOp s4;
if (cell.hasValue()) {
auto s3 = builder->create<MulOp>(loc, int16, cell.getValue().first,
cell.getValue().second, none_af);
s4 = builder->create<AddNOp>(
loc, int16,
llvm::ArrayRef<Value*>(
{*s1.output().begin(), *s2.output().begin(), s3.output()}));
} else {
s4 = builder->create<AddNOp>(
loc, int16,
llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
}
auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);
if (cell.hasValue()) {
return builder->create<LogisticOp>(loc, int16, s5->getResult(0));
} else {
return builder->create<TanhOp>(loc, int16, s5->getResult(0));
}
}
void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
Initialize(lstm, builder);
Region region;
region.push_back(new Block);
builder->setInsertionPointToEnd(&region.front());
Location loc = lstm.getLoc();
Type int32_type = builder->getIntegerType(32);
Type int32_tensor = builder->getTensorType(int32_type);
none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
builder->getUnitAttr());
auto input_gate = CreateGate(
loc, lstm.input(), lstm.input_to_input_weights(),
lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
{lstm.input_cell_state(), lstm.cell_to_input_weights()}),
lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);
auto forget_gate = CreateGate(
loc, lstm.input(), lstm.input_to_forget_weights(),
lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
{lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);
auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(),
lstm.input_activation_state(),
lstm.recurrent_to_cell_weights(), llvm::None,
lstm.cell_layer_norm_coefficients(),
lstm.cell_bias(), builder);
auto forget_cell_state = builder->create<MulOp>(
loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af);
auto input_cell_state = builder->create<MulOp>(
loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af);
auto new_cell = builder->create<AddOp>(loc, int16, forget_cell_state.output(),
input_cell_state.output(), none_af);
auto output_gate = CreateGate(
loc, lstm.input(), lstm.input_to_output_weights(),
lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
{new_cell, lstm.cell_to_output_weights()}),
lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);
auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
auto hidden_state = builder->create<MulOp>(
loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af);
auto act = builder->create<FullyConnectedOp>(
loc, int8, hidden_state.output(), lstm.projection_weights(),
lstm.projection_bias(), none_af, fc_format, keep_dims);
// TODO(fengliuai): define and register the op in the QuantOps Dialect.
OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0),
{int8}, {});
builder->createOperation(return_state);
lstm.internal().takeBody(region);
}
void LoadQuantizationRecipe::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func);
none_af = builder.getStringAttr("NONE");
fc_format = builder.getStringAttr("DEFAULT");
keep_dims = builder.getBoolAttr(false);
func.walk([&](Operation* op) {
if (auto lstm = llvm::dyn_cast<LSTMOp>(op)) {
LoadForLSTMOp(lstm, &builder);
}
// Handles other ops.
});
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
// pass.
std::unique_ptr<FunctionPassBase> CreateLoadQuantizationRecipePass() {
return absl::make_unique<LoadQuantizationRecipe>();
}
static PassRegistration<LoadQuantizationRecipe> pass(
"tfl-load-recipe", "Load TFL op quantization recipe");
} // namespace TFL
} // namespace mlir
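A hedged sketch of how this pass might be wired into a conversion pipeline (not part of this change): the helper name and the pre-configured pass manager below are assumptions; only CreateLoadQuantizationRecipePass() comes from this file.
// Sketch only: assumes an mlir::PassManager `pm` configured by the caller and
// that the surrounding pipeline adds quantization passes afterwards.
void AddRecipePasses(mlir::PassManager& pm) {
  pm.addPass(mlir::TFL::CreateLoadQuantizationRecipePass());
  // Passes that propagate/materialize the loaded quantization parameters
  // would be added next.
}
Alternatively, the PassRegistration above exposes the pass under the -tfl-load-recipe flag for the MLIR opt tooling.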

View File

@ -429,12 +429,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
tf_op.element_dtype().isF64() ||
tf_op.element_dtype().isInteger(1) ||
tf_op.element_dtype().isInteger(8) ||
tf_op.element_dtype().isInteger(16) ||
tf_op.element_dtype().isInteger(32) ||
tf_op.element_dtype().isInteger(64))) {
return tf_op.emitError(
"requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer "
"or 16-bit/32-bit/64-bit "
"float type during TF Lite transformation pass");
}
@ -461,6 +463,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
auto c = ConvertTFTensorListPushBack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListLengthOp>(op)) {
auto c = TFL::ConvertTFTensorListLength(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
UpdateWhileFunctionType(tf_op);

View File

@ -110,3 +110,42 @@ def : Pat<(TFL_MulOp (TFL_DepthwiseConv2DOp $input,
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
// Constraint that makes sure both operands are the same value.
def EqualOperands : Constraint<CPred<"$0 == $1">>;
// Checks if the operand has rank == n
class OperandHasRank<int n> : Constraint<
CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;
// This pattern constructs L2NormalizationOp from
// Mul->Rsqrt->Sum->Square
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add constraints that the kernel code assumes, e.g.
// constraints on axis and depth.
def : Pat<(TFL_MulOp $operand1,
(TFL_RsqrtOp
(TFL_SumOp
(TFL_SquareOp $square_operand),
(ConstantOp I32ElementsAttr:$constant),
$keep_dims)),
TFL_AF_None),
(TFL_L2NormalizationOp $operand1, TFL_AF_None),
[(EqualOperands $operand1, $square_operand)]>;
// This pattern constructs L2NormalizationOp from
// Div->Sqrt->Sum->Square
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add constraints that the kernel code assumes, e.g.
// constraints on axis and depth.
def : Pat<(TFL_DivOp $operand1,
(TFL_SqrtOp
(TFL_SumOp
(TFL_SquareOp $square_operand),
(ConstantOp I32ElementsAttr:$constant),
$keep_dims)),
TFL_AF_None),
(TFL_L2NormalizationOp $operand1, TFL_AF_None),
[(EqualOperands $operand1, $square_operand)]>;

View File

@ -21,8 +21,12 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class FunctionPassBase;
class ModulePassBase;
class FuncOp;
class ModuleOp;
template <typename T>
class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
using ModulePassBase = OpPassBase<ModuleOp>;
namespace TFL {
@ -64,6 +68,10 @@ std::unique_ptr<FunctionPassBase> CreatePrepareCompositeFunctionsPass();
// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
std::unique_ptr<ModulePassBase> CreateExtractOphintPass();
// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
// pass. The composite op is created from the ophint extraction pass.
std::unique_ptr<ModulePassBase> CreateLegalizeOphintFuncOpPass();
} // namespace TFL
} // namespace mlir

View File

@ -18,6 +18,14 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def NonOpaqueElementsAttr : ElementsAttrBase<
CPred<"!$_self.isa<OpaqueElementsAttr>()">,
"non-opaque constant tensor">;
// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
[(AnyStaticShapeTensor $res)]>;
// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
// operations. Specifically, performs the following calculation:
//
@ -81,8 +89,8 @@ class TFi32<int v> : ConstantAttr<I32ElementsAttr, !cast<string>(v)>;
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
(TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp
/*start=*/(TF_RankOp $b),
/*limit=*/(ConstantOp TFi32<0>),
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))),
/*limit=*/(TF_ConstOp TFi32<0>),
/*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))),
$at, ConstBoolAttrTrue)>;
// Matmul with transpose on a to matmul with explicit transpose op and a not
@ -90,10 +98,12 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
(TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp
/*start=*/(TF_RankOp $a),
/*limit=*/(ConstantOp TFi32<0>),
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
/*limit=*/(TF_ConstOp TFi32<0>),
/*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), $b,
ConstBoolAttrFalse, $bt)>;
// Partially supported in TFLite, treated as passthrough IdentityOp
def : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>;
def : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>;
def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;

View File

@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -246,7 +247,8 @@ struct ConvertTFConvOp : public RewritePattern {
filter_type.getShape());
auto bias_type = rewriter.getTensorType({bias_dim}, elem_type);
auto bias_attr = rewriter.getZeroAttr(bias_type);
auto bias = rewriter.create<ConstantOp>(op->getLoc(), bias_type, bias_attr);
auto bias =
rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
auto *conv_state = static_cast<ConvertTFConvOpMatchState *>(state.get());
auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
@ -297,7 +299,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
rewriter.getIntegerType(32));
auto perm_attr =
DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
// Create tensor type for the transpose result.
auto filter_type = filter->getType().cast<RankedTensorType>();
@ -366,7 +368,7 @@ class ConvertTFDepthwiseConv2dNative
auto shape_type = rewriter.getTensorType({4}, rewriter.getIntegerType(64));
auto shape_attr =
DenseElementsAttr::get(shape_type, llvm::makeArrayRef(result_shape));
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);
return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
}
@ -377,6 +379,11 @@ class ConvertTFDepthwiseConv2dNative
void PrepareTFPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
applyPatternsGreedily(func, patterns);
// This pattern was intended to use TFL QDQs to preserve the quantization
// parameters from the TF Quant ops, thus this pattern should run with the
// first `applyPatternsGreedily` method, which would otherwise remove the

View File

@ -14,9 +14,13 @@ limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def CreateTFShapeOp : NativeCodeCall<
"$_builder.create<TF::ShapeOp>($0->getLoc(), $1, $2)">;
//===----------------------------------------------------------------------===//
// TensorList transformation patterns.
// Note that the pattern below rewrites `TensorList` tensors (which have type DT_VARIANT)
@ -34,3 +38,11 @@ def ConvertTFTensorListStack : Pat<
def ConvertTFTensorListGetItem : Pat<
(TF_TensorListGetItemOp $input, $index, $element_shape),
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
// TensorListLength is equivalent to the size of the first dimension of the
// input tensorlist; rewrite it as a combination of Shape and Gather ops.
def ConvertTFTensorListLength: Pat<
(TF_TensorListLengthOp:$old_value $input),
(TF_GatherOp
(CreateTFShapeOp $old_value, $input, /*use 32bit*/ConstBoolAttrTrue),
(ConstantOp ConstantAttr<I32ElementsAttr, "0">), ConstBoolAttrTrue)>;

View File

@ -105,13 +105,13 @@ void TrimFunctionsPass::Verify() {
SymbolTable symbol_table = SymbolTable(getModule());
llvm::SetVector<FuncOp> reachable_funcs;
for (auto func : getModule().getOps<FuncOp>()) {
func.walk<CallOp>([&](CallOp op) {
if (!symbol_table.lookup<FuncOp>(op.getCallee())) {
getModule().emitError()
<< func.getName() << " is not in the funcs whitelist";
return signalPassFailure();
}
auto walk_result = func.walk([&](CallOp op) -> WalkResult {
if (!symbol_table.lookup<FuncOp>(op.getCallee()))
return getModule().emitError()
<< func.getName() << " is not in the funcs whitelist";
return WalkResult::advance();
});
if (walk_result.wasInterrupted()) return signalPassFailure();
}
}

View File

@ -0,0 +1,309 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include <climits>
#include <cstdint>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TFL {
namespace {
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
void runOnFunction() override;
};
void UnrollBatchMatMulPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
applyPatternsGreedily(func, patterns);
}
} // namespace
template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
Value* value, ArrayRef<int64_t> shape, Type elementType, Location loc,
PatternRewriter& rewriter) {
int64_t shape_rank = shape.size();
auto shapeSpecType =
rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64));
Type resultType = rewriter.getTensorType(shape, elementType);
auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape);
auto shapeTensor =
rewriter.create<ConstantOp>(loc, shapeSpecType, constant_attr);
return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
/*shape=*/shapeTensor);
}
template <typename BatchMatMulOpType>
std::vector<Value*> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
Value* value, int batch_size, Location loc, PatternRewriter& rewriter) {
RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
Type elementType = tensorType.getElementType();
int rank = tensorType.getShape().size();
int num_rows = tensorType.getShape()[rank - 2];
int num_cols = tensorType.getShape()[rank - 1];
// Reshape to rank-3 Tensor with first dimension as the batch size.
auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols},
elementType, loc, rewriter);
SmallVector<int64_t, 3> sliceSize = {1, num_rows, num_cols};
std::vector<Value*> sliced;
Type int64Type = rewriter.getIntegerType(64);
Type sliceResultType = rewriter.getTensorType(sliceSize, elementType);
// Slice along each batch index and remember the slice output for future
// use.
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
auto vector3Type = rewriter.getTensorType({3}, int64Type);
auto begin_attr =
DenseElementsAttr::get<int64_t>(vector3Type, {batch_idx, 0, 0});
auto size_attr = DenseElementsAttr::get<int64_t>(vector3Type, sliceSize);
auto begin = rewriter.create<ConstantOp>(loc, vector3Type, begin_attr);
auto size = rewriter.create<ConstantOp>(loc, vector3Type, size_attr);
auto sliceOp =
rewriter.create<TF::SliceOp>(loc, sliceResultType,
/*input=*/reshapeOp.output(), begin, size);
// Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
// num_cols]
auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols},
elementType, loc, rewriter);
sliced.emplace_back(squeezeOp.output());
}
return sliced;
}
template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
Value* value, Location loc, PatternRewriter& rewriter) {
auto valueType = value->getType().cast<RankedTensorType>();
auto shape = valueType.getShape();
int dims = shape.size();
std::vector<int32_t> perm(dims);
for (int i = 0; i < dims - 2; i++) {
perm[i] = i;
}
perm[dims - 2] = dims - 1;
perm[dims - 1] = dims - 2;
auto perm_type = rewriter.getTensorType({static_cast<int32_t>(perm.size())},
rewriter.getIntegerType(32));
auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
int64_t r = transposed_shape[dims - 1];
int64_t c = transposed_shape[dims - 2];
transposed_shape[dims - 1] = c;
transposed_shape[dims - 2] = r;
auto transposed_type =
rewriter.getTensorType(transposed_shape, valueType.getElementType());
return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
}
template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
const std::vector<Value*>& sliced_lhs,
const std::vector<Value*>& sliced_rhs, const tensorflow::MatMulBCast& bcast,
int rows, int cols, Type elementType, Location loc,
PatternRewriter& rewriter) {
auto matmulType = rewriter.getTensorType({rows, cols}, elementType);
std::vector<Value*> matmuls;
for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
int lhs_batch_idx, rhs_batch_idx;
if (bcast.IsBroadcastingRequired()) {
lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
} else {
lhs_batch_idx = batch_idx;
rhs_batch_idx = batch_idx;
}
auto false_attr = rewriter.getBoolAttr(false);
auto matmul = rewriter.create<TF::MatMulOp>(loc, matmulType,
/*a=*/sliced_lhs[lhs_batch_idx],
/*b=*/sliced_rhs[rhs_batch_idx],
/*transpose_a=*/false_attr,
/*transpose_b=*/false_attr);
matmuls.emplace_back(matmul.product());
}
// Combine the result of each individual MatMul into a rank-3 Tensor.
Type packedType = rewriter.getTensorType(
{bcast.output_batch_size(), rows, cols}, elementType);
auto N = rewriter.getI64IntegerAttr(matmuls.size());
auto axis = rewriter.getI64IntegerAttr(0);
return rewriter.create<TF::PackOp>(loc, packedType,
/*values=*/matmuls, N, axis);
}
template <typename BatchMatMulOpType>
PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
BatchMatMulOpType op, PatternRewriter& rewriter) const {
Value* input_lhs = op.x();
Value* input_rhs = op.y();
if (!input_lhs->getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type
return this->matchFailure();
}
if (!input_rhs->getType().isa<RankedTensorType>()) {
// RHS must be a ranked tensor type
return this->matchFailure();
}
auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
auto rhs_type = input_rhs->getType().cast<RankedTensorType>();
auto elementType = lhs_type.getElementType();
if (elementType != rhs_type.getElementType()) {
// The element type of LHS must be the same as the element type of RHS
return this->matchFailure();
}
auto lhs_shape = lhs_type.getShape();
auto rhs_shape = rhs_type.getShape();
Location loc = op.getLoc();
// Transpose LHS input if necessary.
if (op.adj_x()) {
input_lhs = createTransposeOp(input_lhs, loc, rewriter);
lhs_type = input_lhs->getType().cast<RankedTensorType>();
lhs_shape = lhs_type.getShape();
}
// Transpose RHS input if necessary.
if (op.adj_y()) {
input_rhs = createTransposeOp(input_rhs, loc, rewriter);
rhs_type = input_rhs->getType().cast<RankedTensorType>();
rhs_shape = rhs_type.getShape();
}
// Ensure that input ranks are at least 2 and batch shapes are
// broadcastable.
const int dims_a = lhs_shape.size();
const int dims_b = rhs_shape.size();
if (dims_a < 2 || dims_b < 2) {
// Both inputs must have rank >= 2
return this->matchFailure();
}
if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
// Input dimensions must be compatible for multiplication.
return this->matchFailure();
}
if (dims_a == 2 && dims_b == 2) {
// When both inputs are matrices, just replace the op to a matmul op.
Type resultType =
rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType);
auto false_attr = rewriter.getBoolAttr(false);
rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, resultType,
/*a=*/input_lhs,
/*b=*/input_rhs,
/*transpose_a=*/false_attr,
/*transpose_b=*/false_attr);
return this->matchSuccess();
}
tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
lhs_shape.begin(), lhs_shape.end()),
absl::InlinedVector<tensorflow::int64, 4>(
rhs_shape.begin(), rhs_shape.end()));
if (!bcast.IsValid()) {
// Input batch dimensions must be broadcastable
return this->matchFailure();
}
// Compute slices for each batch in the LHS and RHS.
std::vector<Value*> sliced_lhs =
sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
std::vector<Value*> sliced_rhs =
sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);
// Compute (single batch) MatMul for each output batch. The MatMul outputs
// are then packed together into one output Tensor.
auto packOp =
createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
rhs_shape[dims_b - 1], elementType, loc, rewriter);
// Reshape the rank-3 Tensor into the correct output shape.
const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes();
std::vector<int64_t> resultShape(resultBatchShape.begin(),
resultBatchShape.end());
resultShape.push_back(lhs_shape[dims_a - 2]);
resultShape.push_back(rhs_shape[dims_b - 1]);
auto reshapeOp =
createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter);
rewriter.replaceOp(op, reshapeOp.output());
return this->matchSuccess();
}
static PassRegistration<UnrollBatchMatMulPass> pass(
"tfl-unroll-batch-matmul",
"Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
} // namespace TFL
} // namespace mlir
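For readers unfamiliar with the broadcast helper used above, here is a small hedged illustration of tensorflow::MatMulBCast (the shapes are invented for the example; the helper strips the trailing two matrix dimensions and broadcasts the remaining batch dimensions):
// Sketch only: example shapes, not taken from the pass above.
tensorflow::MatMulBCast bcast(
    absl::InlinedVector<tensorflow::int64, 4>({2, 1, 3, 4}),   // LHS: batch {2, 1}, matrix 3x4
    absl::InlinedVector<tensorflow::int64, 4>({1, 5, 4, 6}));  // RHS: batch {1, 5}, matrix 4x6
if (bcast.IsValid()) {
  // The broadcast batch shape is {2, 5}, so output_batch_size() == 10.
  // x_batch_indices()/y_batch_indices() map each of the 10 output batches back
  // to the LHS/RHS slice produced by sliceInput().
}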

View File

@ -0,0 +1,60 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TFL {
// Unrolls a tf.BatchMatMul/tf.BatchMatMulV2 op into a sequence of TF ops. Since
// TFLite does not support the BatchMatMul operation, it unrolls a BatchMatMul
// op into tf.Reshape, tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops.
template <typename BatchMatMulOpType>
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;
static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
Type elementType, Location loc,
PatternRewriter& rewriter);
static std::vector<Value*> sliceInput(Value* value, int batch_size,
Location loc,
PatternRewriter& rewriter);
static TF::TransposeOp createTransposeOp(Value* value, Location loc,
PatternRewriter& rewriter);
static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
const std::vector<Value*>& sliced_rhs,
const tensorflow::MatMulBCast& bcast,
int rows, int cols, Type elementType,
Location loc, PatternRewriter& rewriter);
PatternMatchResult matchAndRewrite(BatchMatMulOpType op,
PatternRewriter& rewriter) const override;
};
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_

View File

@ -0,0 +1,456 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
Value* CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
int32_t val, mlir::Location location) {
auto type = builder->getTensorType(shape, builder->getIntegerType(32));
auto attr = DenseElementsAttr::get(type, val);
return builder->create<ConstantOp>(location, type, attr);
}
Value* CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
float val, mlir::Location location) {
auto type = builder->getTensorType(shape, builder->getF32Type());
auto attr = DenseElementsAttr::get(type, val);
return builder->create<ConstantOp>(location, type, attr);
}
Value* CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
ArrayRef<int64_t> values, mlir::Location location) {
auto type = builder->getTensorType(static_cast<int>(shape.size()),
builder->getIntegerType(64));
auto attr = DenseElementsAttr::get(type, values);
return builder->create<ConstantOp>(location, type, attr);
}
Value* CreateNoneValue(OpBuilder* builder, mlir::Location location) {
return builder->create<mlir::ConstantOp>(location, builder->getNoneType(),
builder->getUnitAttr());
}
Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose,
RankedTensorType type, mlir::Location location) {
// Create a constant op for transpose permutation.
SmallVector<int64_t, 2> perm = {1, 0};
auto perm_op = CreateI64DenseConst(builder, perm, perm, location);
// Create tensor type for the transpose result.
auto transpose_type = type;
auto transpose_shape = functional::map(
[transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
perm);
auto elem_type = transpose_type.getElementType();
auto result_type = builder->getTensorType(transpose_shape, elem_type);
return builder->create<TF::TransposeOp>(location, result_type,
value_to_transpose, perm_op);
}
Value* SliceRankedTensor(OpBuilder* builder, Value* input,
ArrayRef<int64_t> begin_shape,
ArrayRef<int64_t> begin_values,
ArrayRef<int64_t> size_shape,
ArrayRef<int64_t> size_values,
mlir::Location location) {
// Create a dense constant op for slice's begin
auto slice_i2c_begin =
CreateI64DenseConst(builder, begin_shape, begin_values, location);
// Create a dense constant op for slice's size
auto slice_i2c_size =
CreateI64DenseConst(builder, size_shape, size_values, location);
return builder->create<TF::SliceOp>(
location,
builder->getTensorType(
size_values,
input->getType().cast<RankedTensorType>().getElementType()),
input, slice_i2c_begin, slice_i2c_size);
}
} // namespace
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
input2cell_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
input2input_ = couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, weight_transposed_,
weight_slice_shape_, begin_i2i_values,
weight_slice_shape_,
weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
input2forget_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
int input_output_start =
couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
input2output_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
rec2cell_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
rec2input_ = couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, weight_transposed_,
weight_slice_shape_, begin_rec2i_values,
weight_slice_shape_,
weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
rec2forget_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
rec2output_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
SmallVector<int64_t, 1> begin_bias2c_values = {0};
bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2c_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
bias2input_ =
couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2i_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2f_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
int bias_output_start =
couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2o_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
SmallVector<int64_t, 2> projection_slice_shape = {
1, num_cols_projection_transposed_};
SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
proj_weight_ =
!projection_
? none_
: SliceRankedTensor(
&builder_, projection_transposed_, projection_slice_shape,
projection_slice_begin_values, projection_slice_shape,
projection_slice_size_values, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
proj_bias_ = !projection_type_
? none_
: CreateI32SplatConst(&builder_, {n_output_}, 0,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
input_cell_state_ =
CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
cell_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
input_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
forget_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
output_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
// Transpose both weight and projection.
weight_transposed_ =
Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
fused_func_op_.getLoc());
none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
// Extract input to cifg gates via slicing the weight tensor
SetWeightForInputToCellGate();
SetWeightForInputToInputGate();
SetWeightForInputToForgetGate();
SetWeightForInputToOutputGate();
// Extract recurrent to cifg gates via slicing the weight tensor
SetWeightForRecurrentToCellGate();
SetWeightForRecurrentToInputGate();
SetWeightForRecurrentToForgetGate();
SetWeightForRecurrentToOutputGate();
// Extract bias to cifg gates via slicing the bias tensor
SetBiasToCellGate();
SetBiasToInputGate();
SetBiasToForgetGate();
SetBiasToOutputGate();
// Extract projection and set an empty projection bias
SetProjection();
SetProjectionBias();
// Set the variable tensors
SetInputActivationState();
SetInputCellState();
// Extract the layer norm coefficients
SetCellLayerNormCoefficients();
SetInputLayerNormCoefficients();
SetForgetLayerNormCoefficients();
SetOutputLayerNormCoefficients();
}
void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
// https://github.com/tensorflow/community/pull/113
auto attr = fused_func_op_.getAttrOfType<StringAttr>("tf._implements");
if (!attr) {
fused_func_op_.setAttr("tf._implements",
builder_.getStringAttr(GetCompositeOpName()));
}
SmallVector<int64_t, 2> output_shape{1, n_output_};
auto input_types = fused_func_op_.getType().getInputs();
auto output_type = builder_.getTensorType(
output_shape,
input_->getType().cast<RankedTensorType>().getElementType());
fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
fused_func_op_.getContext()));
}
void ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
// Update the func signature, based on output shape.
// The func will ultimately return the output of the fused
// LSTM op.
UpdateFuncSignature();
// Transform the weights, projection, bias and layer norm coefficients
// to generate operands for the TFL fused LSTM op.
GenerateFusedOpOperands();
// Create the fused LSTM op.
SmallVector<int64_t, 2> output_shape = {1, n_output_};
auto result_type = builder_.getTensorType(
output_shape,
input_->getType().cast<RankedTensorType>().getElementType());
lstm_ = builder_.create<mlir::TFL::LSTMOp>(
fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
rec2output_, /*cell_to_input_weights*/ none_,
/*cell_to_forget_weights*/ none_,
/*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
bias2output_, proj_weight_, proj_bias_, input_activation_state_,
input_cell_state_, input_layer_norm_coefficients_,
forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
builder_.getStringAttr("FULL"));
builder_.create<mlir::ReturnOp>(fused_func_op_.getLoc(), lstm_.getResult());
}
LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
num_gates_ = couple_input_forget_gates_ ? 3 : 4;
input_ = fused_func_op_.getArgument(0);
bias_ = fused_func_op_.getArgument(2);
weight_ = fused_func_op_.getArgument(1);
weight_type_ = weight_->getType().cast<RankedTensorType>();
if (weight_type_.getRank() != 2) {
return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
}
if (weight_type_.getDimSize(1) % num_gates_ != 0) {
return fused_func_op_.emitError()
<< "Invalid dimension 1 of weight tensor, "
"should be divisible by the number of gates";
}
n_cell_ = weight_type_.getDimSize(1) / num_gates_;
projection_ = fused_func_op_.getArgument(3);
projection_type_ = projection_->getType().cast<RankedTensorType>();
if (projection_type_.getRank() != 2) {
n_output_ = n_cell_;
} else {
n_output_ = projection_type_.getDimSize(1);
}
n_input_ = weight_type_.getDimSize(0) - n_output_;
num_cols_weight_transposed_ = weight_type_.getDimSize(0);
num_cols_projection_transposed_ = projection_type_.getDimSize(0);
bias_slice_shape_ = {n_cell_};
bias_size_values_ = {n_cell_};
weight_slice_shape_ = {1, num_cols_weight_transposed_};
weight_slice_size_input_values_ = {n_cell_, n_input_};
weight_slice_size_recurrent_values_ = {n_cell_, n_output_};
return success();
}
LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
return fused_func_op_.emitError()
<< "Specified LayerNormalizedLSTMCellSimple was not of the expected "
"interface and cannot not be converted to the fused LSTM op";
}
layer_norm_scale_ = fused_func_op_.getArgument(4);
layer_norm_scale_type_ =
layer_norm_scale_->getType().cast<RankedTensorType>();
if (layer_norm_scale_type_.getRank() != 1) {
return fused_func_op_.emitError()
<< "The layer_norm_scale tensor was not of rank 1";
}
layer_norm_slice_shape_ = {n_cell_};
layer_norm_size_values_ = {n_cell_};
return success();
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetCellLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
cell_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_cell_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetInputLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
input_layer_norm_coefficients_ =
couple_input_forget_gates_
? none_
: SliceRankedTensor(
&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_input_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetForgetLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
forget_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_forget_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetOutputLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
output_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_output_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,214 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
constexpr char kLstmCellSimple[] = "LSTMCellSimple";
constexpr char kLayerNormalizedLstmCellSimple[] =
"LayerNormalizedLstmCellSimple";
// A utility class that enables the conversion of the LSTMCellSimple composite
// op into a fused TFL LSTM op. The fused op is contained within a FuncOp
// that also contains other supporting ops needed to construct the operands for
// the fused op. The caller provides the containing FuncOp as input with
// arguments specifying the input, weight, projection and bias.
// The weight, projection, bias and layer norm scale all need to be
// RankedTensorType.
// This class sets the layer norm coefficients to NoneType.
class ConvertLSTMCellSimpleToFusedLSTM {
public:
// TODO(b/140053256): The couple_input_forget_gates should be specified on
// FuncOp as an attribute.
explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op,
bool couple_input_forget_gates)
: fused_func_op_(fused_func_op),
couple_input_forget_gates_(couple_input_forget_gates),
builder_(fused_func_op.getBody()) {}
// not copyable.
ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
delete;
ConvertLSTMCellSimpleToFusedLSTM& operator=(
const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}
// verify input func op arguments and initialize internal state.
virtual LogicalResult Initialize();
virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }
// Rewrite the func body with constructed fused lstm.
void RewriteFunc();
protected:
void UpdateFuncSignature();
void GenerateFusedOpOperands();
void SetWeightForInputToCellGate();
void SetWeightForInputToInputGate();
void SetWeightForInputToForgetGate();
void SetWeightForInputToOutputGate();
void SetWeightForRecurrentToCellGate();
void SetWeightForRecurrentToInputGate();
void SetWeightForRecurrentToForgetGate();
void SetWeightForRecurrentToOutputGate();
void SetBiasToCellGate();
void SetBiasToInputGate();
void SetBiasToForgetGate();
void SetBiasToOutputGate();
void SetProjection();
void SetProjectionBias();
void SetInputActivationState();
void SetInputCellState();
virtual void SetCellLayerNormCoefficients();
virtual void SetInputLayerNormCoefficients();
virtual void SetForgetLayerNormCoefficients();
virtual void SetOutputLayerNormCoefficients();
// specified state
FuncOp fused_func_op_;
Value* input_;
Value* weight_;
Value* bias_;
Value* projection_;
bool couple_input_forget_gates_;
// internal state
Value* weight_transposed_;
Value* projection_transposed_;
RankedTensorType weight_type_;
RankedTensorType projection_type_;
int num_gates_;
int n_cell_;
int n_output_;
int n_input_;
int num_cols_weight_transposed_;
int num_cols_projection_transposed_;
// input -> cifg
Value* input2input_;
Value* input2forget_;
Value* input2cell_;
Value* input2output_;
// recurrent -> cifg
Value* rec2input_;
Value* rec2forget_;
Value* rec2cell_;
Value* rec2output_;
// bias -> cifg
Value* bias2input_;
Value* bias2forget_;
Value* bias2cell_;
Value* bias2output_;
// projection
Value* proj_weight_;
Value* proj_bias_;
// state
Value* input_activation_state_;
Value* input_cell_state_;
// layer norm coefficients
Value* input_layer_norm_coefficients_;
Value* forget_layer_norm_coefficients_;
Value* cell_layer_norm_coefficients_;
Value* output_layer_norm_coefficients_;
mlir::TFL::LSTMOp lstm_;
Value* none_;
SmallVector<int64_t, 1> bias_slice_shape_;
SmallVector<int64_t, 1> bias_size_values_;
SmallVector<int64_t, 2> weight_slice_shape_;
SmallVector<int64_t, 2> weight_slice_size_input_values_;
SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
OpBuilder builder_;
};
// A utility class that enables the conversion of the
// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
// fused op is contained within a FuncOp that also contains other supporting ops
// needed to construct the operands for the fused op. The caller provides the
// containing FuncOp as input with arguments specifying the input, weight,
// projection, bias and layer norm scale. The weight, projection, bias and
// layer norm scale all need to be RankedTensorType.
// This class overrides the layer norm coefficient setters from the base class.
class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
: public ConvertLSTMCellSimpleToFusedLSTM {
public:
// TODO(b/140053256): The couple_input_forget_gates should be specified on
// FuncOp as an attribute.
explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
mlir::FuncOp fused_func_op, bool couple_input_forget_gates)
: ConvertLSTMCellSimpleToFusedLSTM(fused_func_op,
couple_input_forget_gates) {}
// not copyable.
ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}
llvm::StringRef GetCompositeOpName() override {
return kLayerNormalizedLstmCellSimple;
}
LogicalResult Initialize() override;
protected:
void SetCellLayerNormCoefficients() override;
void SetInputLayerNormCoefficients() override;
void SetForgetLayerNormCoefficients() override;
void SetOutputLayerNormCoefficients() override;
private:
// specified state
Value* layer_norm_scale_;
// internal state
RankedTensorType layer_norm_scale_type_;
SmallVector<int64_t, 1> layer_norm_slice_shape_;
SmallVector<int64_t, 1> layer_norm_size_values_;
};
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_

View File

@ -0,0 +1,216 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include <memory>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/platform/test.h"
namespace mlir {
namespace TFL {
FuncOp createFusedFunc(mlir::Builder* builder) {
SmallVector<int64_t, 2> input_shape{1, 2};
SmallVector<int64_t, 2> weight_shape{3, 12};
SmallVector<int64_t, 1> bias_shape{2};
SmallVector<int64_t, 2> projection_shape{1, 2};
SmallVector<int64_t, 1> layer_norm_scale{4};
SmallVector<int64_t, 2> output_shape{1, 2};
auto input_type = builder->getTensorType(input_shape, builder->getF32Type());
auto weight_type =
builder->getTensorType(weight_shape, builder->getF32Type());
auto bias_type = builder->getTensorType(bias_shape, builder->getF32Type());
auto projection_type =
builder->getTensorType(projection_shape, builder->getF32Type());
auto layer_norm_scale_type =
builder->getTensorType(layer_norm_scale, builder->getF32Type());
auto output_type =
builder->getTensorType(output_shape, builder->getF32Type());
SmallVector<mlir::Type, 4> input_types{input_type, weight_type, bias_type,
projection_type,
layer_norm_scale_type};
auto func_type = builder->getFunctionType(input_types, output_type);
auto func =
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
builder->getContext()),
"fused_func", func_type, {});
func.addEntryBlock();
return func;
}
// TODO(ashwinm): Revisit if this test should be moved to a test pass
// with FileCheck test after the pass that consumes the lstm_utils to stack
// the layers.
class LstmUtilsTest : public ::testing::Test {
protected:
LstmUtilsTest() {}
void SetUp() override {
builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
fused_lstm_func_ = createFusedFunc(builder_.get());
}
void TearDown() override {
fused_lstm_func_.erase();
builder_.reset();
}
FuncOp fused_lstm_func_;
mlir::MLIRContext context_;
std::unique_ptr<mlir::Builder> builder_;
};
TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, false);
auto result = convert.Initialize();
EXPECT_FALSE(failed(result));
convert.RewriteFunc();
fused_lstm_func_.dump();
// verify transpose
EXPECT_EQ(
fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
convert.GetCompositeOpName());
EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto transpose_op = fused_lstm_func_.getBody().front().begin();
transpose_op++;
EXPECT_EQ(transpose_op->getOperand(0)
->getType()
.cast<RankedTensorType>()
.getDimSize(0),
3);
EXPECT_EQ(transpose_op->getOperand(0)
->getType()
.cast<RankedTensorType>()
.getDimSize(1),
12);
EXPECT_EQ(
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
0),
12);
EXPECT_EQ(
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
1),
3);
auto return_op = fused_lstm_func_.getBody().back().rbegin();
EXPECT_EQ(return_op->getName().getStringRef(),
mlir::ReturnOp::getOperationName());
return_op++;
EXPECT_EQ(return_op->getName().getStringRef(),
mlir::TFL::LSTMOp::getOperationName());
EXPECT_EQ(return_op->getNumOperands(), 24);
EXPECT_EQ(return_op->getNumResults(), 1);
// cifg = false, so input2input is not None.
EXPECT_FALSE(return_op->getOperand(1)->getType().isa<NoneType>());
// input layer norm is None
EXPECT_TRUE(return_op->getOperand(20)->getType().isa<NoneType>());
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_lstm_func_.getType().getResults();
SmallVector<int64_t, 2> output_shape{1, 2};
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
output_shape.size());
for (int i = 0; i < output_shape.size(); i++) {
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
output_shape[i]);
}
}
TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) {
mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, true);
auto result = convert.Initialize();
EXPECT_FALSE(failed(result));
convert.RewriteFunc();
fused_lstm_func_.dump();
auto it = fused_lstm_func_.getBody().back().rbegin();
EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
it++;
EXPECT_EQ(it->getName().getStringRef(),
mlir::TFL::LSTMOp::getOperationName());
EXPECT_EQ(it->getNumOperands(), 24);
EXPECT_EQ(it->getNumResults(), 1);
// cifg = true, so input2input is None.
EXPECT_TRUE(it->getOperand(1)->getType().isa<NoneType>());
}
TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
mlir::TFL::ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM convert(
fused_lstm_func_, false);
auto result = convert.Initialize();
EXPECT_FALSE(failed(result));
convert.RewriteFunc();
fused_lstm_func_.dump();
EXPECT_EQ(
fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
convert.GetCompositeOpName());
EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto it = fused_lstm_func_.getBody().back().rbegin();
EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
it++;
EXPECT_EQ(it->getName().getStringRef(),
mlir::TFL::LSTMOp::getOperationName());
EXPECT_EQ(it->getNumOperands(), 24);
EXPECT_EQ(it->getNumResults(), 1);
// cifg = false, so input2input is not None.
EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>());
// input layer norm
EXPECT_FALSE(it->getOperand(20)->getType().isa<NoneType>());
EXPECT_EQ(
it->getOperand(20)->getType().cast<RankedTensorType>().getShape().size(),
1);
EXPECT_EQ(
it->getOperand(20)->getType().cast<RankedTensorType>().getDimSize(0), 3);
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_lstm_func_.getType().getResults();
SmallVector<int64_t, 2> output_shape{1, 2};
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
output_shape.size());
for (int i = 0; i < output_shape.size(); i++) {
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
output_shape[i]);
}
}
} // namespace TFL
} // namespace mlir
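
The three tests above boil down to one driver pattern. The helper below is only a sketch distilled from the test bodies; the name ConvertToFusedLstm and the parameter name couple_input_forget are ours (the tests pass a bare false/true, and their comments indicate it selects CIFG, i.e. coupled input and forget gates):

// Sketch: mirrors how the tests drive the converter.
// (Assumes the same includes as the test file above.)
static void ConvertToFusedLstm(mlir::FuncOp func, bool couple_input_forget) {
  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(func, couple_input_forget);
  // Initialize() presumably validates the composite function before rewriting.
  if (mlir::failed(convert.Initialize())) return;
  // RewriteFunc() replaces the body with a single mlir::TFL::LSTMOp.
  convert.RewriteFunc();
}

After the rewrite, the tests expect the function to return the result of an LSTM op with 24 operands, where operand 1 (input-to-input weights) is None only in the CIFG case and operand 20 (input layer normalization coefficients) is None unless the layer-normalized converter is used.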

View File

@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/op_name_mapper.h"
#include "llvm/ADT/APInt.h"
namespace tensorflow {
using llvm::StringRef;
using mlir::Operation;
OpNameMapper::~OpNameMapper() {}

std::string OpNameMapper::GetUniqueName(llvm::StringRef prefix) {
std::string name = prefix;
auto& val = name_to_count_[name];
if (!val) {
++val;
return name;
}
llvm::SmallString<64> probe_name(prefix);
while (true) {
probe_name.resize(prefix.size());
// TODO(jpienaar): Subtract one so that the initial suffix is 0 instead
// of 1.
// TODO(jpienaar): Switch to radix 36 and update tests.
llvm::APInt(32, val++).toString(probe_name, /*Radix=*/10,
/*Signed=*/false);
if (!name_to_count_.count(probe_name)) {
name = llvm::StringRef(probe_name);
break;
}
}
return name;
}
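// Behavior sketch: the first request for a prefix is returned verbatim;
// later requests for the same prefix append the running count in decimal,
// so repeated calls to GetUniqueName("x") should yield "x", "x1", "x2", ...,
// skipping any candidate already present in name_to_count_ (for example a
// name registered earlier through InitOpName).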

const std::string& OpNameMapper::GetUniqueName(Operation* op) {
auto& name = op_to_name_[op];
if (!name.empty()) return name;
// Update the value in the map with unique name.
name = GetUniqueName(GetName(op));
return name;
}

int OpNameMapper::InitOpName(mlir::Operation* op, llvm::StringRef name) {
op_to_name_[op] = name;
return name_to_count_[name]++;
}

std::string OpLocNameMapper::GetName(Operation* op) {
if (auto name_loc = op->getLoc().dyn_cast<mlir::NameLoc>())
return name_loc.getName().str();
if (auto call_loc = op->getLoc().dyn_cast<mlir::CallSiteLoc>()) {
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
// if imported with DebugInfo), else use the fallback naming scheme below.
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
return name_loc.getName().str();
}
// If the location is none of the expected types, then simply use name
// generated using the op type.
return op->getName().getStringRef();
}
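// Illustrative example (the names are hypothetical): an op imported with
// debug info and a location of NameLoc("MobilenetV1/Conv2d_0/Relu6") would be
// named exactly that, while an op with an unknown location falls back to its
// type string, e.g. "tfl.conv_2d".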

std::string OpStripNameMapper::GetName(Operation* op) {
return llvm::APInt(32, count_++)
.toString(/*Radix=*/36,
/*Signed=*/false);
}
} // namespace tensorflow
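
Taken together, a consumer of this file might drive the location-based mapper roughly as follows. This is a minimal sketch under stated assumptions: the GetUniqueName overloads are part of OpNameMapper's public interface (as the definitions above suggest), and the NameAllOps helper is ours, not something introduced by this change.

#include "mlir/IR/Function.h"  // assumed header for mlir::FuncOp at this revision
#include "tensorflow/compiler/mlir/op_name_mapper.h"

void NameAllOps(mlir::FuncOp func) {
  tensorflow::OpLocNameMapper mapper;
  func.walk([&](mlir::Operation* op) {
    // Derives a name from the op's NameLoc / CallSiteLoc when available,
    // falls back to the op type otherwise, then makes it unique.
    const std::string& name = mapper.GetUniqueName(op);
    (void)name;  // e.g. record it as the tensor name in an exported model
  });
}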

Some files were not shown because too many files have changed in this diff.