Merge branch 'master' into toupstream/fix-tflite-interpreter-test
commit d3564251e3
4 .bazelrc
@@ -92,17 +92,13 @@ build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true

# Options extracted from configure script
build:gdr --define=with_gdr_support=true
build:ngraph --define=with_ngraph_support=true
build:verbs --define=with_verbs_support=true
build:numa --define=with_numa_support=true

# Options to disable default on features
build:noaws --define=no_aws_support=true
build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nokafka --define=no_kafka_support=true
build:noignite --define=no_ignite_support=true
build:nonccl --define=no_nccl_support=true

build --define=use_fast_cpp_protos=true
@@ -5,7 +5,6 @@
/tenosrflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/python/autograph/ @mdanatg @kkimdev
/tensorflow/python/debug @caisq
/tensorflow/python/eager @jaingaurav @alextp
@@ -37,9 +36,7 @@
/tensorflow/contrib/hadoop @yongtang
/tensorflow/contrib/hvx/ @satok16
/tensorflow/contrib/integrate/ @shoyer
/tensorflow/contrib/kafka @yongtang
/tensorflow/contrib/kernel_methods/ @petrosmol
/tensorflow/contrib/kinesis @yongtang
/tensorflow/contrib/ios_examples/ @petewarden
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke
22 WORKSPACE
@@ -49,34 +49,34 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
    name = "build_bazel_rules_apple",
    sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1",
    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"],
    sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
)  # https://github.com/bazelbuild/rules_apple/releases
http_archive(
    name = "build_bazel_rules_swift",
    sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311",
    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"],
    sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
)  # https://github.com/bazelbuild/rules_swift/releases
http_archive(
    name = "build_bazel_apple_support",
    sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
    sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
)  # https://github.com/bazelbuild/apple_support/releases
http_archive(
    name = "bazel_skylib",
    sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
    sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
)  # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
    name = "com_github_apple_swift_swift_protobuf",
    type = "zip",
    strip_prefix = "swift-protobuf-1.5.0/",
    urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"],
    strip_prefix = "swift-protobuf-1.6.0/",
    urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
)  # https://github.com/apple/swift-protobuf/releases
http_file(
    name = "xctestrunner",
    executable = 1,
    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
)  # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
@@ -3,56 +3,56 @@ package(default_visibility = ["//visibility:public"])
filegroup(
    name = "gcc",
    srcs = [
        "bin/arm-linux-gnueabihf-gcc",
        "bin/arm-rpi-linux-gnueabihf-gcc",
    ],
)

filegroup(
    name = "ar",
    srcs = [
        "bin/arm-linux-gnueabihf-ar",
        "bin/arm-rpi-linux-gnueabihf-ar",
    ],
)

filegroup(
    name = "ld",
    srcs = [
        "bin/arm-linux-gnueabihf-ld",
        "bin/arm-rpi-linux-gnueabihf-ld",
    ],
)

filegroup(
    name = "nm",
    srcs = [
        "bin/arm-linux-gnueabihf-nm",
        "bin/arm-rpi-linux-gnueabihf-nm",
    ],
)

filegroup(
    name = "objcopy",
    srcs = [
        "bin/arm-linux-gnueabihf-objcopy",
        "bin/arm-rpi-linux-gnueabihf-objcopy",
    ],
)

filegroup(
    name = "objdump",
    srcs = [
        "bin/arm-linux-gnueabihf-objdump",
        "bin/arm-rpi-linux-gnueabihf-objdump",
    ],
)

filegroup(
    name = "strip",
    srcs = [
        "bin/arm-linux-gnueabihf-strip",
        "bin/arm-rpi-linux-gnueabihf-strip",
    ],
)

filegroup(
    name = "as",
    srcs = [
        "bin/arm-linux-gnueabihf-as",
        "bin/arm-rpi-linux-gnueabihf-as",
    ],
)
81 configure.py
@@ -1145,78 +1145,6 @@ def set_trisycl_include_dir(environ_cp):
  write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)


def set_mpi_home(environ_cp):
  """Set MPI_HOME."""

  default_mpi_home = which('mpirun') or which('mpiexec') or ''
  default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))

  def valid_mpi_path(mpi_home):
    exists = (
        os.path.exists(os.path.join(mpi_home, 'include')) and
        (os.path.exists(os.path.join(mpi_home, 'lib')) or
         os.path.exists(os.path.join(mpi_home, 'lib64')) or
         os.path.exists(os.path.join(mpi_home, 'lib32'))))
    if not exists:
      print(
          'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found'
          % (os.path.join(mpi_home, 'include'),
             os.path.exists(os.path.join(mpi_home, 'lib')),
             os.path.exists(os.path.join(mpi_home, 'lib64')),
             os.path.exists(os.path.join(mpi_home, 'lib32'))))
    return exists

  _ = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='MPI_HOME',
      var_default=default_mpi_home,
      ask_for_var='Please specify the MPI toolkit folder.',
      check_success=valid_mpi_path,
      error_msg='',
      suppress_default_error=True)


def set_other_mpi_vars(environ_cp):
  """Set other MPI related variables."""
  # Link the MPI header files
  mpi_home = environ_cp.get('MPI_HOME')
  symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h')

  # Determine if we use OpenMPI or MVAPICH, these require different header files
  # to be included here to make bazel dependency checker happy
  if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')):
    symlink_force(
        os.path.join(mpi_home, 'include/mpi_portable_platform.h'),
        'third_party/mpi/mpi_portable_platform.h')
    # TODO(gunan): avoid editing files in configure
    sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = False',
                 'MPI_LIB_IS_OPENMPI = True')
  else:
    # MVAPICH / MPICH
    symlink_force(
        os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h')
    symlink_force(
        os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h')
    # TODO(gunan): avoid editing files in configure
    sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = True',
                 'MPI_LIB_IS_OPENMPI = False')

  if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
    symlink_force(
        os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
  elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')):
    symlink_force(
        os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so')
  elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')):
    symlink_force(
        os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so')

  else:
    raise ValueError(
        'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
        (mpi_home, mpi_home, mpi_home))


def system_specific_test_config(env):
  """Add default build and test flags required for TF tests to bazelrc."""
  write_to_bazelrc('test --flaky_test_attempts=3')
@@ -1549,11 +1477,6 @@ def main():
    raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. '
                         'At most 1 GPU platform can be configured.')

  set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
  if environ_cp.get('TF_NEED_MPI') == '1':
    set_mpi_home(environ_cp)
    set_other_mpi_vars(environ_cp)

  set_cc_opt_flags(environ_cp)
  set_system_libs_flag(environ_cp)
  if is_windows():
@@ -1580,8 +1503,6 @@ def main():
      'details.')
  config_info_line('mkl', 'Build with MKL support.')
  config_info_line('monolithic', 'Config for mostly static monolithic build.')
  config_info_line('gdr', 'Build with GDR support.')
  config_info_line('verbs', 'Build with libverbs support.')
  config_info_line('ngraph', 'Build with Intel nGraph support.')
  config_info_line('numa', 'Build with NUMA support.')
  config_info_line(
@@ -1593,8 +1514,6 @@ def main():
  config_info_line('noaws', 'Disable AWS S3 filesystem support.')
  config_info_line('nogcp', 'Disable GCP support.')
  config_info_line('nohdfs', 'Disable HDFS support.')
  config_info_line('noignite', 'Disable Apache Ignite support.')
  config_info_line('nokafka', 'Disable Apache Kafka support.')
  config_info_line('nonccl', 'Disable NVIDIA NCCL support.')
@@ -267,18 +267,6 @@ config_setting(
    visibility = ["//visibility:public"],
)

config_setting(
    name = "no_ignite_support",
    define_values = {"no_ignite_support": "true"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "no_kafka_support",
    define_values = {"no_kafka_support": "true"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "no_nccl_support",
    define_values = {"no_nccl_support": "true"},
@@ -309,18 +297,6 @@ config_setting(
    visibility = ["//visibility:public"],
)

config_setting(
    name = "with_gdr_support",
    define_values = {"with_gdr_support": "true"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "with_verbs_support",
    define_values = {"with_verbs_support": "true"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "with_numa_support",
    define_values = {"with_numa_support": "true"},
@@ -421,12 +397,6 @@ config_setting(
    },
)

config_setting(
    name = "with_mpi_support",
    values = {"define": "with_mpi_support=true"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "override_eigen_strong_inline",
    values = {"define": "override_eigen_strong_inline=true"},
@@ -470,6 +440,7 @@ config_setting(
package_group(
    name = "internal",
    packages = [
        "//perftools/accelerators/xprof/api/...",
        "//tensorflow/...",
        "//tensorflow_estimator/python/estimator/...",
        "//tensorflow_models/official/...",
@@ -56,10 +56,10 @@ elif _tf_api_dir not in __path__:
  __path__.append(_tf_api_dir)

# Hook external TensorFlow modules.

# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2  # pylint: disable=pointless-statement
try:
  from tensorboard.summary._tf import summary
  _current_module.__path__ = (
@@ -125,25 +125,6 @@ if _running_from_pip_package():
    if _fi.file_exists(plugin_dir):
      _ll.load_library(plugin_dir)

# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
try:
  del python
except NameError:
  pass
try:
  del core
except NameError:
  pass
try:
  del compiler
except NameError:
  pass
# pylint: enable=undefined-variable

# Add module aliases
if hasattr(_current_module, 'keras'):
  losses = keras.losses
@@ -60,6 +60,10 @@ elif _tf_api_dir not in __path__:
  __path__.append(_tf_api_dir)

# Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2  # pylint: disable=pointless-statement
try:
  from tensorflow_estimator.python.estimator.api._v1 import estimator
  _current_module.__path__ = (
@@ -134,23 +138,3 @@ if _running_from_pip_package():
    if _fi.file_exists(plugin_dir):
      _ll.load_library(plugin_dir)

# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
try:
  del python
except NameError:
  pass
try:
  del core
except NameError:
  pass
try:
  del compiler
except NameError:
  pass

_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# pylint: enable=undefined-variable
@@ -270,6 +270,7 @@ tf_cuda_library(
        "//tensorflow/core/platform",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

exports_files(
@@ -159,7 +159,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
                                              TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (TF_GetCode(status) == TF_OK) {
  if (status->status.ok()) {
    return new TF_DeprecatedSession({session});
  } else {
    DCHECK_EQ(nullptr, session);
@@ -332,7 +332,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
    // TODO(nolivia): check this on a subset of the graph instead of all of
    // it.
    status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
    if (TF_GetCode(status) != TF_OK) {
    if (!status->status.ok()) {
      session->graph->mu.unlock();
      return false;
    }
@@ -352,7 +352,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
    *graph_def.mutable_library() = graph.flib_def().ToProto();
    session->graph->mu.unlock();
    status->status = session->session->Extend(std::move(graph_def));
    if (TF_GetCode(status) != TF_OK) {
    if (!status->status.ok()) {
      // Contract is we always delete input_values[i].
      return false;
    }
@@ -382,7 +382,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
  const int ninputs = input_pairs->size();
  for (int i = 0; i < ninputs; ++i) {
    status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
    if (TF_GetCode(status) != TF_OK) return false;
    if (!status->status.ok()) return false;
  }
  return true;
}
@@ -439,7 +439,7 @@ static void TF_Run_Helper(
    // Serialize back to upstream client, who now owns the new buffer
    if (run_metadata != nullptr) {
      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
      if (TF_GetCode(status) != TF_OK) return;
      if (!status->status.ok()) return;
    }
  } else {
    // NOTE(zongheng): PRun does not support RunOptions yet.
@@ -459,7 +459,7 @@ static void TF_Run_Helper(
      continue;
    }
    c_outputs[i] = TF_TensorFromTensor(src, status);
    if (TF_GetCode(status) != TF_OK) return;
    if (!status->status.ok()) return;
  }
}

@@ -516,7 +516,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
  string new_handle;
  status->status = s->session->PRunSetup(input_names, output_names,
                                         target_oper_names, &new_handle);
  if (TF_GetCode(status) == TF_OK) {
  if (status->status.ok()) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
@@ -555,7 +555,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
  status->status = tensorflow::LoadLibrary(
      library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
      &lib_handle->op_list.length);
  if (TF_GetCode(status) != TF_OK) {
  if (!status->status.ok()) {
    delete lib_handle;
    return nullptr;
  }
@@ -983,7 +983,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
                      TF_Tensor* value, TF_Status* status) {
  Tensor t;
  status->status = TF_TensorToTensor(value, &t);
  if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}

void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
@@ -993,13 +993,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
  std::vector<Tensor> t;
  t.reserve(num_values);

  for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) {
  for (int i = 0; i < num_values && status->status.ok(); ++i) {
    Tensor v;
    status->status = TF_TensorToTensor(values[i], &v);
    t.emplace_back(v);
  }

  if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}

void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
@@ -1048,11 +1048,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
  status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
                                               /*consume=*/true);

  if (TF_GetCode(status) == TF_OK) {
  if (status->status.ok()) {
    // Run shape inference function for newly added node.
    status->status = desc->graph->refiner.AddNode(ret);
  }
  if (TF_GetCode(status) == TF_OK) {
  if (status->status.ok()) {
    // Add the node to the name-to-node mapping.
    desc->graph->name_map[ret->name()] = ret;
  } else if (ret != nullptr) {
@@ -1101,7 +1101,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
  NameRangeMap name_ranges;
  status->status =
      NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
  if (TF_GetCode(status) != TF_OK) return -1;
  if (!status->status.ok()) return -1;
  auto iter = name_ranges.find(arg_name);
  if (iter == name_ranges.end()) {
    status->status = InvalidArgument("Output arg '", arg_name, "' not found");
@@ -1123,7 +1123,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
  NameRangeMap name_ranges;
  status->status =
      NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
  if (TF_GetCode(status) != TF_OK) return -1;
  if (!status->status.ok()) return -1;
  auto iter = name_ranges.find(arg_name);
  if (iter == name_ranges.end()) {
    status->status = InvalidArgument("Input arg '", arg_name, "' not found");
@@ -1142,6 +1142,16 @@ TF_Output TF_OperationInput(TF_Input oper_in) {
  return {ToOperation(edge->src()), edge->src_output()};
}

void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
                           int max_inputs) {
  for (auto* edge : oper->node.in_edges()) {
    if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
      inputs[edge->dst_input()] = {ToOperation(edge->src()),
                                   edge->src_output()};
    }
  }
}

int TF_OperationOutputNumConsumers(TF_Output oper_out) {
  int count = 0;
  for (const auto* edge : oper_out.oper->node.out_edges()) {
@@ -1221,7 +1231,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
                                            TF_Status* status) {
  TF_AttrMetadata metadata;
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK) return metadata;
  if (!status->status.ok()) return metadata;
  switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
  case tensorflow::AttrValue::kK: \
@@ -1328,7 +1338,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
                               void* value, size_t max_length,
                               TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kS) {
    status->status =
        InvalidArgument("Attribute '", attr_name, "' is not a string");
@@ -1346,7 +1356,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
                                   int max_values, void* storage,
                                   size_t storage_size, TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list");
@@ -1379,7 +1389,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
                  int max_values, TF_Status* status) { \
    const auto* attr = GetAttrValue(oper, attr_name, status); \
    if (TF_GetCode(status) != TF_OK) return; \
    if (!status->status.ok()) return; \
    if (attr->value_case() != tensorflow::AttrValue::kList) { \
      status->status = \
          InvalidArgument("Value for '", attr_name, "' is not a list."); \
@@ -1401,7 +1411,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
  PartialTensorShape shape;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  auto len = std::min(shape.dims(), num_dims);
  for (int i = 0; i < len; ++i) {
    value[i] = shape.dim_size(i);
@@ -1415,7 +1425,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
  std::vector<PartialTensorShape> shapes;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
  int64_t* p = storage;
  int storage_left = storage_size;
@@ -1443,7 +1453,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
                                         const char* attr_name,
                                         TF_Buffer* value, TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kShape) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a shape.");
@@ -1457,7 +1467,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
                                             TF_Buffer** values, int max_values,
                                             TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list");
@@ -1467,7 +1477,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
  for (int i = 0; i < len; ++i) {
    values[i] = TF_NewBuffer();
    status->status = MessageToBuffer(attr->list().shape(i), values[i]);
    if (TF_GetCode(status) != TF_OK) {
    if (!status->status.ok()) {
      // Delete everything allocated to far, the operation has failed.
      for (int j = 0; j <= i; ++j) {
        TF_DeleteBuffer(values[j]);
@@ -1482,7 +1492,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
  *value = nullptr;
  Tensor t;
  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  *value = TF_TensorFromTensor(t, status);
}

@@ -1491,7 +1501,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
                                   TF_Status* status) {
  std::vector<Tensor> ts;
  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  const auto len = std::min(max_values, static_cast<int>(ts.size()));
  for (int i = 0; i < len; ++i) {
    values[i] = TF_TensorFromTensor(ts[i], status);
@@ -1502,7 +1512,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
                                   TF_Buffer* output_attr_value,
                                   TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;
  status->status = MessageToBuffer(*attr, output_attr_value);
}

@@ -1583,7 +1593,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
  {
    mutex_lock l(graph->mu);
    status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
    if (TF_GetCode(status) != TF_OK) return;
    if (!status->status.ok()) return;
  }
  status->status = MessageToBuffer(*op_def, output_op_def);
}
@@ -1701,7 +1711,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
  tensorflow::ImportGraphDefResults results;
  status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
                                              &graph->refiner, &results);
  if (TF_GetCode(status) != TF_OK) return;
  if (!status->status.ok()) return;

  // Add new nodes to name_map
  for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
@@ -1755,7 +1765,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
  auto results = new TF_ImportGraphDefResults();
  mutex_lock l(graph->mu);
  GraphImportGraphDefLocked(graph, def, options, results, status);
  if (TF_GetCode(status) != TF_OK) {
  if (!status->status.ok()) {
    delete results;
    return nullptr;
  }
@@ -1813,7 +1823,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
  TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
  // TODO(skyewm): set placeholder shape
  TF_Operation* oper = TF_FinishOperation(desc, status);
  if (TF_GetCode(status) != TF_OK) return false;
  if (!status->status.ok()) return false;
  *input = {oper, 0};
  return true;
}
@@ -1958,7 +1968,7 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
  TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
                           body_graph, body_inputs, body_outputs, name};

  if (TF_GetCode(status) != TF_OK) {
  if (!status->status.ok()) {
    FreeWhileResources(&params);
    return EmptyWhileParams();
  }
@@ -2160,7 +2170,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
                          TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (TF_GetCode(status) == TF_OK) {
  if (status->status.ok()) {
    TF_Session* new_session = new TF_Session(session, graph);
    if (graph != nullptr) {
      mutex_lock l(graph->mu);
@@ -2208,7 +2218,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
  status->status =
      tensorflow::LoadSavedModel(session_options->options, run_options_proto,
                                 export_dir, tag_set, &bundle);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (!status->status.ok()) return nullptr;

  // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
  // extends using GraphDefs. The Graph instance is different, but equivalent
@@ -2221,11 +2231,11 @@ TF_Session* TF_LoadSessionFromSavedModel(
  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                            import_opts, &results, status);
  TF_DeleteImportGraphDefOptions(import_opts);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (!status->status.ok()) return nullptr;

  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    if (!status->status.ok()) return nullptr;
  }

  TF_Session* session = new TF_Session(bundle.session.release(), graph);
@@ -2325,7 +2335,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
  string new_handle;
  status->status = session->session->PRunSetup(input_names, output_names,
                                               target_names, &new_handle);
  if (TF_GetCode(status) == TF_OK) {
  if (status->status.ok()) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
@@ -2387,9 +2397,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
      tensor, graph->refiner, *graph->graph.op_registry(),
      graph->graph.versions().producer(), &evaluated, &result_tensor);
  if (evaluated) {
    DCHECK(TF_GetCode(status) == TF_OK);
    DCHECK(status->status.ok());
    *result = TF_TensorFromTensor(result_tensor, status);
    if (TF_GetCode(status) != TF_OK) evaluated = false;
    if (!status->status.ok()) evaluated = false;
  }
  return evaluated;
}
@@ -2444,7 +2454,7 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  if (TF_GetCode(status) != TF_OK) {
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
@@ -2456,7 +2466,7 @@ TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
  tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (TF_GetCode(status) != TF_OK) {
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
@@ -2468,7 +2478,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
      tensorflow::GetRegisteredKernelsForOp(name);
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (TF_GetCode(status) != TF_OK) {
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
@@ -2498,7 +2508,7 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len,

  std::unique_ptr<tensorflow::ServerInterface> out_server;
  status->status = tensorflow::NewServer(server_def, &out_server);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (!status->status.ok()) return nullptr;

  return new TF_Server(std::move(out_server));
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
@@ -435,6 +435,15 @@ TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper,
// producer.index) to consumer.oper's input (given by consumer.index).
TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in);

// Get list of all inputs of a specific operation. `inputs` must point to
// an array of length at least `max_inputs` (ideally set to
// TF_OperationNumInputs(oper)). Beware that a concurrent
// modification of the graph can increase the number of inputs of
// an operation.
TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper,
                                                 TF_Output* inputs,
                                                 int max_inputs);

// Get the number of current consumers of a specific output of an
// operation. Note that this number can change when new operations
// are added to the graph.
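As a reading aid for the hunk above: a minimal, hypothetical usage sketch for the newly declared TF_OperationAllInputs. DumpAllInputs is an invented helper, while TF_OperationNumInputs and TF_OperationName are existing C API functions; sizing the buffer with TF_OperationNumInputs follows the header comment's own recommendation.

// Hypothetical helper; not part of the change above.
#include <cstdio>
#include <vector>
#include "tensorflow/c/c_api.h"

void DumpAllInputs(TF_Operation* op) {
  const int n = TF_OperationNumInputs(op);  // size the buffer first
  std::vector<TF_Output> inputs(n);
  TF_OperationAllInputs(op, inputs.data(), n);
  for (int i = 0; i < n; ++i) {
    // Each TF_Output names the producing operation and its output index.
    std::printf("input %d <- %s:%d\n", i, TF_OperationName(inputs[i].oper),
                inputs[i].index);
  }
}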
@@ -510,10 +510,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
  return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}

static void CheckOk(TF_Status* status) {
  CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}

void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
  auto* status = TF_NewStatus();
  if (!TFE_TensorHandleIsConcrete(handle)) {
@@ -41,6 +41,7 @@ namespace {
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
// TODO(b/139886381) Unify this and the one in graph_to_functiondef.cc
class NodeNameMapping {
 public:
  NodeNameMapping() = default;
@@ -64,14 +65,14 @@ class NodeNameMapping {
  string Lookup(const string& name) const;

 private:
  string UniquifyHelper(const string& name) const;
  string UniquifyHelper(const string& name);
  static string Normalize(string name);

  // The normalized/uniquified names already used as
  // input names (in signature), output names (in signature), and node names
  // (in node_def).
  // This is a superset of values in name_mapping_.
  std::unordered_set<string> used_names_;
  std::unordered_map<string, uint64> used_names_;
  // Mapping from original node name from the graph to the normalized
  // and uniquified version of it.
  std::unordered_map<string, string> name_mapping_;
@@ -102,13 +103,16 @@ string NodeNameMapping::Normalize(string name) {
  return i == n ? "unknown" : name.substr(i);
}

string NodeNameMapping::UniquifyHelper(const string& name) const {
string NodeNameMapping::UniquifyHelper(const string& name) {
  auto it = used_names_.emplace(name, 0);
  // If the name hasn't been used yet, use it as-is.
  if (used_names_.find(name) == used_names_.end()) return name;
  if (it.second) return name;

  // Add a suffix to name to make it unique.
  for (int i = 0;; ++i) {
    const string candidate = strings::StrCat(name, "_", i);
    if (used_names_.find(candidate) == used_names_.end()) return candidate;
  while (true) {
    const string candidate = strings::StrCat(name, "_", it.first->second);
    it.first->second++;
    if (used_names_.emplace(candidate, 0).second) return candidate;
  }
}

@@ -120,16 +124,13 @@ string NodeNameMapping::GetInputName(const string& name) {

string NodeNameMapping::GetOutputName(const string& name) {
  const string& input_name = UniquifyHelper(Normalize(name));
  // Record that we used this name, but don't add it to name_mapping_
  // since this name is not for a node.
  used_names_.insert(input_name);
  // Don't add it to name_mapping_ since this name is not for a node.
  return input_name;
}

string NodeNameMapping::Uniquify(const string& name) {
  const string uniqued = UniquifyHelper(name);
  name_mapping_[name] = uniqued;
  used_names_.insert(uniqued);
  return uniqued;
}

@@ -139,7 +140,7 @@ Status NodeNameMapping::UseOutputName(const string& name) {
    return InvalidArgument("Cannot have duplicate output names. Name '", name,
                           "' appears more than once in 'output_names' array.");
  }
  used_names_.insert(iter, name);
  used_names_.emplace(name, 0);
  return Status::OK();
}

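The UniquifyHelper change above replaces a linear rescan of suffixes starting at "_0" with a per-name counter stored in the map value. A standalone sketch of the same scheme, with std::string standing in for TensorFlow's string alias and a free function in place of the member:

#include <cstdint>
#include <string>
#include <unordered_map>

// Maps each base name to the next suffix to try; mirrors the new used_names_.
static std::unordered_map<std::string, std::uint64_t> used_names;

std::string Uniquify(const std::string& name) {
  auto it = used_names.emplace(name, 0);
  if (it.second) return name;  // First use of this name: keep it as-is.
  // Collision: resume from the remembered counter instead of rescanning.
  while (true) {
    const std::string candidate =
        name + "_" + std::to_string(it.first->second);
    it.first->second++;
    if (used_names.emplace(candidate, 0).second) return candidate;
  }
}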
@@ -9,7 +9,6 @@ load(
)
load(
    "//tensorflow/core/platform:default/build_config.bzl",
    "tf_additional_device_tracer_test_flags",
    "tf_kernel_tests_linkstatic",
)
load(
@@ -37,6 +36,7 @@ tf_cuda_library(
            "//tensorflow/core:android_tensorflow_lib_lite",
        ],
        "//conditions:default": [
            "@com_google_absl//absl/container:fixed_array",
            "//tensorflow/c:c_api",
            "//tensorflow/c:c_api_internal",
            "//tensorflow/core:core_cpu",
@@ -79,6 +79,7 @@ tf_cuda_library(
        "//tensorflow/core/profiler/lib:profiler_session",
        "//tensorflow/core:gpu_runtime",
    ],
    alwayslink = 1,
)

tf_cuda_library(
@@ -226,6 +227,7 @@ tf_cuda_library(
        "//tensorflow/core/profiler/rpc/client:capture_profile",
        "//tensorflow/core:gpu_runtime",
    ],
    alwayslink = 1,
)

tf_cuda_cc_test(
@@ -234,8 +236,7 @@ tf_cuda_cc_test(
    srcs = [
        "c_api_experimental_test.cc",
    ],
    args =
        ["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
@@ -26,12 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/platform.h"  // NOLINT
#ifdef TENSORFLOW_EAGER_USE_XLA
@@ -60,6 +62,7 @@ limitations under the License.
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -99,32 +102,34 @@ string DeviceName(const tensorflow::Device* d) {
tensorflow::Status GetAllRemoteDevices(
    const std::vector<string>& remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
    std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  tensorflow::Status status;
  // TODO(nareshmodi) do this in parallel instead of serially.
  for (const string& remote_worker : remote_workers) {
    tensorflow::Notification n;
  tensorflow::mutex remote_devices_mu;
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);
  for (int i = 0; i < num_remote_workers; i++) {
    tensorflow::NewRemoteDevices(
        tensorflow::Env::Default(), worker_cache, remote_worker,
        [&status, &n, &remote_devices](
        tensorflow::Env::Default(), worker_cache, remote_workers[i],
        [i, &statuses, &counter, &remote_devices, &remote_devices_mu](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
          status = s;
          statuses[i] = s;
          if (s.ok()) {
            tensorflow::mutex_lock l(remote_devices_mu);
            for (tensorflow::Device* d : *devices) {
              remote_devices.emplace_back(d);
            }
          }
          n.Notify();
          counter.DecrementCount();
        });
    n.WaitForNotification();
  }
  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
      new tensorflow::DeviceMgr(std::move(remote_devices)));

  TF_RETURN_IF_ERROR(status);

  counter.Wait();
  for (int i = 0; i < num_remote_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
  auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
  TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
  *device_mgr = std::move(remote_device_mgr);
  return tensorflow::Status::OK();
}
@@ -134,11 +139,15 @@ tensorflow::Status CreateRemoteContexts(
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    const tensorflow::eager::CreateContextRequest& base_request) {
  for (int i = 0; i < remote_workers.size(); i++) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);
  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];

    tensorflow::eager::CreateContextRequest request(base_request);
    tensorflow::eager::CreateContextResponse response;
    tensorflow::eager::CreateContextResponse* response =
        new tensorflow::eager::CreateContextResponse();
    request.set_context_id(context_id);
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
@@ -158,16 +167,17 @@ tensorflow::Status CreateRemoteContexts(
      return tensorflow::errors::Internal(
          "Cannot find a client for the given target:", remote_worker);
    }
    tensorflow::Notification n;
    tensorflow::Status status;
    // TODO(nareshmodi) do this in parallel instead of serially.
    eager_client->CreateContextAsync(
        &request, &response, [&status, &n](const tensorflow::Status& s) {
          status = s;
          n.Notify();
        &request, response,
        [i, &statuses, &counter, response](const tensorflow::Status& s) {
          statuses[i] = s;
          delete response;
          counter.DecrementCount();
        });
    n.WaitForNotification();
    TF_RETURN_IF_ERROR(status);
  }
  counter.Wait();
  for (int i = 0; i < num_remote_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
  return tensorflow::Status::OK();
}
@@ -214,7 +224,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
      std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
      remote_workers.end());

  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
  std::unique_ptr<tensorflow::DynamicDeviceMgr> remote_device_mgr;
  LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
      remote_workers, grpc_server->master_env()->worker_cache,
      &remote_device_mgr));
@@ -246,7 +256,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
  LOG_AND_RETURN_IF_ERROR(
      CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
                           server_def, remote_eager_workers.get(),
                           ctx->context->Executor()->Async(), base_request));
                           ctx->context->Executor().Async(), base_request));

  tensorflow::RemoteRendezvous* r =
      grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
@@ -384,7 +394,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DeviceMgr(std::move(devices)));
      new tensorflow::StaticDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
@@ -563,7 +573,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
    const tensorflow::Tensor* t = nullptr;
    tensorflow::TensorHandle* h_cpu = nullptr;
    status->status = EagerCopyToDevice(
        handle, handle->Context(), handle->Context()->Executor(),
        handle, handle->Context(), &handle->Context()->Executor(),
        handle->Context()->HostCPU(), false, &h_cpu);
    if (!status->status.ok()) {
      return nullptr;
@@ -893,10 +903,9 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  VLOG(1) << "Calling TFE_Execute() on op " << op;
  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
      *num_retvals);
  status->status =
      tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals);
  absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
  status->status = tensorflow::EagerExecute(&op->operation,
                                            handle_retvals.data(), num_retvals);
  if (!status->status.ok()) {
    return;
  }
@@ -916,7 +925,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
    return nullptr;
  }
  status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
                                                 ctx->context->Executor(),
                                                 &ctx->context->Executor(),
                                                 device, false, &handle);
  if (status->status.ok()) {
    return new TFE_TensorHandle(handle);
@@ -967,7 +976,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  status->status = ctx->context->Executor()->WaitForAllPendingNodes();
  status->status = ctx->context->Executor().WaitForAllPendingNodes();
  if (!status->status.ok()) return;
  tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
  status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
@@ -979,9 +988,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (TF_GetCode(status) != TF_OK) return nullptr;
    if (!status->status.ok()) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    if (!status->status.ok()) return nullptr;
  }
  return func_op;
}
@@ -1029,7 +1038,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
    } break;
    case tensorflow::AttrValue::kFunc: {
      const auto func_op = GetFunc(ctx, default_value.func(), status);
      if (TF_GetCode(status) != TF_OK) return;
      if (!status->status.ok()) return;
      // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
      // require TFE_Op* and just convert it internally a NameAttrValue, so
      // consider adding an overload to the C API to make this case easier.
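The two loops above replace a per-iteration Notification (issue, block, check) with a fan-out: every request is issued up front, each callback writes its own statuses[i] slot and decrements a shared BlockingCounter, and a single Wait() joins them all. A condensed sketch of the pattern, where IssueAsync is a hypothetical stand-in for NewRemoteDevices/CreateContextAsync:

#include <functional>
#include <vector>
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"

// Hypothetical async primitive standing in for the real RPC calls.
void IssueAsync(int i, std::function<void(const tensorflow::Status&)> done);

tensorflow::Status FanOut(int num_workers) {
  tensorflow::BlockingCounter counter(num_workers);
  std::vector<tensorflow::Status> statuses(num_workers);
  for (int i = 0; i < num_workers; i++) {
    // Each callback owns exactly one slot, so no lock is needed to record it.
    IssueAsync(i, [i, &statuses, &counter](const tensorflow::Status& s) {
      statuses[i] = s;
      counter.DecrementCount();
    });
  }
  counter.Wait();  // Join all callbacks before reading the results.
  for (int i = 0; i < num_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
  return tensorflow::Status::OK();
}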
@@ -597,5 +597,5 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
}

TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
  return new TFE_Executor(ctx->context->Executor());
  return new TFE_Executor(&ctx->context->Executor());
}
@@ -84,11 +84,6 @@ void ExecuteWithProfiling(bool async) {
  string profile_proto_str = profile_proto.DebugString();
  if (!gpu_device_name.empty()) {
    EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
    // device name with "stream:all" is collected by Device Tracer.
#ifndef TENSORFLOW_USE_ROCM
    // ROCm platform does not yet support stream level tracing
    EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
#endif
  }
  // "/host:CPU" is collected by TraceMe
  EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
@@ -1069,10 +1069,13 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
    // still fail.
    TF_SetStatus(status, TF_OK, "");
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    TF_SetStatus(status, TF_OK, "");
    retvals[0] = nullptr;
    TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_ExecutorClearError(executor);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -18,6 +18,7 @@ limitations under the License.
|
||||
// Language-agnostic gradient tape. Does not perform backpropagation, just
|
||||
// maintains the data structures required to do so.
|
||||
|
||||
#include <stack>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
@ -209,7 +210,9 @@ class ForwardAccumulator {
|
||||
// ForwardAccumulator.
|
||||
explicit ForwardAccumulator(
|
||||
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
|
||||
: vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
|
||||
: vspace_(vspace) {
|
||||
call_state_.emplace(nullptr, false);
|
||||
}
|
||||
|
||||
virtual ~ForwardAccumulator() {
|
||||
for (auto accumulated : accumulated_gradients_) {
|
||||
@ -262,11 +265,11 @@ class ForwardAccumulator {
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter);
|
||||
|
||||
// Returns true if `Accumulate` is active somewhere above on the stack. This
|
||||
// is useful for ordering ForwardAccumulators, where more deeply nested
|
||||
// accumulators should not see computations from less deeply nested
|
||||
// accumulators.
|
||||
bool BusyAccumulating() const { return this->accumulating_; }
|
||||
// Returns true if `Accumulate` is active somewhere above on the stack and
|
||||
// there isn't an intervening PushState. This is useful for ordering
|
||||
// ForwardAccumulators, where more deeply nested accumulators should not see
|
||||
// computations from less deeply nested accumulators.
|
||||
bool BusyAccumulating() const { return call_state_.top().accumulating; }
|
||||
|
||||
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
|
||||
// a nullptr if none is available.
|
||||
@ -282,6 +285,15 @@ class ForwardAccumulator {
|
||||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes);
|
||||
|
||||
// Temporarily push or pop transient state for this accumulator.
|
||||
//
|
||||
// Allows an accumulator which is currently processing an operation to
|
||||
// temporarily reset its state. Without pushing and poping, accumulators
|
||||
// ignore operations executed as a direct result of their own jvp
// computations.
void PushState() { call_state_.emplace(nullptr, false); }
void PopState() { call_state_.pop(); }

private:
// Helper for Accumulate: uses a GradientTape to compute forward gradients
// from a backward gradient function. Fills `out_grads` corresponding to
@ -289,7 +301,7 @@ class ForwardAccumulator {
//
// Executes the backward function in order to trace its gradient, which will
// waste computation if executing eagerly (when graph building the unneeded
// computation is pruned). Temporarily sets `backward_tape_` so that
// computation is pruned). Temporarily sets `backward_tape` so that
// Accumulate will forward op executions to the tape while the backward
// function is running; this effectively adds the backward tape to the active
// set (but does not require complicated callbacks to the language bindings).
@ -305,16 +317,26 @@ class ForwardAccumulator {
// Not owned; provides operations on Tensors which are currently only
// available in language bindings (e.g. Python).
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
// Set temporarily while in the Accumulate method; if backward_tape_ is not
// nullptr then we forward op executions to it so Accumulate can compute a
// backward pass on its backward function.
//
// Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
// keeps ownership.
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
// While the Accumulate method is running (accumulating_ is True), any op
// executions not forwarded to backward_tape_ should be ignored.
bool accumulating_;

struct AccumulatorCallState {
AccumulatorCallState(
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
bool accumulating)
: backward_tape(backward_tape), accumulating(accumulating) {}
// Set temporarily while in the Accumulate method; if backward_tape is not
// nullptr then we forward op executions to it so Accumulate can compute a
// backward pass on its backward function.
//
// Not owned by the ForwardAccumulator. The method which sets
// `backward_tape` keeps ownership.
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
// While the Accumulate method is running (accumulating is True), any op
// executions not forwarded to backward_tape should be ignored.
bool accumulating;
};
// A deque-backed stack, whose element references are not invalidated by
// pushes and pops at the back.
std::stack<AccumulatorCallState> call_state_;
};

// Template instantiations here
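The hunks above replace the accumulator's two loose members (backward_tape_, accumulating_) with a stack of AccumulatorCallState frames, which is what makes PushState()/PopState() and re-entrant Accumulate calls safe. A minimal, self-contained sketch of that pattern — illustrative names only, not the TensorFlow implementation:

#include <stack>

struct CallState {
  void* backward_tape;  // stand-in for GradientTape*; null when inactive
  bool accumulating;    // true while an Accumulate call is running
};

class AccumulatorSketch {
 public:
  AccumulatorSketch() { call_state_.push(CallState{nullptr, false}); }
  bool BusyAccumulating() const { return call_state_.top().accumulating; }
  // Mirrors PushState()/PopState(): give nested work a fresh frame, then
  // restore the outer frame untouched.
  void PushState() { call_state_.push(CallState{nullptr, false}); }
  void PopState() { call_state_.pop(); }

 private:
  // std::stack is deque-backed by default, so references to elements below
  // the top stay valid across pushes and pops at the back.
  std::stack<CallState> call_state_;
};

int main() {
  AccumulatorSketch acc;
  acc.PushState();                     // enter a nested context
  bool busy = acc.BusyAccumulating();  // false: fresh state on top
  acc.PopState();                      // outer context restored
  return busy ? 1 : 0;
}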
@ -645,16 +667,15 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
Status s = InitialGradients(vspace, target_tensor_ids,
sources_that_are_targets, output_gradients,
tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
auto cleanup = gtl::MakeCleanup([this, &state]() {
if (!persistent_) {
// Release all backprop functions
for (const auto& pair : state.op_tape) {
pair.second.backward_function_deleter(pair.second.backward_function);
}
}
};
});
if (!s.ok()) {
cleanup();
return s;
}

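The switch from a plain lambda to gtl::MakeCleanup above turns the cleanup into a scope guard that fires on every exit path instead of only where it is explicitly invoked. A minimal guard in that spirit, assuming only the standard library (this is not the gtl implementation):

#include <cstdio>
#include <utility>

template <typename F>
class ScopeGuard {
 public:
  explicit ScopeGuard(F f) : f_(std::move(f)) {}
  ScopeGuard(ScopeGuard&& other)
      : f_(std::move(other.f_)), active_(other.active_) {
    other.active_ = false;  // the moved-from guard must not fire
  }
  ~ScopeGuard() {
    if (active_) f_();
  }
  ScopeGuard(const ScopeGuard&) = delete;
  ScopeGuard& operator=(const ScopeGuard&) = delete;

 private:
  F f_;
  bool active_ = true;
};

template <typename F>
ScopeGuard<F> MakeCleanup(F f) {
  return ScopeGuard<F>(std::move(f));
}

int main() {
  auto cleanup = MakeCleanup([] { std::puts("backward functions released"); });
  return 0;  // the guard fires here, and on any early return above it
}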
@ -732,11 +753,17 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
s = vspace.CallBackwardFunction(trace.backward_function,
unneeded_gradients, out_gradients,
&in_gradients);
if (in_gradients.size() != trace.input_tensor_id.size()) {
return tensorflow::errors::Internal(
"Recorded operation '", trace.op_type,
"' returned too few gradients. Expected ",
trace.input_tensor_id.size(), " but received ",
in_gradients.size());
}
if (!persistent_) {
trace.backward_function_deleter(trace.backward_function);
}
if (!s.ok()) {
cleanup();
return s;
}
} else {
@ -847,12 +874,12 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
if (backward_tape_ != nullptr) {
// If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
if (call_state_.top().backward_tape != nullptr) {
// If we're forwarding Accumulate calls to backward_tape's RecordOperation,
// we should also delegate ShouldRecord.
return backward_tape_->ShouldRecord(tensor_ids, dtypes);
return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
}
if (accumulating_) {
if (call_state_.top().accumulating) {
return false;
}
for (int i = 0; i < tensor_ids.size(); ++i) {
@ -884,9 +911,10 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
*/
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
backward_tape_ = tape.get();
AccumulatorCallState& call_state = call_state_.top();
call_state.backward_tape = tape.get();
auto pop_backward_tape =
gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
std::vector<Gradient*> forwardprop_aids;
std::vector<int64> sources;
std::unordered_set<int64> sources_set;
@ -894,6 +922,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
for (const TapeTensor& output_tensor : output_tensors) {
// Ownership of `aid` transferred to CallBackwardFunction below.
Gradient* aid = vspace_.Ones(output_tensor);
if (TF_PREDICT_FALSE(aid == nullptr)) {
return tensorflow::errors::Internal(
"Failed to create ones tensor for tensor ", output_tensor.GetID(),
" with dtype ", output_tensor.GetDType());
}
forwardprop_aids.push_back(aid);
int64 aid_id = vspace_.TensorId(aid);
sources.push_back(aid_id);
@ -961,10 +994,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
const ForwardFunction<Gradient>* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (backward_tape_ != nullptr) {
// If backward_tape_ is not null, then this call to Accumulate is the result
if (call_state_.top().backward_tape != nullptr) {
// If backward_tape is not null, then this call to Accumulate is the result
// of a still-active call to Accumulate which is running operations. We
// forward these operations to backward_tape_ so the outer Accumulate call
// forward these operations to backward_tape so the outer Accumulate call
// can do its work.
//
// Rather than re-entering and delegating Accumulate like this, we could
@ -972,9 +1005,9 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
// (so it can deactivate itself and activate its GradientTape). Currently
// that is managed by the language binding and would require relatively
// messy callbacks.
backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
input_dtypes, backward_function_getter,
backward_function_deleter);
call_state_.top().backward_tape->RecordOperation(
op_type, output_tensors, input_tensor_id, input_dtypes,
backward_function_getter, backward_function_deleter);
return Status::OK();
}
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
@ -1012,9 +1045,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(

// Avoid infinite recursion. Whichever forward function we run, it'll end up
// executing ops, and we don't want to watch those with this accumulator.
accumulating_ = true;
auto reset_accumulating =
    gtl::MakeCleanup([this] { this->accumulating_ = false; });
call_state_.emplace(nullptr, true);
auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });

std::vector<Gradient*> forward_grads;
if (forward_function == nullptr) {
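Taken together, the Accumulate changes above combine the two building blocks already sketched: push a frame that says "busy, no tape" so the accumulator ignores ops it spawns itself, and pop it via a cleanup guard on every exit path. A compact, hypothetical illustration:

#include <stack>

struct CallState { bool accumulating; };

// Returns true when the nested frame correctly shadows the outer one.
bool RunNestedWork(std::stack<CallState>& state) {
  state.push(CallState{true});  // ops executed below must be ignored
  bool nested_busy = state.top().accumulating;
  state.pop();  // in the real code this pop is a MakeCleanup guard
  return nested_busy;
}

int main() {
  std::stack<CallState> state;
  state.push(CallState{false});  // quiescent outer frame
  bool ok = RunNestedWork(state) && !state.top().accumulating;
  return ok ? 0 : 1;
}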
@ -123,6 +123,7 @@ cc_library(
"//tensorflow/core/util/tensor_bundle:naming",
# mobile not supported yet
]),
alwayslink = 1,
)

tf_cc_test(
@ -48,12 +48,12 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
export_dir);
}

Status FindMetaGraphDef(const SavedModel& saved_model_proto,
                        const std::unordered_set<string>& tags,
Status FindMetaGraphDef(const std::unordered_set<string>& tags,
                        SavedModel* saved_model_proto,
                        MetaGraphDef* meta_graph_def) {
LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
<< " }";
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) {
// Get tags from the graph_def.
std::unordered_set<string> graph_tags;
for (const string& tag : graph_def.meta_info_def().tags()) {
@ -61,7 +61,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
}
// Match with the set of tags provided.
if (graph_tags == tags) {
*meta_graph_def = graph_def;
*meta_graph_def = std::move(graph_def);
return Status::OK();
}
}
@ -81,7 +81,8 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
MetaGraphDef* const meta_graph_def) {
SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
TF_RETURN_IF_ERROR(
    FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def));
return Status::OK();
}

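The loader change above switches FindMetaGraphDef from a const reference to a mutable pointer precisely so the matching MetaGraphDef can be moved out instead of copied; for large SavedModels that copy is expensive. The same idea, with a standard container standing in for the mutable repeated proto field (names are illustrative):

#include <string>
#include <utility>
#include <vector>

// Taking the container by pointer permits mutable iteration, so the matching
// element's payload can be stolen with std::move rather than deep-copied.
bool FindAndTake(std::vector<std::string>* graphs, const std::string& wanted,
                 std::string* out) {
  for (std::string& g : *graphs) {
    if (g == wanted) {
      *out = std::move(g);  // g is left valid-but-unspecified, like the proto
      return true;
    }
  }
  return false;
}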
@ -1,5 +1,3 @@
# -*- Python -*-

"""Build macro that compiles a TensorFlow graph into a cc_library.

To use from your BUILD file, add the following line to load the macro:
@ -1,12 +1,11 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
|
||||
load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
":internal",
|
||||
"//tensorflow/core/common_runtime/eager:__pkg__",
|
||||
# BEGIN-GOOGLE-INTERNAL
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
# END-GOOGLE-INTERNAL
|
||||
@ -39,7 +38,7 @@ cc_library(
|
||||
":xla_cpu_device",
|
||||
":xla_cpu_jit",
|
||||
"//tensorflow/compiler/plugin",
|
||||
] + if_cuda([
|
||||
] + if_cuda_or_rocm([
|
||||
":xla_gpu_device",
|
||||
":xla_gpu_jit",
|
||||
]),
|
||||
@ -62,7 +61,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "xla_gpu_jit",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_cuda([
|
||||
deps = if_cuda_or_rocm([
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
@ -145,8 +144,57 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
XLA_DEVICE_DEPS = [
|
||||
":common",
|
||||
":xla_launch_util",
|
||||
":xla_tensor",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:stream_pool",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
"//tensorflow/core:control_flow_ops_op_lib",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//tensorflow/core:no_op_op_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||
"//tensorflow/core:sendrecv_ops_op_lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
"//tensorflow/core/kernels:fifo_queue",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
"//tensorflow/core/kernels/data:optional_ops",
|
||||
"//tensorflow/core/kernels/data:prefetch_dataset_op",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
]
|
||||
|
||||
cc_library(
|
||||
name = "xla_device",
|
||||
name = "xla_device_no_jit_rewrite_registration",
|
||||
srcs = [
|
||||
"xla_compile_on_demand_op.cc",
|
||||
"xla_device.cc",
|
||||
@ -159,56 +207,22 @@ cc_library(
|
||||
"xla_device_context.h",
|
||||
"xla_device_ops.h",
|
||||
],
|
||||
deps = XLA_DEVICE_DEPS,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_device",
|
||||
hdrs = [
|
||||
"xla_compile_on_demand_op.h",
|
||||
"xla_device.h",
|
||||
"xla_device_context.h",
|
||||
"xla_device_ops.h",
|
||||
],
|
||||
# Public visibility is needed for external TF/XLA backends.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":common",
|
||||
deps = XLA_DEVICE_DEPS + [
|
||||
":jit_compilation_passes",
|
||||
":xla_launch_util",
|
||||
":xla_tensor",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:stream_pool",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
"//tensorflow/core:control_flow_ops_op_lib",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//tensorflow/core:no_op_op_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||
"//tensorflow/core:sendrecv_ops_op_lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
"//tensorflow/core/kernels:fifo_queue",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
"//tensorflow/core/kernels/data:optional_ops",
|
||||
"//tensorflow/core/kernels/data:prefetch_dataset_op",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
":xla_device_no_jit_rewrite_registration",
|
||||
],
|
||||
)
|
||||
|
||||
@ -262,6 +276,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -269,6 +284,8 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/core:refcount",
|
||||
"//tensorflow/core/lib/gtl:array_slice",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
@ -323,22 +340,27 @@ cc_library(
|
||||
":compilation_passes",
|
||||
":xla_activity_logging_listener",
|
||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
|
||||
"//tensorflow/compiler/tf2xla:mlir_bridge_pass_registration",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
] + tf_jit_compilation_passes_extra_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Linked by tensorflow core, without registration of jit compilation passes
|
||||
# which is not necessary to create and run a XlaLocalLaunchBase kernel.
|
||||
# Linking jit compilation passes can currently cause programs to get stuck (b/140069592).
cc_library(
name = "xla_kernel_creator",
name = "xla_kernel_creator_util",
srcs = [
"xla_kernel_creator.cc",
"xla_kernel_creator_util.cc",
],
hdrs = ["xla_kernel_creator.h"],
hdrs = ["xla_kernel_creator_util.h"],
visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"],
deps = [
":common",
":compilability_check_util",
":compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@ -351,6 +373,23 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "xla_kernel_creator",
srcs = [
"xla_kernel_creator.cc",
"xla_kernel_creator.h",
],
deps = [
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)

tf_cc_test(
name = "xla_kernel_creator_test",
srcs = [
@ -846,6 +885,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/memory",
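The BUILD refactor above splits targets into a "*_no_jit_rewrite_registration" core plus a thin wrapper that adds the registration passes. The mechanism being controlled is C++ static registration: global objects whose constructors register something at load time, which alwayslink = 1 keeps the linker from discarding. A minimal sketch of that idiom, with an illustrative registry rather than TensorFlow's:

#include <cstdio>

struct PassRegistry {
  static void Register(const char* name) {
    std::printf("registered %s\n", name);
  }
};

struct Registrar {
  explicit Registrar(const char* name) { PassRegistry::Register(name); }
};

// Nothing references this object; its only job is the load-time side effect.
// Without alwayslink (or --whole-archive), a static linker may drop it and
// the pass silently never runs -- hence the registration-free variants for
// clients that must not pull the passes in.
static Registrar jit_pass_registrar("jit_compilation_pass");

int main() { return 0; }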
@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@ -85,7 +86,7 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,

} // anonymous namespace

std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
const Node& node, FunctionLibraryRuntime* lib_runtime,
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
@ -100,12 +101,14 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
}
}
stack_trace.emplace_back(StackFrameView{node.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes);

RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace,
                 /*encapsulating_function=*/nullptr, &uncompilable_nodes);
return uncompilable_nodes;
}

std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
@ -120,22 +123,31 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
}
}
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes);

RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
IsCompilableCall(call_def, lib_runtime, &stack_trace,
                 /*encapsulating_function=*/nullptr, &uncompilable_nodes);
return uncompilable_nodes;
}

bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
bool RecursiveCompilabilityChecker::HasXLAKernel(
    const Node& node, string* uncompilable_reason) const {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node.type_string() == "SymbolicGradient") return false;
if (node.type_string() == "SymbolicGradient") {
*uncompilable_reason =
    "SymbolicGradient should be handled by IsCompilableCall().";
return false;
}
if (node.type_string() == "Const") {
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
// registered Const KernelDef says that it does, to support no-op Assert for
// tfcompile.
const AttrValue* attr = node.attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
*uncompilable_reason =
    "Const op with type DT_STRING is not supported by XLA.";
return false;
}
}
@ -145,10 +157,16 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
*uncompilable_reason = "Identity with unsafe cast.";
return false;
}

return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok();
Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
if (!s.ok()) {
*uncompilable_reason = s.error_message();
return false;
}
return true;
}
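HasXLAKernel now reports why a node was rejected through a string out-parameter instead of only logging, which is what lets callers attach "unsupported op: <reason>" to the returned map. The contract in miniature, with a hypothetical checker rather than the TensorFlow API:

#include <string>

// Return false and fill a human-readable reason the caller can surface.
bool CheckSupported(int dtype, std::string* uncompilable_reason) {
  constexpr int kStringType = 7;  // stand-in for DT_STRING
  if (dtype == kStringType) {
    *uncompilable_reason = "string dtype is not supported";
    return false;
  }
  return true;
}

int main() {
  std::string reason;
  // reason now explains the failure instead of being buried in a VLOG.
  return CheckSupported(7, &reason) ? 1 : 0;
}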

// Tests whether 'if_node' is compilable. Every operator in the then_branch and
@ -156,16 +174,18 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
bool RecursiveCompilabilityChecker::IsCompilableIf(
const Node& if_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
bool is_compilable = true;
is_compilable &= ExtractNodeDefAndCheckCompilability(
    if_node, "then_branch", "if_then", lib_runtime, stack_trace,
    uncompilable_nodes);
    if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
    stack_trace, uncompilable_nodes);
if (!uncompilable_nodes && !is_compilable) return is_compilable;

is_compilable &= ExtractNodeDefAndCheckCompilability(
    if_node, "else_branch", "if_else", lib_runtime, stack_trace,
    uncompilable_nodes);
    if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
    stack_trace, uncompilable_nodes);

return is_compilable;
}
@ -176,37 +196,43 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
bool RecursiveCompilabilityChecker::IsCompilableWhile(
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
bool is_compilable = true;
is_compilable &= ExtractNodeDefAndCheckCompilability(
    while_node, "cond", "while_cond", lib_runtime, stack_trace,
    uncompilable_nodes);
    while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
    stack_trace, uncompilable_nodes);

if (!uncompilable_nodes && !is_compilable) return is_compilable;

is_compilable &= ExtractNodeDefAndCheckCompilability(
    while_node, "body", "while_body", lib_runtime, stack_trace,
    uncompilable_nodes);
    while_node, "body", "while_body", encapsulating_function, lib_runtime,
    stack_trace, uncompilable_nodes);

return is_compilable;
}

bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
const Node& node, const std::string& attr_name,
const std::string& call_name, FunctionLibraryRuntime* lib_runtime,
const std::string& call_name, NameAttrList* encapsulating_function,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
NodeDef call;
call.set_name(call_name);
if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
const auto uncompilable_reason = absl::StrCat(
    "missing '", attr_name, "' attribute from node", node.name());
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
        << ".";
return false;
}
if (!IsCompilableCall(call, lib_runtime, stack_trace, uncompilable_nodes)) {
if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
                      uncompilable_nodes)) {
VLOG(2) << "Rejecting node " << node.name()
        << ": can't compile : " << call.op();
return false;
@ -220,24 +246,33 @@ bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
bool RecursiveCompilabilityChecker::IsCompilableCall(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
if (stack_trace->size() > kMaxRecursionDepth) {
std::string uncompilable_reason = "function depth limit exceeded";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
        << ".";
return false;
}

FunctionLibraryRuntime::Handle handle;
Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle);
if (!status.ok()) {
Status s;
NameAttrList function;
s = NameAndAttrsFromFunctionCall(call_def, &function);
if (s.ok()) {
s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                             &handle);
}

if (!s.ok()) {
std::string uncompilable_reason = "could not instantiate call";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
        << uncompilable_reason << " : " << status;
        << uncompilable_reason << " : " << s;
return false;
}

@ -246,9 +281,9 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
bool is_compilable = true;
for (const Node* node : fbody->graph->op_nodes()) {
stack_trace->emplace_back(StackFrameView{node->name(), call_def.op()});
is_compilable &=
    IsCompilableNode(*node, lib_runtime, stack_trace, uncompilable_nodes);
stack_trace->emplace_back(StackFrameView{node->name(), function.name()});
is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                  &function, uncompilable_nodes);
stack_trace->pop_back();
if (!uncompilable_nodes && !is_compilable) return is_compilable;
}
@ -279,12 +314,14 @@ bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
bool RecursiveCompilabilityChecker::IsCompilableNode(
const Node& node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
auto stack_depth = stack_trace->size();
if (node.IsSource() || node.IsSink()) {
absl::string_view uncompilable_reason = "source or sink node";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -295,7 +332,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
(node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
absl::string_view uncompilable_reason = "top level _Arg or _Retval";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -307,33 +344,36 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
absl::string_view uncompilable_reason =
    "_scoped_allocator or _forward_from attribute";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}

string uncompilable_reason;
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
                      uncompilable_nodes)) {
                      encapsulating_function, uncompilable_nodes)) {
LogNotCompilable(node, "unsupported function");
return false;
}
} else if (!HasXLAKernel(node)) {
absl::string_view uncompilable_reason = "unsupported op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
} else if (!HasXLAKernel(node, &uncompilable_reason)) {
MaybeMarkUncompilableNode(
    absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
    encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}

if (node.IsWhileNode() &&
    !IsCompilableWhile(node, lib_runtime, stack_trace, uncompilable_nodes)) {
    !IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function,
                       uncompilable_nodes)) {
LogNotCompilable(node, "unsupported while");
return false;
}

if (node.IsIfNode() &&
    !IsCompilableIf(node, lib_runtime, stack_trace, uncompilable_nodes)) {
    !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
                    uncompilable_nodes)) {
LogNotCompilable(node, "unsupported if");
return false;
}
@ -342,7 +382,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
IsStatefulRandomOp(node.type_string())) {
absl::string_view uncompilable_reason = "stateful random op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -350,7 +390,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
absl::string_view uncompilable_reason = "not allowed control trigger";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -359,7 +399,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
IsAssertOrCheckNumerics(node.type_string())) {
absl::string_view uncompilable_reason = "Assert or CheckNumerics";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -368,7 +408,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
OpProducesOrConsumesVariant(node)) {
absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -376,7 +416,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
absl::string_view uncompilable_reason = "Stack op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -384,7 +424,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
absl::string_view uncompilable_reason = "TensorArray op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -394,7 +434,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
absl::string_view uncompilable_reason =
    "resource variable op in called function";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -406,7 +446,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
                 node.DebugString())
    .IgnoreError();
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -417,7 +457,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
                 node.DebugString())
    .IgnoreError();
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                          uncompilable_nodes);
                          encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}
@ -446,8 +486,9 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
const absl::string_view reason,
const std::vector<StackFrameView>& stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_node_list) {
if (!uncompilable_node_list) return;
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
if (!uncompilable_nodes) return;

UncompilableNodeInfo node_info;
node_info.uncompilable_reason = std::string(reason);
@ -459,7 +500,20 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
});

node_info.name = std::string(stack_trace.back().name);
(*uncompilable_node_list).push_back(std::move(node_info));
auto function =
    encapsulating_function ? *encapsulating_function : NameAttrList();
auto function_identifier = function.ShortDebugString();

auto it = uncompilable_nodes->find(function_identifier);
if (it == uncompilable_nodes->end()) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
    uncompileable_node_info{std::move(node_info)};
uncompilable_nodes->emplace(
    std::move(function_identifier),
    std::make_pair(function, std::move(uncompileable_node_info)));
} else {
it->second.second.emplace_back(std::move(node_info));
}
}

} // namespace tensorflow
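With MaybeMarkUncompilableNode grouping results by the encapsulating function's ShortDebugString(), consumers now iterate a map instead of a flat vector. A self-contained sketch of the new result shape and a typical traversal, with NameAttrList reduced to a plain string to keep the example standalone:

#include <map>
#include <string>
#include <utility>
#include <vector>

struct NodeInfo {
  std::string name;
  std::string uncompilable_reason;
};

// Key: function identifier ("" for nodes in the top-level graph).
// Value: (function, nodes inside that function that XLA cannot compile).
using UncompilableNodesMap =
    std::map<std::string, std::pair<std::string, std::vector<NodeInfo>>>;

int CountUncompilable(const UncompilableNodesMap& m) {
  int total = 0;
  for (const auto& entry : m) {
    total += static_cast<int>(entry.second.second.size());
  }
  return total;
}

int main() {
  UncompilableNodesMap m;
  m[""].second.push_back({"AssertNode", "Assert or CheckNumerics"});
  return CountUncompilable(m) == 1 ? 0 : 1;
}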
@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
@ -129,19 +130,35 @@ class RecursiveCompilabilityChecker {
const DeviceType* jit_device_type)
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}

// Returns a list of uncompilable nodes. When `node` is inside a function
// body, users can set `node_stack_trace` to provide an additional
// context for `node`'s placement within the outer most graph.
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
using UncompilableNodesMap =
    std::map<std::string,
             std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;

// Returns a map where the key is the function identifier (short debug
// string) of the function encapsulating the uncompilable nodes, and the
// value is a pair of NameAttrList of the function and a vector of
// uncompilable node info. When an uncompilable node is not inside any
// function call node, the key is the ShortDebugString() of an empty
// NameAttrList.
//
// Also, when `node` is inside a function body, users can set
// `node_stack_trace` to provide an additional context for `node`'s
// placement within the outermost graph.
UncompilableNodesMap FindUncompilableNodes(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    const std::vector<StackFrame>* node_stack_trace = nullptr) const;

// Returns a list of uncompilable nodes in `call_def` that cannot be
// compiled by XLA. It is assumed that `call_def` is a call operation.
// When `node` is inside a function body, users can set
// Returns a map where the key is the function identifier (short debug
// string) of the function encapsulating the uncompilable nodes, and the
// value is a pair of NameAttrList of the function and a vector of
// uncompilable node info. When an uncompilable node is not inside any
// function call node, the key is the ShortDebugString() of an empty
// NameAttrList.
//
// Also, when `node` is inside a function body, users can set
// `node_stack_trace` to provide an additional context for `node`'s
// placement within the outermost graph.
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
UncompilableNodesMap FindUncompilableNodes(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    const std::vector<StackFrame>* node_stack_trace = nullptr) const;

@ -176,27 +193,31 @@ class RecursiveCompilabilityChecker {
bool IsCompilableNode(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
    NameAttrList* encapsulating_function = nullptr,
    UncompilableNodesMap* uncompilable_nodes = nullptr) const;
bool IsCompilableCall(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
bool IsCompilableIf(
    const Node& if_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
bool IsCompilableWhile(
    const Node& while_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
    NameAttrList* encapsulating_function = nullptr,
    UncompilableNodesMap* uncompilable_nodes = nullptr) const;
bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
                    std::vector<StackFrameView>* stack_trace,
                    NameAttrList* encapsulating_function,
                    UncompilableNodesMap* uncompilable_nodes) const;
bool IsCompilableWhile(const Node& while_node,
                       FunctionLibraryRuntime* lib_runtime,
                       std::vector<StackFrameView>* stack_trace,
                       NameAttrList* encapsulating_function,
                       UncompilableNodesMap* uncompilable_nodes) const;

// Returns compilability of node def retrieved from `node`'s attribute with
// name `attr_name`.
bool ExtractNodeDefAndCheckCompilability(
    const Node& node, const std::string& attr_name,
    const std::string& call_name, FunctionLibraryRuntime* lib_runtime,
    const std::string& call_name, NameAttrList* encapsulating_function,
    FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
    UncompilableNodesMap* uncompilable_nodes) const;

bool IsStackOp(const Node& node) const {
  const XlaResourceOpInfo* op_info =
@ -226,12 +247,14 @@ class RecursiveCompilabilityChecker {
         absl::c_any_of(node.output_types(), is_variant);
}

bool HasXLAKernel(const Node& node) const;
bool HasXLAKernel(const Node& node,
                  string* uncompilable_reason = nullptr) const;

static void MaybeMarkUncompilableNode(
    const absl::string_view reason,
    const std::vector<StackFrameView>& stack_trace,
    std::vector<UncompilableNodeInfo>* uncompilable_node_list);
    NameAttrList* encapsulating_function,
    UncompilableNodesMap* uncompilable_nodes_map);

// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;
@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -117,10 +119,16 @@ TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
const auto uncompilable_nodes =
    checker_->FindUncompilableNodes(*uncompilable_op, flib_runtime);
ASSERT_EQ(1, uncompilable_nodes.size());
const auto& node_info = uncompilable_nodes.at(0);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
ASSERT_EQ(1, node_info.stack_trace.size());
ASSERT_EQ("", node_info.stack_trace.at(0).function_name);
auto node_info_it =
    uncompilable_nodes.find(NameAttrList().ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
const auto& uncompilable_nodes_inside_function = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_nodes_inside_function.size());
const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0);
EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason,
                              "unsupported op"));
ASSERT_EQ(1, uncompilable_node_info.stack_trace.size());
ASSERT_EQ("", uncompilable_node_info.stack_trace.at(0).function_name);
}

TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
@ -147,14 +155,21 @@ TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
    checker_->FindUncompilableNodes(*functional_node, flib_runtime);

EXPECT_EQ(1, uncompilable_nodes.size());
const auto& node_info = uncompilable_nodes.at(0);
NameAttrList function;
function.set_name(kUncompilableFunctionName);
const auto node_info_it =
    uncompilable_nodes.find(function.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
const auto& uncompilable_node_list = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_node_list.size());
const auto& node_info = uncompilable_node_list.at(0);
const auto& node_stack = node_info.stack_trace;
ASSERT_EQ(2, node_stack.size());
EXPECT_EQ("D", node_stack.at(0).name);
EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name);

EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
EXPECT_TRUE(
    absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}

TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
@ -212,7 +227,15 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
    checker_->FindUncompilableNodes(**while_node_it, flib_runtime);
ASSERT_EQ(1, uncompilable_nodes.size());

const auto& node_info = uncompilable_nodes.at(0);
NameAttrList function;
function.set_name(kUncompilableFunctionName);
const auto node_info_it =
    uncompilable_nodes.find(function.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
const auto& uncompilable_node_list = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_node_list.size());
const auto& node_info = uncompilable_node_list.at(0);

const auto& node_stack = node_info.stack_trace;
ASSERT_EQ(2, node_stack.size());
const auto& stacktrace_first_node_info = node_stack.at(0);
@ -225,7 +248,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
          stacktrace_second_node_info.function_name);

EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
EXPECT_TRUE(
    absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}

TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
@ -280,7 +304,14 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
    checker_->FindUncompilableNodes(**if_node_it, flib_runtime);
ASSERT_EQ(2, uncompilable_nodes.size());

const auto& uncompilable_node_one = uncompilable_nodes.at(0);
NameAttrList function_one;
function_one.set_name(kUncompilableFunctionName);
auto it = uncompilable_nodes.find(function_one.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), it);

const auto& uncompilable_node_list = it->second.second;
ASSERT_EQ(1, uncompilable_node_list.size());
const auto& uncompilable_node_one = uncompilable_node_list.at(0);
const auto& node_one_stack = uncompilable_node_one.stack_trace;

ASSERT_EQ(2, node_one_stack.size());
@ -294,9 +325,17 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
          stacktrace_second_node_info.function_name);

EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
EXPECT_EQ("unsupported op", uncompilable_node_one.uncompilable_reason);
EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
                              "unsupported op"));

const auto& uncompilable_node_two = uncompilable_nodes.at(1);
NameAttrList function_two;
function_two.set_name(kUncompilableFunctionTwoName);
it = uncompilable_nodes.find(function_two.ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), it);

const auto& uncompilable_node_two_list = it->second.second;
ASSERT_EQ(1, uncompilable_node_two_list.size());
const auto& uncompilable_node_two = uncompilable_node_two_list.at(0);
const auto& node_two_stack = uncompilable_node_two.stack_trace;
ASSERT_EQ(2, node_two_stack.size());
const auto& node_two_stacktrace_first_node = node_two_stack.at(0);
@ -310,7 +349,8 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
          node_two_stacktrace_second_node.function_name);

EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name);
EXPECT_EQ("unsupported op", uncompilable_node_two.uncompilable_reason);
EXPECT_TRUE(absl::StrContains(uncompilable_node_two.uncompilable_reason,
"unsupported op"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -29,6 +29,9 @@ limitations under the License.

namespace tensorflow {
namespace jit {
class DeviceInfoCache;
class DeviceSet;

// Instances of DeviceId represent TensorFlow devices as integers.
//
// This helps avoid having to manipulate device names as strings when
@ -1193,7 +1193,7 @@ Status EncapsulateSubgraphsPass::Run(
}

std::unique_ptr<DeviceMgr> device_mgr =
    absl::make_unique<DeviceMgr>(std::move(devices));
    absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
    new ProcessFunctionLibraryRuntime(device_mgr.get(),

@ -510,7 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
TF_CHECK_OK(DeviceFactory::AddDevices(
    session_options, "/job:localhost/replica:0/task:0", &devices));
OptimizerOptions opts;
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
    device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
    opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

@ -232,7 +232,7 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
    session_options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
}

Status ExtractOutsideCompilationTest(
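The three one-line changes above track an upstream rename: the concrete device manager whose device set is fixed at construction became StaticDeviceMgr. A sketch of the distinction, with a stand-in Device type rather than the TensorFlow class:

#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Device { std::string name; };

class StaticDeviceMgrSketch {
 public:
  explicit StaticDeviceMgrSketch(std::vector<std::unique_ptr<Device>> devices)
      : devices_(std::move(devices)) {}  // list is set once, never mutated
  int NumDevices() const { return static_cast<int>(devices_.size()); }

 private:
  const std::vector<std::unique_ptr<Device>> devices_;
};

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(std::make_unique<Device>(Device{"/device:CPU:0"}));
  StaticDeviceMgrSketch mgr(std::move(devices));
  return mgr.NumDevices() == 1 ? 0 : 1;
}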
@ -5,35 +5,48 @@ package(
licenses = ["notice"], # Apache 2.0
)

XLA_OPS_DEPS = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_activity_listener",
"//tensorflow/compiler/jit:xla_activity_proto_cc",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
]

# Linked by tensorflow core, without registration of jit compilation passes.
cc_library(
name = "xla_ops",
name = "xla_ops_no_jit_rewrite_registration",
srcs = ["xla_ops.cc"],
hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_activity_listener",
"//tensorflow/compiler/jit:xla_activity_proto_cc",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
deps = XLA_OPS_DEPS,
alwayslink = 1,
)

cc_library(
name = "xla_ops",
hdrs = ["xla_ops.h"],
deps = XLA_OPS_DEPS + [
":xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:jit_compilation_passes",
],
alwayslink = 1,
)
@ -313,6 +313,8 @@ static Status CompileToLocalExecutable(
options.shape_representation_fn =
    platform_info.xla_device_metadata()->shape_representation_fn();
}
// TODO(b/138728225): Set options.alias_passthrough_params for clusters
// without ref variables.

std::map<int, Tensor> constant_args;
for (int i : constants) {
@ -397,9 +399,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";

const xla::HloInputOutputAliasConfig& input_output_alias =
    executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
                        ctx, kernel, run_result.ConsumeValueOrDie(),
                        /*missing_ctx_input_prefix=*/0));
                        /*missing_ctx_input_prefix=*/0, input_output_alias));
VLOG(1) << "Done";
}

@ -595,6 +599,9 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

const xla::HloInputOutputAliasConfig& input_output_alias =
    closure.executable()->executable()->module().input_output_alias_config();

tensorflow::profiler::TraceMe hlo_module_activity(
    [&] {
      return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
@ -605,7 +612,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
    ctx,
    launch_context.PopulateOutputs(
        ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
        /*missing_ctx_input_prefix=*/closure.num_constant_args()));
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias));
}

REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
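The pattern in both Compute methods above is the same: fetch the executable's input/output alias configuration once and thread it into PopulateOutputs, so output buffers that XLA aliased to inputs are not treated as fresh allocations. The shape of the wiring, with stand-in types rather than the XLA API:

struct AliasConfig {
  bool output0_reuses_input0 = false;  // stand-in for real alias entries
};

struct Module {
  AliasConfig alias;
  const AliasConfig& input_output_alias_config() const { return alias; }
};

// The output-population step consults the alias config before allocating.
int PopulateOutputs(const AliasConfig& input_output_alias) {
  return input_output_alias.output0_reuses_input0 ? 0 : 1;  // 0 = reused buffer
}

int main() {
  Module module;
  module.alias.output0_reuses_input0 = true;
  const AliasConfig& input_output_alias = module.input_output_alias_config();
  return PopulateOutputs(input_output_alias);
}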
@@ -1639,10 +1639,9 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
}
}  // anonymous namespace

bool IsCompilable(
    FunctionLibraryRuntime* flr, const NodeDef& ndef,
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>*
        uncompilable_node_info) {
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
                  RecursiveCompilabilityChecker::UncompilableNodesMap*
                      uncompilable_node_info) {
  Device* device = flr->device();
  const XlaOpRegistry::DeviceRegistration* registration;
  CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
@@ -1668,8 +1667,8 @@ bool IsCompilable(
    return checker.IsCompilableCall(ndef, flr);
  }

  std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
      uncompilable_node_result = checker.FindUncompilableNodes(ndef, flr);
  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
      checker.FindUncompilableNodes(ndef, flr);
  uncompilable_node_info->swap(uncompilable_node_result);
  return uncompilable_node_info->empty();
}

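Editor's note: a minimal calling sketch for the new map-based signature (not part of the patch; the iteration shape mirrors how CreateXlaKernel consumes the map later in this commit):

    RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
    if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
      for (const auto& it : uncompilable_nodes_map) {
        // it.second.second holds the UncompilableNodeInfo entries for one
        // uncompilable function, exactly as CreateXlaKernel flattens them.
        for (const auto& info : it.second.second) {
          VLOG(1) << info.name << ": " << info.uncompilable_reason;
        }
      }
    }
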
@@ -52,10 +52,9 @@ class MarkForCompilationPass : public GraphOptimizationPass {
// function is compilable iff every operator in the function body is
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
bool IsCompilable(
    FunctionLibraryRuntime* flr, const NodeDef& ndef,
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>*
        uncompilable_node_info = nullptr);
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
                  RecursiveCompilabilityChecker::UncompilableNodesMap*
                      uncompilable_node_info = nullptr);

namespace testing {
// DO NOT USE IN PRODUCTION.

@@ -186,7 +186,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt(
      /*input_buffer_bytes=*/k_buffer_size,
      /*output_buffer_bytes=*/k_buffer_size,
      io::ZlibCompressionOptions::GZIP());
  string decompressed_pbtxt_string;
  tstring decompressed_pbtxt_string;
  Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string);
  if (!s.ok() && !errors::IsOutOfRange(s)) {
    // OutOfRange is fine since we set the number of read bytes to INT_MAX.

@@ -158,6 +158,7 @@ Status XlaCompilationCache::BuildExecutable(
                        : client_->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator);
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);

  auto compile_result =
      client_->Compile(*result.computation, argument_layouts, build_options);

@@ -83,9 +83,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
      executable->Run(launch_context.arguments(), run_options);
  TF_RETURN_IF_ERROR(run_result.status());

  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
      ctx, result, run_result.ConsumeValueOrDie(),
      /*missing_ctx_input_prefix=*/0));
      /*missing_ctx_input_prefix=*/0, input_output_alias));
  return Status::OK();
}

@@ -98,10 +98,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);

// Kernel registrations

constexpr std::array<DataType, 14> kAllXlaCpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
     DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
     DT_BFLOAT16}};
constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);

@@ -147,10 +147,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);

// Kernel registrations

constexpr std::array<DataType, 14> kAllXlaGpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
     DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
     DT_BFLOAT16}};
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);

@@ -14,243 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  // Scans forward in the vector looking for "value", updating the internal
  // position into the vector.
  // Returns true iff the vector contains the given value at or after current
  // position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  int current_index_;
  const std::vector<int>* values_;
};
} // namespace

bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
                                       const NodeDef& node_def) const {
  const FunctionDef* function_def =
      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
  if (function_def == nullptr) {
    // The node def is not calling a function. Individual ops can be
    // run directly using on-demand mode, no need to create XlaLaunch
    // kernel for them.
    return false;
  }

  // If kXlaCompileAttr is set on the node_def, use its value.
  const auto& it = node_def.attr().find(kXlaCompileAttr);
  if (it != node_def.attr().end()) {
    return it->second.b();
  }

  // kXlaCompileAttr is not set on node_def, check if it is set on
  // FunctionDef.
  bool xla_compile = false;
  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
      node_def, kXlaCompileAttr, &xla_compile);
  if (!status.ok() || !xla_compile) {
    if (VLOG_IS_ON(3)) {
      if (!status.ok()) {
        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
                << node_def.op() << ". status=" << status.ToString();
      } else {
        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
      }
    }
    return false;
  }
  return true;
}

// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NodeDef& node_def,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  // If node_def is not instantiable, e.g., the function does not exist,
  // simply bail out.
  TF_RETURN_IF_ERROR(
      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (int i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (int i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
  return CanCreateXlaKernel(flr, node_def);
}

Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
                                      const NodeDef& node_def,
                                      std::unique_ptr<OpKernel>* kernel) const {
  if (!CanCreateKernel(*flr, node_def)) {
    return errors::Internal("Invalid node: ", node_def.ShortDebugString());
  }

  VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();

  // Make sure that kernels have been registered on the JIT device.
  XlaOpRegistry::RegisterCompilationKernels();
  std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
      uncompilable_node_info;
  if (!IsCompilable(flr, node_def, &uncompilable_node_info)) {
    string message = absl::StrCat(
        "Function invoked by the following node is not compilable: ",
        node_def.ShortDebugString(), ".\n");
    absl::StrAppend(&message, "Uncompilable nodes:\n");
    for (const auto& node_info : uncompilable_node_info) {
      string node_message =
          absl::StrCat("\t", node_info.name, ": ",
                       node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
      for (const auto& stack_frame : node_info.stack_trace) {
        absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
                              stack_frame.name, stack_frame.function_name);
      }
      absl::StrAppend(&message, node_message);
    }
    VLOG(1) << message;
    // node_def is calling a function that XLA can't compile.
    return errors::InvalidArgument(message);
  }

  // Get function body, constant args, and resource args.
  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));

  // Set input and output memory types.
  MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
  // These indices are used only for optimization purposes. They allow us
  // to loop over constant_arg_indices and resource_arg_indices only once
  // while iterating over all the function arguments checking if it is a
  // resource or a constant.
  // The reason we optimized this code is because functions can have a lot of
  // captured arguments. For example, the backward pass of ResNet50 takes in all
  // 214 variables and a similar number of activations.
  SinglePassSearch constants_search(&constant_arg_indices);
  SinglePassSearch resources_search(&resource_arg_indices);
  for (int i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = HOST_MEMORY;
    }
  }
  // One might wonder about the case where a compile-time constant argument
  // (which must be in host memory) is also used as an input into an op,
  // e.g. Add, that expects its inputs in device memory. Here is how it
  // works now.
  // First, what do we mean by "op expects an input in XYZ memory"?
  // There are two types of "ops" here: the tf2xla kernel and the HLO
  // computation it builds. The tf2xla kernel needs to retrieve the actual
  // numeric value of the compile-time constant tensors, so it really expects
  // them to be in host memory. However, for other inputs, it refers to them
  // using xla::ComputationDataHandle, which is just a symbolic handle that
  // xla::ComputationBuilder assigns. How does this handle get assigned for
  // constant arguments? Even constant arguments get an _Arg node in the graph
  // instantiated for Function compilation. The tf2xla kernel for constant _Arg
  // nodes takes the constant value, converts it to XlaLiteral, and feeds it
  // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
  // constant XlaLiteral is included in the HLO graph, and subsequently, in
  // the actual executable, which is copied to the device before being
  // executed. Thus, when this executable runs, the constant is available in
  // device memory.

  // XlaLaunch kernel keeps all outputs (including constants, which it copies),
  // in device memory except for resources.
  MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
  for (int i = 0; i < fbody->ret_types.size(); ++i) {
    if (fbody->ret_types[i] == DT_RESOURCE) {
      output_memory_types[i] = HOST_MEMORY;
    }
  }

  // Create the kernel.
  NameAttrList function;
  function.set_name(node_def.op());
  *(function.mutable_attr()) = node_def.attr();

  Device* dev = flr->device();
  Status s;
  OpKernelConstruction construction(
      DeviceType(dev->device_type()), dev,
      dev->GetAllocator(AllocatorAttributes()), &node_def,
      &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);

  *kernel = absl::make_unique<XlaLocalLaunchBase>(
      &construction, constant_arg_indices, resource_arg_indices, function);
  return s;
  return CreateXlaKernel(flr, node_def, kernel);
}

namespace {

@@ -71,7 +71,7 @@ class XlaKernelCreatorTest : public ::testing::Test {
    lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
        OpRegistry::Global(), proto);
    OptimizerOptions opts;
    device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
    device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
    pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
        device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
        opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

tensorflow/compiler/jit/xla_kernel_creator_util.cc (new file, 259 lines)
@@ -0,0 +1,259 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  // Scans forward in the vector looking for "value", updating the internal
  // position into the vector.
  // Returns true iff the vector contains the given value at or after current
  // position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  int current_index_;
  const std::vector<int>* values_;
};
} // namespace
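Editor's note: a short usage sketch for SinglePassSearch (values illustrative, not from the patch). Queries must arrive in increasing order, which is exactly how CreateXlaKernel below walks the argument indices:

    std::vector<int> sorted_indices = {1, 4, 7};  // must be sorted ascending
    SinglePassSearch search(&sorted_indices);
    for (int i = 0; i < 10; ++i) {
      if (search.ScanForValue(i)) {
        // Taken for i == 1, 4, 7; total work across all calls is O(n).
      }
    }
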

bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
                        const NodeDef& node_def) {
  const FunctionDef* function_def =
      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
  if (function_def == nullptr) {
    // The node def is not calling a function. Individual ops can be
    // run directly using on-demand mode, no need to create XlaLaunch
    // kernel for them.
    return false;
  }

  // If kXlaCompileAttr is set on the node_def, use its value.
  const auto& it = node_def.attr().find(kXlaCompileAttr);
  if (it != node_def.attr().end()) {
    return it->second.b();
  }

  // kXlaCompileAttr is not set on node_def, check if it is set on
  // FunctionDef.
  bool xla_compile = false;
  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
      node_def, kXlaCompileAttr, &xla_compile);
  if (!status.ok() || !xla_compile) {
    if (VLOG_IS_ON(3)) {
      if (!status.ok()) {
        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
                << node_def.op() << ". status=" << status.ToString();
      } else {
        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
      }
    }
    return false;
  }
  return true;
}
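Editor's note: a hedged illustration of the precedence just implemented (the NodeDef here is hypothetical and assumes it names a function known to flr): the node-level attribute short-circuits the FunctionDef lookup.

    NodeDef node_def;
    node_def.set_op("MyXlaFunction");  // hypothetical function call
    (*node_def.mutable_attr())[kXlaCompileAttr].set_b(false);
    // CanCreateXlaKernel(flr, node_def) now returns false from the
    // node-level attr alone, without consulting the FunctionDef's attrs.
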

// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NodeDef& node_def,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  // If node_def is not instantiable, e.g., the function does not exist,
  // simply bail out.
  TF_RETURN_IF_ERROR(
      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (int i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (int i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
}

Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
                       std::unique_ptr<OpKernel>* kernel) {
  if (!CanCreateXlaKernel(*flr, node_def)) {
    return errors::Internal("Invalid node: ", node_def.ShortDebugString());
  }

  VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();

  // Make sure that kernels have been registered on the JIT device.
  XlaOpRegistry::RegisterCompilationKernels();
  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
  if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
        uncompilable_node_info;
    for (const auto& it : uncompilable_nodes_map) {
      for (const auto& info : it.second.second) {
        uncompilable_node_info.emplace_back(info);
      }
    }
    string message = absl::StrCat(
        "Function invoked by the following node is not compilable: ",
        node_def.ShortDebugString(), ".\n");
    absl::StrAppend(&message, "Uncompilable nodes:\n");
    for (const auto& node_info : uncompilable_node_info) {
      string node_message =
          absl::StrCat("\t", node_info.name, ": ",
                       node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
      for (const auto& stack_frame : node_info.stack_trace) {
        absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
                              stack_frame.name, stack_frame.function_name);
      }
      absl::StrAppend(&message, node_message);
    }
    VLOG(1) << message;
    // node_def is calling a function that XLA can't compile.
    return errors::InvalidArgument(message);
  }

  // Get function body, constant args, and resource args.
  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));

  // Set input and output memory types.
  MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
  // These indices are used only for optimization purposes. They allow us
  // to loop over constant_arg_indices and resource_arg_indices only once
  // while iterating over all the function arguments checking if it is a
  // resource or a constant.
  // The reason we optimized this code is because functions can have a lot of
  // captured arguments. For example, the backward pass of ResNet50 takes in all
  // 214 variables and a similar number of activations.
  SinglePassSearch constants_search(&constant_arg_indices);
  SinglePassSearch resources_search(&resource_arg_indices);
  for (int i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = HOST_MEMORY;
    }
  }
  // One might wonder about the case where a compile-time constant argument
  // (which must be in host memory) is also used as an input into an op,
  // e.g. Add, that expects its inputs in device memory. Here is how it
  // works now.
  // First, what do we mean by "op expects an input in XYZ memory"?
  // There are two types of "ops" here: the tf2xla kernel and the HLO
  // computation it builds. The tf2xla kernel needs to retrieve the actual
  // numeric value of the compile-time constant tensors, so it really expects
  // them to be in host memory. However, for other inputs, it refers to them
  // using xla::ComputationDataHandle, which is just a symbolic handle that
  // xla::ComputationBuilder assigns. How does this handle get assigned for
  // constant arguments? Even constant arguments get an _Arg node in the graph
  // instantiated for Function compilation. The tf2xla kernel for constant _Arg
  // nodes takes the constant value, converts it to XlaLiteral, and feeds it
  // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
  // constant XlaLiteral is included in the HLO graph, and subsequently, in
  // the actual executable, which is copied to the device before being
  // executed. Thus, when this executable runs, the constant is available in
  // device memory.

  // XlaLaunch kernel keeps all outputs (including constants, which it copies),
  // in device memory except for resources.
  MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
  for (int i = 0; i < fbody->ret_types.size(); ++i) {
    if (fbody->ret_types[i] == DT_RESOURCE) {
      output_memory_types[i] = HOST_MEMORY;
    }
  }

  // Create the kernel.
  NameAttrList function;
  function.set_name(node_def.op());
  *(function.mutable_attr()) = node_def.attr();

  Device* dev = flr->device();
  Status s;
  OpKernelConstruction construction(
      DeviceType(dev->device_type()), dev,
      dev->GetAllocator(AllocatorAttributes()), &node_def,
      &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);

  *kernel = absl::make_unique<XlaLocalLaunchBase>(
      &construction, constant_arg_indices, resource_arg_indices, function);
  return s;
}
} // namespace tensorflow
tensorflow/compiler/jit/xla_kernel_creator_util.h (new file, 39 lines)
@@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

class FunctionLibraryRuntime;
class OpKernel;

// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
// true if 'node_def' is a call to a compilable function defined in 'flr',
// with the kXlaCompileAttr set.
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
                        const NodeDef& node_def);

// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
                       std::unique_ptr<OpKernel>* kernel);

} // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
@@ -247,9 +247,50 @@ void XlaComputationLaunchContext::PopulateInputs(
    }
  }

namespace {

bool MustAliasOutput(const xla::HloInputOutputAliasConfig& input_output_alias,
                     int output_num) {
  xla::ShapeIndex output_index;
  if (input_output_alias.shape().IsTuple()) {
    output_index = {output_num};
  } else {
    DCHECK_EQ(output_num, 0)
        << "output_num must be 0 for non-tuple shapes but is " << output_num;
    output_index = {};
  }
  if (input_output_alias.shape().tuple_shapes_size() == 0) {
    return false;
  }
  return input_output_alias.OutputHasAlias(output_index) &&
         input_output_alias.GetAliasedParameter(output_index).value().kind ==
             xla::HloInputOutputAliasConfig::kUserAlias;
}

}  // namespace

Tensor XlaComputationLaunchContext::MakeOutputTensor(
    DataType type, const TensorShape& shape, se::DeviceMemoryBase buffer,
    int output_num, const xla::HloInputOutputAliasConfig& input_output_alias,
    Allocator* allocator) {
  bool is_aliased = false;
  if (MustAliasOutput(input_output_alias, output_num)) {
    int xla_param = input_output_alias.GetAliasedParameter({output_num})
                        .value()
                        .parameter_number;
    DCHECK(arg_ptrs_[xla_param] != nullptr);
    buffer = arg_ptrs_[xla_param]->root_buffer();
    is_aliased = true;
  }
  return XlaTensorBuffer::MakeTensor(type, shape,
                                     /*unref_buffer=*/!is_aliased, buffer,
                                     allocator);
}
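Editor's note: the /*unref_buffer=*/!is_aliased argument is the crux. A sketch of the two resulting contracts (types as in this file; values illustrative):

    // Non-aliased output: ownership of `buffer` transfers to the tensor, so
    // the temporary XlaTensorBuffer reference is dropped on construction.
    Tensor transferred = XlaTensorBuffer::MakeTensor(
        DT_FLOAT, shape, /*unref_buffer=*/true, buffer, allocator);
    // Aliased output: `buffer` still belongs to the aliased input, so the
    // extra reference is kept and the memory is not freed with the tensor.
    Tensor aliased = XlaTensorBuffer::MakeTensor(
        DT_FLOAT, shape, /*unref_buffer=*/false, buffer, allocator);
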
Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
    ScopedShapedBuffer output, int missing_ctx_input_prefix) {
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

@@ -343,8 +384,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      if (MustAliasOutput(input_output_alias, output_num)) {
        DCHECK(output.buffer({output_num}).is_null())
            << "Expected output buffer to be aliased, but it is not nil.";
      }
      if (allocate_xla_tensors_) {
        if (MustAliasOutput(input_output_alias, output_num)) {
          return errors::Unimplemented(
              "Aliasing is not yet supported for allocate_xla_tensors_.");
        }
        Tensor* output_tensor;
        TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
        XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
@@ -359,8 +407,10 @@ Status XlaComputationLaunchContext::PopulateOutputs(
          CHECK_EQ(output_tensor->TotalBytes(), 0);
        }
      } else {
        Tensor output_tensor = XlaTensorBuffer::MakeTensor(
            ctx->expected_output_dtype(i), shape, buffer, allocator);
        se::DeviceMemoryBase buffer = output.buffer({output_num});
        Tensor output_tensor =
            MakeOutputTensor(ctx->expected_output_dtype(i), shape, buffer,
                             output_num, input_output_alias, allocator);
        output.set_buffer(se::OwningDeviceMemory(), {output_num});
        ctx->set_output(i, output_tensor);
      }
@@ -408,6 +458,10 @@ Status XlaComputationLaunchContext::PopulateOutputs(
    }

    if (allocate_xla_tensors_) {
      if (MustAliasOutput(input_output_alias, output_num)) {
        return errors::Unimplemented(
            "Aliasing is not yet supported for allocate_xla_tensors_.");
      }
      Tensor output_tensor;
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(write.type, write.shape, &output_tensor));
@@ -423,8 +477,9 @@ Status XlaComputationLaunchContext::PopulateOutputs(
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      Tensor output_tensor = XlaTensorBuffer::MakeTensor(
          write.type, write.shape, buffer, allocator);
      Tensor output_tensor =
          MakeOutputTensor(write.type, write.shape, buffer, output_num,
                           input_output_alias, allocator);
      *variable_infos[i].var()->tensor() = output_tensor;
    }
    ++output_num;

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
@@ -149,16 +150,21 @@ class XlaComputationLaunchContext {
  //
  // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
  // missing and adjusts input indices accordingly.
  Status PopulateOutputs(OpKernelContext* ctx,
                         const XlaCompiler::CompilationResult* kernel,
                         xla::ScopedShapedBuffer output,
                         int missing_ctx_input_prefix);
  Status PopulateOutputs(
      OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
      xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
      const xla::HloInputOutputAliasConfig& input_output_alias);

  // Return the argument list. Only valid after PopulateInputs() has been
  // called.
  const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }

 private:
  Tensor MakeOutputTensor(
      DataType type, const TensorShape& shape, se::DeviceMemoryBase buffer,
      int output_num, const xla::HloInputOutputAliasConfig& input_output_alias,
      Allocator* allocator);

  xla::LocalClient* client_;
  se::DeviceMemoryAllocator* xla_allocator_;
  bool allocate_xla_tensors_;
@@ -193,12 +199,15 @@ class XlaTensorBuffer : public TensorBuffer {
  }

  static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                           se::DeviceMemoryBase buffer, Allocator* allocator) {
                           bool unref_buffer, se::DeviceMemoryBase buffer,
                           Allocator* allocator) {
    size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
    auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
                                              buffer.size(), allocator);
    Tensor t(dtype, shape, tensor_buffer);
    tensor_buffer->Unref();
    if (unref_buffer) {
      tensor_buffer->Unref();
    }
    return t;
  }

@@ -4,7 +4,10 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")

package(
    default_visibility = ["@local_config_mlir//:friends"],
    default_visibility = [
        "//tensorflow/compiler/tf2xla:__subpackages__",
        "@local_config_mlir//:friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

@@ -19,10 +22,20 @@ filegroup(
    srcs = glob(["**/*.td"]),
)

cc_library(
    name = "op_name_mapper",
    srcs = ["op_name_mapper.cc"],
    hdrs = ["op_name_mapper.h"],
    deps = [
        "@com_google_absl//absl/container:flat_hash_map",
        "@llvm//:support",
        "@local_config_mlir//:IR",
    ],
)

cc_library(
    name = "tf_mlir_opt_main",
    srcs = ["tf_mlir_opt_main.cc"],
    copts = ["-std=c++14"],
    deps = [
        ":init_mlir",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
@@ -32,9 +45,12 @@ cc_library(
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/xla",
        "//tensorflow/compiler/mlir/xla:lxla",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:lhlo",
        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
        "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
        "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",

tensorflow/compiler/mlir/g3doc/README.md (new file, 3 lines)
@@ -0,0 +1,3 @@
# TensorFlow MLIR

These are the docs for: https://www.tensorflow.org/mlir
tensorflow/compiler/mlir/g3doc/_book.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
  path: /resources
  is_default: true
  menu:
  - include: /resources/_menu_toc.yaml
  lower_tabs:
    # Subsite tabs
    other:
    - name: Guide & Tutorials
      contents:
      - title: Overview
        path: /mlir/overview
      - heading: Dialects
      - title: TensorFlow
        path: /mlir/tf_ops
      - title: TensorFlow Lite
        path: /mlir/tfl_ops

- include: /_upper_tabs_right.yaml
tensorflow/compiler/mlir/g3doc/_index.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
book_path: /mlir/_book.yaml
project_path: /mlir/_project.yaml
description: <!--no description-->
landing_page:
  custom_css_path: /site-assets/css/style.css
  rows:
  - heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
    items:
    - description: >
        The MLIR project defines a common intermediate representation (IR) that
        unifies the infrastructure required to execute high performance machine
        learning models in TensorFlow and similar ML frameworks. This project
        will include the application of HPC techniques, along with integration of
        search algorithms like reinforcement learning. MLIR aims to reduce the
        cost to bring up new hardware, and improve usability for existing
        TensorFlow users.

    - code_block: |
        <pre class = "prettyprint">
        // Syntactically similar to LLVM:
        func @testFunction(%arg0: i32) {
          %x = call @thingToCall(%arg0) : (i32) -> i32
          br ^bb1
        ^bb1:
          %y = addi %x, %x : i32
          return %y : i32
        }
        </pre>

  - classname: devsite-landing-row-cards
    items:
    - heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
      youtube_id: qzljG6DKgic
      buttons:
      - label: Watch the video
        path: https://www.youtube.com/watch?v=qzljG6DKgic
    - heading: "A new intermediate representation and compiler framework"
      image_path: /resources/images/tf-logo-card-16x9.png
      path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
      buttons:
      - label: Read on TensorFlow blog
        path: https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d
    - heading: TensorFlow MLIR on GitHub
      image_path: /resources/images/github-card-16x9.png
      path: https://github.com/tensorflow/mlir
      buttons:
      - label: View on GitHub
        path: https://github.com/tensorflow/mlir
tensorflow/compiler/mlir/g3doc/_project.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
name: TensorFlow MLIR
breadcrumb_name: MLIR
home_url: /mlir/
parent_project_metadata_path: /_project.yaml
description: >
  MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
use_site_branding: true
hide_from_products_list: true
content_license: cc-apache
buganizer_id: 443907
include: /_project_included.yaml
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg (new file, 148 KiB; diff suppressed because one or more lines are too long)

tensorflow/compiler/mlir/g3doc/overview.md (new file, 5 lines)
@@ -0,0 +1,5 @@
# MLIR overview

## Overview

<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>

tensorflow/compiler/mlir/g3doc/tf_ops.md (new file, 2761 lines; diff suppressed because it is too large)
tensorflow/compiler/mlir/g3doc/tfl_ops.md (new file, 1606 lines; diff suppressed because it is too large)
@@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
load(
    "@local_config_mlir//:tblgen.bzl",
    "gentbl",
@@ -146,7 +146,6 @@ cc_library(
    hdrs = [
        "utils/validators.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        "@local_config_mlir//:Dialect",
        "@local_config_mlir//:IR",
@@ -169,7 +168,6 @@ cc_library(
        "utils/attribute_utils.h",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":tensorflow_lite_ops_inc_gen",
        ":validators",
@@ -187,6 +185,41 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "lstm_utils",
    srcs = [
        "utils/lstm_utils.cc",
    ],
    hdrs = [
        "utils/lstm_utils.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
)

tf_cc_test(
    name = "lstm_utils_test",
    size = "small",
    srcs = ["utils/lstm_utils_test.cc"],
    deps = [
        ":lstm_utils",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
)

cc_library(
    name = "tensorflow_lite_legalize_tf",
    srcs = [
@@ -194,16 +227,18 @@ cc_library(
        "transforms/generated_legalize_tf.inc",
        "transforms/generated_lower_static_tensor_list.inc",
        "transforms/generated_prepare_tf.inc",
        "transforms/legalize_ophint_func_op.cc",
        "transforms/legalize_tf.cc",
        "transforms/lower_static_tensor_list.cc",
        "transforms/prepare_composite_functions_tf.cc",
        "transforms/prepare_tf.cc",
        "transforms/trim_functions_tf.cc",
        "transforms/unroll_batch_matmul.cc",
    ],
    hdrs = [
        "transforms/passes.h",
        "transforms/unroll_batch_matmul.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":common",
        ":tensorflow_lite",
@@ -233,7 +268,6 @@ cc_library(
    hdrs = [
        "transforms/passes.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":tensorflow_lite",
        ":validators",
@@ -252,6 +286,7 @@ cc_library(
    name = "tensorflow_lite_quantize",
    srcs = [
        "transforms/generated_quantize.inc",
        "transforms/load_quantization_recipe.cc",
        "transforms/post_quantize.cc",
        "transforms/prepare_quantize.cc",
        "transforms/quantize.cc",
@@ -260,7 +295,6 @@ cc_library(
    hdrs = [
        "transforms/passes.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":tensorflow_lite",
        ":validators",
@@ -305,7 +339,6 @@ cc_library(
    srcs = [
        "ir/dialect_registration.cc",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":tensorflow_lite",
        "@local_config_mlir//:IR",
@@ -349,7 +382,6 @@ cc_library(
    hdrs = [
        "flatbuffer_operator.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
@@ -379,7 +411,6 @@ cc_library(
    hdrs = [
        "emit_error_reporter.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        "//tensorflow/lite/core/api",
        "@local_config_mlir//:IR",
@@ -398,11 +429,11 @@ cc_library(
        "flatbuffer_translate.h",
        "utils/convert_type.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":flatbuffer_tflite_operator_lib",
        ":tensorflow_lite",
        ":tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir:op_name_mapper",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
@@ -450,7 +481,6 @@ cc_library(
    hdrs = [
        "tf_tfl_translate_cl.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        "@llvm//:support",
    ],
@@ -462,7 +492,6 @@ cc_library(
    hdrs = [
        "common/tfl_pass_config.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        "@llvm//:support",
    ],
@@ -523,7 +552,6 @@ cc_library(
    hdrs = [
        "tf_tfl_passes.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":common",
        ":tensorflow_lite_legalize_tf",
@@ -531,6 +559,7 @@ cc_library(
        ":tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
@@ -552,7 +581,6 @@ cc_library(
    hdrs = [
        "tf_to_tfl_flatbuffer.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        ":flatbuffer_translate_lib",
        ":tensorflow_lite",

@@ -98,8 +98,7 @@ Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
  if (tensor.name.empty()) {
    return base;
  }
  return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base,
                            builder.getContext());
  return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base);
}

// Returns the correct type for a quantized tensor
@@ -478,8 +477,7 @@ StatusOr<FuncOp> ConvertSubgraph(
  llvm::SmallVector<mlir::Type, 2> ret_types;
  llvm::SmallVector<mlir::Type, 4> input_types;

  auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc,
                                     builder.getContext());
  auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);

  // Construct function type
  for (auto input : subgraph.inputs) {

@@ -53,9 +53,10 @@ limitations under the License.
#include "mlir/Translation.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/op_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils//convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -89,6 +90,8 @@ using mlir::TranslateFromMLIRRegistration;
using mlir::Type;
using mlir::UnknownLoc;
using mlir::Value;
using tensorflow::OpLocNameMapper;
using tensorflow::OpNameMapper;
using tensorflow::Status;
using tflite::flex::IsWhitelistedFlexOp;
using xla::StatusOr;
@@ -96,7 +99,10 @@ using xla::StatusOr;
template <typename T>
using BufferOffset = flatbuffers::Offset<T>;

using CustomOptionsOffset = BufferOffset<flatbuffers::Vector<uint8_t>>;
template <typename T>
using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;

using CustomOptionsOffset = VectorBufferOffset<uint8_t>;

namespace error = tensorflow::error;
namespace tfl = mlir::TFL;
@ -341,16 +347,16 @@ class Translator {
|
||||
bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops,
|
||||
bool emit_custom_ops,
|
||||
bool strip_debug_info);
|
||||
OpNameMapper* op_name_mapper);
|
||||
|
||||
private:
|
||||
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
|
||||
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
bool strip_debug_info)
|
||||
OpNameMapper* op_name_mapper)
|
||||
: module_(module),
|
||||
builder_(kInitialBufferSize),
|
||||
strip_debug_info_(strip_debug_info) {
|
||||
name_mapper_(*op_name_mapper),
|
||||
builder_(kInitialBufferSize) {
|
||||
// The first buffer must be empty according to the schema definition.
|
||||
empty_buffer_ = tflite::CreateBuffer(builder_);
|
||||
buffers_.push_back(empty_buffer_);
|
||||
@ -369,10 +375,6 @@ class Translator {
|
||||
|
||||
Optional<std::string> TranslateInternal();
|
||||
|
||||
// Returns name that should be used by tensors for values generated by this
|
||||
// operation.
|
||||
std::string GetName(Operation* inst);
|
||||
|
||||
// Returns TFLite buffer populated with constant value if the operation is
|
||||
// TFLite constant operation. Otherwise, returns an empty buffer. Emits error
|
||||
// and returns llvm::None on failure.
|
||||
@ -416,6 +418,15 @@ class Translator {
|
||||
|
||||
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
|
||||
|
||||
// Builds Metadata with the given `name` and buffer `content`.
|
||||
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
|
||||
StringRef content);
|
||||
|
||||
// Encodes the `tfl.metadata` dictionary attribute of the module to the
|
||||
// metadata section in the final model.
|
||||
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
|
||||
CreateMetadataVector();
|
||||
|
||||
// Uses the tf.entry_function attribute (if set) to initialize the op to name
|
||||
// mapping.
|
||||
void InitializeNamesFromAttribute(FuncOp fn);
|
||||
@ -427,11 +438,10 @@ class Translator {
|
||||
// Returns a unique name for `op`.
|
||||
std::string UniqueName(mlir::Operation* op);
|
||||
|
||||
// Returns a unique name starting with a given prefix.
|
||||
std::string UniqueName(llvm::StringRef prefix);
|
||||
|
||||
ModuleOp module_;
|
||||
|
||||
tensorflow::OpNameMapper& name_mapper_;
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder_;
|
||||
BufferOffset<tflite::Buffer> empty_buffer_;
|
||||
|
||||
@ -446,61 +456,14 @@ class Translator {
|
||||
absl::flat_hash_map<std::string, int> subgraph_index_map_;
|
||||
absl::flat_hash_set<OpType> enabled_op_types_;
|
||||
|
||||
// Maps from op to name.
|
||||
absl::flat_hash_map<mlir::Operation*, std::string> op_to_name_;
|
||||
absl::flat_hash_map<std::string, int64_t> name_to_count_;
|
||||
|
||||
// Points to TensorFlow and TFLite dialects, respectively. nullptr if the
|
||||
// dialect is not registered.
|
||||
const Dialect* tf_dialect_;
|
||||
const Dialect* tfl_dialect_;
|
||||
|
||||
// Suffix used to generate unique tensor names from operation names.
|
||||
int name_counter_ = 0;
|
||||
|
||||
// Whether to strip or not emit debug info.
|
||||
const bool strip_debug_info_;
|
||||
};
|
||||
|
||||
std::string Translator::GetName(Operation* inst) {
|
||||
// If strip_debug_info_ is set, then simply return counter value.
|
||||
if (strip_debug_info_) return Twine(name_counter_++).str();
|
||||
|
||||
if (auto name_loc = inst->getLoc().dyn_cast<mlir::NameLoc>())
|
||||
return name_loc.getName().str();
|
||||
|
||||
if (auto call_loc = inst->getLoc().dyn_cast<mlir::CallSiteLoc>()) {
|
||||
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
|
||||
// if imported with DebugInfo), else use the fallback naming scheme below.
|
||||
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
|
||||
return name_loc.getName().str();
|
||||
}
|
||||
|
||||
// If the location is none of the expected types, then simply use name
|
||||
// generated using the op type.
|
||||
return inst->getName().getStringRef().str();
|
||||
}
|
||||
|
||||
std::string Translator::UniqueName(llvm::StringRef prefix) {
|
||||
// Keep incrementing the counter until we find a unique name.
|
||||
std::string name = prefix;
|
||||
int64_t& prefix_count = name_to_count_[name];
|
||||
int64_t val = prefix_count;
|
||||
while (val != 0) {
|
||||
name = (prefix + Twine(prefix_count)).str();
|
||||
++prefix_count;
|
||||
val = name_to_count_[name];
|
||||
}
|
||||
name_to_count_[name] = 1;
|
||||
return name;
|
||||
}

std::string Translator::UniqueName(mlir::Operation* op) {
  auto& name = op_to_name_[op];
  if (!name.empty()) return name;
  // Update the value in the map with unique name.
  name = UniqueName(GetName(op));
  return name;
  return name_mapper_.GetUniqueName(op);
}

Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@ -867,8 +830,8 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn) {
    return;
  }
  for (auto it : llvm::enumerate(fn.getArguments())) {
    op_to_name_[*it.value()->user_begin()] = input_names[it.index()];
    ++name_to_count_[input_names[it.index()].str()];
    name_mapper_.InitOpName(*it.value()->user_begin(),
                            input_names[it.index()]);
  }
}

@ -888,8 +851,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn) {
    // insert an op so that we can have a buffer named such. This cannot
    // currently happen due to pseudo_input nodes.
    if (auto op = it.value()->getDefiningOp()) {
      op_to_name_[op] = output_names[it.index()];
      name_to_count_[output_names[it.index()].str()] = 1;
      name_mapper_.InitOpName(op, output_names[it.index()]);
    } else {
      fn.emitWarning() << "output is not due to an op and '"
                       << output_names[it.index()]
@ -1027,14 +989,44 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
      /*name=*/builder_.CreateString(fn.getName().str()));
}

BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
                                                         StringRef content) {
  auto buffer_index = buffers_.size();
  auto buffer_data = builder_.CreateVector(
      reinterpret_cast<const uint8_t*>(content.data()), content.size());
  buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
  return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
}

Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
Translator::CreateMetadataVector() {
  auto dict_attr = module_.getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
  if (!dict_attr) return VectorBufferOffset<BufferOffset<tflite::Metadata>>();

  std::vector<BufferOffset<tflite::Metadata>> metadata;
  for (const auto& named_attr : dict_attr) {
    StringRef name = named_attr.first;
    mlir::Attribute attr = named_attr.second;
    if (auto content = attr.dyn_cast<StringAttr>()) {
      metadata.push_back(BuildMetadata(name, content.getValue()));
    } else {
      module_.emitError(
          "all values in tfl.metadata's dictionary key-value pairs should be "
          "string attributes");
      return llvm::None;
    }
  }
  return builder_.CreateVector(metadata);
}
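
// Illustrative input (assumed attribute syntax; the key "min_runtime" is a
// hypothetical example): a module such as
//   module attributes {tfl.metadata = {min_runtime = "1.14.0"}} { ... }
// would serialize one Metadata entry named "min_runtime" whose buffer holds
// the bytes "1.14.0"; any non-string value takes the emitError path above.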

Optional<std::string> Translator::Translate(ModuleOp module,
                                            bool emit_builtin_tflite_ops,
                                            bool emit_select_tf_ops,
                                            bool emit_custom_ops,
                                            bool strip_debug_info) {
                                            OpNameMapper* op_name_mapper) {
  if (!IsValidTFLiteMlirModule(module)) return llvm::None;
  Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
                        emit_custom_ops, strip_debug_info);
                        emit_custom_ops, op_name_mapper);
  return translator.TranslateInternal();
}

@ -1074,12 +1066,17 @@ Optional<std::string> Translator::TranslateInternal() {
  } else {
    model_description = "MLIR Converted.";
  }

  // Build the model and finish the model building process.
  auto description = builder_.CreateString(model_description.data());
  VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
  auto metadata = CreateMetadataVector();
  if (!metadata) return llvm::None;

  auto model = tflite::CreateModel(
      builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_),
      builder_.CreateVector(subgraphs), description,
      builder_.CreateVector(buffers_));
      builder_.CreateVector(buffers_), metadata_buffer, *metadata);
  tflite::FinishModelBuffer(builder_, model);

  // Return serialized string for the built FlatBuffer.
@ -1100,22 +1097,38 @@ Optional<std::string> Translator::TranslateInternal() {
//
bool tflite::MlirToFlatBufferTranslateFunction(
    ModuleOp module, std::string* serialized_flatbuffer,
    bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
    bool emit_custom_ops) {
    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
    OpNameMapper* op_name_mapper) {
  auto maybe_translated =
      Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops,
                            emit_custom_ops, strip_debug_info_flag);
                            emit_custom_ops, op_name_mapper);
  if (!maybe_translated) return true;
  *serialized_flatbuffer = std::move(*maybe_translated);
  return false;
}

bool tflite::MlirToFlatBufferTranslateFunction(
    ModuleOp module, std::string* serialized_flatbuffer,
    bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
    bool emit_custom_ops) {
  OpLocNameMapper op_name_mapper;
  return MlirToFlatBufferTranslateFunction(
      module, serialized_flatbuffer, emit_builtin_tflite_ops,
      emit_select_tf_ops, emit_custom_ops, &op_name_mapper);
}
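
// Minimal caller sketch (assumes a valid ModuleOp `module`; note the
// true-on-failure convention used by these entry points):
//   std::string serialized;
//   if (tflite::MlirToFlatBufferTranslateFunction(
//           module, &serialized, /*emit_builtin_tflite_ops=*/true,
//           /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false)) {
//     // handle translation failure
//   }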

static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction(
    ModuleOp module, llvm::StringRef filename) {
  std::string serialized_flatbuffer;
  std::unique_ptr<OpNameMapper> op_name_mapper;
  if (strip_debug_info) {
    op_name_mapper = std::make_unique<tensorflow::OpStripNameMapper>();
  } else {
    op_name_mapper = std::make_unique<OpLocNameMapper>();
  }
  if (tflite::MlirToFlatBufferTranslateFunction(
          module, &serialized_flatbuffer, emit_builtin_tflite_ops,
          emit_select_tf_ops, emit_custom_ops))
          emit_select_tf_ops, emit_custom_ops, op_name_mapper.get()))
    return mlir::failure();

  auto file = openOutputFile(filename);

@ -19,6 +19,7 @@ limitations under the License.
#include <string>

#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/op_name_mapper.h"

// These flags are used to control whether different kinds of ops are emitted
// during the flatbuffer translation.
@ -33,12 +34,19 @@ extern bool strip_debug_info;
namespace tflite {

// Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string.
// serialized flatbuffer into the string. This uses OpLocNameMapper to convert
// the location of each op into its name in the flatbuffer.
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
                                       std::string *serialized_flatbuffer,
                                       std::string* serialized_flatbuffer,
                                       bool emit_builtin_tflite_ops,
                                       bool emit_select_tf_ops,
                                       bool emit_custom_ops);

// Same as the above but with a custom op name mapper.
bool MlirToFlatBufferTranslateFunction(
    mlir::ModuleOp module, std::string* serialized_flatbuffer,
    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
    tensorflow::OpNameMapper* op_name_mapper);
}  // namespace tflite

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_

@ -655,14 +655,40 @@ static LogicalResult Verify(UnpackOp op) {
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value *value) {
  ElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) return {};

  IntegerAttr int_attr = attr.getValue(llvm::None).cast<IntegerAttr>();
  return int_attr.getValue().getSExtValue();
}

// Returns a RankedTensorType which is similar to `input_type` but replaces the
// dimension size of `dim` with `dim_size`. For example,
// `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
// `tensor<3x2xi32>`.
static RankedTensorType SubstituteRankedTensorTypeDimSize(
    RankedTensorType input_type, int64_t dim, int64_t dim_size) {
  auto shape = input_type.getShape().vec();
  shape[dim] = dim_size;
  return RankedTensorType::get(shape, input_type.getElementType());
}

// Verifies the output tensor types of SplitOp or SplitVOp.
template <typename ExpectedOutputTypeGetter>
static LogicalResult VerifySplitOpOutputTypes(
    Operation *op, int64_t num_splits,
    ExpectedOutputTypeGetter get_expected_output_type) {
  for (int64_t i = 0; i < num_splits; ++i) {
    auto expected_output_type = get_expected_output_type(i);
    Value *output = op->getResult(i);
    auto output_type = output->getType().dyn_cast<RankedTensorType>();
    if (!output_type || output_type != expected_output_type)
      return op->emitOpError()
             << "output #" << i << " should be " << expected_output_type;
  }
  return success();
}

static LogicalResult Verify(SplitOp op) {
  int64_t num_splits = op.num_splits().getSExtValue();
  if (op.getOperation()->getNumResults() != num_splits)
  if (op.getNumResults() != num_splits)
    return op.emitOpError("output count should match 'num_splits' attribute");

  // If 'split_dim' is not a constant, there are no other checks.
@ -688,21 +714,100 @@ static LogicalResult Verify(SplitOp op) {
  if (dim_size % num_splits != 0)
    return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");

  // Creates sliced tensor type.
  auto slice_shape = input_type.getShape().vec();
  slice_shape[split_dim] = dim_size / num_splits;
  RankedTensorType slice_type =
      RankedTensorType::get(slice_shape, input_type.getElementType());
  // Verifies output tensor types.
  RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
      input_type, split_dim, dim_size / num_splits);
  return VerifySplitOpOutputTypes(
      op.getOperation(), num_splits,
      [expected_output_type](int64_t) { return expected_output_type; });
}

  // Verifies result tensor types.
  for (int64_t i = 0; i < num_splits; ++i) {
    Value *result = op.getResult(i);
    auto result_type = result->getType().dyn_cast<RankedTensorType>();
    if (!result_type || result_type != slice_type)
      return op.emitOpError() << "output #" << i << " should be " << slice_type;
static LogicalResult Verify(SplitVOp op) {
  int64_t num_splits = op.num_splits().getSExtValue();
  if (op.getNumResults() != num_splits)
    return op.emitOpError("output count should match 'num_splits' attribute");

  // If 'split_dim' is not a constant, there are no other checks.
  llvm::Optional<int64_t> split_dim_opt =
      ExtractConstantIntFromTensor(op.split_dim());
  if (!split_dim_opt) return success();

  // If 'input' is not a ranked tensor, there are no other checks.
  auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
  if (!input_type) return success();

  int64_t split_dim = split_dim_opt.getValue();
  const int64_t rank = input_type.getRank();
  if (split_dim < 0) split_dim += rank;
  if (split_dim < 0 || split_dim >= rank)
    return op.emitOpError("'split_dim' should be in [-rank, rank)");

  // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
  // there are no other checks.
  const int64_t dim_size = input_type.getDimSize(split_dim);
  if (ShapedType::isDynamic(dim_size)) return success();

  // If 'size_splits' is not a constant, there are no other checks.
  ElementsAttr size_splits_attr;
  if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
    return success();

  if (size_splits_attr.getNumElements() != num_splits) {
    auto size_splits_type =
        op.size_splits()->getType().cast<RankedTensorType>();
    RankedTensorType expected_size_splits_type =
        RankedTensorType::get({num_splits}, size_splits_type.getElementType());
    return op.emitOpError("'size_splits' should be ")
           << expected_size_splits_type;
  }

  return success();
  // Normalizes and verifies 'size_splits'.
  // Note: TensorFlow allows one -1 element in 'size_splits'. The -1 element
  // means the rest of the dimension size.
  llvm::SmallVector<int64_t, 4> size_splits;
  size_splits.reserve(num_splits);

  int64_t negative_size_split_loc = -1;
  int64_t total_size_splits = 0;

  for (int64_t i = 0; i < num_splits; ++i) {
    auto size_split_attr = size_splits_attr.getValue<IntegerAttr>(i);
    int64_t size_split = size_split_attr.getValue().getSExtValue();
    size_splits.push_back(size_split);
    if (size_split >= 0) {
      total_size_splits += size_split;
      continue;
    }
    if (size_split < -1)
      return op.emitOpError(
          "elements of 'size_splits' should be greater than or equal to -1");
    if (negative_size_split_loc != -1)
      return op.emitOpError("'size_splits' can only have one -1");
    negative_size_split_loc = i;
  }

  if (negative_size_split_loc != -1) {
    if (total_size_splits > dim_size)
      return op.emitOpError(
          "sum of non-negative elements of 'size_splits' is greater than the "
          "dimension size of 'split_dim' axis");
    size_splits[negative_size_split_loc] = dim_size - total_size_splits;
    total_size_splits = dim_size;
  }

  if (total_size_splits != dim_size)
    return op.emitOpError(
        "sum of 'size_splits' should match the dimension size of 'split_dim' "
        "axis");

  // Verifies result tensor types.
  auto get_expected_output_type = [input_type, split_dim,
                                   &size_splits](int64_t i) {
    return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
                                             size_splits[i]);
  };
  return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
                                  get_expected_output_type);
}
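
// Worked example (illustrative values): for a value of type tensor<10x4xf32>,
// split_dim = 0 and size_splits = [2, -1, 3], the non-negative entries sum
// to 5, so the single -1 is normalized to 10 - 5 = 5 and the verifier
// expects outputs tensor<2x4xf32>, tensor<5x4xf32> and tensor<3x4xf32>.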

//===----------------------------------------------------------------------===//

@ -132,15 +132,35 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

class TFL_OperandIsUnrankedPred<int n> :
  CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;

// TODO: Some of these could be generalized and/or moved to a more general
// location.
// Returns true if the n-th operand has unknown rank or has rank m.
class TFL_OperandHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
    Or<[TFL_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
        ")->getType().cast<ShapedType>().getRank() == " # m>]>>;

// CPred version of TFL_OperandHasRank.
class TFL_OperandHasRankPred<int n, int m> :
  Or<[TFL_OperandIsUnrankedPred<n>,
    CPred<"$_op.getOperand(" # n #
      ")->getType().cast<ShapedType>().getRank() == " # m>]>;

// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
  CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
    # dim>]>;

class TFL_OperandDimEquals<int n, int dim, int size> : And<[
  TFL_OperandIsRankedAndHasDimPred<n, dim>,
  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
    ".getShape()[" # dim # " ] == " # size>]>;

// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
@ -155,6 +175,32 @@ class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
      "$_op.getOperand(" # y #
        ")->getType().cast<ShapedType>().getShape()[0]">>;

// True if x_shape[dim] == y_shape[dim].
class TFL_DimOfOperandEqualsDimOfOperandPred<int x, int y, int dim> : And<[
  TFL_OperandIsRankedAndHasDimPred<x, dim>,
  TFL_OperandIsRankedAndHasDimPred<y, dim>,
  CPred<"$_op.getOperand(" # x #
    ")->getType().cast<ShapedType>().getShape()[" # dim # "] == "
    "$_op.getOperand(" # y #
    ")->getType().cast<ShapedType>().getShape()[" # dim # "]">]>;

// Select operands must satisfy one of the following constraints:
// All inputs are unranked/scalars
// OR
// All inputs are ranked AND have equal dim[0] AND X & Y have same rank.
def SelectShapeConstraints :
  PredOpTrait<"Select operands meet shape criteria",
    Or<[
      And<[
        TFL_OperandHasRankPred<0, 0>,
        TFL_OperandHasRankPred<1, 0>,
        TFL_OperandHasRankPred<2, 0>]>,
      And<[
        TFL_DimOfOperandEqualsDimOfOperandPred<0, 1, 0>,
        TFL_DimOfOperandEqualsDimOfOperandPred<0, 2, 0>,
        CPred<"$_op.getOperand(1)->getType().cast<ShapedType>().getRank() == "
          "$_op.getOperand(2)->getType().cast<ShapedType>().getRank()">]>]>>;

// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
  TCOpResIsShapedTypePred<i, j>,
@ -239,7 +285,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
}

class TFL_ConvOp<string mnemonic, string opSummary> :
    TFL_Op<mnemonic, [NoSideEffect, TFL_AccumulatorUniformScale<2, 0, 1>]> {
    TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
  let summary = opSummary # " operator";

  let description = [{
@ -315,7 +361,7 @@ def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {

// TODO(haoliang): Implement legalization pass after pattern rewrite generator
// supports variadic inputs.
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "add_n operator";

  let description = [{
@ -323,11 +369,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
  }];

  let arguments = (ins
    Variadic<TensorOf<[F32, I32]>>:$inputs
    Variadic<TensorOf<[F32, I32, QI16, QUI16]>>:$inputs
  );

  let results = (outs
    TensorOf<[F32, I32]>:$sum
    TensorOf<[F32, I32, QI16, QUI16]>:$sum
  );
}

@ -359,7 +405,7 @@ retained with length 1.
}

def TFL_AveragePool2DOp:
    TFL_Op<"average_pool_2d", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
    TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Average_pool_2d operator";

  let description = [{
@ -454,7 +500,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
    NoSideEffect,
    PredOpTrait<"values and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_SameOperandsAndResultsScale
    SameOperandsAndResultsScale
  ]> {
  let summary = "Concatenation operator";

@ -464,14 +510,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",

  let arguments = (
    ins Variadic<TensorOf<
      [F32, I64, I32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>>:$values,
      [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>>:$values,
    I32Attr:$axis,
    TFL_AFAttr:$fused_activation_function
  );

  let results = (outs
    TensorOf<
      [F32, I64, I32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$output
      [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output
  );

  let hasOptions = 1;
@ -529,13 +575,13 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
// TODO(jpienaar): Update post discussion on semantics of FC OP.
// TODO(jpienaar): Include more shape verification.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
    NoSideEffect, TFL_AccumulatorUniformScale<2, 0, 1>]> {
    NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
  let summary = "Fully connected op";

  let arguments = (ins
    TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>:$input,
    TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>:$filter,
    TFL_TensorOfOrNone<[F32, TFL_QI32, TFL_QUI32]>:$bias,
    TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
    TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter,
    TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,

    TFL_AFAttr:$fused_activation_function,
    TFL_FullyConnectedOptionsWeightFormatAttr:$weights_format,
@ -544,7 +590,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [

  // Depending on the weights format, this op can have one or two outputs.
  let results = (outs
    Variadic<TensorOf<[F32, TFL_QI8, TFL_QUI8, TFL_QI16, TFL_QUI16]>>:$output
    Variadic<TensorOf<[F32, QI8, QUI8, QI16, QUI16]>>:$output
  );

  let hasOptions = 1;
@ -552,7 +598,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [

def TFL_GatherOp : TFL_Op<"gather", [
    NoSideEffect,
    TFL_SameOperandsAndResultsScale,
    SameOperandsAndResultsScale,
    TFL_OperandHasAtleastRank<0, 1>,
    PredOpTrait<"params and output must have same element type",
      TCresVTEtIsSameAsOp<0, 0>>
@ -564,7 +610,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
  }];

  let arguments = (ins
    TensorOf<[F32, I8, I32, I64, TFL_Str, TFL_QI8, TFL_QUI8]>:$params,
    TensorOf<[F32, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
    TensorOf<[I32, I64]>:$indices,
    I32Attr:$axis
  );
@ -577,7 +623,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
  ];

  let results = (outs
    TensorOf<[F32, I16, I32, I64, TFL_Str, TFL_QI8, TFL_QUI8]>:$output
    TensorOf<[F32, I16, I32, I64, TFL_Str, QI8, QUI8]>:$output
  );

  let hasOptions = 1;
@ -602,7 +648,7 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {

// Same type check of lhs and rhs is handled by the Broadcastable trait.
def TFL_LessEqualOp : TFL_Op<"less_equal", [
    Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
    Broadcastable, NoSideEffect, NoQuantizableResult]> {
  let summary = "Less_equal operator";

  let description = [{
@ -610,8 +656,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
  }];

  let arguments = (
    ins TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$lhs,
    TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$rhs);
    ins TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs,
    TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs);

  let results = (outs TFL_BoolTensor:$output);

@ -643,7 +689,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
  }];

  let arguments = (ins
    TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$input,
    TensorOf<[F32, QI8, QUI8]>:$input,
    I32Attr:$radius,
    F32Attr:$bias,
    F32Attr:$alpha,
@ -651,14 +697,14 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
  );

  let results = (outs
    TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$output
    TensorOf<[F32, QI8, QUI8]>:$output
  );

  let hasOptions = 1;
}

def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
    Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
    Broadcastable, NoSideEffect, NoQuantizableResult]> {
  let summary = "Greater_equal operator";

  let description = [{
@ -680,8 +726,119 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
  let hasOptions = 0;
}

// These ops are named NonMaxSuppressionV4 & NonMaxSuppressionV5 to be
// consistent with TensorFlow's naming. They are NOT 'versions' of NMS in the
// sense that one is an incremental change over the other.
// In reality NonMaxSuppressionV5 implements Soft Non Max Suppression and
// NonMaxSuppressionV4 performs hard NMS.

def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [
    NoSideEffect,
    // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
    TFL_OperandHasRank<0, 2>,
    PredOpTrait<"boxes should have dim[1] == 4",
      TFL_OperandDimEquals<0, 1, 4>>,
    // Operand 1 (scores) should be a 1-dim tensor
    TFL_OperandHasRank<1, 1>,
    // Other operands are scalar params.
    TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
    TFL_OperandHasRank<4, 0>]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
  }];

  let arguments = (ins
    TFL_FpTensor:$boxes,
    TFL_FpTensor:$scores,
    I32Tensor:$max_output_size,
    TFL_FpTensor:$iou_threshold,
    TFL_FpTensor:$score_threshold
  );

  let results = (outs
    I32Tensor:$selected_indices,
    I32Tensor:$valid_outputs
  );
}

def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [
    NoSideEffect,
    // Operand 0 (boxes) should have rank 2 with the dim[1] == 4 (box corners)
    TFL_OperandHasRank<0, 2>,
    PredOpTrait<"boxes should have dim[1] == 4",
      TFL_OperandDimEquals<0, 1, 4>>,
    // Operand 1 (scores) should be a 1-dim tensor
    TFL_OperandHasRank<1, 1>,
    // Other operands are scalar params.
    TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>,
    TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>]> {
  let summary = [{
Greedily selects a subset of bounding boxes in descending order of score,
  }];

  let description = [{
pruning away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes with score less than
`score_threshold` are removed. Bounding boxes are supplied as
[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
is agnostic to where the origin is in the coordinate system and more
generally is invariant to orthogonal transformations and translations
of the coordinate system; thus translations or reflections of the coordinate
system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input
collection of bounding boxes representing the selected boxes. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
  selected_indices = tf.image.non_max_suppression_v2(
      boxes, scores, max_output_size, iou_threshold, score_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
This op also supports a Soft-NMS (with Gaussian weighting) mode (c.f.
Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0.
  }];

  let arguments = (ins
    TFL_FpTensor:$boxes,
    TFL_FpTensor:$scores,
    I32Tensor:$max_output_size,
    TFL_FpTensor:$iou_threshold,
    TFL_FpTensor:$score_threshold,
    TFL_FpTensor:$soft_nms_sigma
  );

  let results = (outs
    I32Tensor:$selected_indices,
    TFL_FpTensor:$selected_scores,
    I32Tensor:$valid_outputs
  );
}

def TFL_NotEqualOp : TFL_Op<"not_equal", [
    Broadcastable, Commutative, NoSideEffect, TFL_NoQuantizableResult]> {
    Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
  let summary = "Not_equal operator";

  let description = [{
@ -766,7 +923,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
}

def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
    TFL_NoQuantizableResult,
    NoQuantizableResult,
    PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
  let summary = "Equal operator";

@ -776,8 +933,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,

  let arguments = (
    ins
      TensorOf<[I1, F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$x,
      TensorOf<[I1, F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$y
      TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$x,
      TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$y
  );

  let results = (outs TFL_BoolTensor:$output);
@ -844,7 +1001,7 @@ size 1.
}

def TFL_SqueezeOp: TFL_Op<"squeeze", [NoSideEffect,
    TFL_SameOperandsAndResultsScale]> {
    SameOperandsAndResultsScale]> {
  let summary = "Removes dimensions of size 1 from the shape of a tensor.";

  let description = [{
@ -944,7 +1101,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
  let builders = [TFL_BroadcastableBinaryBuilder];
}

def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, TFL_NoQuantizableResult]> {
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
  let summary = "Greater operator";

  let description = [{
@ -979,6 +1136,25 @@ def TFL_InputOp : Op<TFL_Dialect, "pseudo_input", [SameOperandsAndResultType]> {
  let results = (outs AnyTensor:$output);
}

def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect]> {
  let summary = "L2 Normalize Operator";

  let description = [{
L2Normalization Op
  }];

  let arguments = (ins
    TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$input,
    TFL_AFAttr:$fused_activation_function
  );

  let results = (outs TensorOf<[F32, QUI8, QI8, QUI16, QI16, I8]>:$output);

  let hasOptions = 1;

  let customOption = "L2NormOptions";
}

def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Leaky Relu operator";

@ -1000,7 +1176,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
  let hasOptions = 0b1;
}

def TFL_LessOp : TFL_Op<"less", [NoSideEffect, TFL_NoQuantizableResult]> {
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
  let summary = "Less operator";

  let description = [{
@ -1073,17 +1249,17 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
    SameOperandsAndResultShape,
    // zero_point = 0
    // scale = 1. / (max_value + 1)
    TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
    TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
    FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
  let summary = "Logistic operator";

  let description = [{
Computes element-wise Sigmoid of input
  }];

  let arguments = (ins TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$x);
  let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x);

  let results = (outs TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$y);
  let results = (outs TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y);
}
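
// Worked arithmetic (illustrative): FixedResultScale encodes the scale as
// smantissa * 10^sexp, so Int8UniformQuantizedType<-128, 390625, -8> means
// zero_point = -128 and scale = 390625e-8 = 0.00390625 = 1/256, i.e.
// 1 / (max_value + 1) with max_value = 255, matching the comment above.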

def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
@ -1106,8 +1282,8 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
    SameOperandsAndResultShape,
    // zero_point = max_value
    // scale = -log_softmax_output_min / (max_value + 1)
    TFL_FixedResultScale<TFL_Int8UniformQuantizedType<127, 625, -4>>,
    TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<255, 625, -4>>]> {
    FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
    FixedResultScale<UInt8UniformQuantizedType<255, 625, -4>>]> {
  let summary = "Log softmax operator";

  let description = [{
@ -1133,13 +1309,13 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and "
    And<[
      // The input and output tensors should have the same elemental type
      // and they should be one of the specified types below.
      TCopVTEtIs<0, AnyTypeOf<[F32, TFL_QI8, TFL_QUI8]>>,
      TCopVTEtIs<0, AnyTypeOf<[F32, QI8, QUI8]>>,
      TFL_TCresVTEtIsSameAsOp<0, 0>]>>;

def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
    NoSideEffect,
    MaxPoolOperandAndResultConstraints,
    TFL_SameOperandsAndResultsScale]> {
    SameOperandsAndResultsScale]> {
  let summary = "Max Pool 2D op";

  let description = [{
@ -1167,19 +1343,19 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
}

def TFL_MaximumOp : TFL_Op<"maximum", [
    Broadcastable, NoSideEffect, Commutative, TFL_SameOperandsAndResultsScale]> {
    Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
  let summary = "Max operator";
  let description = [{
    Element-wise max operation.
  }];

  let arguments = (
    ins TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$lhs,
    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$rhs
    ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
    TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
  );

  let results = (outs
    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$max
    TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max
  );

  let builders = [TFL_BroadcastableBinaryBuilder];
@ -1187,7 +1363,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
  let hasOptions = 0;
}

def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Mean operator";

  let description = [{
@ -1199,13 +1375,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_SameOperandsAndResultsScale]>
  }];

  let arguments = (ins
    TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$input,
    TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
    TensorOf<[I32, I64]>:$axis,
    BoolAttr:$keep_dims
  );

  let results = (outs
    TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$output);
    TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output);

  let hasOptions = 1;
  let customOption = "ReducerOptions";
@ -1256,7 +1432,7 @@ Rounds the values of a tensor to the nearest integer, element-wise.
}

def TFL_SliceOp : TFL_Op<"slice", [
    NoSideEffect, TFL_SameOperandsAndResultsScale]> {
    NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Return a slice from 'input'.";

  let description = [{
@ -1363,19 +1539,19 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}

def TFL_MinimumOp : TFL_Op<"minimum", [
    Broadcastable, NoSideEffect, Commutative, TFL_SameOperandsAndResultsScale]> {
    Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
  let summary = "Min operator";
  let description = [{
    Element-wise min operation.
  }];

  let arguments = (
    ins TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$lhs,
    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$rhs
    ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
    TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
  );

  let results = (outs
    TensorOf<[AnyFloat, TFL_Int32Or64, TFL_QI8, TFL_QUI8]>:$min
    TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min
  );

  let builders = [TFL_BroadcastableBinaryBuilder];
@ -1422,7 +1598,7 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
  let hasOptions = 0b1;
}

def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Packs a list of tensors along a dimension into one tensor";

  let description = [{
@ -1453,14 +1629,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
  }];

  let arguments = (ins
    Variadic<TensorOf<[F32, I8, I16, I32, I64]>>:$values,
    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,

    I32Attr:$values_count,
    I32Attr:$axis
  );

  let results = (outs
    TensorOf<[F32, I8, I16, I32, I64]>:$output
    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
  );

  let verifier = [{ return Verify(*this); }];
@ -1470,7 +1646,7 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {

def TFL_PadOp : TFL_Op<"pad", [
    NoSideEffect,
    TFL_SameOperandsAndResultsScale,
    SameOperandsAndResultsScale,
    TFL_OperandHasRank<1, 2>,
    TFL_OperandRankEquals1DimOfOperand<0, 1>]> {
  let summary = "Padding operator";
@ -1500,17 +1676,17 @@ def TFL_PadOp : TFL_Op<"pad", [
  }];

  let arguments = (
    ins TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
    ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
    TFL_I32OrI64Tensor:$padding);

  let results = (outs TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output);
  let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);

  let hasOptions = 1;
}

def TFL_PadV2Op : TFL_Op<"padv2", [
    NoSideEffect,
    TFL_SameOperandsAndResultsScale,
    SameOperandsAndResultsScale,
    TFL_OperandHasRank<1, 2>,
    TFL_OperandHasRank<2, 0>,
    TFL_OperandRankEquals1DimOfOperand<0, 1>,
@ -1545,11 +1721,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
  }];

  let arguments = (
    ins TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
    ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
    TFL_I32OrI64Tensor:$padding,
    TensorOf<[F32, I8, I32, I64]>:$constant_values);

  let results = (outs TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output);
  let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);

  let hasOptions = 1;
}
@ -1570,6 +1746,8 @@ def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect]> {
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];

  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];

  let builders = [TFL_BroadcastableBinaryBuilder];
}

def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
@ -1587,7 +1765,7 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {

def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
    SameOperandsAndResultShape,
    TFL_SameOperandsAndResultsScale]> {
    SameOperandsAndResultsScale]> {
  let summary = "Relu operator";

  let description = [{
@ -1602,7 +1780,7 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,

def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
    SameOperandsAndResultShape,
    TFL_SameOperandsAndResultsScale]> {
    SameOperandsAndResultsScale]> {
  let summary = "Relu6 operator";

  let description = [{
@ -1616,7 +1794,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
}

def TFL_ReshapeOp: TFL_Op<"reshape", [
    NoSideEffect, TFL_SameOperandsAndResultsScale]> {
    NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Reshape operator";

  let description = [{
@ -1682,7 +1860,7 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
  let hasFolder = 1;
}

def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, TFL_NoQuantizableResult]> {
def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, NoQuantizableResult]> {
  let summary = "Shape operator";

  let description = [{
@ -1756,8 +1934,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
}

def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
    // TODO(jpienaar): This is too restrictive, rank 1 input is also allowed.
    SameOperandsAndResultShape,
    SelectShapeConstraints,
    PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
    PredOpTrait<"operands and result have same element type",
      TCresVTEtIsSameAsOp<0, 1>>]> {
@ -1810,12 +1987,12 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
    SameOperandsAndResultShape,
    // zero_point = 0
    // scale = 1. / (max_value + 1)
    TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
    TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
    FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
  let summary = "Softmax operator";

  let description = [{
Computes element-wise softmax activiations with the following formula
Computes element-wise softmax activations with the following formula

    exp(input) / tf.reduce_sum(exp(input * beta), dim)
  }];
@ -1851,9 +2028,9 @@ def TFL_SquareOp: TFL_Op<"square", [NoSideEffect, SameOperandsAndResultType]> {
Computes element-wise Square of input
  }];

  let arguments = (ins TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$x);
  let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x);

  let results = (outs TensorOf<[AnyFloat, TFL_QI8, TFL_QUI8]>:$y);
  let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y);

  let hasOptions = 0b1;

@ -1913,17 +2090,17 @@ def TFL_TanhOp: TFL_Op<"tanh", [
    // central_value = min_value / 2 + (max_value - 1) / 2 + 1
    // zero_point = central_value
    // scale = 1. / (central_value - min_value)
    TFL_FixedResultScale<TFL_Int8UniformQuantizedType<0, 78125, -7>>,
    TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<128, 78125, -7>>]> {
    FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
    FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
  let summary = "Hyperbolic tangent operator";

  let description = [{
Computes element-wise Hyperbolic tangent of input
  }];

  let arguments = (ins TensorOf<[F32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$x);
  let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);

  let results = (outs TensorOf<[F32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$y);
  let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
}
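
// Worked arithmetic (illustrative) for the int8 variant above: with
// min_value = -128 and max_value = 127, central_value = -128/2 + 126/2 + 1
// = 0, so zero_point = 0 and scale = 1 / (0 - (-128)) = 1/128 = 78125e-7,
// which is exactly Int8UniformQuantizedType<0, 78125, -7>.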
|
||||
|
||||
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
|
||||
@ -1988,7 +2165,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
|
||||
// TFL_OperandRankEquals1DimOfOperand<0, 1>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_SameOperandsAndResultsScale]> {
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "Transpose operator";
|
||||
|
||||
let description = [{
|
||||
@ -2028,14 +2205,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, I32, TFL_QI8, TFL_QUI8]>:$input,
|
||||
TensorOf<[F32, I8, I32, QI8, QUI8]>:$input,
|
||||
|
||||
I32Attr:$num,
|
||||
I32Attr:$axis
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TensorOf<[F32, I8, I32, TFL_QI8, TFL_QUI8]>>:$outputs
|
||||
Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
@ -2059,7 +2236,7 @@ def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [NoSideEffect]> {
|
||||
|
||||
def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
|
||||
NoSideEffect,
|
||||
TFL_SameOperandsAndResultsScale,
|
||||
SameOperandsAndResultsScale,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>
|
||||
]> {
|
||||
@ -2070,19 +2247,19 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
|
||||
TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
|
||||
TensorOf<[I32]>:$block_shape,
|
||||
TensorOf<[I32]>:$indices
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$output
|
||||
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
|
||||
NoSideEffect,
|
||||
TFL_SameOperandsAndResultsScale,
|
||||
SameOperandsAndResultsScale,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>
|
||||
]> {
|
||||
@ -2093,19 +2270,19 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
|
||||
TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
|
||||
TensorOf<[I32]>:$block_shape,
|
||||
TensorOf<[I32]>:$paddings
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$output
|
||||
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
|
||||
NoSideEffect,
|
||||
TFL_SameOperandsAndResultsScale,
|
||||
SameOperandsAndResultsScale,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>
|
||||
]> {
|
||||
@ -2119,12 +2296,12 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$input,
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input,
|
||||
I32Attr:$block_size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$output
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2132,7 +2309,7 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
|
||||
|
||||
def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
|
||||
NoSideEffect,
|
||||
TFL_SameOperandsAndResultsScale,
|
||||
SameOperandsAndResultsScale,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>
|
||||
]> {
|
||||
@ -2148,21 +2325,18 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$input,
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input,
|
||||
I32Attr:$block_size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_QUI8]>:$output
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def Rank0I32Tensor : Type<And<[I32Tensor.predicate, HasAnyRankOfPred<[0]>]>,
|
||||
"tensor<i32>">;
|
||||
|
||||
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
|
||||
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
|
||||
|
||||
let description = [{
|
||||
@ -2172,13 +2346,13 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Rank0I32Tensor:$split_dim,
|
||||
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
|
||||
0DTensorOf<[I32]>:$split_dim,
|
||||
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
|
||||
PositiveI32Attr:$num_splits
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
@ -2186,7 +2360,7 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
|
||||
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
|
||||
|
||||
let description = [{
|
||||
@ -2196,21 +2370,23 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, TFL_SameOperandsAndResultsSc
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
|
||||
I32Tensor:$size_splits,
|
||||
I32Tensor:$split_dim,
|
||||
I32Attr:$num_splits
|
||||
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
|
||||
1DTensorOf<[I32]>:$size_splits,
|
||||
0DTensorOf<[I32]>:$split_dim,
|
||||
PositiveI32Attr:$num_splits
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
|
||||
NoSideEffect, TFL_SameOperandsAndResultsScale]> {
|
||||
NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
let summary = "ResizeBilinear Op";
|
||||
|
||||
let description = [{
|
||||
@ -2219,12 +2395,12 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
|
||||
|
||||
let arguments = (ins
|
||||
// TODO(ycling): Support quantized types.
|
||||
TensorOf<[F32, I32, TFL_QI8, TFL_QUI8]>:$input,
|
||||
TensorOf<[F32, I32, QI8, QUI8]>:$input,
|
||||
TensorOf<[I32]>:$size,
|
||||
BoolAttr:$align_corners);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$output
|
||||
TensorOf<[F32, QI8, QUI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2232,7 +2408,7 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
|
||||
|
||||
def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
|
||||
[NoSideEffect,
|
||||
TFL_SameOperandsAndResultsScale]> {
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "ResizeNearestNeighbor Op";
|
||||
|
||||
let description = [{
|
||||
@ -2240,13 +2416,13 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, TFL_Uint8, TFL_QUI8, TFL_QI8]>:$input,
|
||||
TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input,
|
||||
TensorOf<[I32]>:$size,
|
||||
BoolAttr:$align_corners
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I8, TFL_Uint8, TFL_QUI8, TFL_QI8]>:$output
|
||||
TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2294,7 +2470,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_SameOperandsAndResultsScale
|
||||
SameOperandsAndResultsScale
|
||||
]> {
|
||||
let summary = "StridedSlice Op";
|
||||
|
||||
@ -2303,7 +2479,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8]>:$input,
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$input,
|
||||
TensorOf<[I32]>:$begin,
|
||||
TensorOf<[I32]>:$end,
|
||||
TensorOf<[I32]>:$strides,
|
||||
@ -2316,13 +2492,14 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8]>:$output
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
def TFL_CastOp : TFL_Op<"cast", [
|
||||
NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> {
|
||||
let summary = "Cast operator";
|
||||
|
||||
let description = [{
|
||||
@ -2411,7 +2588,7 @@ in the unique output `y`. In other words:
|
||||
// Quantization ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def TFL_DequantizeOp: TFL_Op<"dequantize", [
|
||||
NoSideEffect, TFL_NoQuantizableResult]> {
|
||||
NoSideEffect, NoQuantizableResult]> {
|
||||
let summary = "Dequantize operator";
|
||||
|
||||
let description = [{
|
||||
@ -2447,7 +2624,7 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
|
||||
}
|
||||
|
||||
def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
|
||||
NoSideEffect, FirstAttrDerivedResultType, TFL_NoQuantizableResult]> {
|
||||
NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
|
||||
let summary = "Quantized constant pseudo op";
|
||||
|
||||
let description = [{
|
||||
@ -2465,7 +2642,7 @@ def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
|
||||
}
|
||||
|
||||
def TFL_QuantizeOp: TFL_Op<"quantize", [
|
||||
NoSideEffect, FirstAttrDerivedResultType, TFL_NoQuantizableResult]> {
|
||||
NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
|
||||
let summary = "Quantize operator";
|
||||
|
||||
let description = [{
|
||||
@ -2609,6 +2786,10 @@ Ba et al. “Layer Normalization”
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
|
||||
// TODO(fengliuai): customize printer and parser to not display
|
||||
// empty region.
|
||||
let regions = (region AnyRegion:$internal);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
@ -15,7 +15,6 @@ cc_library(
|
||||
hdrs = [
|
||||
"graphdef_to_tfl_flatbuffer.h",
|
||||
],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:common",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
|
||||
|
@ -138,7 +138,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
  TF_ASSIGN_OR_RETURN(
      auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));

  mlir::PassManager pm;
  mlir::PassManager pm(module->getContext());
  bool run_quantize = tensorflow::ShouldRunQuantizePasses(module.get());
  mlir::TFL::PassConfig pass_config;
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
@ -33,7 +33,6 @@ cc_library(
    hdrs = [
        "quantization_utils.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        "@com_google_absl//absl/memory",
        "@llvm//:support",
@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation definition file for TensorFlow Lite.
// This is the quantization definition file for TensorFlow.

#ifdef TFL_Quantization
#ifdef TF_Quantization
#else
#define TFL_Quantization
#define TF_Quantization

#ifdef OP_BASE
#else
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// The base class of a quantized type.
|
||||
class TFL_QuantizedType<string n, list<int> params, bit signed>
|
||||
class QuantizedType<string n, list<int> params, bit signed>
|
||||
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
|
||||
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
|
||||
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
|
||||
@ -59,21 +59,21 @@ class TFL_QuantizedType<string n, list<int> params, bit signed>
|
||||
// Uniform quantized types. Two integers "smantissa" and "sexp" are used to
|
||||
// express the Mantissa and Exponent components of the floating-point scale so
|
||||
// the scale of the quantized type is "smantissa * 10 ^ sexp".
|
||||
class TFL_UInt8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
|
||||
: TFL_QuantizedType<"Uniform",
|
||||
class UInt8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
|
||||
: QuantizedType<"Uniform",
|
||||
[8, zero_pt, smantissa, sexp, 0, 255], 0>;
|
||||
class TFL_Int8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
|
||||
: TFL_QuantizedType<"Uniform",
|
||||
class Int8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
|
||||
: QuantizedType<"Uniform",
|
||||
[8, zero_pt, smantissa, sexp, -128, 127], 1>;
|
||||
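Editor's note: the fixed-point scale encoding above is compact but easy to misread, so here is a quick standalone illustration (mine, not part of the patch; the numbers are hypothetical trait arguments):

#include <cmath>
#include <cstdio>

// How a (smantissa, sexp) pair encodes the floating-point scale of a
// uniform quantized type: scale = smantissa * 10 ^ sexp.
int main() {
  int smantissa = 15, sexp = -3;  // hypothetical values for illustration
  double scale = smantissa * std::pow(10.0, sexp);
  std::printf("scale = %g\n", scale);  // prints: scale = 0.015
}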

// General uniform quantized types. The definitions can be used to specify
// operands' tensor types.
def TFL_QUI8 : TFL_QuantizedType<"Uniform", [8], 0>;
def TFL_QI8 : TFL_QuantizedType<"Uniform", [8], 1>;
def TFL_QUI16 : TFL_QuantizedType<"Uniform", [16], 0>;
def TFL_QI16 : TFL_QuantizedType<"Uniform", [16], 1>;
def TFL_QUI32 : TFL_QuantizedType<"Uniform", [32], 0>;
def TFL_QI32 : TFL_QuantizedType<"Uniform", [32], 1>;
def QUI8 : QuantizedType<"Uniform", [8], 0>;
def QI8 : QuantizedType<"Uniform", [8], 1>;
def QUI16 : QuantizedType<"Uniform", [16], 0>;
def QI16 : QuantizedType<"Uniform", [16], 1>;
def QUI32 : QuantizedType<"Uniform", [32], 0>;
def QI32 : QuantizedType<"Uniform", [32], 1>;

//===----------------------------------------------------------------------===//
// TFL native op traits (for quantization).
@ -83,23 +83,23 @@ def TFL_QI32 : TFL_QuantizedType<"Uniform", [32], 1>;
//===----------------------------------------------------------------------===//

// Specify this trait if the op has a fixed output value range.
class TFL_FixedResultScale<TFL_QuantizedType qt> : NativeOpTrait<!strconcat(
    "TFL::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
    "quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;

// Specify this trait if the op requires the same quantization scales for its
// inputs and outputs.
def TFL_SameOperandsAndResultsScale : NativeOpTrait<
    "TFL::SameOperandsAndResultsScale">;
def SameOperandsAndResultsScale : NativeOpTrait<
    "quant::SameOperandsAndResultsScale">;

// Specify this trait if the b-th input of the op is a bias input, which needs
// a scale based on the scales of op1 and op2.
class TFL_AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
    !strconcat("TFL::AccumulatorUniformScale<",
class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
    !strconcat("quant::AccumulatorUniformScale<",
               StrJoinInt<[bias, op1, op2]>.result,
               ">::Impl")>;

// Specify this trait if the op doesn't have a quantizable output. We shouldn't
// apply quantization on this op.
def TFL_NoQuantizableResult : NativeOpTrait<"TFL::NoQuantizableResult">;
def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">;

#endif // TFL_Quantization
#endif // TF_Quantization
@ -18,6 +18,7 @@ limitations under the License.
#include <utility>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@ -118,17 +119,19 @@ class QuantizationDriver {
  // result.
  void Finalize();

  // Whether the constant is used as a bias input of another op. Here we assume
  // the bias is used immediately by its user. This assumption is always
  // correct after constant folding.
  bool UsedAsBias(ConstantOp cst) {
    Value *value = cst.getResult();
    for (auto &use : value->getUses()) {
      auto biases = GetQuantSpec(use.getOwner())->biases_params;
      if (biases.find(use.getOperandNumber()) != biases.end()) return true;
    }
    return false;
  }
  // The quantization parameters of a bias operand are usually determined by
  // the other operands, so if a constant is used by different ops as a bias,
  // it needs to be duplicated so that each op can assign its own quantization
  // parameter for this bias. This method also adds all the non-bias constants
  // to a set for looking up later.
  void PreprocessConstantOps();

  // Set up all the data structures for quantization propagation.
  void SetupAllStates();

  // Whether the constant is a weight, which shouldn't be shared by different
  // ops.
  bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); }

  // Returns all the related quantization constraints of the op.
  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);
@ -266,6 +269,11 @@ class QuantizationDriver {
  OpBuilder builder_;
  bool is_signed_;

  // We should distinguish weights and bias constants. Biases are specified by
  // the quantization spec or are the operands of ops with the same-scale
  // spec. The rest are weights.
  llvm::DenseSet<Operation *> weights_;

  // All the ops we need to propagate the quantization parameters to.
  std::vector<Operation *> work_list_;
  std::unordered_set<Operation *> quantized_;
@ -539,12 +547,49 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
  return {};
}

// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
void QuantizationDriver::PreprocessConstantOps() {
  fn_.walk([&](ConstantOp cst) {
    // Non-float tensors are neither weights nor require quantization.
    auto type = cst.getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return;

    Value *value = cst.getResult();
    SmallVector<std::pair<Operation *, int>, 4> bias_users;
    bool used_as_weight = false;
    for (auto &use : value->getUses()) {
      auto spec = GetQuantSpec(use.getOwner());
      auto biases = spec->biases_params;
      Operation *user = use.getOwner();
      int operand_num = use.getOperandNumber();

      // If the user doesn't use this value as a bias operand and doesn't
      // require the same scale, this constant is considered to be a weight.
      if (biases.find(operand_num) == biases.end() &&
          !spec->requires_same_scale) {
        used_as_weight = true;
      } else {
        bias_users.push_back({user, operand_num});
      }
    }

    // If the constant is used as a weight, it will be duplicated for each
    // bias user, so it isn't shared with the weight usage. Otherwise, one
    // bias user can keep the original constant and the rest use duplicates,
    // so we pop that user from the list.
    if (used_as_weight) {
      weights_.insert(cst);
    } else {
      bias_users.pop_back();
      builder_.setInsertionPoint(cst);
    }
    for (auto bias_user : bias_users) {
      auto copied = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
      bias_user.first->setOperand(bias_user.second, copied.getResult());
    }
  });
}

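Editor's note: a tiny standalone sketch (not the pass itself; all names are made up) of why a bias constant shared by two ops has to be duplicated before per-op quantization parameters can be assigned:

#include <cstdio>
#include <memory>

// Toy model: each op derives its bias scale from its own input/weight scales,
// so two ops writing through one shared constant would clobber each other.
struct Const { double value; double scale; };

struct Op {
  std::shared_ptr<Const> bias;
  double input_scale, weight_scale;
  void AssignBiasScale() { bias->scale = input_scale * weight_scale; }
};

int main() {
  auto shared = std::make_shared<Const>(Const{0.5, 0.0});
  Op a{shared, 0.02, 0.1};  // wants bias scale 0.002
  Op b{shared, 0.05, 0.2};  // wants bias scale 0.010
  a.AssignBiasScale();
  b.AssignBiasScale();      // clobbers a's choice: both users now see 0.010
  b.bias = std::make_shared<Const>(*shared);  // duplicate, as the pass does
  a.AssignBiasScale();
  b.AssignBiasScale();      // now a keeps 0.002 and b keeps 0.010
  std::printf("a=%g b=%g\n", a.bias->scale, b.bias->scale);
}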
void QuantizationDriver::SetupAllStates() {
  llvm::DenseMap<Value *, int> value_to_state;

  fn_.walk([&](Operation *op) {
@ -582,6 +627,21 @@ void QuantizationDriver::Initialize() {
  });
}

// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
  // Duplicate the bias constants, so the states can be set up correctly.
  // TODO(fengliuai): The function definition should also be duplicated if
  // there are multiple call sites.
  PreprocessConstantOps();

  // Set up all the internal states.
  SetupAllStates();
}

bool QuantizationDriver::PropagateParams() {
  // TODO(fengliuai): use a typed indicator instead of a bool value.
  bool changed = false;
@ -590,7 +650,7 @@ bool QuantizationDriver::PropagateParams() {
    work_list_.pop_back();

    // This op has been quantized, so we should not consider it again.
    if (quantized_.find(op) != quantized_.end()) continue;
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);

    auto spec = GetQuantSpec(op);
@ -600,9 +660,8 @@ bool QuantizationDriver::PropagateParams() {
    if (!spec->is_quantizable) continue;

    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
      // If this constant is used as a bias in another op, the quantization
      // parameters are determined by that op.
      if (UsedAsBias(cst) || IsQuantized(op)) continue;
      // If it isn't a weight or has been quantized, skip.
      if (!IsWeight(cst) || IsQuantized(op)) continue;

      // The quantization parameters are determined by the content of the
      // constant.
@ -23,7 +23,7 @@ limitations under the License.

namespace mlir {
namespace OpTrait {
namespace TFL {
namespace quant {

using QuantizedType = mlir::quant::QuantizedType;
using UniformQuantizedType = mlir::quant::UniformQuantizedType;
@ -119,7 +119,7 @@ class NoQuantizableResult
  static bool IsQuantizable() { return false; }
};

}  // namespace TFL
}  // namespace quant
}  // namespace OpTrait
}  // namespace mlir

@ -32,11 +32,12 @@ static Type GetQuantizedType(Builder builder, Type input_type, double min,
                             double max, int storage_type_width,
                             bool narrow_range, bool is_signed) {
  auto converter =
      quant::ExpressedToUniformQuantizedConverter::forInputType(input_type);
      quant::ExpressedToQuantizedConverter::forInputType(input_type);

  quant::UniformQuantizedType quantizedEleType = quant::fakeQuantAttrsToType(
      builder.getUnknownLoc(), storage_type_width, min, max, narrow_range,
      converter.expressedType, is_signed);
  if (!quantizedEleType) return {};
  return converter.convert(quantizedEleType);
}
@ -79,20 +80,24 @@ Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr,
  double min = std::numeric_limits<double>::max();
  double max = std::numeric_limits<double>::min();
  if (auto fp = attr.dyn_cast<DenseFPElementsAttr>()) {
    for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
      double ele_value = FloatAttr::getValueAsDouble(*it);
      min = std::min(min, ele_value);
      max = std::max(max, ele_value);
    // If all the element values are the same, we don't need to scan the
    // content.
    if (fp.isSplat()) {
      min = max =
          FloatAttr::getValueAsDouble(fp.getSplatValue<llvm::APFloat>());
    } else {
      for (auto it = fp.begin(), e = fp.end(); it != e; ++it) {
        double ele_value = FloatAttr::getValueAsDouble(*it);
        min = std::min(min, ele_value);
        max = std::max(max, ele_value);
      }
    }
    // The range must straddle zero.
    if (min > 0.0 || max < 0.0) return {};
    auto type = GetQuantizedType(builder, attr.getType(), min, max,
                                 storage_type_width, narrow_range, is_signed);
    if (auto ele_type = type.dyn_cast_or_null<TensorType>())
      return ele_type.getElementType();
  }

  // The range from SplatElementAttr and other element attribute types couldn't
  // straddle zero, so the quantization parameters couldn't be derived from its
  // range.
  return {};
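Editor's note: the straddle-zero requirement exists because affine quantization must represent 0.0 exactly. A quick worked example with made-up numbers:

#include <cmath>
#include <cstdio>

// Derive uniform (affine) quantization parameters from a min/max range that
// straddles zero; 8-bit, non-narrow range. Values are illustrative only.
int main() {
  double min = -1.5, max = 6.0;        // must satisfy min <= 0 <= max
  double scale = (max - min) / 255.0;  // ~0.0294
  int zero_point = static_cast<int>(std::round(-min / scale));  // 51
  std::printf("scale=%f zero_point=%d\n", scale, zero_point);
  // A real value x maps to round(x / scale) + zero_point, clamped to [0, 255];
  // requiring the range to straddle zero keeps 0.0 exactly representable.
}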
@ -40,7 +40,8 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
      "FixedResultUniformScale<([0-9]+).*(true|false)>"};
  emitSourceFileHeader("Generated Ops Quant Spec Getters", os);

  // Retrieve all the definitions derived from TFL_Op and sort by record name.
  // Retrieve all the definitions derived from the Op definition and sort by
  // record name.
  std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
  llvm::sort(defs, LessRecord());

@ -53,8 +54,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
  for (const auto t : op.getTraits()) {
    if (auto opTrait = llvm::dyn_cast<mlir::tblgen::NativeOpTrait>(&t)) {
      auto trait = opTrait->getTrait();
      // We only handle TFL specific native op traits.
      if (!trait.consume_front("OpTrait::TFL::")) continue;
      if (!trait.consume_front("OpTrait::quant::")) continue;

      OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName()
             << ">(op)) {\n";
@ -73,7 +73,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
      OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n";
      OUT(6) << "spec->restricted_output_params[std::make_pair("
             << matches[1] << ", " << matches[2]
             << ")].push_back(tfl.OpTrait::TFL::" << trait << "<"
             << ")].push_back(tfl.OpTrait::quant::" << trait << "<"
             << op.getQualCppClassName()
             << ">::GetResultQuantizedType(i));\n";
      matches.clear();
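Editor's note: the consume_front filter above does a prefix-test-and-strip in one step. A plain-C++ sketch of the same idea (my approximation, without the LLVM dependency):

#include <cstdio>
#include <string_view>

// Keep a trait only if it lives in the OpTrait::quant:: namespace, and strip
// that prefix for later use -- mirroring llvm::StringRef::consume_front.
static bool ConsumeFront(std::string_view &s, std::string_view prefix) {
  if (s.substr(0, prefix.size()) != prefix) return false;
  s.remove_prefix(prefix.size());
  return true;
}

int main() {
  std::string_view trait = "OpTrait::quant::SameOperandsAndResultsScale";
  if (ConsumeFront(trait, "OpTrait::quant::"))
    std::printf("handled trait: %.*s\n", (int)trait.size(), trait.data());
}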
@ -13,6 +13,11 @@ func @extractSimpleOphint() {
  return
}

// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation"}

// -----

// CHECK-LABEL: extractPackedInputOphint
func @extractPackedInputOphint() {
  // CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
@ -30,6 +35,11 @@ func @extractPackedInputOphint() {
  return
}

// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack"}

// -----

// CHECK-LABEL: extractFirstInputOphint
func @extractFirstInputOphint() {
  // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b703f0f4b9ec11e99426dc4a3e957995(%0) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -46,6 +56,11 @@ func @extractFirstInputOphint() {
  return
}

// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_first"}

// -----

// CHECK-LABEL: extractLastInputOphint
func @extractLastInputOphint() {
  // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @e31fcf90b9ed11e99426dc4a3e957995(%1) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -62,6 +77,11 @@ func @extractLastInputOphint() {
  return
}

// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_last"}

// -----

// CHECK-LABEL: extractPackOneInputOphint
func @extractPackOneInputOphint() {
  // CHECK: %[[RESHAPE:[0-9]*]] = "tfl.reshape"(%0) : (tensor<1x16x1xf32>) -> tensor<1x1x16x1xf32>
@ -75,13 +95,16 @@ func @extractPackOneInputOphint() {
  return
}

// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_pack_input_one"}

// -----

// CHECK-LABEL: extractStackInputOutputOphint
func @extractStackInputOutputOphint() {
  // CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
  // CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
  // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
  // CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
  // CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>

  %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
  %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -98,11 +121,14 @@ func @extractStackInputOutputOphint() {
  return
}

// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack_input_output"}

// -----

// CHECK-LABEL: extractMultipleInputsOutputsOphint
func @extractMultipleInputsOutputsOphint() {
  // CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
  // CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
  // CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
  // CHECK: %[[MULTI_INPUT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)

  %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
  %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
@ -119,21 +145,33 @@ func @extractMultipleInputsOutputsOphint() {
  return
}

// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation"}
// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_stack"}
// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_first"}
// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_last"}
// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_pack_input_one"}
// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: attributes {_tflite_function_name = "cool_activation_stack_input_output"}
// CHECK: func @a6ca45beb9f411e99426dc4a3e957995(tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK: attributes {_tflite_function_name = "cool_activation_multiple_input_output"}
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32], _tflite_function_name = "cool_activation_multiple_input_output"}

// -----

// CHECK-LABEL: inputsAfterOutputs
func @inputsAfterOutputs() {
  // CHECK: %[[PLACE_HOLDER:[0-9]*]] = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
  // CHECK: %[[INPUT_PROCESS:[0-9]*]] = "tf.Sigmoid"(%[[PLACE_HOLDER]]) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @d6266124d2dd11e9b52cdc4a3e957995(%0, %1, %[[INPUT_PROCESS]]) : (tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)

  %0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<f32>) -> tensor<f32>
  %2 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
  %3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %4 = "tf.Add"(%3, %1) {T = "tfdtype$DT_FLOAT", device = "", name = "Add"} : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
  %5 = "tf.Identity"(%4) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %6 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32>
  %7 = "tf.Sigmoid"(%6) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %8 = "tf.Identity"(%7) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 2 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-2-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %9 = "tf.Add"(%5, %8) {T = "tfdtype$DT_FLOAT", device = "", name = "Add_1"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %10 = "tf.Identity"(%9) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return
}

// CHECK: func @d6266124d2dd11e9b52cdc4a3e957995(tensor<2x2xf32>, tensor<f32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32, 2 : i32], _tflite_function_name = "CustomOp"}

// -----

@ -0,0 +1,26 @@
// RUN: tf-opt -tfl-legalize-ophint-func-op %s | FileCheck %s

module {
  // CHECK-LABEL: func @testConvertUnidirectionalSequenceRNN
  // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<1x3xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<1x3xf32>)
  func @testConvertUnidirectionalSequenceRNN(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x4xf32> {
    // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<1x4xf32>
    // CHECK: %[[CST_0:.*]] = constant dense<0.000000e+00> : tensor<4xf32>
    // CHECK: %[[CST_1:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32>
    // CHECK: %[[CST_2:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32>
    // CHECK: %[[PACKED_INPUT:[a-z0-9]*]] = "tfl.pack"(%[[ARG_0]], %[[ARG_1]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32>
    // CHECK: %[[FUSED_OUTPUT:[a-z0-9]*]] = "tfl.unidirectional_sequence_rnn"(%[[PACKED_INPUT]], %[[CST_1]], %[[CST_2]], %[[CST_0]], %[[CST]]) {fused_activation_function = "TANH", time_major = true} : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32>
    // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[FUSED_OUTPUT]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>)

    %cst = constant dense<0.000000e+00> : tensor<1x4xf32>
    %cst0 = constant dense<0.000000e+00> : tensor<4xf32>
    %cst1 = constant dense<0.000000e+00> : tensor<4x3xf32>
    %cst2 = constant dense<0.000000e+00> : tensor<4x4xf32>
    %2 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32>
    %3 = call @a9211722c23011e9875cdc4a3e957995(%2, %cst1, %cst2, %cst0, %cst) : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32>
    %4:2 = "tfl.unpack"(%3) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>)
    return %4#0 : tensor<1x4xf32>
  }
  func @a9211722c23011e9875cdc4a3e957995(tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32>
    attributes {_tflite_function_name = "UnidirectionalSequenceRnn"}
}
@ -50,7 +50,7 @@ func @biasAddInt(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor
func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i32 {
  %0 = "tf.Squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
  %1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
  %2 = constant dense<[2, 5]> : tensor<2xi32>
  %2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
  %3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
  %4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
  return %4 : i32
@ -119,8 +119,8 @@ func @fakeQuantArgsTrue(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
}

func @fakeQuantVarsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
  %arg1 = constant dense<-0.1> : tensor<f32>
  %arg2 = constant dense<0.2> : tensor<f32>
  %arg1 = "tf.Const"() { value = dense<-0.1> : tensor<f32> } : () -> tensor<f32>
  %arg2 = "tf.Const"() { value = dense<0.2> : tensor<f32> } : () -> tensor<f32>
  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor<f32>, tensor<f32>) -> tensor<8x8x8x8xf32>
  return %0 : tensor<8x8x8x8xf32>

@ -153,6 +153,14 @@ func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}

func @placeholder_int(%arg0: tensor<i32>) -> tensor<i32> {
  %0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<i32>) -> tensor<i32>
  return %0: tensor<i32>

  // CHECK-LABEL: @placeholder_int
  // CHECK-NEXT: "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
}

func @placeholder_min(%arg0: tensor<f32>) -> tensor<f32> {
  %0 = "tf.Placeholder.input"(%arg0) {name = "Input", min = -0.1 : f32} : (tensor<f32>) -> tensor<f32>
  return %0: tensor<f32>
@ -409,7 +417,7 @@ func @gatherNdHigherRankIndices(%arg0 : tensor<4x3x2xf32>, %arg1 : tensor<2x2xi3
}

func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
  %0 = constant dense<[1]> : tensor<1xi32>
  %0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
  return %1 : tensor<1x3x5x20xf32>

@ -418,7 +426,7 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
}

func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
  %0 = constant dense<[-1]> : tensor<1xi32>
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
  return %1 : tensor<1x2x3x5xf32>

@ -427,7 +435,7 @@ func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x
}

func @gatherV2NonZeroBatchDims(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
  %0 = constant dense<[1]> : tensor<1xi32>
  %0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
  return %1 : tensor<1x2x3x5xf32>

@ -509,6 +517,15 @@ func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) ->
  // CHECK: return %0 : tensor<8xf32>
}

func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
  return %0: tensor<8x3xf32>

  // CHECK-LABEL: select_multidim
  // CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
  // CHECK: return %0 : tensor<8x3xf32>
}

func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
  return %0: tensor<8xf32>
@ -518,6 +535,15 @@ func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>)
  // CHECK: return %0 : tensor<8xf32>
}

func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
  return %0: tensor<8x3xf32>

  // CHECK-LABEL: select_v2_multidim
  // CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
  // CHECK: return %0 : tensor<8x3xf32>
}

func @sin(%arg0: tensor<f32>) -> tensor<f32> {
  %0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
@ -536,7 +562,7 @@ func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?
}

func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
  %0 = constant dense<2> : tensor<i32>
  %0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<8xf32>, tensor<i32>) -> (tensor<2xf32>, tensor<2xi32>)
  return %1#0, %1#1: tensor<2xf32>, tensor<2xi32>

@ -546,7 +572,7 @@ func @topk_2(%arg0: tensor<8xf32>) -> (tensor<2xf32>, tensor<2xi32>) {
}

func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
  %0 = constant dense<2> : tensor<i32>
  %0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
  return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>

@ -556,7 +582,7 @@ func @topk_3(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
}

func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>) {
  %0 = constant dense<2> : tensor<i32>
  %0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<1x2x3x4xf32>, tensor<i32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>)
  return %1#0, %1#1: tensor<1x2x3x2xf32>, tensor<1x2x3x2xi32>

@ -566,7 +592,7 @@ func @topk_4(%arg0: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x2xf32>, tensor<1x2x3x2
}

func @topk_5(%arg0: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
  %0 = constant dense<2> : tensor<i32>
  %0 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %1:2 = "tf.TopKV2"(%arg0, %0) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
  return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>

@ -660,9 +686,18 @@ func @pad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
  // CHECK: return %0 : tensor<?xf32>
}

func @pow(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x1x1xf32>) -> tensor<2x1x3xf32> {
  %0 = "tf.Pow"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<2x1x1xf32>) -> tensor<2x1x3xf32>
  return %0 : tensor<2x1x3xf32>

  // CHECK-LABEL: pow
  // CHECK: %[[pow:.*]] = "tfl.pow"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<2x1x1xf32>) -> tensor<2x1x3xf32>
  // CHECK: return %[[pow]] : tensor<2x1x3xf32>
}

func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {
^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>):
  %cst = constant dense<[1, 2]> : tensor<2xi32>
  %cst = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> } : () -> tensor<2xi32>
  %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32>
  return %0 : tensor<2x6xf32>

@ -673,7 +708,7 @@ func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> {

func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
  %cst = constant dense<2.0> : tensor<f32>
  %cst = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
  %0 = "tf.PadV2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
  return %0#0 : tensor<? x f32>

@ -832,12 +867,12 @@ func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32>
  // CHECK: %0:3 = "tfl.split"(%arg0, %arg1) {num_splits = 3 : i32} : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
}

func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<1xi32>) -> tensor<1x4x2x3xf32> {
  %0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) {num_split = 2 : i64} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<1x4x2x3xf32> {
  %0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) {num_split = 2 : i64} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
  return %0#0 : tensor<1x4x2x3xf32>

  // CHECK-LABEL: splitv
  // CHECK: %0:2 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
  // CHECK: %0:2 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
}

func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
@ -849,8 +884,8 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
}

func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
  %0 = constant dense<[1]> : tensor<1xi32>
  %1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  %0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %1 = "tf.Concat"(%0, %arg0, %arg1) {N = 2 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  return %1 : tensor<2x2xi32>

  // CHECK-LABEL: concat2Tensors
@ -858,8 +893,8 @@ func @concat2Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi
}

func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
  %0 = constant dense<[-1]> : tensor<1xi32>
  %1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
  %0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
  %1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) {N = 3 : i64} : (tensor<i32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
  return %1 : tensor<2x3xi32>

  // CHECK-LABEL: concat3Tensors
@ -867,8 +902,8 @@ func @concat3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2
}

func @concatv2With3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2x3xi32> {
  %0 = constant dense<[-1]> : tensor<1xi32>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xi32>) -> tensor<2x3xi32>
  %0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) {N = 3 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<2x3xi32>
  return %1 : tensor<2x3xi32>

  // CHECK-LABEL: concatv2With3Tensors
@ -1084,3 +1119,35 @@ func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> {
  // CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32>
  // CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
}

func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
  %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
  return %0#0 : tensor<2xi32>

  // CHECK-LABEL: non_max_suppression_v4
  // CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}

func @non_max_suppression_v4_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<2xi32> {
  %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %arg2, %arg3, %arg4) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
  return %0#0 : tensor<2xi32>

  // CHECK-LABEL: non_max_suppression_v4_no_pad
  // CHECK: %0:2 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
}

func @non_max_suppression_v5(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
  %0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
  return %0#0 : tensor<2xi32>

  // CHECK-LABEL: non_max_suppression_v5
  // CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}

func @non_max_suppression_v5_no_pad(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> tensor<2xi32> {
  %0:3 = "tf.NonMaxSuppressionV5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {pad_to_max_output_size = false}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
  return %0#0 : tensor<2xi32>

  // CHECK-LABEL: non_max_suppression_v5_no_pad
  // CHECK: %0:3 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
}
@ -0,0 +1,107 @@
// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?xf32>, %arg7: tensor<?xf32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
  %0 = "tfl.lstm"(%arg0, // input
    %arg1, %arg2, %arg3, %arg4, // weights
    %arg5, %arg6, %arg7, %arg8, // recurrent weights
    %arg9, %arg10, %arg11, // cell weights
    %arg12, %arg13, %arg14, %arg15, // bias
    %arg16, %arg17, // projection weight and bias
    %arg18, %arg19, // stateful
    %arg20, %arg21, %arg22, %arg23 // layer norm coefficients
  ) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<? xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

  // CHECK-NEXT: "tfl.lstm"
  // CHECK-NEXT: %[[cst:.*]] = constant unit

  // input gate
  // CHECK-NEXT: %[[in1:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[in2:.*]] = "tfl.fully_connected"(%arg18, %arg5, %[[cst]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[in3:.*]] = "tfl.mul"(%arg19, %arg9)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[in4:.*]] = "tfl.add_n"(%[[in1]], %[[in2]], %[[in3]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[in5:.*]] = "tfl.l2_normalization"(%[[in4]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[in6:.*]] = tfl.add %[[in4]], %[[in5]]
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[in7:.*]] = "tfl.fully_connected"(%[[in6]], %arg20, %arg12)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[in8:.*]] = "tfl.logistic"(%[[in7]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

  // forget gate
  // CHECK-NEXT: %[[fo1:.*]] = "tfl.fully_connected"(%arg0, %arg2, %[[cst]])
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[fo2:.*]] = "tfl.fully_connected"(%arg18, %arg6, %[[cst]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[fo3:.*]] = "tfl.mul"(%arg19, %arg10)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[fo4:.*]] = "tfl.add_n"(%[[fo1]], %[[fo2]], %[[fo3]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[fo5:.*]] = "tfl.l2_normalization"(%[[fo4]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[fo6:.*]] = tfl.add %[[fo4]], %[[fo5]]
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[fo7:.*]] = "tfl.fully_connected"(%[[fo6]], %arg21, %arg13)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[fo8:.*]] = "tfl.logistic"(%[[fo7]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

  // cell gate
  // CHECK-NEXT: %[[ce1:.*]] = "tfl.fully_connected"(%arg0, %arg3, %[[cst]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ce2:.*]] = "tfl.fully_connected"(%arg18, %arg7, %[[cst]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ce3:.*]] = "tfl.add_n"(%[[ce1]], %[[ce2]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ce4:.*]] = "tfl.l2_normalization"(%[[ce3]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ce5:.*]] = tfl.add %[[ce3]], %[[ce4]]
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ce6:.*]] = "tfl.fully_connected"(%[[ce5]], %arg22, %arg14)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ce7:.*]] = "tfl.tanh"(%[[ce6]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

  // CHECK-NEXT: %[[ac1:.*]] = "tfl.mul"(%[[fo8]], %arg19)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ac2:.*]] = tfl.mul %[[in8]], %[[ce7]]
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ac3:.*]] = tfl.add %[[ac1]], %[[ac2]]
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>

  // output gate
  // CHECK-NEXT: %[[ou1:.*]] = "tfl.fully_connected"(%arg0, %arg4, %[[cst]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ou2:.*]] = "tfl.fully_connected"(%arg18, %arg8, %[[cst]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ou3:.*]] = "tfl.mul"(%[[ac3]], %arg11)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ou4:.*]] = "tfl.add_n"(%[[ou1]], %[[ou2]], %[[ou3]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ou5:.*]] = "tfl.l2_normalization"(%[[ou4]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ou6:.*]] = tfl.add %[[ou4]], %[[ou5]]
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ou7:.*]] = "tfl.fully_connected"(%[[ou6]], %arg23, %arg15)
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ou8:.*]] = "tfl.logistic"(%[[ou7]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>

  // output activation
  // CHECK-NEXT: %[[ac4:.*]] = "tfl.tanh"(%[[ac3]])
  // CHECK-SAME: -> tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ac5:.*]] = tfl.mul %[[ac4]], %[[ou8]]
  // CHECK-SAME: tensor<?x!quant.any<i16:f32>>
  // CHECK-NEXT: %[[ac6:.*]] = "tfl.fully_connected"(%[[ac5]], %arg16, %arg17)
  // CHECK-SAME: (tensor<?x!quant.any<i16:f32>>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x!quant.any<i8:f32>>
  // CHECK-NEXT: %[[ac7:.*]] = "tf_quant.pseudo_return"(%[[ac6]]) : (tensor<?x!quant.any<i8:f32>>) -> tensor<?x!quant.any<i8:f32>>
  // CHECK-NEXT: })
  // CHECK-NEXT: return

  return %0 : tensor<?xf32>
}
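Editor's note: for readers tracing the CHECK lines above, the gate structure they spell out roughly matches the usual layer-normalized, peephole LSTM. My summary of the recurrence (not text from the patch; LN denotes the normalize-then-rescale step):

\begin{aligned}
i_t &= \sigma\big(\mathrm{LN}(W_i x_t + R_i h_{t-1} + p_i \odot c_{t-1})\big),\\
f_t &= \sigma\big(\mathrm{LN}(W_f x_t + R_f h_{t-1} + p_f \odot c_{t-1})\big),\\
g_t &= \tanh\big(\mathrm{LN}(W_g x_t + R_g h_{t-1})\big),\\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t,\\
o_t &= \sigma\big(\mathrm{LN}(W_o x_t + R_o h_{t-1} + p_o \odot c_t)\big),\\
h_t &= o_t \odot \tanh(c_t).
\end{aligned}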
@ -143,6 +143,19 @@ func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: t
  // CHECK: return [[RESULT]] : tensor<?x10xf32>
}

func @tensorlistLength(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>) -> (tensor<i32>) {
  %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
  %1 = "tf.TensorListLength"(%0) : (tensor<!tf.variant<tensor<10xf32>>>) -> tensor<i32>
  return %1: tensor<i32>

  // CHECK-LABEL: tensorlistLength
  // CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>)
  // CHECK-DAG: [[SHAPE:%.*]] = "tf.Shape"([[INPUT]]) {{.*}} -> tensor<2xi32>
  // CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor<i32>
  // CHECK: [[RESULT:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO]]) {validate_indices = true} : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
  // CHECK: return [[RESULT]] : tensor<i32>
}

func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
  %cst = constant dense<3> : tensor<1xi32>
  %cst_0 = constant dense<0> : tensor<i32>
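Editor's note: the expected lowering reads the list length out of dimension 0 of the packed tensor (tf.Shape followed by tf.Gather with index 0). The same computation in a plain C++ sketch with illustrative values:

#include <array>
#include <cstdio>

int main() {
  std::array<int, 2> shape = {3, 10};  // shape of tensor<3x10xf32>
  int length = shape[0];               // what tf.Gather(shape, 0) extracts
  std::printf("list length = %d\n", length);
}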
@ -278,6 +278,6 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
  %21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32>
  %22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32>
  %23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32>
  %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %24 : tensor<4xf32>
}
@ -0,0 +1,31 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s

module attributes {
  tfl.metadata = {key1 = "value1", key2 = "value2"}
} {
  func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
    attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} {
  ^bb0(%arg0: tensor<3x2xi32>):
    %0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
    %1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
    %2 = "tfl.sub" (%0, %1) {fused_activation_function = "NONE"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
    return %2 : tensor<3x2xi32>
  }
}

// CHECK: buffers: [ {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 49 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 118, 97, 108, 117, 101, 50 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "key1",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "key2",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
@ -1,5 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string -
// | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s

func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
@ -9,7 +8,7 @@ func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf
return %2 : tensor<40x40xf32>
}

// CHECK-NEXT: operators: [ {
// CHECK: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, -1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,
@ -103,7 +103,7 @@ func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x
// test invalid AddN
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer values}}
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}}
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
return %0 : tensor<? x f16>
}
@ -537,7 +537,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
// test invalid Logistic input
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}}
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
@ -591,8 +591,9 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,

// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -600,8 +601,9 @@ func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x

// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -610,7 +612,7 @@ func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %
// test invalid none type applied to a tensor type arg
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -619,7 +621,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test violation of projection weight and projection bias pred op trait
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -628,7 +630,7 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// test invalid kernel type
func @testLstmWithInvalidKernelType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op attribute 'kernel_type' failed to satisfy constraint: lstm kernel type enum case FULL}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -652,6 +654,15 @@ func @testSelect(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi

// -----

// test select with multi-dim inputs
// CHECK-LABEL: testSelectMultiDim
func @testSelectMultiDim(%cond : tensor<?xi1>, %arg0 : tensor<?x4xi32>, %arg1 : tensor<?x4xi32>) -> tensor<?x4xi32> {
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?x4xi32>,tensor<?x4xi32>) -> tensor<?x4xi32>
return %0 : tensor<?x4xi32>
}

// -----

func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi32>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
@ -660,6 +671,14 @@ func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>

// -----

func @testSelectWithUnsupportedShapes(%cond : tensor<2xi1>, %arg0 : tensor<3xi32>, %arg1 : tensor<3xi32>) -> tensor<3xi32> {
// expected-error @+1 {{failed to verify that Select operands meet shape criteria}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<2xi1>,tensor<3xi32>,tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}

// -----

func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xf32>) -> tensor<?xi32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xf32>) -> tensor<?xi32>
@ -762,6 +781,21 @@ func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> te

// -----

// CHECK-LABEL: testPadQuantizedU8
func @testPadQuantizedU8(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}

// CHECK-LABEL: testPadQuantizedI8
func @testPadQuantizedI8(%arg0: tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}
// -----

// CHECK-LABEL: testPadV2
func @testPadV2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
@ -817,6 +851,20 @@ func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) ->

// -----

func @packQuantizedU8(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<u8:f32, 0.1>>, tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<u8:f32, 0.1>>
}

func @packQuantizedI8(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<i8:f32, 0.1>>, tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1>>
}

// -----

func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
@ -1143,8 +1191,8 @@ func @testSplitWithQuantizedTypes(%arg0 : tensor<i32>, %arg1 : tensor<10x!quant.

// -----

func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<i32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<1xi32>, %arg2 : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<1xi32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
}

@ -1331,16 +1379,16 @@ func @testSplitOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8xf

func @testSplitOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%split_dim = constant dense<0> : tensor<2x2xi32>
// expected-error @+1 {{'tfl.split' op operand #0 must be tensor<i32>}}
// expected-error @+1 {{'tfl.split' op operand #0 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<2x2xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}

// -----

func @testSplitOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<? x i32>) -> tensor<16x4x4xf32> {
// expected-error @+1 {{'tfl.split' op operand #0 must be tensor<i32>}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<?xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
func @testSplitOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<*xi32>) -> tensor<16x4x4xf32> {
// expected-error @+1 {{'tfl.split' op operand #0 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<*xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}

@ -1424,3 +1472,252 @@ func @testSplitOpWithValidTensorTypeDynamic(%arg0 : tensor<16x?xf32>) -> (tensor
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>)
return %0, %1 : tensor<8x?xf32>, tensor<8x?xf32>
}

// -----

func @testSplitVOpWithBadNumSplits(%arg0 : tensor<16xf32>) -> () {
%size_splits = constant dense<[]> : tensor<0xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op attribute 'num_splits' failed to satisfy constraint: positive 32-bit integer attribute}}
"tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 0 : i32} : (tensor<16xf32>, tensor<0xi32>, tensor<i32>) -> ()
return
}

// -----

func @testSplitVOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[4, 4, 4, 4]> : tensor<4xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output count should match 'num_splits' attribute}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 4 : i32} : (tensor<16xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}

// -----

func @testSplitVOpWithBadSizeSplitsTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[[8, 8], [2, 2]]> : tensor<2x2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<2x2xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}

// -----

func @testSplitVOpWithBadSizeSplitsUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %size_splits: tensor<*xi32>) -> tensor<16x4x4xf32> {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<*xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}

// -----

func @testSplitVOpWithBadSizeSplitsConstant(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[-2]> : tensor<1xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op elements of 'size_splits' should be greater than or equal to -1}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}

// -----

func @testSplitVOpWithBadSizeSplitsConstantMultipleNegativeOne(%arg0: tensor<16x4x4xf32>) -> (tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>) {
%size_splits = constant dense<[-1, -1, 14]> : tensor<3xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'size_splits' can only have one -1}}
%0, %1, %2 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 3 : i32} : (tensor<16x4x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>)
return %0, %1, %2 : tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>
}

// -----

func @testSplitVOpWithBadSizeSplitsConstantSum(%arg0: tensor<16x4x4xf32>) -> (tensor<0x4x4xf32>, tensor<16x4x4xf32>) {
%size_splits = constant dense<[-1, 17]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op sum of non-negative elements of 'size_splits' is greater than the dimension size of 'split_dim' axis}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<0x4x4xf32>, tensor<16x4x4xf32>)
return %0, %1 : tensor<0x4x4xf32>, tensor<16x4x4xf32>
}

// -----

func @testSplitVOpWithBadSizeSplitsSize(%arg0: tensor<16x4x4xf32>) -> tensor<15x4x4xf32> {
%size_splits = constant dense<[15, 1]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'size_splits' should be 'tensor<1xi32>'}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<2xi32>, tensor<i32>) -> tensor<15x4x4xf32>
return %0 : tensor<15x4x4xf32>
}

// -----

func @testSplitVOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[16]> : tensor<1xi32>
%split_dim = constant dense<0> : tensor<2x2xi32>
// expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<2x2xi32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}

// -----

func @testSplitVOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<*xi32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[16]> : tensor<1xi32>
// expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<*xi32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}

// -----

func @testSplitVOpWithOutOfRangeSplitDim(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<1> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}

// -----

func @testSplitVOpWithOutOfRangeSplitDimTFLConst(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}

// -----

func @testSplitVOpWithOutOfRangeSplitDimNegative(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<-2> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}

// -----

func @testSplitVOpWithMismatchSizeSplitsSum(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 4]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op sum of 'size_splits' should match the dimension size of 'split_dim' axis}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<4xf32>)
return %0, %1 : tensor<8xf32>, tensor<4xf32>
}

// -----

func @testSplitVOpWithMismatchTensorTypeSplitDimOut0(%arg0 : tensor<16xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #0 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<4xf32>, tensor<4xf32>)
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}

// -----

func @testSplitVOpWithMismatchTensorTypeSplitDimOut1(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #1 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<4xf32>)
return %0, %1 : tensor<8xf32>, tensor<4xf32>
}

// -----

func @testSplitVOpWithMismatchTensorTypeNonSplitDim(%arg0 : tensor<16x4xf32>) -> (tensor<8x2xf32>, tensor<8x2xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #0 should be 'tensor<8x4xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x2xf32>, tensor<8x2xf32>)
return %0, %1 : tensor<8x2xf32>, tensor<8x2xf32>
}

// -----

func @testSplitVOpWithValidTensorType(%arg0 : tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>) {
%size_splits_0 = constant dense<[8, 8]> : tensor<2xi32>
%split_dim_0 = constant dense<0> : tensor<i32>
%0, %1 = "tfl.split_v"(%arg0, %size_splits_0, %split_dim_0) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x4xf32>, tensor<8x4xf32>)

%size_splits_1 = constant dense<[2, 2]> : tensor<2xi32>
%split_dim_1 = constant dense<1> : tensor<i32>
%2, %3 = "tfl.split_v"(%arg0, %size_splits_1, %split_dim_1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)

return %0, %1, %2, %3 : tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>
}

// -----

func @testSplitVOpWithValidTensorTypeDynamic(%arg0 : tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x?xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x?xf32>, tensor<8x?xf32>)
return %0, %1 : tensor<8x?xf32>, tensor<8x?xf32>
}

// -----

func @testSplitVOpWithValidSizeSplitsUneven(%arg0 : tensor<16x4xf32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x1xf32>, tensor<16x3xf32>) {
%size_splits_0 = constant dense<[7, 3, 6]> : tensor<3xi32>
%split_dim_0 = constant dense<0> : tensor<i32>
%0, %1, %2 = "tfl.split_v"(%arg0, %size_splits_0, %split_dim_0) {num_splits = 3 : i32} : (tensor<16x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>)

%size_splits_1 = constant dense<[1, 3]> : tensor<2xi32>
%split_dim_1 = constant dense<1> : tensor<i32>
%3, %4 = "tfl.split_v"(%arg0, %size_splits_1, %split_dim_1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x3xf32>)

return %0, %1, %2, %3, %4 : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x1xf32>, tensor<16x3xf32>
}

// -----

func @testSplitVOpWithValidSizeSplitsNegative(%arg0 : tensor<16x4xf32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>) {
%size_splits_0 = constant dense<[7, -1, 6]> : tensor<3xi32>
%split_dim_0 = constant dense<0> : tensor<i32>
%0, %1, %2 = "tfl.split_v"(%arg0, %size_splits_0, %split_dim_0) {num_splits = 3 : i32} : (tensor<16x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>)

%size_splits_1 = constant dense<[-1, 4]> : tensor<2xi32>
%split_dim_1 = constant dense<1> : tensor<i32>
%3, %4 = "tfl.split_v"(%arg0, %size_splits_1, %split_dim_1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x0xf32>, tensor<16x4xf32>)

return %0, %1, %2, %3, %4 : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>
}

// -----

func @testNonMaxSuppressionV4WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}

// -----

func @testNonMaxSuppressionV4WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v4' op failed to verify that boxes should have dim[1] == 4}}
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}

// -----

func @testNonMaxSuppressionV5WithCorrectBoxShape(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}

// -----

func @testNonMaxSuppressionV5WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v5' op failed to verify that boxes should have dim[1] == 4}}
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}
@ -107,10 +107,10 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {

return %1 : tensor<4x2xf32>

// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %0 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %[[RES]] : tensor<4x2xf32>
}

// CHECK-LABEL: @fuseMulIntoFullyConnectedBroadcast
@ -123,10 +123,10 @@ func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>
return %1 : tensor<1x2xf32>

// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %0 : tensor<1x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %[[RES]] : tensor<1x2xf32>
}

// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias
@ -139,9 +139,9 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te

return %1 : tensor<4x2xf32>

// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %0 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %[[RES]] : tensor<4x2xf32>
}

// CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
@ -255,3 +255,40 @@ func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
// CHECK: %1 = "tfl.reshape"(%0) : (tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
// CHECK: %2 = "tfl.strided_slice"(%1, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}

// CHECK-LABEL: @L2NormalizePattern
func @L2NormalizePattern(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.rsqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.mul"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
return %3: tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: return %[[RES]]
}

// CHECK-LABEL: @L2NormalizePattern1
func @L2NormalizePattern1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.sqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
return %3: tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: return %[[RES]]
}

// CHECK-LABEL: @InvalidL2NormalizePattern
// Div and square ops must take the same argument to be eligible.
func @InvalidL2NormalizePattern(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%cst = constant dense<[0]> : tensor<1xi32>
%0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%2 = "tfl.sqrt"(%1) : (tensor<f32>) -> tensor<f32>
%3 = "tfl.div"(%arg1, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
return %3: tensor<2xf32>
// CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
// CHECK: return %3
}
@ -211,6 +211,17 @@ func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>
// CHECK: return %3 : tensor<1x6x6x16xf32>
}

// CHECK-LABEL: NotQuantizeConcatConstantOperand
func @NotQuantizeConcatConstantOperand(%arg0: tensor<2xf32>) -> tensor<2x2xf32> {
%0 = constant dense<1.0> : tensor<2xf32>
%1 = "tfl.concatenation"(%arg0, %0) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>

// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[cst]])
// CHECK-NEXT: return %[[cc]]
}

// CHECK-LABEL: QuantizeConcatOperand0ToAll
func @QuantizeConcatOperand0ToAll(tensor<2x!quant.uniform<u8:f32, 0.1:128>>, tensor<2xf32>) -> tensor<2x2xf32> {
^bb0(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1:128>>, %arg1: tensor<2xf32>):
@ -360,4 +371,85 @@ func @QuantizeConstant() -> tensor<2x3xf32> {
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8<1:255>:f32, 0.023622047244094488:128>>}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x3xf32>
}
}

// CHECK-LABEL: NotQuantizeNoneType
func @NotQuantizeNoneType() -> none {
%cst = constant unit
return %cst : none

// CHECK-NEXT: %[[cst:.*]] = constant unit
// CHECK-NEXT: return %[[cst]]
}

// CHECK-LABEL: QuantizeZeroSplat
func @QuantizeZeroSplat() -> tensor<2x3xf32> {
%cst = constant dense<0.0> : tensor<2x3xf32>
return %cst : tensor<2x3xf32>

// CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
}

// CHECK-LABEL: QuantizeZeroScalar
func @QuantizeZeroScalar() -> tensor<f32> {
%cst = constant dense<0.0> : tensor<f32>
return %cst : tensor<f32>

// CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
}

// CHECK-LABEL: QuantizeSharedBiases
func @QuantizeSharedBiases(
%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>,
%arg2: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> (tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>) {
%cst = constant dense<1.0> : tensor<32xf32>
%1 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x224x224x3xf32>
%2 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%conv1 = "tfl.conv_2d"(%1, %2, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%3 = "tfl.quantize"(%conv1) {qtype = tensor<1x112x112x32xf32>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>

%4 = "tfl.dequantize"(%3) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%5 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> tensor<32x3x3x3xf32>
%conv2 = "tfl.conv_2d"(%4, %5, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
%6 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>

return %6 : tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>

// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}

// CHECK-LABEL: QuantizeSharedBiases2
func @QuantizeSharedBiases2(
%arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>,
%arg2: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> (tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>) {
%cst = constant dense<1.0> : tensor<32xf32>
%1 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
%add = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%3 = "tfl.quantize"(%add) {qtype = tensor<32xf32>} : (tensor<32xf32>) -> tensor<32x!quant.uniform<u8:f32, 1.0>>

%5 = "tfl.dequantize"(%arg1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%6 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 2.0>>) -> tensor<32x3x3x3xf32>
%conv2 = "tfl.conv_2d"(%5, %6, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
%7 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>
return %3, %7 : tensor<32x!quant.uniform<u8:f32, 1.0>>, tensor<1x56x56x32x!quant.uniform<u8:f32, 1.0>>

// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>}
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}
@ -16,13 +16,13 @@ func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> (tensor<256x30x30x1
return %0, %1, %2, %3, %4 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>

// CHECK-LABEL: conv
// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
// CHECK: %0 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
// CHECK: %0 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %1 = "tfl.conv_2d"(%arg0, %0, %[[CONSTANT]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %2 = "tf.Conv2D"
// CHECK: %3 = "tf.Transpose"(%arg1, %cst_0) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %3 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
// CHECK: %4 = "tfl.conv_2d"(%arg0, %3, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: %5 = "tf.Conv2D"
// CHECK: %6 = "tf.Conv2D"
}
@ -41,13 +41,13 @@ func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>) -> (tensor<2
return %0, %1, %2, %3 : tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>

// CHECK-LABEL: depthwiseConv2D
// CHECK: %cst = constant dense<0.000000e+00> : tensor<12xf32>
// CHECK: %cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
// CHECK: %0 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<12xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<[1, 3, 3, 12]> : tensor<4xi64>
// CHECK: %0 = "tf.Reshape"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %1 = "tfl.depthwise_conv_2d"(%arg0, %0, %[[CONSTANT]]) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %2 = "tf.DepthwiseConv2dNative"
// CHECK: %3 = "tf.Reshape"(%arg1, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %3 = "tf.Reshape"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32>
// CHECK: %4 = "tfl.depthwise_conv_2d"(%arg0, %3, %[[CONSTANT]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<1x3x3x12xf32>, tensor<12xf32>) -> tensor<256x30x30x12xf32>
// CHECK: %5 = "tf.DepthwiseConv2dNative"
}

@ -155,10 +155,10 @@ func @fakeQuantFolded() -> (tensor<8xf32>) {
|
||||
%rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
|
||||
return %rst : tensor<8xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<8xf32>
|
||||
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: return %1 : tensor<8xf32>
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<8xf32>
|
||||
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT]]) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
|
||||
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
|
||||
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fakeQuantNotFolded
|
||||
@ -261,12 +261,12 @@ func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>)
|
||||
%rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %rst : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32>
|
||||
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
|
||||
// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32, 1.000000e+00>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: %2 = "tfl.conv_2d"(%arg0, %1, %cst)
|
||||
// CHECK: return %2
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
|
||||
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
|
||||
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32, 1.000000e+00>>}
|
||||
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
|
||||
// CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
|
||||
// CHECK: return %[[CONV]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
|
||||
@ -281,12 +281,12 @@ func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30
|
||||
%rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %rst : tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<48xf32>
|
||||
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
|
||||
// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32, 1.000000e+00>>}
|
||||
// CHECK: %1 = "tfl.dequantize"(%0)
|
||||
// CHECK: %2 = "tfl.depthwise_conv_2d"(%arg0, %1, %cst)
|
||||
// CHECK: return %2
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<48xf32>
|
||||
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
|
||||
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32, 1.000000e+00>>}
|
||||
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
|
||||
// CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
|
||||
// CHECK: return %[[CONV]]
|
||||
}
|
||||
|
||||
func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor<30xi32>) -> (tensor<10xi32>, tensor<20xi32>, tensor<30xi32>) {
|
||||
@ -348,3 +348,11 @@ func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
|
||||
// CHECK-LABEL: stop_gradient
|
||||
// CHECK: return %arg0 : tensor<3xi32>
|
||||
}
|
||||
|
||||
func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||
%0 = "tf.CheckNumerics"(%arg0) {message = ""}: (tensor<3xf32>) -> tensor<3xf32>
|
||||
return %0 : tensor<3xf32>
|
||||
// Should be converted to Identity and then from Identity to value
|
||||
// CHECK-LABEL: CheckNumerics
|
||||
// CHECK: return %arg0 : tensor<3xf32>
|
||||
}
|
||||
|
223
tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
Normal file
223
tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
Normal file
@ -0,0 +1,223 @@
|
||||
// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s
|
||||
|
||||
func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2TwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
|
||||
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
|
||||
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2FlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
return %0 : tensor<4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2Matrix
|
||||
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: return %[[v0]] : tensor<4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulTwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
|
||||
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
|
||||
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulFlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
return %0 : tensor<4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulMatrix
|
||||
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: return %[[v0]] : tensor<4x6xf32>
|
||||
}
|
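Editor's note: the new tests above pin down the unrolling scheme — collapse the leading batch dimensions into one, slice the flattened operands into per-batch 2-D matrices, emit one tf.MatMul per slice, then pack and reshape the results back. A standalone C++ sketch of the same shape arithmetic (illustrative only, not part of the commit; all names are hypothetical):

#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // [2, 3, 4, 5] x [2, 3, 5, 6]: all leading dims fold into one batch dim.
  const std::vector<int64_t> lhs_shape = {2, 3, 4, 5};
  int64_t batch = 1;
  for (size_t i = 0; i + 2 < lhs_shape.size(); ++i) batch *= lhs_shape[i];
  assert(batch == 6);  // matches the dense<[6, 4, 5]> reshape constant above

  // Slice i starts at {i, 0, 0} with size {1, 4, 5} -- the cst_3..cst_8 begin
  // constants and the cst_0 size constant in batchMatMulV2TwoDim.
  for (int64_t i = 0; i < batch; ++i) {
    std::array<int64_t, 3> begin = {i, 0, 0};
    std::array<int64_t, 3> size = {1, 4, 5};
    (void)begin;
    (void)size;  // each slice feeds one 4x5 * 5x6 tf.MatMul
  }
  // The six MatMul results are packed to [6, 4, 6], then reshaped back to
  // [2, 3, 4, 6], the original result shape.
  return 0;
}

The flat-input tests exercise the same pattern with a batch of three (the reshapes are no-ops there), and the matrix tests confirm that rank-2 inputs fall through to a single MatMul.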
@@ -41,10 +41,13 @@ bool ShouldRunQuantizePasses(mlir::ModuleOp m) {

void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
                                mlir::PassManager* pass_manager) {
  pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
  pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
  pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
  // Ophint extraction will happen after island extraction pass.
  pass_manager->addPass(mlir::TFL::CreateExtractOphintPass());
  // Convert composite op pass will happen after ophint extraction pass.
  pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass());

  if (pass_config.lower_tensor_list_ops) {
    // Execute this pass before `CanonicalizerPass` in case some TensorList
@@ -131,7 +131,7 @@ int main(int argc, char **argv) {
  // message. So we can just return here.
  if (!module.ok()) return kTrFailure;

  mlir::PassManager pm;
  mlir::PassManager pm(&context);
  bool run_quantize =
      tensorflow::ShouldRunQuantizePasses(module.ValueOrDie().get());
  mlir::TFL::PassConfig pass_config;
@@ -149,7 +149,12 @@ int main(int argc, char **argv) {
                                 lower_tensor_list_ops, &result, &pm);
  if (!status.ok()) return kTrFailure;

  auto output = mlir::openOutputFile(output_file_name);
  std::string error_msg;
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return kTrFailure;
  }
  output->os() << result;
  output->keep();

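The hunk above moves error reporting to the caller: openOutputFile now fills an error string instead of diagnosing internally. A minimal self-contained sketch of that contract (assumed includes; WriteResult is a hypothetical name, and the return code mirrors kTrFailure in the tool above):

#include <string>

#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/FileUtilities.h"  // TF:local_config_mlir

int WriteResult(const std::string& output_file_name,
                const std::string& result) {
  std::string error_msg;
  // On failure the returned pointer is null and error_msg explains why.
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return 1;
  }
  output->os() << result;
  output->keep();  // keep() prevents the output file from being deleted
  return 0;
}
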
@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include <map>
#include <queue>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
@@ -353,6 +355,127 @@ struct OphintCompositeOp {
  std::map<int, AggregatedOperand> outputs;
};

// Preprocess the graph for topo sort (each operation is a node, while
// inputs/outputs indicate edges). Assume the graph is acyclic. The preprocess
// does the following:
//   Compute each operation's in-degree (how many distinct input nodes it takes).
//   Get all consumer operations for every operation (operation_to_ouputs).
//   Get the init_queue (those operations will be processed first).
void PreprocessTopoSortGraph(
    Block* block, std::queue<Operation*>* init_queue,
    llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>* operation_to_ouputs,
    llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
  for (auto& op : *block) {
    if (&op == block->getTerminator()) continue;
    if (op.getNumOperands() == 0) {
      init_queue->push(&op);
    } else {
      // The operands of an op are not a direct indication of its "edges": we
      // can have a pack op after an unpack op (they share multiple operands),
      // which should only count as one edge.
      llvm::DenseSet<Operation*> input_ops;
      for (int i = 0; i < op.getNumOperands(); ++i) {
        Operation* input_op = op.getOperand(i)->getDefiningOp();
        if (input_op) input_ops.insert(input_op);
      }
      if (input_ops.empty()) {
        init_queue->push(&op);
        continue;
      }
      operation_to_in_degrees->try_emplace(&op, input_ops.size());
      for (auto* input_op : input_ops) {
        auto preceeding_op_it = operation_to_ouputs->find(input_op);
        if (preceeding_op_it == operation_to_ouputs->end()) {
          auto result = operation_to_ouputs->try_emplace(
              input_op, llvm::DenseSet<Operation*>());
          preceeding_op_it = result.first;
        }
        preceeding_op_it->second.insert(&op);
      }
    }
  }
}

bool IsSideEffectOp(Operation* op) {
  if (op->hasNoSideEffect()) return false;

  // Identity op has no side effect.
  // Checking the OperationName might be more elegant here.
  auto tf_identity_op = dyn_cast_or_null<TF::IdentityOp>(op);
  if (tf_identity_op) return false;
  return true;
}

// It's possible other transformations could benefit from this util function,
// but since currently there are none, we limit this function to the ophint
// extraction pass. We may refactor it to extend the usage in the future.
//
// Assume the graph is disconnected from outside.
// Also assume the block has no arguments.
LogicalResult TopoSortOperations(OpBuilder* builder) {
  std::queue<Operation*> init_queue;
  llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_ouputs;
  llvm::DenseMap<Operation*, int> operation_to_in_degrees;
  std::vector<Operation*> sorted_ops;

  PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
                          &operation_to_ouputs, &operation_to_in_degrees);
  while (!init_queue.empty()) {
    Operation* current_op = init_queue.front();
    init_queue.pop();
    sorted_ops.push_back(current_op);

    auto current_op_to_output_it = operation_to_ouputs.find(current_op);
    if (current_op_to_output_it == operation_to_ouputs.end()) {
      continue;
    }
    for (Operation* output_op : current_op_to_output_it->second) {
      auto output_op_it = operation_to_in_degrees.find(output_op);
      if (output_op_it == operation_to_in_degrees.end()) return failure();

      output_op_it->second -= 1;
      if (output_op_it->second == 0) {
        init_queue.push(output_op);
        operation_to_in_degrees.erase(output_op_it);
      }
    }
    operation_to_ouputs.erase(current_op_to_output_it);
  }

  // Before we perform the sort, we need to make sure we didn't change the
  // relative ordering of the original side-effect operations.
  // It's possible those side-effect operations have no topological relations
  // at all!
  std::vector<Operation*> original_side_effect_ops;
  std::vector<Operation*> after_sort_side_effect_ops;
  for (auto& op : *builder->getBlock()) {
    if (IsSideEffectOp(&op) && (&op != builder->getBlock()->getTerminator()))
      original_side_effect_ops.push_back(&op);
  }
  for (auto* op : sorted_ops) {
    if (IsSideEffectOp(op)) after_sort_side_effect_ops.push_back(op);
  }
  if (original_side_effect_ops.size() != after_sort_side_effect_ops.size())
    return failure();
  for (int i = 0; i < original_side_effect_ops.size(); ++i) {
    if (original_side_effect_ops[i] != after_sort_side_effect_ops[i])
      return failure();
  }

  // Perform the sort.
  // Ideally it would be nice to just clear the block and then write out the
  // sorted ops, but unfortunately that's hard to do.
  for (int i = sorted_ops.size() - 1; i > 0; --i) {
    Operation* current_op = sorted_ops[i];
    for (int j = i - 1; j >= 0; --j) {
      Operation* prev_op = sorted_ops[j];
      prev_op->moveBefore(current_op);
    }
  }

  return success();
}

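PreprocessTopoSortGraph and TopoSortOperations together are Kahn's algorithm over the block. For readers less familiar with it, here is the same in-degree/queue mechanic on a toy integer DAG (a standalone illustration, not part of the commit):

#include <cassert>
#include <map>
#include <queue>
#include <set>
#include <vector>

int main() {
  // Toy DAG: 0 -> {1, 2}, 1 -> {3}, 2 -> {3}; node 0 has no producers.
  std::map<int, std::set<int>> consumers = {{0, {1, 2}}, {1, {3}}, {2, {3}}};
  std::map<int, int> in_degree = {{1, 1}, {2, 1}, {3, 2}};
  std::queue<int> ready;
  ready.push(0);  // the init_queue: nodes with no producers

  std::vector<int> order;
  while (!ready.empty()) {
    int node = ready.front();
    ready.pop();
    order.push_back(node);
    for (int succ : consumers[node]) {
      // A node becomes ready once every one of its producers was emitted.
      if (--in_degree[succ] == 0) ready.push(succ);
    }
  }
  assert(order.size() == 4);  // all nodes emitted, so the graph was acyclic
  assert(order.front() == 0 && order.back() == 3);
  return 0;
}

If the queue drains before every node is emitted, a cycle exists; that is exactly the condition TopoSortOperations reports as failure().
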
Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
                            Operation* insert_before_op,
                            const std::map<int, Value*>& inputs,
@@ -360,10 +483,12 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
                            OpBuilder* builder, ModuleOp* module_op) {
  SmallVector<Type, 4> input_types;
  SmallVector<Value*, 4> input_values;
  SmallVector<int, 4> input_indexes;
  for (const auto& kv : inputs) {
    Value* input = kv.second;
    input_types.push_back(input->getType());
    input_values.push_back(input);
    input_indexes.push_back(kv.first);
  }

  SmallVector<Type, 4> func_output_types;
@@ -378,6 +503,8 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
  SmallVector<NamedAttribute, 4> attrs;
  attrs.push_back(builder->getNamedAttr(
      kTfLiteFunctionName, builder->getStringAttr(fused_func_type)));
  attrs.push_back(builder->getNamedAttr(
      kTfLiteFunctionInputIndex, builder->getI32ArrayAttr(input_indexes)));
  FuncOp func_op = FuncOp::create(insert_before_op->getLoc(), func_name,
                                  function_type, llvm::makeArrayRef(attrs));
  module_op->push_back(func_op);
@@ -507,6 +634,10 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
  };

  builder->getBlock()->walk(removeRemovableOps);

  // Step 8: Topo sort to fix any invalid temporary IRs.
  if (failed(TopoSortOperations(builder))) return failure();

  return success();
}

@@ -0,0 +1,209 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Block.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace {

constexpr char kTfLiteFunctionName[] = "_tflite_function_name";
constexpr char kUnidirectionalSequenceRnn[] = "UnidirectionalSequenceRnn";

// This pass is used for converting to TFLite composite ops like
// UnidirectionalSequenceRNN, UnidirectionalSequenceLSTM or SVDF Op. Currently,
// this pass only handles ophint-converted function ops. See the diagram below:
//
//      InputOp1    InputOp2 ...
//          \          /
//           \        /
//     call funcOp (say UnidirectionalSequenceRNN)
//               |
//               |
//           OutputOp1
//
// funcOp() { '_tflite_function_name' = 'UnidirectionalSequenceRNN'}
//
//               ||
//               ||
//              \  /
//
//      InputOp1    InputOp2 ...
//          \          /
//           \        /
//    tfl.UnidirectionalSequenceRNN
//               |
//               |
//           OutputOp1
struct LegalizeOphintFuncOpPass : public ModulePass<LegalizeOphintFuncOpPass> {
  void runOnModule() override;
};

llvm::StringMap<FuncOp> FindCompositeFuncOps(ModuleOp module) {
  llvm::StringMap<FuncOp> composite_func_ops;
  for (FuncOp func : module.getOps<FuncOp>()) {
    if (func.getAttr(kTfLiteFunctionName))
      composite_func_ops[func.getName()] = func;
  }
  return composite_func_ops;
}

LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
                                               CallOp* call_op,
                                               OpBuilder* builder,
                                               Operation** fused_op) {
  // UnidirectionalSequenceRnn takes exactly 5 inputs.
  if (composite_func_op.getNumArguments() != 5) return failure();
  if (call_op->getNumOperands() != 5) return failure();
  // UnidirectionalSequenceRnn has exactly 1 output.
  if (call_op->getNumResults() != 1) return failure();

  // Input is indexed at 0.
  Value* input = call_op->getOperand(0);
  // Input_weight is indexed at 1.
  Value* weight = call_op->getOperand(1);
  // Recurrent_weight is indexed at 2.
  Value* recurrent_weight = call_op->getOperand(2);
  // Bias is indexed at 3.
  Value* bias = call_op->getOperand(3);
  // Hidden_state is indexed at 4.
  Value* hidden_state = call_op->getOperand(4);

  // Build Output.
  auto output_type = call_op->getResult(0)->getType();

  // Currently, ophinted RNN only supports time_major = True.
  const bool time_major = true;
  // Activation will always be TanH.
  StringAttr fused_activation_function = builder->getStringAttr("TANH");

  builder->setInsertionPoint(call_op->getOperation());
  *fused_op = builder->create<TFL::UnidirectionalSequenceRNNOp>(
      call_op->getLoc(), output_type, input, weight, recurrent_weight, bias,
      hidden_state, builder->getBoolAttr(time_major),
      fused_activation_function);
  return success();
}

LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name,
                                             FuncOp composite_func_op,
                                             CallOp* call_op,
                                             OpBuilder* builder) {
  Operation* fused_op = nullptr;
  if (func_name == kUnidirectionalSequenceRnn) {
    // TODO(renjieliu): Validate the func op inputs.
    LogicalResult build_fused_op_result = BuildUnidirectionalSequenceRnnOp(
        composite_func_op, call_op, builder, &fused_op);
    if (failed(build_fused_op_result)) return build_fused_op_result;
  } else {  // If we support more fused ops, their conversions should be added here.
    return failure();
  }

  call_op->replaceAllUsesWith(fused_op);

  // Delete call op.
  Operation* call = call_op->getOperation();
  call->dropAllDefinedValueUses();
  call->dropAllReferences();
  call->erase();
  return success();
}

LogicalResult ConvertCallOps(llvm::StringMap<FuncOp>* composite_func_ops,
                             ModuleOp* module) {
  for (auto func : module->getOps<FuncOp>()) {
    // Ideally it would be much simpler if we could just use walk, but we also
    // want to return early if we encounter errors. :(
    OpBuilder builder(func.getBody());
    // The call_op replacement within this loop works like an in-place
    // replacement, so it should be safe to do so.
    for (auto call_op :
         llvm::make_early_inc_range(builder.getBlock()->getOps<CallOp>())) {
      auto it = composite_func_ops->find(call_op.getCallee());
      if (it == composite_func_ops->end()) return failure();

      // Replace the call op with the TfLite fused op.
      // Currently this is only handled case by case, but ideally it would be
      // much better if we could do this automatically.
      FuncOp composite_func_op = it->second;
      StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName)
                                .cast<StringAttr>()
                                .getValue();
      if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op,
                                                &call_op, &builder)))
        return failure();

      composite_func_ops->erase(it);
      // Delete func op.
      Operation* func = composite_func_op.getOperation();
      func->erase();
    }
  }
  return success();
}

void LegalizeOphintFuncOpPass::runOnModule() {
  ModuleOp module = getModule();
  // Find all composite funcs, then for every call op inside every func op
  // within the module, we go ahead and replace the call op with the
  // corresponding tflite op and destroy the func op. This two-phase
  // processing is intentional:
  //
  // Every func op is meant to be used exactly once.
  // Instead of finding the composite func and then looping through the graph
  // to convert the call ops immediately, we break finding & converting into
  // two phases. This changes the complexity from O(op_in_module *
  // function_in_module * attr_in_func) to O(op_in_module) * O(map_look_up) +
  // O(function_in_module * attr_in_func). O(op_in_module) is the dominant
  // factor here and map look-up should be very cheap.
  llvm::StringMap<FuncOp> composite_func_ops = FindCompositeFuncOps(module);
  if (composite_func_ops.empty()) return;
  if (failed(ConvertCallOps(&composite_func_ops, &module))) {
    module.emitError() << "Legalize ophint: ConvertCallOps failed.";
    return signalPassFailure();
  }
}

}  // namespace

/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
/// pass.
std::unique_ptr<ModulePassBase> CreateLegalizeOphintFuncOpPass() {
  return std::make_unique<LegalizeOphintFuncOpPass>();
}

static PassRegistration<LegalizeOphintFuncOpPass> pass(
    "tfl-legalize-ophint-func-op", "Convert composite op for TfLite dialect.");

}  // namespace TFL
}  // namespace mlir
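How this pass is meant to be consumed, sketched from the pipeline hunk earlier in this commit (illustrative only; BuildOphintLegalizationPipeline is a hypothetical name and the includes are assumed):

#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

void BuildOphintLegalizationPipeline(mlir::MLIRContext* context) {
  mlir::PassManager pm(context);
  // Extraction runs first so the composite func ops and call sites exist...
  pm.addPass(mlir::TFL::CreateExtractOphintPass());
  // ...then this pass rewrites each call into the fused TFL op and deletes
  // the now-unused composite func.
  pm.addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass());
}

The same ordering appears in AddTFToTFLConversionPasses above; the pass can also be exercised in isolation via the registered tf-opt flag, -tfl-legalize-ophint-func-op.
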
@ -20,6 +20,10 @@ include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
def NonOpaqueElementsAttr : ElementsAttrBase<
|
||||
CPred<"!$_self.isa<OpaqueElementsAttr>()">,
|
||||
"non-opaque constant tensor">;
|
||||
|
||||
def F32ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
||||
|
||||
@ -56,8 +60,13 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Nullary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
|
||||
|
||||
// Convert to std constant for statically shaped, non-opaque constants.
|
||||
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
|
||||
[(AnyStaticShapeTensor $res)], (addBenefit 10)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -124,6 +133,7 @@ def : Pat<(TF_MinimumOp $arg1, $arg2), (TFL_MinimumOp $arg1, $arg2)>;
|
||||
def : Pat<(TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis),
|
||||
(TFL_OneHotOp $indices, $depth, $on_value, $off_value,
|
||||
(convertIntAttrTo32Bit $axis))>;
|
||||
def : Pat<(TF_PowOp $x, $y), (TFL_PowOp $x, $y)>;
|
||||
def : Pat<(TF_RangeOp $start, $limit, $delta), (TFL_RangeOp $start, $limit, $delta)>;
|
||||
def : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>;
|
||||
def : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>;
|
||||
@ -156,7 +166,8 @@ def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
|
||||
// The following two rules can both match an tf.Placeholder.input node with
|
||||
// min/max/type attributes, so we increase the benefit of the first rule by one
|
||||
// so the tfl.quantize and tfl.dequantize ops will be inserted if it matches.
|
||||
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
|
||||
def : Pat<(TF_PlaceholderInputOp TensorOf<[F16, F32, F64]>:$inputs,
|
||||
$min, $max, $type),
|
||||
(TFL_DequantizeOp
|
||||
(TFL_QuantizeOp
|
||||
(TFL_InputOp $inputs),
|
||||
@ -190,7 +201,8 @@ def : Pat<(TF_GatherV2Op $params, $indices,
|
||||
|
||||
def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_NotEqualOp $l, $r), (TFL_NotEqualOp $l, $r)>;
|
||||
def : Pat<(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue),
|
||||
(TFL_NotEqualOp $l, $r)>;
|
||||
|
||||
def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
|
||||
|
||||
@ -251,7 +263,7 @@ def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)),
|
||||
|
||||
def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>;
|
||||
|
||||
def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>;
|
||||
def : Pat<(TF_EqualOp $arg0, $arg1, /*incompatible_shape_error=*/ConstBoolAttrTrue), (TFL_EqualOp $arg0, $arg1)>;
|
||||
|
||||
def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
|
||||
|
||||
@ -307,3 +319,11 @@ def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>;
|
||||
def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>;
|
||||
|
||||
def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>;
|
||||
|
||||
def : Pat<
|
||||
(TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $pad_to_max_output_size),
|
||||
(TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold)>;
|
||||
|
||||
def : Pat<
|
||||
(TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma, $pad_to_max_output_size),
|
||||
(TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma)>;
|
||||
|
@ -0,0 +1,228 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass prepare the tflite fused ops for quantization.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The LoadQuantizationRecipe Pass.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
namespace {
|
||||
|
||||
// This pass loads the quantization recipe for the TFLite ops to be quantized.
|
||||
// Specifically, it extends the fused ops with their internal implementation as
|
||||
// op regions. Each ops in the region produces results with element type
|
||||
// AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also
|
||||
// defines the op quantization traits, which are used to propgate the
|
||||
// quantization parameters by the following passes.
|
||||
struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
void Initialize(LSTMOp lstm, OpBuilder* builder);
|
||||
|
||||
// Create LSTM gates with different weights for input, recurrent and
|
||||
// cell state, and also the layer normalization parameters.
|
||||
Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
|
||||
Value* rec_w,
|
||||
llvm::Optional<std::pair<Value*, Value*>> cell,
|
||||
Value* ln_w, Value* ln_bias, OpBuilder* builder);
|
||||
|
||||
Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
|
||||
Value* ln_bias, OpBuilder* builder);
|
||||
|
||||
// Add the internal implementation of the LSTM to its regions.
|
||||
void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);
|
||||
|
||||
StringAttr none_af;
|
||||
StringAttr fc_format;
|
||||
BoolAttr keep_dims;
|
||||
Type int8;
|
||||
Type int16;
|
||||
ConstantOp none_cst;
|
||||
};
|
||||
|
||||
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
|
||||
Type expressed_type =
|
||||
lstm.input()->getType().cast<ShapedType>().getElementType();
|
||||
Type int8_storage_type = builder->getIntegerType(8);
|
||||
Type int16_storage_type = builder->getIntegerType(16);
|
||||
auto flag = quant::QuantizationFlags::FlagValue::Signed;
|
||||
int64_t int8_min = quant::QuantizedType::getDefaultMininumForInteger(
|
||||
flag, /*integralWidth=*/8);
|
||||
int64_t int8_max = quant::QuantizedType::getDefaultMaxinumForInteger(
|
||||
flag, /*integralWidth=*/8);
|
||||
int64_t int16_min = quant::QuantizedType::getDefaultMininumForInteger(
|
||||
flag, /*integralWidth=*/16);
|
||||
int64_t int16_max = quant::QuantizedType::getDefaultMaxinumForInteger(
|
||||
flag, /*integralWidth=*/16);
|
||||
auto any_int8 = quant::AnyQuantizedType::get(
|
||||
flag, int8_storage_type, expressed_type, int8_min, int8_max);
|
||||
auto any_int16 = quant::AnyQuantizedType::get(
|
||||
flag, int16_storage_type, expressed_type, int16_min, int16_max);
|
||||
|
||||
int8 = any_int8.castFromExpressedType(lstm.input()->getType());
|
||||
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
|
||||
}
|
||||
|
||||
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
|
||||
Value* ln_w, Value* ln_bias,
|
||||
OpBuilder* builder) {
|
||||
// Note that l2_normalization and add ops here are not the execution kernle
|
||||
// implementation for layer_normalization and we just want to use them to
|
||||
// model the quantization requirement.
|
||||
auto l2_norm = builder->create<L2NormalizationOp>(loc, int16, in, none_af);
|
||||
auto add = builder->create<AddOp>(loc, int16, in, l2_norm, none_af);
|
||||
return builder->create<FullyConnectedOp>(loc, int16, add, ln_w, ln_bias,
|
||||
none_af, fc_format, keep_dims);
|
||||
}
|
||||
|
||||
Operation* LoadQuantizationRecipe::CreateGate(
|
||||
Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
|
||||
llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
|
||||
OpBuilder* builder) {
|
||||
auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
|
||||
none_af, fc_format, keep_dims);
|
||||
auto s2 = builder->create<FullyConnectedOp>(loc, int16, rec, rec_w, none_cst,
|
||||
      none_af, fc_format, keep_dims);

  AddNOp s4;
  if (cell.hasValue()) {
    auto s3 = builder->create<MulOp>(loc, int16, cell.getValue().first,
                                     cell.getValue().second, none_af);
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value*>(
            {*s1.output().begin(), *s2.output().begin(), s3.output()}));

  } else {
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
  }

  auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);

  if (cell.hasValue()) {
    return builder->create<LogisticOp>(loc, int16, s5->getResult(0));
  } else {
    return builder->create<TanhOp>(loc, int16, s5->getResult(0));
  }
}

void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
  Initialize(lstm, builder);

  Region region;
  region.push_back(new Block);
  builder->setInsertionPointToEnd(&region.front());
  Location loc = lstm.getLoc();
  Type int32_type = builder->getIntegerType(32);
  Type int32_tensor = builder->getTensorType(int32_type);
  none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
                                         builder->getUnitAttr());

  auto input_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_input_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
      llvm::Optional<std::pair<Value*, Value*>>(
          {lstm.input_cell_state(), lstm.cell_to_input_weights()}),
      lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);

  auto forget_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_forget_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
      llvm::Optional<std::pair<Value*, Value*>>(
          {lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
      lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);

  auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(),
                              lstm.input_activation_state(),
                              lstm.recurrent_to_cell_weights(), llvm::None,
                              lstm.cell_layer_norm_coefficients(),
                              lstm.cell_bias(), builder);

  auto forget_cell_state = builder->create<MulOp>(
      loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af);
  auto input_cell_state = builder->create<MulOp>(
      loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af);
  auto new_cell = builder->create<AddOp>(loc, int16, forget_cell_state.output(),
                                         input_cell_state.output(), none_af);

  auto output_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_output_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
      llvm::Optional<std::pair<Value*, Value*>>(
          {new_cell, lstm.cell_to_output_weights()}),
      lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);

  auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
  auto hidden_state = builder->create<MulOp>(
      loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af);
  auto act = builder->create<FullyConnectedOp>(
      loc, int8, hidden_state.output(), lstm.projection_weights(),
      lstm.projection_bias(), none_af, fc_format, keep_dims);

  // TODO(fengliuai): define and register the op in the QuantOps Dialect.
  OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0),
                              {int8}, {});
  builder->createOperation(return_state);

  lstm.internal().takeBody(region);
}

void LoadQuantizationRecipe::runOnFunction() {
  FuncOp func = getFunction();
  OpBuilder builder(func);
  none_af = builder.getStringAttr("NONE");
  fc_format = builder.getStringAttr("DEFAULT");
  keep_dims = builder.getBoolAttr(false);

  func.walk([&](Operation* op) {
    if (auto lstm = llvm::dyn_cast<LSTMOp>(op)) {
      LoadForLSTMOp(lstm, &builder);
    }
    // Handles other ops.
  });
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
// pass.
std::unique_ptr<FunctionPassBase> CreateLoadQuantizationRecipePass() {
  return absl::make_unique<LoadQuantizationRecipe>();
}

static PassRegistration<LoadQuantizationRecipe> pass(
    "tfl-load-recipe", "Load TFL op quantization recipe");

}  // namespace TFL
}  // namespace mlir
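For reference, the gate built by CreateGate above computes gate = act(LayerNorm(FC(input) + FC(state) + cell * peephole)), with logistic as the activation when a peephole cell input is present and tanh otherwise. A minimal plain-C++ sketch of the element-wise tail of that dataflow (the name GateCombineRef and the float types are illustrative, not TFL APIs):

#include <cmath>
#include <cstddef>
#include <vector>

// Element-wise tail of CreateGate: sum the two (or three) branch outputs,
// layer-normalize, then squash. Matmul and layer-norm internals elided.
std::vector<float> GateCombineRef(const std::vector<float>& s1,  // FC(input)
                                  const std::vector<float>& s2,  // FC(state)
                                  const std::vector<float>* s3,  // peephole, optional
                                  bool use_logistic) {
  std::vector<float> out(s1.size());
  for (std::size_t i = 0; i < s1.size(); ++i) {
    float s4 = s1[i] + s2[i] + (s3 ? (*s3)[i] : 0.f);
    // LayerNorm (s5 in the pass) omitted here for brevity.
    out[i] = use_logistic ? 1.f / (1.f + std::exp(-s4)) : std::tanh(s4);
  }
  return out;
}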
@ -429,12 +429,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
  } else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
    if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
          tf_op.element_dtype().isF64() ||
          tf_op.element_dtype().isInteger(1) ||
          tf_op.element_dtype().isInteger(8) ||
          tf_op.element_dtype().isInteger(16) ||
          tf_op.element_dtype().isInteger(32) ||
          tf_op.element_dtype().isInteger(64))) {
      return tf_op.emitError(
          "requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
          "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
          "integer "
          "or 16-bit/32-bit/64-bit "
          "float type during TF Lite transformation pass");
    }
@ -461,6 +463,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
    auto c = ConvertTFTensorListPushBack(context);
    rewriter->setInsertionPoint(op);
    c.matchAndRewrite(op, *rewriter);
  } else if (auto tf_op = llvm::dyn_cast<TF::TensorListLengthOp>(op)) {
    auto c = TFL::ConvertTFTensorListLength(context);
    rewriter->setInsertionPoint(op);
    c.matchAndRewrite(op, *rewriter);
  } else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
    if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
    UpdateWhileFunctionType(tf_op);
@ -110,3 +110,42 @@ def : Pat<(TFL_MulOp (TFL_DepthwiseConv2DOp $input,
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
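Why the quantize/dequantize pair folds away: with the same scale and zero point the round trip reproduces the stored value, so the pair is an identity. A plain-C++ sketch of that arithmetic (QdqRoundTrip is an illustrative name, not a TFL helper):

#include <cmath>
#include <cstdint>

// quantize(dequantize(q)) == q when scale and zero point match:
// dequantize gives scale * (q - zp); quantizing divides by scale and adds
// zp back, so the rounded result is q again.
int8_t QdqRoundTrip(int8_t q, float scale, int32_t zero_point) {
  float real = scale * (q - zero_point);  // dequantize
  return static_cast<int8_t>(std::lround(real / scale) + zero_point);  // quantize
}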
// Constraint that makes sure both operands are the same value.
def EqualOperands : Constraint<CPred<"$0 == $1">>;

// Checks if the operand has rank == n.
class OperandHasRank<int n> : Constraint<
    CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;

// This pattern constructs L2NormalizationOp from
// Mul->Rsqrt->Sum->Square.
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
def : Pat<(TFL_MulOp $operand1,
              (TFL_RsqrtOp
                  (TFL_SumOp
                      (TFL_SquareOp $square_operand),
                      (ConstantOp I32ElementsAttr:$constant),
                      $keep_dims)),
              TFL_AF_None),
          (TFL_L2NormalizationOp $operand1, TFL_AF_None),
          [(EqualOperands $operand1, $square_operand)]>;

// This pattern constructs L2NormalizationOp from
// Div->Sqrt->Sum->Square.
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
def : Pat<(TFL_DivOp $operand1,
              (TFL_SqrtOp
                  (TFL_SumOp
                      (TFL_SquareOp $square_operand),
                      (ConstantOp I32ElementsAttr:$constant),
                      $keep_dims)),
              TFL_AF_None),
          (TFL_L2NormalizationOp $operand1, TFL_AF_None),
          [(EqualOperands $operand1, $square_operand)]>;
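Both patterns fire on the same identity, x / sqrt(sum(x^2)) = x * rsqrt(sum(x^2)). A small plain-C++ reference of what the fused TFL_L2NormalizationOp computes over one vector (illustrative only; the kernel operates per the axis noted in the TODO):

#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics of the rewrite target: out[i] = x[i] / sqrt(sum_j x[j]^2).
// Covers both matched forms, x * rsqrt(sum(square(x))) and
// x / sqrt(sum(square(x))).
std::vector<float> L2NormalizeRef(const std::vector<float>& x) {
  float sum_sq = 0.f;
  for (float v : x) sum_sq += v * v;
  const float inv_norm = 1.f / std::sqrt(sum_sq);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] * inv_norm;
  return out;
}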
@ -21,8 +21,12 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"

namespace mlir {
class FunctionPassBase;
class ModulePassBase;
class FuncOp;
class ModuleOp;
template <typename T>
class OpPassBase;
using FunctionPassBase = OpPassBase<FuncOp>;
using ModulePassBase = OpPassBase<ModuleOp>;

namespace TFL {

@ -64,6 +68,10 @@ std::unique_ptr<FunctionPassBase> CreatePrepareCompositeFunctionsPass();
// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
std::unique_ptr<ModulePassBase> CreateExtractOphintPass();

// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
// pass. The composite op is created from the ophint extraction pass.
std::unique_ptr<ModulePassBase> CreateLegalizeOphintFuncOpPass();

}  // namespace TFL

}  // namespace mlir
@ -18,6 +18,14 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"

def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;

def NonOpaqueElementsAttr : ElementsAttrBase<
    CPred<"!$_self.isa<OpaqueElementsAttr>()">,
    "non-opaque constant tensor">;

// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
          [(AnyStaticShapeTensor $res)]>;

// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
// operations. Specifically, performs the following calculation:
//
@ -81,8 +89,8 @@ class TFi32<int v> : ConstantAttr<I32ElementsAttr, !cast<string>(v)>;
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
          (TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp
             /*start=*/(TF_RankOp $b),
             /*limit=*/(ConstantOp TFi32<0>),
             /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))),
             /*limit=*/(TF_ConstOp TFi32<0>),
             /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))),
          $at, ConstBoolAttrTrue)>;

// Matmul with transpose on a to matmul with explicit transpose op and a not
@ -90,10 +98,12 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse),
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
          (TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp
             /*start=*/(TF_RankOp $a),
             /*limit=*/(ConstantOp TFi32<0>),
             /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
             /*limit=*/(TF_ConstOp TFi32<0>),
             /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), $b,
          ConstBoolAttrFalse, $bt)>;

// Partially supported in TFLite, treated as passthrough IdentityOp
def : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>;
def : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>;
def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;

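The two matmul rewrites above materialize the transposed operand with a permutation computed as range(rank, 0, -1) - 1, i.e. the reversed dimension order. A tiny sketch of that arithmetic (plain C++, ReversedPerm is an illustrative name):

#include <vector>

// perm = range(start=rank, limit=0, delta=-1) - 1  ==  {rank-1, ..., 1, 0}.
// For the rank-2 matmul case this swaps rows and columns.
std::vector<int> ReversedPerm(int rank) {
  std::vector<int> perm(rank);
  // range(rank, 0, -1) yields {rank, rank-1, ..., 1}; subtracting 1 gives
  // {rank-1, ..., 0}.
  for (int i = 0; i < rank; ++i) perm[i] = rank - 1 - i;
  return perm;
}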
@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -246,7 +247,8 @@ struct ConvertTFConvOp : public RewritePattern {
                                     filter_type.getShape());
    auto bias_type = rewriter.getTensorType({bias_dim}, elem_type);
    auto bias_attr = rewriter.getZeroAttr(bias_type);
    auto bias = rewriter.create<ConstantOp>(op->getLoc(), bias_type, bias_attr);
    auto bias =
        rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);

    auto *conv_state = static_cast<ConvertTFConvOpMatchState *>(state.get());
    auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
@ -297,7 +299,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
                                     rewriter.getIntegerType(32));
    auto perm_attr =
        DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
    auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
    auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);

    // Create tensor type for the transpose result.
    auto filter_type = filter->getType().cast<RankedTensorType>();
@ -366,7 +368,7 @@ class ConvertTFDepthwiseConv2dNative
    auto shape_type = rewriter.getTensorType({4}, rewriter.getIntegerType(64));
    auto shape_attr =
        DenseElementsAttr::get(shape_type, llvm::makeArrayRef(result_shape));
    auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
    auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);

    return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
  }
@ -377,6 +379,11 @@ class ConvertTFDepthwiseConv2dNative
void PrepareTFPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();

  patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
  applyPatternsGreedily(func, patterns);

  // This pattern was intended to use TFL QDQs to preserve the quantization
  // parameters from the TF Quant ops, thus this pattern should run with the
  // first `applyPatternsGreedily` method, which would otherwise remove the
@ -14,9 +14,13 @@ limitations under the License.
==============================================================================*/

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

def CreateTFShapeOp : NativeCodeCall<
    "$_builder.create<TF::ShapeOp>($0->getLoc(), $1, $2)">;

//===----------------------------------------------------------------------===//
// TensorList transformation patterns.
// Note that the pattern below rewrites `TensorList` tensors (which have type DT_VARIANT)
@ -34,3 +38,11 @@ def ConvertTFTensorListStack : Pat<
def ConvertTFTensorListGetItem : Pat<
    (TF_TensorListGetItemOp $input, $index, $element_shape),
    (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;

// TensorListLength is equivalent to the size of the first dimension of the
// input tensorlist; rewrite it to a combination of Gather and Shape ops.
def ConvertTFTensorListLength: Pat<
    (TF_TensorListLengthOp:$old_value $input),
    (TF_GatherOp
        (CreateTFShapeOp $old_value, $input, /*use 32bit*/ConstBoolAttrTrue),
        (ConstantOp ConstantAttr<I32ElementsAttr, "0">), ConstBoolAttrTrue)>;
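The ConvertTFTensorListLength pattern works because a lowered tensor list is an ordinary tensor whose leading dimension is the list size, so length(list) becomes gather(shape(list), 0). Reference semantics in plain C++ (illustrative only):

#include <cstdint>
#include <vector>

// After lowering, a TensorList is a plain tensor; its length is the leading
// dimension, i.e. gather(shape(list), /*index=*/0).
int32_t TensorListLengthRef(const std::vector<int64_t>& list_shape) {
  return static_cast<int32_t>(list_shape[0]);
}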
@ -105,13 +105,13 @@ void TrimFunctionsPass::Verify() {
  SymbolTable symbol_table = SymbolTable(getModule());
  llvm::SetVector<FuncOp> reachable_funcs;
  for (auto func : getModule().getOps<FuncOp>()) {
    func.walk<CallOp>([&](CallOp op) {
      if (!symbol_table.lookup<FuncOp>(op.getCallee())) {
        getModule().emitError()
            << func.getName() << " is not in the funcs whitelist";
        return signalPassFailure();
      }
    auto walk_result = func.walk([&](CallOp op) -> WalkResult {
      if (!symbol_table.lookup<FuncOp>(op.getCallee()))
        return getModule().emitError()
               << func.getName() << " is not in the funcs whitelist";
      return WalkResult::advance();
    });
    if (walk_result.wasInterrupted()) return signalPassFailure();
  }
}

309 tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc Normal file
@ -0,0 +1,309 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"

#include <climits>
#include <cstdint>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TFL {

namespace {
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
  void runOnFunction() override;
};

void UnrollBatchMatMulPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();

  patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
  applyPatternsGreedily(func, patterns);
}

}  // namespace

template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
    Value* value, ArrayRef<int64_t> shape, Type elementType, Location loc,
    PatternRewriter& rewriter) {
  int64_t shape_rank = shape.size();
  auto shapeSpecType =
      rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64));
  Type resultType = rewriter.getTensorType(shape, elementType);
  auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape);
  auto shapeTensor =
      rewriter.create<ConstantOp>(loc, shapeSpecType, constant_attr);
  return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                        /*shape=*/shapeTensor);
}

template <typename BatchMatMulOpType>
std::vector<Value*> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
    Value* value, int batch_size, Location loc, PatternRewriter& rewriter) {
  RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
  Type elementType = tensorType.getElementType();

  int rank = tensorType.getShape().size();
  int num_rows = tensorType.getShape()[rank - 2];
  int num_cols = tensorType.getShape()[rank - 1];

  // Reshape to rank-3 Tensor with first dimension as the batch size.
  auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols},
                                   elementType, loc, rewriter);

  SmallVector<int64_t, 3> sliceSize = {1, num_rows, num_cols};

  std::vector<Value*> sliced;
  Type int64Type = rewriter.getIntegerType(64);
  Type sliceResultType = rewriter.getTensorType(sliceSize, elementType);

  // Slice along each batch index and remember the slice output for future
  // use.
  for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    auto vector3Type = rewriter.getTensorType({3}, int64Type);

    auto begin_attr =
        DenseElementsAttr::get<int64_t>(vector3Type, {batch_idx, 0, 0});
    auto size_attr = DenseElementsAttr::get<int64_t>(vector3Type, sliceSize);
    auto begin = rewriter.create<ConstantOp>(loc, vector3Type, begin_attr);
    auto size = rewriter.create<ConstantOp>(loc, vector3Type, size_attr);
    auto sliceOp =
        rewriter.create<TF::SliceOp>(loc, sliceResultType,
                                     /*input=*/reshapeOp.output(), begin, size);

    // Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
    // num_cols]
    auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols},
                                     elementType, loc, rewriter);

    sliced.emplace_back(squeezeOp.output());
  }
  return sliced;
}

template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
    Value* value, Location loc, PatternRewriter& rewriter) {
  auto valueType = value->getType().cast<RankedTensorType>();
  auto shape = valueType.getShape();
  int dims = shape.size();

  std::vector<int32_t> perm(dims);
  for (int i = 0; i < dims - 2; i++) {
    perm[i] = i;
  }
  perm[dims - 2] = dims - 1;
  perm[dims - 1] = dims - 2;

  auto perm_type = rewriter.getTensorType({static_cast<int32_t>(perm.size())},
                                          rewriter.getIntegerType(32));

  auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
  auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);

  std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
  int64_t r = transposed_shape[dims - 1];
  int64_t c = transposed_shape[dims - 2];

  transposed_shape[dims - 1] = c;
  transposed_shape[dims - 2] = r;

  auto transposed_type =
      rewriter.getTensorType(transposed_shape, valueType.getElementType());
  return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
}

template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
    const std::vector<Value*>& sliced_lhs,
    const std::vector<Value*>& sliced_rhs, const tensorflow::MatMulBCast& bcast,
    int rows, int cols, Type elementType, Location loc,
    PatternRewriter& rewriter) {
  auto matmulType = rewriter.getTensorType({rows, cols}, elementType);

  std::vector<Value*> matmuls;
  for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
    int lhs_batch_idx, rhs_batch_idx;
    if (bcast.IsBroadcastingRequired()) {
      lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
      rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
    } else {
      lhs_batch_idx = batch_idx;
      rhs_batch_idx = batch_idx;
    }
    auto false_attr = rewriter.getBoolAttr(false);
    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmulType,
                                                /*a=*/sliced_lhs[lhs_batch_idx],
                                                /*b=*/sliced_rhs[rhs_batch_idx],
                                                /*transpose_a=*/false_attr,
                                                /*transpose_b=*/false_attr);
    matmuls.emplace_back(matmul.product());
  }

  // Combine the result of each individual MatMul into a rank-3 Tensor.
  Type packedType = rewriter.getTensorType(
      {bcast.output_batch_size(), rows, cols}, elementType);

  auto N = rewriter.getI64IntegerAttr(matmuls.size());
  auto axis = rewriter.getI64IntegerAttr(0);
  return rewriter.create<TF::PackOp>(loc, packedType,
                                     /*values=*/matmuls, N, axis);
}

template <typename BatchMatMulOpType>
PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value* input_lhs = op.x();
  Value* input_rhs = op.y();

  if (!input_lhs->getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type.
    return this->matchFailure();
  }
  if (!input_rhs->getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type.
    return this->matchFailure();
  }

  auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs->getType().cast<RankedTensorType>();

  auto elementType = lhs_type.getElementType();

  if (elementType != rhs_type.getElementType()) {
    // The element type of LHS must be the same as the element type of RHS.
    return this->matchFailure();
  }

  auto lhs_shape = lhs_type.getShape();
  auto rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Transpose LHS input if necessary.
  if (op.adj_x()) {
    input_lhs = createTransposeOp(input_lhs, loc, rewriter);

    lhs_type = input_lhs->getType().cast<RankedTensorType>();
    lhs_shape = lhs_type.getShape();
  }

  // Transpose RHS input if necessary.
  if (op.adj_y()) {
    input_rhs = createTransposeOp(input_rhs, loc, rewriter);

    rhs_type = input_rhs->getType().cast<RankedTensorType>();
    rhs_shape = rhs_type.getShape();
  }

  // Ensure that input ranks are at least 2 and batch shapes are
  // broadcastable.
  const int dims_a = lhs_shape.size();
  const int dims_b = rhs_shape.size();
  if (dims_a < 2 || dims_b < 2) {
    // Both inputs must have rank >= 2.
    return this->matchFailure();
  }

  if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
    // Input dimensions must be compatible for multiplication.
    return this->matchFailure();
  }

  if (dims_a == 2 && dims_b == 2) {
    // When both inputs are matrices, just replace the op with a matmul op.
    Type resultType =
        rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType);
    auto false_attr = rewriter.getBoolAttr(false);
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, resultType,
                                              /*a=*/input_lhs,
                                              /*b=*/input_rhs,
                                              /*transpose_a=*/false_attr,
                                              /*transpose_b=*/false_attr);
    return this->matchSuccess();
  }

  tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
                                    lhs_shape.begin(), lhs_shape.end()),
                                absl::InlinedVector<tensorflow::int64, 4>(
                                    rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable.
    return this->matchFailure();
  }

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value*> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value*> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute (single batch) MatMul for each output batch. The MatMul outputs
  // are then packed together into one output Tensor.
  auto packOp =
      createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
                      rhs_shape[dims_b - 1], elementType, loc, rewriter);

  // Reshape the rank-3 Tensor into the correct output shape.
  const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> resultShape(resultBatchShape.begin(),
                                   resultBatchShape.end());
  resultShape.push_back(lhs_shape[dims_a - 2]);
  resultShape.push_back(rhs_shape[dims_b - 1]);

  auto reshapeOp =
      createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter);
  rewriter.replaceOp(op, reshapeOp.output());
  return this->matchSuccess();
}

static PassRegistration<UnrollBatchMatMulPass> pass(
    "tfl-unroll-batch-matmul",
    "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");

}  // namespace TFL
}  // namespace mlir
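To make the shape bookkeeping of the rewrite concrete, here is a standalone sketch under the simplifying assumption that the batch dimensions already match; the pass itself handles true broadcasting via tensorflow::MatMulBCast. UnrolledResultShape is an illustrative name, not part of the pass:

#include <cassert>
#include <cstdint>
#include <vector>

// One concrete case of the unrolling:
//   lhs [2, 3, 4, 5] -> reshaped to [6, 4, 5], sliced into 6 [4, 5] matrices
//   rhs [2, 3, 5, 7] -> reshaped to [6, 5, 7], sliced into 6 [5, 7] matrices
//   6 MatMuls of [4, 5] x [5, 7], packed to [6, 4, 7], reshaped to [2, 3, 4, 7].
std::vector<int64_t> UnrolledResultShape(const std::vector<int64_t>& lhs,
                                         const std::vector<int64_t>& rhs) {
  assert(lhs.size() >= 2 && rhs.size() >= 2);
  assert(lhs[lhs.size() - 1] == rhs[rhs.size() - 2]);  // contraction dims
  std::vector<int64_t> result(lhs.begin(), lhs.end() - 2);  // batch dims
  result.push_back(lhs[lhs.size() - 2]);  // rows
  result.push_back(rhs[rhs.size() - 1]);  // cols
  return result;
}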
@ -0,0 +1,60 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_

#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TFL {

// Unroll tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does not
// support the BatchMatMul operation, it unrolls a BatchMatMul op into
// tf.Reshape, tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops.
template <typename BatchMatMulOpType>
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
  using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;

  static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
                                       Type elementType, Location loc,
                                       PatternRewriter& rewriter);

  static std::vector<Value*> sliceInput(Value* value, int batch_size,
                                        Location loc,
                                        PatternRewriter& rewriter);

  static TF::TransposeOp createTransposeOp(Value* value, Location loc,
                                           PatternRewriter& rewriter);

  static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
                                    const std::vector<Value*>& sliced_rhs,
                                    const tensorflow::MatMulBCast& bcast,
                                    int rows, int cols, Type elementType,
                                    Location loc, PatternRewriter& rewriter);

  PatternMatchResult matchAndRewrite(BatchMatMulOpType op,
                                     PatternRewriter& rewriter) const override;
};

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
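The patterns are declared in a header so that other passes can reuse them; prepare_tf.cc earlier in this diff does exactly that. A minimal reuse sketch, with pass boilerplate elided:

// Inside some FunctionPass<...>::runOnFunction():
//   OwningRewritePatternList patterns;
//   patterns.insert<TFL::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
//                   TFL::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(
//       &getContext());
//   applyPatternsGreedily(getFunction(), patterns);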
456 tensorflow/compiler/mlir/lite/utils/lstm_utils.cc Normal file
@ -0,0 +1,456 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

namespace {

Value* CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                           int32_t val, mlir::Location location) {
  auto type = builder->getTensorType(shape, builder->getIntegerType(32));
  auto attr = DenseElementsAttr::get(type, val);
  return builder->create<ConstantOp>(location, type, attr);
}

Value* CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                           float val, mlir::Location location) {
  auto type = builder->getTensorType(shape, builder->getF32Type());
  auto attr = DenseElementsAttr::get(type, val);
  return builder->create<ConstantOp>(location, type, attr);
}

Value* CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                           ArrayRef<int64_t> values, mlir::Location location) {
  auto type = builder->getTensorType(static_cast<int>(shape.size()),
                                     builder->getIntegerType(64));
  auto attr = DenseElementsAttr::get(type, values);
  return builder->create<ConstantOp>(location, type, attr);
}

Value* CreateNoneValue(OpBuilder* builder, mlir::Location location) {
  return builder->create<mlir::ConstantOp>(location, builder->getNoneType(),
                                           builder->getUnitAttr());
}

Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose,
                   RankedTensorType type, mlir::Location location) {
  // Create a constant op for transpose permutation.
  SmallVector<int64_t, 2> perm = {1, 0};
  auto perm_op = CreateI64DenseConst(builder, perm, perm, location);

  // Create tensor type for the transpose result.
  auto transpose_type = type;
  auto transpose_shape = functional::map(
      [transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
      perm);
  auto elem_type = transpose_type.getElementType();
  auto result_type = builder->getTensorType(transpose_shape, elem_type);

  return builder->create<TF::TransposeOp>(location, result_type,
                                          value_to_transpose, perm_op);
}

Value* SliceRankedTensor(OpBuilder* builder, Value* input,
                         ArrayRef<int64_t> begin_shape,
                         ArrayRef<int64_t> begin_values,
                         ArrayRef<int64_t> size_shape,
                         ArrayRef<int64_t> size_values,
                         mlir::Location location) {
  // Create a dense constant op for slice's begin.
  auto slice_i2c_begin =
      CreateI64DenseConst(builder, begin_shape, begin_values, location);

  // Create a dense constant op for slice's size.
  auto slice_i2c_size =
      CreateI64DenseConst(builder, size_shape, size_values, location);

  return builder->create<TF::SliceOp>(
      location,
      builder->getTensorType(
          size_values,
          input->getType().cast<RankedTensorType>().getElementType()),
      input, slice_i2c_begin, slice_i2c_size);
}

}  // namespace
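The Set* methods below all slice the same transposed combined weight; the row layout they assume is [cell | input | forget | output], each block n_cell rows, with the input block absent when couple_input_forget_gates (CIFG) is set. A small sketch of that offset arithmetic (GateRowOffsets is an illustrative name, derived from the begin values used below):

#include <array>
#include <cstdint>

// Row offsets of each gate's block in the transposed combined weight,
// mirroring the begin values in the SetWeightFor* methods.
std::array<int64_t, 4> GateRowOffsets(int64_t n_cell, bool cifg) {
  const int64_t input = n_cell;                        // unused under CIFG
  const int64_t forget = cifg ? n_cell : 2 * n_cell;
  const int64_t output = cifg ? 2 * n_cell : 3 * n_cell;
  return {0 /*cell*/, input, forget, output};
}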
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
  SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
  input2cell_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
      weight_slice_shape_, weight_slice_size_input_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
  SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
  input2input_ = couple_input_forget_gates_
                     ? none_
                     : SliceRankedTensor(&builder_, weight_transposed_,
                                         weight_slice_shape_, begin_i2i_values,
                                         weight_slice_shape_,
                                         weight_slice_size_input_values_,
                                         fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
  int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
  SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
  input2forget_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
      weight_slice_shape_, weight_slice_size_input_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
  int input_output_start =
      couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
  SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
  input2output_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
      weight_slice_shape_, weight_slice_size_input_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
  SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
  rec2cell_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
      weight_slice_shape_, weight_slice_size_recurrent_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
  SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
  rec2input_ = couple_input_forget_gates_
                   ? none_
                   : SliceRankedTensor(&builder_, weight_transposed_,
                                       weight_slice_shape_, begin_rec2i_values,
                                       weight_slice_shape_,
                                       weight_slice_size_recurrent_values_,
                                       fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
  int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
  SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
  rec2forget_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
      weight_slice_shape_, weight_slice_size_recurrent_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
  int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
  SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
  rec2output_ = SliceRankedTensor(
      &builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
      weight_slice_shape_, weight_slice_size_recurrent_values_,
      fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
  SmallVector<int64_t, 1> begin_bias2c_values = {0};
  bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                                 begin_bias2c_values, bias_slice_shape_,
                                 bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
  SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
  bias2input_ =
      couple_input_forget_gates_
          ? none_
          : SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                              begin_bias2i_values, bias_slice_shape_,
                              bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
  int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
  SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
  bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                                   begin_bias2f_values, bias_slice_shape_,
                                   bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
  int bias_output_start =
      couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
  SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
  bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
                                   begin_bias2o_values, bias_slice_shape_,
                                   bias_size_values_, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
  SmallVector<int64_t, 2> projection_slice_shape = {
      1, num_cols_projection_transposed_};
  SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
  SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
  proj_weight_ =
      !projection_
          ? none_
          : SliceRankedTensor(
                &builder_, projection_transposed_, projection_slice_shape,
                projection_slice_begin_values, projection_slice_shape,
                projection_slice_size_values, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
  proj_bias_ = !projection_type_
                   ? none_
                   : CreateI32SplatConst(&builder_, {n_output_}, 0,
                                         fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
  input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
                                                fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
  input_cell_state_ =
      CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
}

void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
  cell_layer_norm_coefficients_ = none_;
}

void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
  input_layer_norm_coefficients_ = none_;
}

void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
  forget_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
  output_layer_norm_coefficients_ = none_;
}

void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
  // Transpose both weight and projection.
  weight_transposed_ =
      Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
  projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
                                       fused_func_op_.getLoc());

  none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
  // Extract input to cifg gates via slicing the weight tensor.
  SetWeightForInputToCellGate();
  SetWeightForInputToInputGate();
  SetWeightForInputToForgetGate();
  SetWeightForInputToOutputGate();

  // Extract recurrent to cifg gates via slicing the weight tensor.
  SetWeightForRecurrentToCellGate();
  SetWeightForRecurrentToInputGate();
  SetWeightForRecurrentToForgetGate();
  SetWeightForRecurrentToOutputGate();

  // Extract bias to cifg gates via slicing the bias tensor.
  SetBiasToCellGate();
  SetBiasToInputGate();
  SetBiasToForgetGate();
  SetBiasToOutputGate();

  // Extract projection and set an empty projection bias.
  SetProjection();
  SetProjectionBias();

  // Set the variable tensors.
  SetInputActivationState();
  SetInputCellState();

  // Extract the layer norm coefficients.
  SetCellLayerNormCoefficients();
  SetInputLayerNormCoefficients();
  SetForgetLayerNormCoefficients();
  SetOutputLayerNormCoefficients();
}

void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
  // https://github.com/tensorflow/community/pull/113
  auto attr = fused_func_op_.getAttrOfType<StringAttr>("tf._implements");
  if (!attr) {
    fused_func_op_.setAttr("tf._implements",
                           builder_.getStringAttr(GetCompositeOpName()));
  }
  SmallVector<int64_t, 2> output_shape{1, n_output_};
  auto input_types = fused_func_op_.getType().getInputs();
  auto output_type = builder_.getTensorType(
      output_shape,
      input_->getType().cast<RankedTensorType>().getElementType());
  fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
                                                 fused_func_op_.getContext()));
}

void ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
  // Update the func signature, based on output shape.
  // The func will ultimately return the output of the fused
  // LSTM op.
  UpdateFuncSignature();

  // Transform the weights, projection, bias and layer norm coefficients
  // to generate operands for the TFL fused LSTM op.
  GenerateFusedOpOperands();

  // Create the fused LSTM op.
  SmallVector<int64_t, 2> output_shape = {1, n_output_};
  auto result_type = builder_.getTensorType(
      output_shape,
      input_->getType().cast<RankedTensorType>().getElementType());
  lstm_ = builder_.create<mlir::TFL::LSTMOp>(
      fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
      input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
      rec2output_, /*cell_to_input_weights*/ none_,
      /*cell_to_forget_weights*/ none_,
      /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
      bias2output_, proj_weight_, proj_bias_, input_activation_state_,
      input_cell_state_, input_layer_norm_coefficients_,
      forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
      output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
      builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
      builder_.getStringAttr("FULL"));

  builder_.create<mlir::ReturnOp>(fused_func_op_.getLoc(), lstm_.getResult());
}

LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
  num_gates_ = couple_input_forget_gates_ ? 3 : 4;

  input_ = fused_func_op_.getArgument(0);
  bias_ = fused_func_op_.getArgument(2);

  weight_ = fused_func_op_.getArgument(1);
  weight_type_ = weight_->getType().cast<RankedTensorType>();

  if (weight_type_.getRank() != 2) {
    return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
  }

  if (weight_type_.getDimSize(1) % num_gates_ != 0) {
    return fused_func_op_.emitError()
           << "Invalid dimension 1 of weight tensor, "
              "should be divisible by the number of gates";
  }
  n_cell_ = weight_type_.getDimSize(1) / num_gates_;

  projection_ = fused_func_op_.getArgument(3);
  projection_type_ = projection_->getType().cast<RankedTensorType>();
  if (projection_type_.getRank() != 2) {
    n_output_ = n_cell_;
  } else {
    n_output_ = projection_type_.getDimSize(1);
  }
  n_input_ = weight_type_.getDimSize(0) - n_output_;
  num_cols_weight_transposed_ = weight_type_.getDimSize(0);
  num_cols_projection_transposed_ = projection_type_.getDimSize(0);

  bias_slice_shape_ = {n_cell_};
  bias_size_values_ = {n_cell_};
  weight_slice_shape_ = {1, num_cols_weight_transposed_};
  weight_slice_size_input_values_ = {n_cell_, n_input_};
  weight_slice_size_recurrent_values_ = {n_cell_, n_output_};

  return success();
}
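A worked instance of Initialize()'s shape arithmetic, using the shapes that the unit test later in this diff constructs (weight [3, 12], projection [1, 2], no CIFG):

#include <cassert>

// weight dim 1 (12) splits across 4 gates -> n_cell = 3; projection has
// rank 2, so n_output comes from its dim 1; n_input is what remains of
// weight dim 0 after the recurrent (n_output) columns.
void InitializeShapeExample() {
  const int num_gates = 4;  // couple_input_forget_gates == false
  const int weight_dim0 = 3, weight_dim1 = 12;
  assert(weight_dim1 % num_gates == 0);
  const int n_cell = weight_dim1 / num_gates;  // 3
  const int n_output = 2;                      // projection dim 1
  const int n_input = weight_dim0 - n_output;  // 1
  (void)n_cell;
  (void)n_input;
}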
LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
  if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
    return fused_func_op_.emitError()
           << "Specified LayerNormalizedLSTMCellSimple was not of the expected "
              "interface and cannot be converted to the fused LSTM op";
  }

  layer_norm_scale_ = fused_func_op_.getArgument(4);
  layer_norm_scale_type_ =
      layer_norm_scale_->getType().cast<RankedTensorType>();
  if (layer_norm_scale_type_.getRank() != 1) {
    return fused_func_op_.emitError()
           << "The layer_norm_scale tensor was not of rank 1";
  }
  layer_norm_slice_shape_ = {n_cell_};
  layer_norm_size_values_ = {n_cell_};

  return success();
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetCellLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
  cell_layer_norm_coefficients_ =
      SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
                        begin_cell_layer_norm_values, layer_norm_slice_shape_,
                        layer_norm_size_values_, fused_func_op_.getLoc());
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetInputLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
  input_layer_norm_coefficients_ =
      couple_input_forget_gates_
          ? none_
          : SliceRankedTensor(
                &builder_, layer_norm_scale_, layer_norm_slice_shape_,
                begin_input_layer_norm_values, layer_norm_slice_shape_,
                layer_norm_size_values_, fused_func_op_.getLoc());
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetForgetLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
  forget_layer_norm_coefficients_ =
      SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
                        begin_forget_layer_norm_values, layer_norm_slice_shape_,
                        layer_norm_size_values_, fused_func_op_.getLoc());
}

void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
    SetOutputLayerNormCoefficients() {
  SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
  output_layer_norm_coefficients_ =
      SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
                        begin_output_layer_norm_values, layer_norm_slice_shape_,
                        layer_norm_size_values_, fused_func_op_.getLoc());
}

}  // namespace TFL
}  // namespace mlir
214 tensorflow/compiler/mlir/lite/utils/lstm_utils.h Normal file
@ -0,0 +1,214 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header file defines common utils used by TFLite transformation
|
||||
// passes to work with op attributes.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
constexpr char kLstmCellSimple[] = "LSTMCellSimple";
|
||||
constexpr char kLayerNormalizedLstmCellSimple[] =
|
||||
"LayerNormalizedLstmCellSimple";
|
||||
|
||||
// A utility class that enables the conversion of the LSTMCellSimple composite
|
||||
// op into a fused TFL LSTM op. The fused op is contained within a FuncOp
|
||||
// that also contains other supporting ops needed to construct the operands for
|
||||
// the fused op. The caller provides the containing FuncOp as input with
|
||||
// arguments specifying the input, weight, projection and bias.
|
||||
// The weight, pprojection, bias and layer norm scale all need to be
|
||||
// RankedTensorType.
|
||||
// This class sets the layer norm coefficients to NoneType.
|
||||
class ConvertLSTMCellSimpleToFusedLSTM {
|
||||
public:
|
||||
// TODO(b/140053256): The couple_input_forget_gates should be specified on
|
||||
// FuncOp as an attribute.
|
||||
explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op,
|
||||
bool couple_input_forget_gates)
|
||||
: fused_func_op_(fused_func_op),
|
||||
couple_input_forget_gates_(couple_input_forget_gates),
|
||||
builder_(fused_func_op.getBody()) {}
|
||||
|
||||
// not copyable.
|
||||
ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
|
||||
delete;
|
||||
ConvertLSTMCellSimpleToFusedLSTM& operator=(
|
||||
const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
|
||||
virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}
|
||||
|
||||
// verify input func op arguments and initialize internal state.
|
||||
virtual LogicalResult Initialize();
|
||||
|
||||
virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }
|
||||
|
||||
// Rewrite the func body with constructed fused lstm.
|
||||
void RewriteFunc();
|
||||
|
||||
protected:
|
||||
void UpdateFuncSignature();
|
||||
void GenerateFusedOpOperands();
|
||||
|
||||
void SetWeightForInputToCellGate();
|
||||
void SetWeightForInputToInputGate();
|
||||
void SetWeightForInputToForgetGate();
|
||||
void SetWeightForInputToOutputGate();
|
||||
|
||||
void SetWeightForRecurrentToCellGate();
|
||||
void SetWeightForRecurrentToInputGate();
|
||||
void SetWeightForRecurrentToForgetGate();
|
||||
void SetWeightForRecurrentToOutputGate();
|
||||
|
||||
void SetBiasToCellGate();
|
||||
void SetBiasToInputGate();
|
||||
void SetBiasToForgetGate();
|
||||
void SetBiasToOutputGate();
|
||||
|
||||
void SetProjection();
|
||||
void SetProjectionBias();
|
||||
|
||||
void SetInputActivationState();
|
||||
void SetInputCellState();
|
||||
|
||||
virtual void SetCellLayerNormCoefficients();
|
||||
virtual void SetInputLayerNormCoefficients();
|
||||
virtual void SetForgetLayerNormCoefficients();
|
||||
virtual void SetOutputLayerNormCoefficients();
|
||||
|
||||
// specified state
|
||||
FuncOp fused_func_op_;
|
||||
Value* input_;
|
||||
  Value* weight_;
  Value* bias_;
  Value* projection_;
  bool couple_input_forget_gates_;

  // internal state
  Value* weight_transposed_;
  Value* projection_transposed_;
  RankedTensorType weight_type_;
  RankedTensorType projection_type_;
  int num_gates_;
  int n_cell_;
  int n_output_;
  int n_input_;
  int num_cols_weight_transposed_;
  int num_cols_projection_transposed_;

  // input -> cifg
  Value* input2input_;
  Value* input2forget_;
  Value* input2cell_;
  Value* input2output_;

  // recurrent -> cifg
  Value* rec2input_;
  Value* rec2forget_;
  Value* rec2cell_;
  Value* rec2output_;

  // bias -> cifg
  Value* bias2input_;
  Value* bias2forget_;
  Value* bias2cell_;
  Value* bias2output_;

  // projection
  Value* proj_weight_;
  Value* proj_bias_;

  // state
  Value* input_activation_state_;
  Value* input_cell_state_;

  // layer norm coefficients
  Value* input_layer_norm_coefficients_;
  Value* forget_layer_norm_coefficients_;
  Value* cell_layer_norm_coefficients_;
  Value* output_layer_norm_coefficients_;

  mlir::TFL::LSTMOp lstm_;

  Value* none_;
  SmallVector<int64_t, 1> bias_slice_shape_;
  SmallVector<int64_t, 1> bias_size_values_;
  SmallVector<int64_t, 2> weight_slice_shape_;
  SmallVector<int64_t, 2> weight_slice_size_input_values_;
  SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
  OpBuilder builder_;
};

// A utility class that enables the conversion of the
// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
// fused op is contained within a FuncOp that also contains other supporting
// ops needed to construct the operands for the fused op. The caller provides
// the containing FuncOp as input, with arguments specifying the input, weight,
// projection, bias and layer norm scale. The weight, projection, bias and
// layer norm scale all need to be RankedTensorType.
// This class overrides the layer norm coefficient setters from the base class.
class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
    : public ConvertLSTMCellSimpleToFusedLSTM {
 public:
  // TODO(b/140053256): The couple_input_forget_gates should be specified on
  // FuncOp as an attribute.
  explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      mlir::FuncOp fused_func_op, bool couple_input_forget_gates)
      : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op,
                                         couple_input_forget_gates) {}

  // Not copyable.
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}

  llvm::StringRef GetCompositeOpName() override {
    return kLayerNormalizedLstmCellSimple;
  }

  LogicalResult Initialize() override;

 protected:
  void SetCellLayerNormCoefficients() override;
  void SetInputLayerNormCoefficients() override;
  void SetForgetLayerNormCoefficients() override;
  void SetOutputLayerNormCoefficients() override;

 private:
  // specified state
  Value* layer_norm_scale_;

  // internal state
  RankedTensorType layer_norm_scale_type_;
  SmallVector<int64_t, 1> layer_norm_slice_shape_;
  SmallVector<int64_t, 1> layer_norm_size_values_;
};

}  // end namespace TFL
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
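For context, a minimal sketch of how a rewrite pass might drive the two converters declared above, following the construct/Initialize/RewriteFunc sequence exercised by the tests below. The pass entry point, the module iteration idiom, and the kLstmCellSimple constant for the non-layer-norm variant are assumptions for illustration, not part of this header.

// Sketch only: rewrite every annotated composite LSTM function in a module.
void RewriteCompositeLstmFuncs(mlir::ModuleOp module) {
  for (auto func : module.getOps<mlir::FuncOp>()) {
    auto attr = func.getAttrOfType<mlir::StringAttr>("tf._implements");
    if (!attr) continue;
    if (attr.getValue() == kLayerNormalizedLstmCellSimple) {
      // couple_input_forget_gates hard-coded to false; see the TODO above.
      mlir::TFL::ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM convert(
          func, /*couple_input_forget_gates=*/false);
      if (succeeded(convert.Initialize())) convert.RewriteFunc();
    } else if (attr.getValue() == kLstmCellSimple) {  // assumed constant
      mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(
          func, /*couple_input_forget_gates=*/false);
      if (succeeded(convert.Initialize())) convert.RewriteFunc();
    }
  }
}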
216
tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
Normal file
216
tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
Normal file
@ -0,0 +1,216 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

#include <memory>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/core/platform/test.h"

namespace mlir {
namespace TFL {

FuncOp createFusedFunc(mlir::Builder* builder) {
  SmallVector<int64_t, 2> input_shape{1, 2};
  SmallVector<int64_t, 2> weight_shape{3, 12};
  SmallVector<int64_t, 1> bias_shape{2};
  SmallVector<int64_t, 2> projection_shape{1, 2};
  SmallVector<int64_t, 1> layer_norm_scale{4};
  SmallVector<int64_t, 2> output_shape{1, 2};
  auto input_type = builder->getTensorType(input_shape, builder->getF32Type());
  auto weight_type =
      builder->getTensorType(weight_shape, builder->getF32Type());
  auto bias_type = builder->getTensorType(bias_shape, builder->getF32Type());
  auto projection_type =
      builder->getTensorType(projection_shape, builder->getF32Type());
  auto layer_norm_scale_type =
      builder->getTensorType(layer_norm_scale, builder->getF32Type());
  auto output_type =
      builder->getTensorType(output_shape, builder->getF32Type());
  SmallVector<mlir::Type, 4> input_types{input_type, weight_type, bias_type,
                                         projection_type,
                                         layer_norm_scale_type};
  auto func_type = builder->getFunctionType(input_types, output_type);

  auto func =
      FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
                                        builder->getContext()),
                     "fused_func", func_type, {});
  func.addEntryBlock();
  return func;
}

// TODO(ashwinm): Revisit if this test should be moved to a test pass
// with FileCheck test after the pass that consumes the lstm_utils to stack
// the layers.
class LstmUtilsTest : public ::testing::Test {
 protected:
  LstmUtilsTest() {}

  void SetUp() override {
    builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
    fused_lstm_func_ = createFusedFunc(builder_.get());
  }

  void TearDown() override {
    fused_lstm_func_.erase();
    builder_.reset();
  }
  FuncOp fused_lstm_func_;
  mlir::MLIRContext context_;
  std::unique_ptr<mlir::Builder> builder_;
};

TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, false);

  auto result = convert.Initialize();
  EXPECT_FALSE(failed(result));

  convert.RewriteFunc();
  fused_lstm_func_.dump();

  // verify transpose
  EXPECT_EQ(
      fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
      convert.GetCompositeOpName());
  EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);

  auto transpose_op = fused_lstm_func_.getBody().front().begin();
  transpose_op++;
  EXPECT_EQ(transpose_op->getOperand(0)
                ->getType()
                .cast<RankedTensorType>()
                .getDimSize(0),
            3);
  EXPECT_EQ(transpose_op->getOperand(0)
                ->getType()
                .cast<RankedTensorType>()
                .getDimSize(1),
            12);
  EXPECT_EQ(
      transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
          0),
      12);
  EXPECT_EQ(
      transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
          1),
      3);

  auto return_op = fused_lstm_func_.getBody().back().rbegin();
  EXPECT_EQ(return_op->getName().getStringRef(),
            mlir::ReturnOp::getOperationName());
  return_op++;
  EXPECT_EQ(return_op->getName().getStringRef(),
            mlir::TFL::LSTMOp::getOperationName());
  EXPECT_EQ(return_op->getNumOperands(), 24);
  EXPECT_EQ(return_op->getNumResults(), 1);
  // cifg = false, so input2input is not None.
  EXPECT_FALSE(return_op->getOperand(1)->getType().isa<NoneType>());
  // input layer norm is None
  EXPECT_TRUE(return_op->getOperand(20)->getType().isa<NoneType>());

  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
  auto output_types = fused_lstm_func_.getType().getResults();
  SmallVector<int64_t, 2> output_shape{1, 2};
  EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
            output_shape.size());
  for (int i = 0; i < output_shape.size(); i++) {
    EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
              output_shape[i]);
  }
}

TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) {
  mlir::TFL::ConvertLSTMCellSimpleToFusedLSTM convert(fused_lstm_func_, true);

  auto result = convert.Initialize();
  EXPECT_FALSE(failed(result));

  convert.RewriteFunc();
  fused_lstm_func_.dump();

  auto it = fused_lstm_func_.getBody().back().rbegin();
  EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
  it++;
  EXPECT_EQ(it->getName().getStringRef(),
            mlir::TFL::LSTMOp::getOperationName());
  EXPECT_EQ(it->getNumOperands(), 24);
  EXPECT_EQ(it->getNumResults(), 1);
  // cifg = true, so input2input is None.
  EXPECT_TRUE(it->getOperand(1)->getType().isa<NoneType>());
}

TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
  mlir::TFL::ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM convert(
      fused_lstm_func_, false);

  auto result = convert.Initialize();
  EXPECT_FALSE(failed(result));

  convert.RewriteFunc();
  fused_lstm_func_.dump();

  EXPECT_EQ(
      fused_lstm_func_.getAttrOfType<StringAttr>("tf._implements").getValue(),
      convert.GetCompositeOpName());
  EXPECT_EQ(fused_lstm_func_.getNumArguments(), 5);
  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);

  auto it = fused_lstm_func_.getBody().back().rbegin();
  EXPECT_EQ(it->getName().getStringRef(), mlir::ReturnOp::getOperationName());
  it++;
  EXPECT_EQ(it->getName().getStringRef(),
            mlir::TFL::LSTMOp::getOperationName());
  EXPECT_EQ(it->getNumOperands(), 24);
  EXPECT_EQ(it->getNumResults(), 1);
  // cifg = false, so input2input is not None.
  EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>());

  // input layer norm
  EXPECT_FALSE(it->getOperand(20)->getType().isa<NoneType>());
  EXPECT_EQ(
      it->getOperand(20)->getType().cast<RankedTensorType>().getShape().size(),
      1);
  EXPECT_EQ(
      it->getOperand(20)->getType().cast<RankedTensorType>().getDimSize(0), 3);

  EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
  auto output_types = fused_lstm_func_.getType().getResults();
  SmallVector<int64_t, 2> output_shape{1, 2};
  EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
            output_shape.size());
  for (int i = 0; i < output_shape.size(); i++) {
    EXPECT_EQ(output_types[0].cast<RankedTensorType>().getDimSize(i),
              output_shape[i]);
  }
}

}  // namespace TFL
}  // namespace mlir
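The operand indices asserted in these tests follow the 24-operand layout of the TFL LSTM op: operand 1 is the input-to-input weight (None when CIFG couples the input and forget gates), and operands 20-23 are the per-gate layer norm coefficients. A hypothetical set of named constants, covering only what the tests above actually check, would read:

// Hypothetical names for the TFL LSTMOp operand positions used in the tests;
// only indices 1 and 20 are asserted above, the rest are assumed to follow
// the standard TFLite LSTM kernel ordering.
enum TflLstmOperandIndex {
  kLstmInput = 0,
  kLstmInputToInputWeights = 1,          // None under CIFG.
  // ... indices 2-19: remaining weights, biases, projection, and states ...
  kLstmInputLayerNormCoefficients = 20,  // None without layer normalization.
  kLstmOutputLayerNormCoefficients = 23,
};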
86
tensorflow/compiler/mlir/op_name_mapper.cc
Normal file
86
tensorflow/compiler/mlir/op_name_mapper.cc
Normal file
@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/op_name_mapper.h"

#include "llvm/ADT/APInt.h"

namespace tensorflow {

using llvm::StringRef;
using mlir::Operation;

OpNameMapper::~OpNameMapper() {}

std::string OpNameMapper::GetUniqueName(llvm::StringRef prefix) {
  std::string name = prefix;
  auto& val = name_to_count_[name];
  if (!val) {
    ++val;
    return name;
  }

  llvm::SmallString<64> probe_name(prefix);
  while (true) {
    probe_name.resize(prefix.size());
    // TODO(jpienaar): Subtract one so that the initial suffix is 0 instead
    // of 1.
    // TODO(jpienaar): Switch to radix 36 and update tests.
    llvm::APInt(32, val++).toString(probe_name, /*Radix=*/10,
                                    /*Signed=*/false);
    if (!name_to_count_.count(probe_name)) {
      name = llvm::StringRef(probe_name);
      break;
    }
  }
  return name;
}

const std::string& OpNameMapper::GetUniqueName(Operation* op) {
  auto& name = op_to_name_[op];
  if (!name.empty()) return name;
  // Update the value in the map with the unique name.
  name = GetUniqueName(GetName(op));
  return name;
}

int OpNameMapper::InitOpName(mlir::Operation* op, llvm::StringRef name) {
  op_to_name_[op] = name;
  return name_to_count_[name]++;
}

std::string OpLocNameMapper::GetName(Operation* op) {
  if (auto name_loc = op->getLoc().dyn_cast<mlir::NameLoc>())
    return name_loc.getName().str();

  if (auto call_loc = op->getLoc().dyn_cast<mlir::CallSiteLoc>()) {
    // Return the name if the CallSiteLoc's callee has a NameLoc (as should be
    // the case if imported with DebugInfo), else use the fallback naming
    // scheme below.
    if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
      return name_loc.getName().str();
  }

  // If the location is none of the expected types, simply use the name
  // generated from the op type.
  return op->getName().getStringRef();
}

std::string OpStripNameMapper::GetName(Operation* op) {
  return llvm::APInt(32, count_++)
      .toString(/*Radix=*/36,
                /*Signed=*/false);
}

}  // namespace tensorflow
Some files were not shown because too many files have changed in this diff.