Merge branch 'tensorflow-master'

commit 33dffe53fb

.bazelrc: 44 changed lines
@@ -39,32 +39,46 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0

 build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
 build:download_clang --define=using_clang=true
+build:download_clang --action_env TF_DOWNLOAD_CLANG=1
 # Instruct clang to use LLD for linking.
 # This only works with GPU builds currently, since Bazel sets -B/usr/bin in
 # auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over
 # the downloaded one.
 build:download_clang_use_lld --linkopt='-fuse-ld=lld'

-build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
+# This config refers to building with CUDA available. It does not necessarily
-build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
+# mean that we build CUDA op kernels.
+build:using_cuda --define=using_cuda=true
+build:using_cuda --action_env TF_NEED_CUDA=1
+build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain

+# This config refers to building CUDA op kernels with nvcc.
+build:cuda --config=using_cuda
+build:cuda --define=using_cuda_nvcc=true

+# This config refers to building CUDA op kernels with clang.
+build:cuda_clang --config=using_cuda
+build:cuda_clang --define=using_cuda_clang=true
+build:cuda_clang --define=using_clang=true

+build:tensorrt --action_env TF_NEED_TENSORRT=1

 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
 build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
+build:rocm --action_env TF_NEED_ROCM=1
-build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
-build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true

 build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl --define=using_sycl=true --define=using_trisycl=false
+build:sycl --define=using_sycl=true
+build:sycl --action_env TF_NEED_OPENCL_SYCL=1

-build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain
+build:sycl_nodouble --config=sycl
-build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
+build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE

-build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain
+build:sycl_nodouble --config=sycl
-build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
+build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address

-build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain
+build:sycl_nodouble --config=sycl
-build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true
+build:sycl_trisycl --define=using_trisycl=true

 # Options extracted from configure script
 build:gdr --define=with_gdr_support=true
@@ -87,6 +101,9 @@ build --spawn_strategy=standalone
 build --strategy=Genrule=standalone
 build -c opt

+# Make Bazel print out all options from rc files.
+build --announce_rc

 # Other build flags.
 build --define=grpc_no_ares=true

@@ -97,8 +114,7 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
 # Build TF with C++ 17 features.
 build:c++17 --cxxopt=-std=c++1z
 build:c++17 --cxxopt=-stdlib=libc++
-build:c++1z --cxxopt=-std=c++1z
+build:c++1z --config=c++17
-build:c++1z --cxxopt=-stdlib=libc++

 # Default paths for TF_SYSTEM_LIBS
 build --define=PREFIX=/usr
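
Note: the named configs above are opt-in and do not change the default build. A `build:<name>` block only applies when a build is invoked with `--config=<name>` (for example `bazel build --config=cuda <target>`, where the target is left unspecified here), or when the configure script writes a matching `build --config=<name>` line into the user's generated bazelrc, as the configure.py changes later in this diff start doing.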
@@ -38,7 +38,13 @@ working on getting your pull request submitted to our internal repository. After
 the change has been submitted internally, your pull request will be merged
 automatically on GitHub.

-If you want to contribute but you're not sure where to start, take a look at the
+If you want to contribute, start working through the TensorFlow codebase,
+navigate to the
+[Github "issues" tab](https://github.com/tensorflow/tensorflow/issues) and start
+looking through interesting issues. If you are not sure of where to start, then
+start by trying one of the smaller/easier issues here i.e.
+[issues with the "good first issue" label](https://github.com/tensorflow/tensorflow/labels/good%20first%20issue)
+and then take a look at the
 [issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
 These are issues that we believe are particularly well suited for outside
 contributions, often because we probably won't get to them right now. If you
configure.py: 60 changed lines
@@ -403,7 +403,8 @@ def set_action_env_var(environ_cp,
 enabled_by_default,
 question=None,
 yes_reply=None,
-no_reply=None):
+no_reply=None,
+bazel_config_name=None):
 """Set boolean action_env variable.

 Ask user if query_item will be enabled. Default is used if no input is given.
@@ -418,12 +419,16 @@ def set_action_env_var(environ_cp,
 question: optional string for how to ask for user input.
 yes_reply: optional string for reply when feature is enabled.
 no_reply: optional string for reply when feature is disabled.
+bazel_config_name: adding config to .bazelrc instead of action_env.
 """
 var = int(
 get_var(environ_cp, var_name, query_item, enabled_by_default, question,
 yes_reply, no_reply))

+if not bazel_config_name:
 write_action_env_to_bazelrc(var_name, var)
+elif var:
+write_to_bazelrc('build --config=%s' % bazel_config_name)
 environ_cp[var_name] = str(var)


@@ -543,7 +548,8 @@ def set_tf_cuda_clang(environ_cp):
 False,
 question=question,
 yes_reply=yes_reply,
-no_reply=no_reply)
+no_reply=no_reply,
+bazel_config_name='cuda_clang')


 def set_tf_download_clang(environ_cp):
@@ -558,7 +564,8 @@ def set_tf_download_clang(environ_cp):
 False,
 question=question,
 yes_reply=yes_reply,
-no_reply=no_reply)
+no_reply=no_reply,
+bazel_config_name='download_clang')


 def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
@@ -782,8 +789,8 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path):
 print('WARNING: The NDK version in %s is %s, which is not '
 'supported by Bazel (officially supported versions: %s). Please use '
 'another version. Compiling Android targets may result in confusing '
-'errors.\n' % (android_ndk_home_path, ndk_version,
+'errors.\n' %
-_SUPPORTED_ANDROID_NDK_VERSIONS))
+(android_ndk_home_path, ndk_version, _SUPPORTED_ANDROID_NDK_VERSIONS))

 # Now grab the NDK API level to use. Note that this is different from the
 # SDK API level, as the NDK API level is effectively the *min* target SDK
@@ -952,6 +959,7 @@ def set_tf_nccl_version(environ_cp):
 ask_nccl_version, '')
 environ_cp['TF_NCCL_VERSION'] = tf_nccl_version


 def get_native_cuda_compute_capabilities(environ_cp):
 """Get native cuda compute capabilities.

@@ -1293,9 +1301,6 @@ def configure_ios():
 """
 if not is_macos():
 return
-if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000:
-print(
-'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.')
 for filepath in APPLE_BAZEL_FILES:
 existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
 renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
@@ -1386,7 +1391,7 @@ def main():
 # environment variables.
 environ_cp = dict(os.environ)

-current_bazel_version = check_bazel_version('0.24.1', '0.25.2')
+current_bazel_version = check_bazel_version('0.24.1', '0.25.3')
 _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)

 reset_tf_configure_bazelrc()
@@ -1422,7 +1427,12 @@ def main():
 set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
 xla_enabled_by_default, 'xla')

-set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
+set_action_env_var(
+environ_cp,
+'TF_NEED_OPENCL_SYCL',
+'OpenCL SYCL',
+False,
+bazel_config_name='sycl')
 if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
 set_host_cxx_compiler(environ_cp)
 set_host_c_compiler(environ_cp)
@@ -1432,30 +1442,44 @@ def main():
 else:
 set_trisycl_include_dir(environ_cp)

-set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
+set_action_env_var(
+environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
 if (environ_cp.get('TF_NEED_ROCM') == '1' and
 'LD_LIBRARY_PATH' in environ_cp and
 environ_cp.get('LD_LIBRARY_PATH') != '1'):
 write_action_env_to_bazelrc('LD_LIBRARY_PATH',
 environ_cp.get('LD_LIBRARY_PATH'))

-set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
+environ_cp['TF_NEED_CUDA'] = str(
+int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
 if (environ_cp.get('TF_NEED_CUDA') == '1' and
 'TF_CUDA_CONFIG_REPO' not in environ_cp):

-set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
+set_action_env_var(
+environ_cp,
+'TF_NEED_TENSORRT',
+'TensorRT',
+False,
+bazel_config_name='tensorrt')

 environ_save = dict(environ_cp)
 for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):

 if validate_cuda_config(environ_cp):
 cuda_env_names = [
-'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
+'TF_CUDA_VERSION',
-'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
+'TF_CUBLAS_VERSION',
+'TF_CUDNN_VERSION',
+'TF_TENSORRT_VERSION',
+'TF_NCCL_VERSION',
+'TF_CUDA_PATHS',
 # Items below are for backwards compatibility when not using
 # TF_CUDA_PATHS.
-'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
+'CUDA_TOOLKIT_PATH',
-'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
+'CUDNN_INSTALL_PATH',
+'NCCL_INSTALL_PATH',
+'NCCL_HDR_PATH',
+'TENSORRT_INSTALL_PATH'
 ]
 # Note: set_action_env_var above already writes to bazelrc.
 for name in cuda_env_names:
@@ -1506,8 +1530,6 @@ def main():
 # CUDA not required. Ask whether we should download the clang toolchain and
 # use it for the CPU build.
 set_tf_download_clang(environ_cp)
-if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
-write_to_bazelrc('build --config=download_clang')

 # SYCL / ROCm / CUDA are mutually exclusive.
 # At most 1 GPU platform can be configured.
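
In practice, the new `bazel_config_name` argument ties the configure flow to the named configs added in .bazelrc above: when a feature is enabled, `set_action_env_var` now writes a single `build --config=<name>` line (for example `build --config=rocm` or `build --config=tensorrt`) into the generated bazelrc instead of the raw `TF_NEED_*` action_env entry, and that config block in turn carries the required `--action_env` and toolchain settings.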
@@ -59,7 +59,7 @@ except ImportError:

 from tensorflow.python.util.lazy_loader import LazyLoader  # pylint: disable=g-import-not-at-top
 _CONTRIB_WARNING = """
-WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
+The TensorFlow contrib module will not be included in TensorFlow 2.0.
 For more information, please see:
 * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
 * https://github.com/tensorflow/addons
@@ -21,6 +21,9 @@ filegroup(
 srcs = [
 "c_api.h",
 "c_api_experimental.h",
+"tf_attrtype.h",
+"tf_datatype.h",
+"tf_status.h",
 ],
 visibility = ["//tensorflow:__subpackages__"],
 )
@@ -51,6 +54,8 @@ tf_cuda_library(
 hdrs = [
 "c_api.h",
 "c_api_internal.h",
+"tf_datatype.h",
+"tf_status.h",
 ],
 visibility = [
 "//tensorflow:internal",
@@ -61,6 +66,7 @@ tf_cuda_library(
 "//tensorflow/core:android_tensorflow_lib_lite",
 ],
 "//conditions:default": [
+":tf_attrtype",
 "//tensorflow/core:core_cpu",
 "//tensorflow/core:framework",
 "//tensorflow/core:lib",
@@ -71,16 +77,26 @@ tf_cuda_library(
 }),
 )

+cc_library(
+name = "tf_attrtype",
+hdrs = ["tf_attrtype.h"],
+visibility = ["//visibility:public"],
+)

 tf_cuda_library(
 name = "c_api",
 hdrs = [
 "c_api.h",
+"tf_attrtype.h",
+"tf_datatype.h",
+"tf_status.h",
 ],
 copts = tf_copts(),
 visibility = ["//visibility:public"],
 deps = [
 ":c_api_no_xla",
 ":c_api_internal",
+":tf_attrtype",
 ] + select({
 "//tensorflow:with_xla_support": [
 "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -96,14 +112,21 @@ tf_cuda_library(
 "c_api.cc",
 "c_api_function.cc",
 ],
-hdrs = ["c_api.h"],
+hdrs = [
+"c_api.h",
+],
 copts = tf_copts(),
 visibility = ["//tensorflow/c:__subpackages__"],
-deps = [":c_api_internal"] + select({
+deps = [
+":c_api_internal",
+":tf_attrtype",
+":tf_datatype",
+] + select({
 "//tensorflow:android": [
 "//tensorflow/core:android_tensorflow_lib_lite",
 ],
 "//conditions:default": [
+":tf_status",
 "@com_google_absl//absl/strings",
 "//tensorflow/cc/saved_model:loader_lite",
 "//tensorflow/cc:gradients",
@@ -124,6 +147,37 @@ tf_cuda_library(
 }),
 )

+cc_library(
+name = "tf_status",
+srcs = ["tf_status.cc"],
+hdrs = ["tf_status.h"],
+visibility = ["//visibility:public"],
+deps = select({
+"//tensorflow:android": [
+"//tensorflow/core:android_tensorflow_lib_lite",
+],
+"//conditions:default": [
+"//tensorflow/c:c_api_internal",
+"//tensorflow/core:lib",
+],
+}),
+)

+cc_library(
+name = "tf_datatype",
+srcs = ["tf_datatype.cc"],
+hdrs = ["tf_datatype.h"],
+visibility = ["//visibility:public"],
+deps = select({
+"//tensorflow:android": [
+"//tensorflow/core:android_tensorflow_lib_lite",
+],
+"//conditions:default": [
+"//tensorflow/core:framework",
+],
+}),
+)

 tf_cuda_library(
 name = "c_api_experimental",
 srcs = [
@@ -137,6 +191,7 @@ tf_cuda_library(
 deps = [
 ":c_api",
 ":c_api_internal",
+":checkpoint_reader",
 "//tensorflow/c/eager:c_api",
 "//tensorflow/c/eager:c_api_internal",
 "//tensorflow/compiler/jit:flags",
@@ -151,15 +206,6 @@ tf_cuda_library(
 ],
 )

-cc_library(
-name = "c_api_headers",
-hdrs = [
-"c_api.h",
-],
-copts = tf_copts(),
-visibility = ["//tensorflow:__subpackages__"],
-)

 exports_files(
 [
 "version_script.lds",
@@ -20,6 +20,7 @@ limitations under the License.
 #include <memory>
 #include <vector>

+#include "absl/strings/match.h"
 // Required for IS_MOBILE_PLATFORM
 #include "tensorflow/core/platform/platform.h"  // NOLINT

@@ -97,7 +98,6 @@ using tensorflow::TensorId;
 using tensorflow::TensorShape;
 using tensorflow::TensorShapeProto;
 using tensorflow::VersionDef;
-using tensorflow::error::Code;
 using tensorflow::errors::FailedPrecondition;
 using tensorflow::errors::InvalidArgument;
 using tensorflow::gtl::ArraySlice;
@@ -108,34 +108,6 @@ extern "C" {
 // --------------------------------------------------------------------------
 const char* TF_Version() { return TF_VERSION_STRING; }

-// --------------------------------------------------------------------------
-size_t TF_DataTypeSize(TF_DataType dt) {
-return static_cast<size_t>(
-tensorflow::DataTypeSize(static_cast<DataType>(dt)));
-}

-// --------------------------------------------------------------------------

-TF_Status* TF_NewStatus() { return new TF_Status; }

-void TF_DeleteStatus(TF_Status* s) { delete s; }

-void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
-if (code == TF_OK) {
-s->status = Status::OK();
-return;
-}
-s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
-}

-TF_Code TF_GetCode(const TF_Status* s) {
-return static_cast<TF_Code>(s->status.code());
-}

-const char* TF_Message(const TF_Status* s) {
-return s->status.error_message().c_str();
-}

 // --------------------------------------------------------------------------

 namespace {
@@ -1697,7 +1669,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
 if (metadata.list_size == 0) {
 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
 const auto& a = oper->node.op_def().attr(i);
-if (a.name().compare(attr_name) != 0) continue;
+if (a.name() != attr_name) continue;
 const string& typestr = a.type();
 if (typestr == "list(string)") {
 metadata.type = TF_ATTR_STRING;
@@ -2517,8 +2489,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
 // used in this graph
 for (const auto& pair : g->name_map) {
 const string& name = pair.first;
-if (name.compare(prefix) == 0 ||
+if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
-tensorflow::str_util::StartsWith(name, prefix_cmp)) {
 status->status = InvalidArgument(
 "prefix [", prefix,
 "] conflicts with existing node in the graph named [", name, "]");
@@ -2548,8 +2519,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
 // Adding the gradients to the graph can alter the prefix to prevent
 // name collisions only if this prefix has not been provided explicitly
 // by the user. If it was provided, assert that it remained intact.
-if (prefix != nullptr &&
+if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
-!tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
 status->status = tensorflow::errors::Internal(
 "BUG: The gradients prefix have been unexpectedly altered when "
 "adding the nodes to the graph. This is a bug. Please file an "
@@ -19,6 +19,10 @@ limitations under the License.
 #include <stddef.h>
 #include <stdint.h>

+#include "tensorflow/c/tf_attrtype.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/tf_status.h"

 // --------------------------------------------------------------------------
 // C API for TensorFlow.
 //
@@ -69,7 +73,7 @@ limitations under the License.
 // .dylib, .dll).
 // This duplicates the TF_EXPORT macro definition in
 // tensorflow/core/platform/macros.h in order to keep this .h file independent
-// of any other includes.$a
+// of any other includes.
 #ifdef SWIG
 #define TF_CAPI_EXPORT
 #else
@@ -93,89 +97,6 @@ extern "C" {
 // TensorFlow library. TensorFlow using semantic versioning.
 TF_CAPI_EXPORT extern const char* TF_Version(void);

-// --------------------------------------------------------------------------
-// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
-// The enum values here are identical to corresponding values in types.proto.
-typedef enum TF_DataType {
-TF_FLOAT = 1,
-TF_DOUBLE = 2,
-TF_INT32 = 3,  // Int32 tensors are always in 'host' memory.
-TF_UINT8 = 4,
-TF_INT16 = 5,
-TF_INT8 = 6,
-TF_STRING = 7,
-TF_COMPLEX64 = 8,  // Single-precision complex
-TF_COMPLEX = 8,  // Old identifier kept for API backwards compatibility
-TF_INT64 = 9,
-TF_BOOL = 10,
-TF_QINT8 = 11,  // Quantized int8
-TF_QUINT8 = 12,  // Quantized uint8
-TF_QINT32 = 13,  // Quantized int32
-TF_BFLOAT16 = 14,  // Float32 truncated to 16 bits. Only for cast ops.
-TF_QINT16 = 15,  // Quantized int16
-TF_QUINT16 = 16,  // Quantized uint16
-TF_UINT16 = 17,
-TF_COMPLEX128 = 18,  // Double-precision complex
-TF_HALF = 19,
-TF_RESOURCE = 20,
-TF_VARIANT = 21,
-TF_UINT32 = 22,
-TF_UINT64 = 23,
-} TF_DataType;

-// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
-// to the given TF_DataType enum value. Returns 0 for variable length types
-// (eg. TF_STRING) or on failure.
-TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);

-// --------------------------------------------------------------------------
-// TF_Code holds an error code. The enum values here are identical to
-// corresponding values in error_codes.proto.
-typedef enum TF_Code {
-TF_OK = 0,
-TF_CANCELLED = 1,
-TF_UNKNOWN = 2,
-TF_INVALID_ARGUMENT = 3,
-TF_DEADLINE_EXCEEDED = 4,
-TF_NOT_FOUND = 5,
-TF_ALREADY_EXISTS = 6,
-TF_PERMISSION_DENIED = 7,
-TF_UNAUTHENTICATED = 16,
-TF_RESOURCE_EXHAUSTED = 8,
-TF_FAILED_PRECONDITION = 9,
-TF_ABORTED = 10,
-TF_OUT_OF_RANGE = 11,
-TF_UNIMPLEMENTED = 12,
-TF_INTERNAL = 13,
-TF_UNAVAILABLE = 14,
-TF_DATA_LOSS = 15,
-} TF_Code;

-// --------------------------------------------------------------------------
-// TF_Status holds error information. It either has an OK code, or
-// else an error code with an associated error message.
-typedef struct TF_Status TF_Status;

-// Return a new status object.
-TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);

-// Delete a previously created status object.
-TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);

-// Record <code, msg> in *s. Any previous information is lost.
-// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
-TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
-const char* msg);

-// Return the code record in *s.
-TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);

-// Return a pointer to the (null-terminated) error message in *s. The
-// return value points to memory that is only usable until the next
-// mutation to *s. Always returns an empty string if TF_GetCode(s) is
-// TF_OK.
-TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);

 // --------------------------------------------------------------------------
 // TF_Buffer holds a pointer to a block of data and its associated length.
 // Typically, the data consists of a serialized protocol buffer, but other data
@@ -686,19 +607,6 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs(
 TF_Operation* oper, TF_Operation** control_outputs,
 int max_control_outputs);

-// TF_AttrType describes the type of the value of an attribute on an operation.
-typedef enum TF_AttrType {
-TF_ATTR_STRING = 0,
-TF_ATTR_INT = 1,
-TF_ATTR_FLOAT = 2,
-TF_ATTR_BOOL = 3,
-TF_ATTR_TYPE = 4,
-TF_ATTR_SHAPE = 5,
-TF_ATTR_TENSOR = 6,
-TF_ATTR_PLACEHOLDER = 7,
-TF_ATTR_FUNC = 8,
-} TF_AttrType;

 // TF_AttrMetadata describes the value of an attribute on an operation.
 typedef struct TF_AttrMetadata {
 // A boolean: 1 if the attribute value is a list, 0 otherwise.
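
The declarations removed from c_api.h here are not dropped from the public API; they move into the new standalone headers tf_attrtype.h, tf_datatype.h, and tf_status.h, which c_api.h now includes (the new files appear at the end of this diff). A minimal sketch of a C caller that should keep compiling unchanged after the split; build and link flags are omitted:

#include <stdio.h>

#include "tensorflow/c/c_api.h"  /* still provides TF_Status, TF_DataType, TF_AttrType */

int main(void) {
  /* TF_NewStatus/TF_DeleteStatus are now declared in tf_status.h,
     pulled in transitively through c_api.h. */
  TF_Status* s = TF_NewStatus();
  printf("TensorFlow %s, sizeof(TF_FLOAT) = %zu\n", TF_Version(),
         TF_DataTypeSize(TF_FLOAT));
  TF_DeleteStatus(s);
  return 0;
}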
@@ -18,6 +18,7 @@ limitations under the License.
 #include "absl/strings/substitute.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/checkpoint_reader.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/compiler/jit/flags.h"
@@ -37,6 +38,7 @@ using tensorflow::FunctionDef;
 using tensorflow::Node;
 using tensorflow::NodeBuilder;
 using tensorflow::Status;
+using tensorflow::errors::InvalidArgument;

 namespace {
 typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
@@ -149,7 +151,7 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
 }

 char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
-const auto& debug_str = func->fdef.DebugString();
+const auto& debug_str = DebugString(func->fdef);
 *len = debug_str.size();
 char* ret = static_cast<char*>(malloc(*len + 1));
 memcpy(ret, debug_str.c_str(), *len + 1);
@@ -576,6 +578,73 @@ void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
 status->status = tensorflow::errors::Internal(errMsg);
 }

+struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
+using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
+std::vector<std::string> variable_list;
+};

+TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
+TF_Status* status) {
+TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
+if (!status->status.ok()) return nullptr;
+const auto& m = reader->GetVariableToDataTypeMap();
+for (auto it = m.begin(); it != m.end(); ++it)
+reader->variable_list.push_back(it->first);
+std::sort(reader->variable_list.begin(), reader->variable_list.end());
+return reader;
+}

+void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

+int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
+const char* name) {
+return reader->HasTensor(name);
+}

+const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
+int index) {
+return reader->variable_list[index].c_str();
+}

+int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
+return reader->variable_list.size();
+}

+TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
+const char* name) {
+const auto& m = reader->GetVariableToDataTypeMap();
+return static_cast<TF_DataType>(m.at(name));
+}

+TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
+const char* name, TF_Status* status) {
+std::unique_ptr<tensorflow::Tensor> tensor;
+reader->GetTensor(name, &tensor, status);
+if (!status->status.ok()) return nullptr;
+return tensorflow::TF_TensorFromTensor(*tensor.get(), status);
+}

+void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
+const char* name, int64_t* dims,
+int num_dims, TF_Status* status) {
+const auto& shape = reader->GetVariableToShapeMap().at(name);
+int rank = shape.dims();
+if (num_dims != rank) {
+status->status = InvalidArgument("Expected rank is ", num_dims,
+" but actual rank is ", rank);
+return;
+}
+for (int i = 0; i < num_dims; i++) {
+dims[i] = shape.dim_size(i);
+}
+}

+int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
+const char* name) {
+const auto& m = reader->GetVariableToShapeMap();
+return m.at(name).dims();
+}

 // This builder is used in the eager API to build a NodeDef.
 struct TF_AttrBuilder : public tensorflow::AttrBuilder {
 using tensorflow::AttrBuilder::AttrBuilder;
@@ -208,6 +208,34 @@ TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
 TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
 const char* errMsg);

+// TF_NewCheckpointReader() return the CheckpointReader that can be use to
+// investigate or load the variable from the checkpoint file
+typedef struct TF_CheckpointReader TF_CheckpointReader;
+TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader(
+const char* filename, TF_Status* status);
+TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader(
+TF_CheckpointReader* reader);
+TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor(
+TF_CheckpointReader* reader, const char* name);
+// Get the variable name at the given index
+TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable(
+TF_CheckpointReader* reader, int index);
+// Get the number of variable in the checkpoint
+TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader);
+// Get the DataType of a variable
+TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType(
+TF_CheckpointReader* reader, const char* name);
+// Read the shape of a variable and write to `dims`
+TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape(
+TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims,
+TF_Status* status);
+// Get the number of dimension of a variable
+TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims(
+TF_CheckpointReader* reader, const char* name);
+// Load the weight of a variable
+TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor(
+TF_CheckpointReader* reader, const char* name, TF_Status* status);

 // TF_NewAttrBuilder() returns an object that you can set attributes on as
 // though it were an op. This allows querying properties of that op for
 // type-checking purposes like if the op will run on a particular device type.
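
A minimal usage sketch for the checkpoint-reader API declared above, based only on these declarations; the checkpoint prefix "/tmp/model.ckpt" is hypothetical and error handling is reduced to a single status check:

#include <stdio.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

void list_checkpoint_variables(void) {
  TF_Status* status = TF_NewStatus();
  TF_CheckpointReader* reader =
      TF_NewCheckpointReader("/tmp/model.ckpt", status);  /* hypothetical path */
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "open failed: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return;
  }
  int n = TF_CheckpointReaderSize(reader);
  for (int i = 0; i < n; ++i) {
    const char* name = TF_CheckpointReaderGetVariable(reader, i);
    int rank = TF_CheckpointReaderGetVariableNumDims(reader, name);
    printf("%s: dtype=%d rank=%d\n", name,
           (int)TF_CheckpointReaderGetVariableDataType(reader, name), rank);
    TF_Tensor* value = TF_CheckpointReaderGetTensor(reader, name, status);
    if (value != NULL) TF_DeleteTensor(value);  /* load the weight, then release it */
  }
  TF_DeleteCheckpointReader(reader);
  TF_DeleteStatus(status);
}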
@@ -62,8 +62,8 @@ protocol: "grpc"
 TF_Buffer* null_result =
 TFE_GetServerDef(malformed_text_proto.c_str(), status);
 EXPECT_NE(TF_GetCode(status), TF_OK);
-EXPECT_TRUE(tensorflow::str_util::StrContains(
+EXPECT_TRUE(absl::StrContains(TF_Message(status),
-TF_Message(status), "Invalid text proto for ServerDef"));
+"Invalid text proto for ServerDef"));
 EXPECT_EQ(null_result, nullptr);

 // Cleanup
@@ -253,7 +253,7 @@ class CApiFunctionTest : public ::testing::Test {
 const std::unordered_set<string>& nodes) {
 ASSERT_EQ(nodes.size(), fdef.node_def_size())
 << "Got unexpected number of nodes. Expected: ["
-<< str_util::Join(nodes, ", ")
+<< absl::StrJoin(nodes, ", ")
 << "] Actual nodes in fdef: " << fdef.DebugString();
 for (const NodeDef& node_def : fdef.node_def()) {
 ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
@@ -56,7 +56,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 namespace {

 static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
-EXPECT_TRUE(str_util::StrContains(s, expected))
+EXPECT_TRUE(absl::StrContains(s, expected))
 << "'" << s << "' does not contain '" << expected << "'";
 }

@@ -1,6 +1,8 @@
 # Experimental extensions to the C API for eager execution of kernels.

-licenses(["notice"])  # Apache 2.0
+package(
+licenses = ["notice"],  # Apache 2.0
+)

 load(
 "//tensorflow:tensorflow.bzl",
@@ -30,7 +30,9 @@ limitations under the License.
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/platform/platform.h"  // NOLINT
 #ifdef TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #endif  // TENSORFLOW_EAGER_USE_XLA
@@ -135,11 +137,12 @@ tensorflow::Status CreateRemoteContexts(
 const std::vector<string>& remote_workers, int64 rendezvous_id,
 int keep_alive_secs, const tensorflow::ServerDef& server_def,
 tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
+const tensorflow::eager::CreateContextRequest& base_request,
 tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
 for (int i = 0; i < remote_workers.size(); i++) {
 const string& remote_worker = remote_workers[i];

-tensorflow::eager::CreateContextRequest request;
+tensorflow::eager::CreateContextRequest request(base_request);
 tensorflow::eager::CreateContextResponse response;
 request.set_rendezvous_id(rendezvous_id);
 tensorflow::DeviceNameUtils::ParsedName parsed_name;
@@ -221,6 +224,23 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 remote_workers, grpc_server->master_env()->worker_cache,
 &remote_device_mgr));

+std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
+remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);

+std::vector<tensorflow::DeviceAttributes> local_device_attributes;
+grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
+&local_device_attributes);

+// This request make sure that we can create Rendevzous properly between
+// Local and Remote context.
+tensorflow::eager::CreateContextRequest base_request;
+for (const auto& da : cluster_device_attributes) {
+*base_request.add_cluster_device_attributes() = da;
+}
+for (const auto& da : local_device_attributes) {
+*base_request.add_cluster_device_attributes() = da;
+}

 std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
 grpc_server->channel_cache();
 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
@@ -230,14 +250,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
 LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
 remote_workers, rendezvous_id, keep_alive_secs, server_def,
-remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
+remote_eager_workers.get(), ctx->context->Async(), base_request,
+&remote_contexts));

 tensorflow::RemoteRendezvous* r =
 grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);

 auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
 TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
-session_name, server_def, true));
+session_name, server_def, base_request.cluster_device_attributes(),
+true));

 std::shared_ptr<tensorflow::WorkerSession> worker_session;
 TF_RETURN_IF_ERROR(
@@ -250,9 +272,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 auto* device_mgr = grpc_server->worker_env()->device_mgr;

 return ctx->context->InitializeRemote(
-std::move(server), std::move(remote_eager_workers),
+std::move(server), grpc_server->worker_env(), worker_session,
-std::move(remote_device_mgr), remote_contexts, r, device_mgr,
+std::move(remote_eager_workers), std::move(remote_device_mgr),
-keep_alive_secs);
+remote_contexts, r, device_mgr, keep_alive_secs,
+worker_session->cluster_flr.get());
 #undef LOG_AND_RETURN_IF_ERROR
 }
 #endif  // !IS_MOBILE_PLATFORM
@@ -970,6 +993,23 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
 return t;
 }

+TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
+TF_Status* status) {
+// TensorHandles created by PyFuncOp lack context and therefore could
+// not be copied.
+if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
+tensorflow::TensorHandle* handle;
+status->status = tensorflow::EagerCopyToDevice(
+h->handle, h->handle->Context(), "CPU:0", &handle);
+if (status->status.ok()) {
+return new TFE_TensorHandle(handle);
+} else {
+return nullptr;
+}
+}
+return h;
+}

 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
 TF_Status* status) {
 TFE_ContextAsyncWait(ctx, status);
@@ -462,6 +462,9 @@ class Tensor;

 const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
 TFE_TensorHandle* h, TF_Status* status);

+TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
+TF_Status* status);
 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
 #endif

@@ -78,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
 status->status = tensorflow::Status::OK();
 } else {
 VLOG(3) << "Fully padded shape of ["
-<< tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
+<< absl::StrJoin(shape_to_log, ", ") << "] is "
 << padded_shape.DebugString();
 }
 }
@@ -33,7 +33,7 @@ namespace tensorflow {
 namespace {

 static bool HasSubstr(absl::string_view base, absl::string_view substr) {
-bool ok = str_util::StrContains(base, substr);
+bool ok = absl::StrContains(base, substr);
 EXPECT_TRUE(ok) << base << ", expected substring " << substr;
 return ok;
 }
@@ -1408,6 +1408,10 @@ void FunctionDefAndExecute(bool async) {
 status);
 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

+for (bool clear_cache : {true, false, true}) {
+if (clear_cache) {
+TFE_ContextClearCaches(ctx);
+}
 TFE_TensorHandle* m = TestMatrixTensorHandle();
 TFE_TensorHandle* retval[1] = {nullptr};
 int num_retvals = 1;
@@ -1431,6 +1435,7 @@ void FunctionDefAndExecute(bool async) {
 EXPECT_EQ(10, product[1]);
 EXPECT_EQ(15, product[2]);
 EXPECT_EQ(22, product[3]);
+}
 TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
 ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
 TFE_DeleteContext(ctx);
@@ -1,7 +1,9 @@
 # Description:
 # Experimental C APIs for TensorFlow.

-licenses(["notice"])  # Apache 2.0
+package(
+licenses = ["notice"],  # Apache 2.0
+)

 load(
 "//tensorflow:tensorflow.bzl",
@@ -6,10 +6,9 @@ load(

 package(
 default_visibility = ["//visibility:public"],
+licenses = ["notice"],  # Apache 2.0
 )

-licenses(["notice"])  # Apache 2.0

 tf_kernel_library(
 name = "bitcast_op",
 prefix = "bitcast_op",
tensorflow/c/tf_attrtype.h: new file, 39 lines

@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_
+#define TENSORFLOW_C_TF_ATTRTYPE_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// TF_AttrType describes the type of the value of an attribute on an operation.
+typedef enum TF_AttrType {
+TF_ATTR_STRING = 0,
+TF_ATTR_INT = 1,
+TF_ATTR_FLOAT = 2,
+TF_ATTR_BOOL = 3,
+TF_ATTR_TYPE = 4,
+TF_ATTR_SHAPE = 5,
+TF_ATTR_TENSOR = 6,
+TF_ATTR_PLACEHOLDER = 7,
+TF_ATTR_FUNC = 8,
+} TF_AttrType;
+
+#ifdef __cplusplus
+} /* end extern "C" */
+#endif
+
+#endif  // TENSORFLOW_C_TF_ATTRTYPE_H_
23  tensorflow/c/tf_datatype.cc  (new file)
@ -0,0 +1,23 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   http://www.apache.org/licenses/LICENSE-2.0
==============================================================================*/

#include "tensorflow/c/tf_datatype.h"

#include "tensorflow/core/framework/types.h"

size_t TF_DataTypeSize(TF_DataType dt) {
  return static_cast<size_t>(
      tensorflow::DataTypeSize(static_cast<tensorflow::DataType>(dt)));
}
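A minimal usage sketch of TF_DataTypeSize as defined above (the printed sizes assume the usual 4-byte float and the documented 0 for variable-length types):

// Sketch only: exercises the TF_DataTypeSize definition above.
#include <cstdio>
#include "tensorflow/c/tf_datatype.h"

int main() {
  std::printf("TF_FLOAT size:  %zu\n", TF_DataTypeSize(TF_FLOAT));   // Typically 4.
  std::printf("TF_STRING size: %zu\n", TF_DataTypeSize(TF_STRING));  // 0: variable length.
  return 0;
}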
83  tensorflow/c/tf_datatype.h  (new file)
@ -0,0 +1,83 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   http://www.apache.org/licenses/LICENSE-2.0
==============================================================================*/

#ifndef TENSORFLOW_C_TF_DATATYPE_H_
#define TENSORFLOW_C_TF_DATATYPE_H_

#include <stddef.h>

// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif  // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // _WIN32
#endif  // SWIG

#ifdef __cplusplus
extern "C" {
#endif

// --------------------------------------------------------------------------
// TF_DataType holds the type for a scalar value.  E.g., one slot in a tensor.
// The enum values here are identical to corresponding values in types.proto.
typedef enum TF_DataType {
  TF_FLOAT = 1,
  TF_DOUBLE = 2,
  TF_INT32 = 3,  // Int32 tensors are always in 'host' memory.
  TF_UINT8 = 4,
  TF_INT16 = 5,
  TF_INT8 = 6,
  TF_STRING = 7,
  TF_COMPLEX64 = 8,  // Single-precision complex
  TF_COMPLEX = 8,    // Old identifier kept for API backwards compatibility
  TF_INT64 = 9,
  TF_BOOL = 10,
  TF_QINT8 = 11,     // Quantized int8
  TF_QUINT8 = 12,    // Quantized uint8
  TF_QINT32 = 13,    // Quantized int32
  TF_BFLOAT16 = 14,  // Float32 truncated to 16 bits.  Only for cast ops.
  TF_QINT16 = 15,    // Quantized int16
  TF_QUINT16 = 16,   // Quantized uint16
  TF_UINT16 = 17,
  TF_COMPLEX128 = 18,  // Double-precision complex
  TF_HALF = 19,
  TF_RESOURCE = 20,
  TF_VARIANT = 21,
  TF_UINT32 = 22,
  TF_UINT64 = 23,
} TF_DataType;

// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
// to the given TF_DataType enum value. Returns 0 for variable length types
// (eg. TF_STRING) or on failure.
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);

#ifdef __cplusplus
}  /* end extern "C" */
#endif

#endif  // TENSORFLOW_C_TF_DATATYPE_H_
42  tensorflow/c/tf_status.cc  (new file)
@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   http://www.apache.org/licenses/LICENSE-2.0
==============================================================================*/

#include "tensorflow/c/tf_status.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/lib/core/status.h"

using ::tensorflow::Status;
using ::tensorflow::error::Code;

TF_Status* TF_NewStatus() { return new TF_Status; }

void TF_DeleteStatus(TF_Status* s) { delete s; }

void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
  if (code == TF_OK) {
    s->status = Status::OK();
    return;
  }
  s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}

TF_Code TF_GetCode(const TF_Status* s) {
  return static_cast<TF_Code>(s->status.code());
}

const char* TF_Message(const TF_Status* s) {
  return s->status.error_message().c_str();
}
88  tensorflow/c/tf_status.h  (new file)
@ -0,0 +1,88 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   http://www.apache.org/licenses/LICENSE-2.0
==============================================================================*/

#ifndef TENSORFLOW_C_TF_STATUS_H_
#define TENSORFLOW_C_TF_STATUS_H_

#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif  // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // _WIN32
#endif  // SWIG

#ifdef __cplusplus
extern "C" {
#endif

typedef struct TF_Status TF_Status;

// --------------------------------------------------------------------------
// TF_Code holds an error code.  The enum values here are identical to
// corresponding values in error_codes.proto.
typedef enum TF_Code {
  TF_OK = 0,
  TF_CANCELLED = 1,
  TF_UNKNOWN = 2,
  TF_INVALID_ARGUMENT = 3,
  TF_DEADLINE_EXCEEDED = 4,
  TF_NOT_FOUND = 5,
  TF_ALREADY_EXISTS = 6,
  TF_PERMISSION_DENIED = 7,
  TF_UNAUTHENTICATED = 16,
  TF_RESOURCE_EXHAUSTED = 8,
  TF_FAILED_PRECONDITION = 9,
  TF_ABORTED = 10,
  TF_OUT_OF_RANGE = 11,
  TF_UNIMPLEMENTED = 12,
  TF_INTERNAL = 13,
  TF_UNAVAILABLE = 14,
  TF_DATA_LOSS = 15,
} TF_Code;

// --------------------------------------------------------------------------

// Return a new status object.
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);

// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);

// Record <code, msg> in *s.  Any previous information is lost.
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
                                        const char* msg);

// Return the code record in *s.
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);

// Return a pointer to the (null-terminated) error message in *s.  The
// return value points to memory that is only usable until the next
// mutation to *s.  Always returns an empty string if TF_GetCode(s) is
// TF_OK.
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);

#ifdef __cplusplus
}  /* end extern "C" */
#endif

#endif  // TENSORFLOW_C_TF_STATUS_H_
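Taken together, the functions declared above form a small create/set/inspect/delete cycle. A hedged sketch of that lifecycle (the error text is illustrative only):

// Sketch of the TF_Status lifecycle using only the API declared above.
#include <iostream>
#include "tensorflow/c/tf_status.h"

int main() {
  TF_Status* s = TF_NewStatus();                       // New status starts as TF_OK.
  TF_SetStatus(s, TF_INVALID_ARGUMENT, "bad shape");   // Record <code, msg>.
  if (TF_GetCode(s) != TF_OK) {
    std::cout << "error: " << TF_Message(s) << "\n";   // Prints "error: bad shape".
  }
  TF_SetStatus(s, TF_OK, "");                          // Common idiom to clear a status.
  TF_DeleteStatus(s);
  return 0;
}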
@ -4,10 +4,9 @@

 package(
     default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
 )

-licenses(["notice"])  # Apache 2.0
-
 filegroup(
     name = "srcs",
     srcs = [
@ -638,6 +637,7 @@ cc_library(
         "//tensorflow/core:op_gen_lib",
         "//tensorflow/core:proto_text",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
     ],
 )
@ -657,6 +657,7 @@ tf_cc_test(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "@com_google_absl//absl/strings",
     ],
 )
@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include "tensorflow/cc/framework/cc_op_gen.h"
+
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>

-#include "tensorflow/cc/framework/cc_op_gen.h"
+#include "absl/strings/escaping.h"
 #include "tensorflow/core/framework/api_def.pb.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
@ -133,7 +135,7 @@ string MakeComment(StringPiece text, StringPiece indent) {
 }

 string PrintString(const string& str) {
-  return strings::StrCat("\"", str_util::CEscape(str), "\"");
+  return strings::StrCat("\"", absl::CEscape(str), "\"");
 }

 string PrintTensorShape(const TensorShapeProto& shape_proto) {
@ -191,7 +193,7 @@ string PrintTensor(const TensorProto& tensor_proto) {
       string ret;
       for (int64 i = 0; i < num_elts; ++i) {
         if (i > 0) strings::StrAppend(&ret, " ");
-        strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
+        strings::StrAppend(&ret, absl::CEscape(t.flat<string>()(i)));
       }
       return ret;
 }
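The PrintString/PrintTensor changes above switch to absl::CEscape from "absl/strings/escaping.h". A small sketch of what it produces (the input string is chosen purely for illustration):

// Sketch: absl::CEscape replaces control characters and quotes with C escapes.
#include <iostream>
#include <string>
#include "absl/strings/escaping.h"

int main() {
  std::string escaped = absl::CEscape("tab\there \"quoted\"\n");
  std::cout << escaped << "\n";  // Prints: tab\there \"quoted\"\n
  return 0;
}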
@ -62,12 +62,12 @@ op {
 )";

 void ExpectHasSubstr(StringPiece s, StringPiece expected) {
-  EXPECT_TRUE(str_util::StrContains(s, expected))
+  EXPECT_TRUE(absl::StrContains(s, expected))
       << "'" << s << "' does not contain '" << expected << "'";
 }

 void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
-  EXPECT_FALSE(str_util::StrContains(s, expected))
+  EXPECT_FALSE(absl::StrContains(s, expected))
       << "'" << s << "' contains '" << expected << "'";
 }
@ -275,7 +275,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
   if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
     for (const string& entry : node_constraints) {
       StringPiece s(entry);
-      if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
+      if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {
         current_constraints.emplace(s);
       }
     }
@ -3,10 +3,9 @@

 package(
     default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
 )

-licenses(["notice"])  # Apache 2.0
-
 exports_files(["LICENSE"])

 load(
@ -308,7 +308,7 @@ Status LoadSavedModel(const SessionOptions& session_options,
   const Status status = LoadSavedModelInternal(session_options, run_options,
                                                export_dir, tags, bundle);
   auto log_and_count = [&](const string& status_str) {
-    LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
+    LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
               << " }; Status: " << status_str << ". Took "
               << GetLatencyMicroseconds(start_microseconds) << " microseconds.";
     load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
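The logging change above replaces str_util::Join with absl::StrJoin from "absl/strings/str_join.h"; with an illustrative tag set it produces the same space-separated list:

// Sketch: joining a tag set the way the SavedModel loader logs it.
#include <iostream>
#include <string>
#include <vector>
#include "absl/strings/str_join.h"

int main() {
  std::vector<std::string> tags = {"serve", "gpu"};  // Example tags only.
  std::cout << "SavedModel load for tags { " << absl::StrJoin(tags, " ") << " }\n";
  // Prints: SavedModel load for tags { serve gpu }
  return 0;
}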
@ -136,7 +136,7 @@ TEST_F(LoaderTest, NoTagMatch) {
   Status st = LoadSavedModel(session_options, run_options, export_dir,
                              {"missing-tag"}, &bundle);
   EXPECT_FALSE(st.ok());
-  EXPECT_TRUE(str_util::StrContains(
+  EXPECT_TRUE(absl::StrContains(
       st.error_message(),
       "Could not find meta graph def matching supplied tags: { missing-tag }"))
       << st.error_message();
@ -152,7 +152,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
   Status st = LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe, "missing-tag"}, &bundle);
   EXPECT_FALSE(st.ok());
-  EXPECT_TRUE(str_util::StrContains(
+  EXPECT_TRUE(absl::StrContains(
       st.error_message(),
       "Could not find meta graph def matching supplied tags: "))
       << st.error_message();
@ -172,7 +172,7 @@ TEST_F(LoaderTest, SessionCreationFailure) {
   Status st = LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle);
   EXPECT_FALSE(st.ok());
-  EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget))
+  EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget))
       << st.error_message();
 }
@ -51,7 +51,7 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
 Status FindMetaGraphDef(const SavedModel& saved_model_proto,
                         const std::unordered_set<string>& tags,
                         MetaGraphDef* meta_graph_def) {
-  LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
+  LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
             << " }";
   for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
     // Get tags from the graph_def.
@ -69,7 +69,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
         error::Code::NOT_FOUND,
         strings::StrCat(
             "Could not find meta graph def matching supplied tags: { ",
-            str_util::Join(tags, " "),
+            absl::StrJoin(tags, " "),
             " }. To inspect available tag-sets in the SavedModel, please "
             "use the SavedModel CLI: `saved_model_cli`"));
   }
@ -64,7 +64,7 @@ TEST_F(ReaderTest, NoTagMatch) {
   Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
                                              &meta_graph_def);
   EXPECT_FALSE(st.ok());
-  EXPECT_TRUE(str_util::StrContains(
+  EXPECT_TRUE(absl::StrContains(
       st.error_message(),
       "Could not find meta graph def matching supplied tags: { missing-tag }"))
       << st.error_message();
@ -78,7 +78,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
   Status st = ReadMetaGraphDefFromSavedModel(
       export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
   EXPECT_FALSE(st.ok());
-  EXPECT_TRUE(str_util::StrContains(
+  EXPECT_TRUE(absl::StrContains(
       st.error_message(),
       "Could not find meta graph def matching supplied tags: "))
       << st.error_message();
@ -3,10 +3,9 @@

 package(
     default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
 )

-licenses(["notice"])  # Apache 2.0
-
 exports_files(["LICENSE"])

 load(
@ -167,8 +167,7 @@ namespace {

 bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
                     int32* dst) {
-  if (tensorflow::str_util::ConsumePrefix(&arg, flag) &&
-      tensorflow::str_util::ConsumePrefix(&arg, "=")) {
+  if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) {
     char extra;
     return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
   }
@ -178,7 +177,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,

 bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
                    bool* dst) {
-  if (tensorflow::str_util::ConsumePrefix(&arg, flag)) {
+  if (absl::ConsumePrefix(&arg, flag)) {
     if (arg.empty()) {
       *dst = true;
       return true;
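ParseInt32Flag and ParseBoolFlag above now rely on absl::ConsumePrefix from "absl/strings/strip.h", which strips a matching prefix in place and reports whether it matched. A self-contained sketch of the same flag-parsing pattern (the flag name is hypothetical):

// Sketch of ConsumePrefix-based flag parsing, mirroring the code above.
#include <cstdio>
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

int main() {
  absl::string_view arg = "--iterations=25";
  int value = 0;
  if (absl::ConsumePrefix(&arg, "--iterations") && absl::ConsumePrefix(&arg, "=")) {
    char extra;
    if (std::sscanf(arg.data(), "%d%c", &value, &extra) == 1) {
      std::printf("iterations = %d\n", value);  // Prints: iterations = 25
    }
  }
  return 0;
}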
@ -1,7 +1,6 @@
-licenses(["notice"])  # Apache 2.0
-
 package(
     default_visibility = ["//visibility:private"],
+    licenses = ["notice"],  # Apache 2.0
 )

 load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
@ -1,4 +1,12 @@
-licenses(["notice"])  # Apache 2.0
+package(
+    default_visibility = [
+        ":internal",
+        # BEGIN-GOOGLE-INTERNAL
+        "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
+        # END-GOOGLE-INTERNAL
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)

 package_group(
     name = "internal",
@ -14,15 +22,6 @@ package_group(
     ],
 )

-package(
-    default_visibility = [
-        ":internal",
-        # BEGIN-GOOGLE-INTERNAL
-        "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
-        # END-GOOGLE-INTERNAL
-    ],
-)
-
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
|||||||
"//tensorflow/core/kernels:host_constant_op",
|
"//tensorflow/core/kernels:host_constant_op",
|
||||||
"//tensorflow/core/kernels:identity_n_op",
|
"//tensorflow/core/kernels:identity_n_op",
|
||||||
"//tensorflow/core/kernels:identity_op",
|
"//tensorflow/core/kernels:identity_op",
|
||||||
|
"//tensorflow/core/kernels:logging_ops",
|
||||||
"//tensorflow/core/kernels:no_op",
|
"//tensorflow/core/kernels:no_op",
|
||||||
"//tensorflow/core/kernels:queue_op",
|
"//tensorflow/core/kernels:queue_op",
|
||||||
"//tensorflow/core/kernels:resource_variable_ops",
|
"//tensorflow/core/kernels:resource_variable_ops",
|
||||||
@ -257,10 +257,8 @@ cc_library(
|
|||||||
name = "xla_launch_util",
|
name = "xla_launch_util",
|
||||||
srcs = ["xla_launch_util.cc"],
|
srcs = ["xla_launch_util.cc"],
|
||||||
hdrs = ["xla_launch_util.h"],
|
hdrs = ["xla_launch_util.h"],
|
||||||
# TODO(skyewm): remove this once XlaAllocator is factored out.
|
|
||||||
visibility = [
|
visibility = [
|
||||||
":internal",
|
":internal",
|
||||||
"//tensorflow/compiler/xla/python:__pkg__",
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
|
@ -244,11 +244,11 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
         "resource variable op in called function");
   }

-  if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
+  if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
     return LogNotCompilableAndReturn(node, "operation with correctness issues");
   }

-  if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
+  if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
     return LogNotCompilableAndReturn(node, "slow operation");
   }

@ -268,8 +268,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
       registration.elide_assert_and_checknumerics;
   op_filter.allow_ops_producing_or_consuming_variant =
       registration.cluster_variant_ops;
-  op_filter.allow_slow_and_inaccurate_ops =
-      registration.cluster_slow_and_inaccurate_ops;
+  op_filter.allow_slow_ops = registration.cluster_slow_ops;
+  op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
   return op_filter;
 }
@ -97,9 +97,12 @@ class RecursiveCompilabilityChecker {
     // live-out DT_VARIANT values.
     bool allow_ops_producing_or_consuming_variant;

-    // Whether ops known to be slow or to have correctness issues should be
-    // auto-clustered.
-    bool allow_slow_and_inaccurate_ops;
+    // Whether ops known to be slow on XLA-GPU should be considered compilable.
+    bool allow_slow_ops;
+
+    // Whether ops known to have numerical accuracy issues should be considered
+    // compilable.
+    bool allow_inaccurate_ops;
   };

   RecursiveCompilabilityChecker(const OperationFilter* op_filter,
@ -14,10 +14,12 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/compiler/jit/deadness_analysis.h"

 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@ -43,12 +45,12 @@ limitations under the License.
 // ------------------------------------------
 //
 // If we ignore cycles for a moment, computing predicates is fairly
-// straightforward.  We traverse the graph in RPO, mapping each node to a
-// predicate based on the predicates its inputs are mapped to.  For instance a
-// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)).
-// Roughly speaking, we abstractly interpret each node on the "liveness" domain,
-// where values in the domain represent if a tensor carries a dead signal or
-// not.
+// straightforward.  We traverse the graph in a topological order, mapping each
+// node to a predicate based on the predicates its inputs are mapped to.  For
+// instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
+// PredicateFor(Y)).  Roughly speaking, we abstractly interpret each node on
+// the "liveness" domain, where values in the domain represent whether a tensor
+// carries a dead signal or not.
 //
 //
 // DEALING WITH CYCLES
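As a toy illustration of the acyclic rule described above (this is not the actual Predicate machinery; predicates are plain strings here), a Merge maps to the OR of its inputs' predicates:

// Toy sketch only: Merge(X, Y) -> Or(PredicateFor(X), PredicateFor(Y)).
#include <iostream>
#include <string>

int main() {
  std::string pred_x = "#true";  // PredicateFor(X), assumed already computed.
  std::string pred_y = "sw:1";   // PredicateFor(Y), e.g. one Switch output.
  std::string pred_merge = "Or(" + pred_x + ", " + pred_y + ")";
  std::cout << "PredicateFor(Merge) = " << pred_merge << "\n";
  return 0;
}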
@ -85,22 +87,28 @@ limitations under the License.
 // true on iteration 0, 1, 2 respectively.  This is made more precise in the
 // comment on the AndRecurrence class.
 //
-// The general algorithm that deals with cycles does two RPO (reverse post
-// order) passes over the graph.  On the first pass it assigns a symbolic
-// predicate to merge nodes with backedges.  On the second pass it tries to
-// pattern match the predicates for the backedges of these merges and infer an
-// AndRecurrence for the merge.
+// The general algorithm that deals with cycles does two topological-order
+// iterations over the graph.  On the first iteration it assigns a symbolic
+// predicate to merge nodes with backedges.  On the second iteration it tries
+// to pattern match the predicates for the backedges of these merges and infer
+// an AndRecurrence for the merge.  In other words, we do a data flow analysis
+// where the data-flow lattice has two elements, Symbolic and NonSymbolic with
+// Symbolic > NonSymbolic.  The lattice has height = 2 so two iterations are
+// sufficient to converge.
 //
-// In other words, we do a pessimistic data flow analysis where the data-flow
-// lattice has two elements, Symbolic and NonSymbolic with Symbolic >
-// NonSymbolic.  The lattice has height = 2 so two iterations are sufficient to
-// converge.  We don't do an optimistic data flow analysis to make pattern
-// matching easier: if we assigned the predicate of the initial value to the
-// merge during the first pass, on the second pass the backedge may see a
-// simplified value that would be difficult to pattern match.
+// We first do an optimistic analysis and, if it does not converge, we then
+// fall back to a pessimistic analysis.  The optimistic analysis assigns the
+// same symbolic predicate to all the merge nodes whose preceding enter nodes
+// have the same frame name on the first iteration.  On the second iteration,
+// if all the merge nodes are pattern matched into the same AndRecurrence
+// predicate instance, the optimistic assignment of the same symbolic predicate
+// is correct and the analyzed result is taken.
 //
-// We still use symbolic predicates for merges for which we can't pattern match
-// on the backedge predicate.  This is conservatively correct.
+// Otherwise, if the optimistic analysis fails to converge, we then obtain the
+// result by falling back to the pessimistic analysis which assigns a unique
+// symbolic predicate to each merge on the first iteration.  We still use
+// symbolic predicates for merges for which we can't pattern match on the
+// backedge predicate.  This is conservatively correct.

 namespace tensorflow {
@ -636,6 +644,35 @@ Predicate* PredicateFactory::MakeAndOrImpl(
     negated_ops.insert(negated_op);
   }

+  // Simplify {S,&,X} & ~X & ... => S & ...
+  if (is_and) {
+    absl::flat_hash_set<Predicate*> to_remove;
+    std::vector<Predicate*> to_add;
+    for (Predicate* op : simplified_ops) {
+      if (op->kind() == Predicate::Kind::kAndRecurrence) {
+        auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
+        if (negated_ops.contains(and_rec->step())) {
+          // Remove and_rec and ~X and insert S.  Note that checking the
+          // existence of ~X through negated_ops is sufficient since it makes
+          // sure the predicate is in the input operands.  It does not need to
+          // be in simplified_ops if it was already cancelled out.
+          to_remove.insert(and_rec);
+          to_remove.insert(MakeNotPredicate(and_rec->step()));
+          to_add.push_back(and_rec->start());
+        }
+      }
+    }
+    auto it = simplified_ops.begin();
+    while (it != simplified_ops.end()) {
+      if (to_remove.contains(*it)) {
+        it = simplified_ops.erase(it);
+      } else {
+        ++it;
+      }
+    }
+    simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
+  }
+
   // If all ops contain the same subop, then factor it out thanks to the
   // distributive property. Such as:
   //   - (A & B) | (A & C) | (A & D) => A & (B | C | D)
@ -699,8 +736,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
   explicit DeadnessAnalysisImpl(const Graph* graph)
       : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}

-  Status Populate();
-  Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
+  Status Populate(bool enable_optimistic);
+  Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
+                       bool* success);
   StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
       Node* n, int oidx) const override;
   void Print() const override;
@ -742,16 +780,29 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
   }

   Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
-  Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
+  Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
+                     bool use_optimistic_mode);
   Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
   Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
-  Status HandleNode(Node* n, std::vector<bool>* should_revisit);
+  Status HandleNode(Node* n, std::vector<bool>* should_revisit,
+                    bool use_optimistic_mode = false);
+
+  Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
+
+  bool IsRootEnter(const Node* n) const {
+    return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
+  }
+
+  bool IsRootExit(const Node* n) const {
+    return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
+  }

   const Graph& graph_;
   absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
   PredicateFactory predicate_factory_;
   std::vector<ControlFlowInfo> control_flow_info_;
   bool vlog_;
+  absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
 };

 TensorId InputEdgeToTensorId(const Edge* e) {
@ -914,10 +965,32 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,

   return Status::OK();
 }

+// If the node is inside some frames, get the name of the outermost non-empty
+// frame.  Otherwise, get an empty frame name.
+Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
+                    absl::string_view* frame) {
+  int depth = 0;
+  const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
+  while (!cfi_iter->parent_frame->IsSource()) {
+    n = cfi_iter->parent_frame;
+    cfi_iter = &cfi_infos[n->id()];
+
+    if (depth++ > 5000) {
+      return errors::Internal(
+          "Frame of depth > 5000: Probably malformed graph or a bug in "
+          "BuildControlFlowInfo");
+    }
+  }
+
+  *frame = cfi_iter->frame_name;
+  return Status::OK();
+}
 }  // namespace

 Status DeadnessAnalysisImpl::HandleMerge(Node* n,
-                                         std::vector<bool>* should_revisit) {
+                                         std::vector<bool>* should_revisit,
+                                         bool use_optimistic_mode) {
   // Merge ignores deadness of its control inputs.  A merge that isn't the
   // target of a backedge is alive iff any of its data inputs are.  The
   // liveness of a merge that is the target of a backedge can sometimes be
|
|||||||
// We're visiting this merge for the first time and it has an unvisited
|
// We're visiting this merge for the first time and it has an unvisited
|
||||||
// backedge.
|
// backedge.
|
||||||
Predicate* input_data_pred;
|
Predicate* input_data_pred;
|
||||||
|
if (use_optimistic_mode) {
|
||||||
|
// In the optimistic mode, we use the first-seen Merge node per
|
||||||
|
// frame as the representative Merge node. It is just convenient and
|
||||||
|
// does not affect the result after pattern-matching into the
|
||||||
|
// AndRecurrence form.
|
||||||
|
absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
|
||||||
|
auto insert_result = frame_to_merge_node_.insert({frame_name, n});
|
||||||
|
Node* representative = insert_result.first->second;
|
||||||
|
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||||
|
representative, /*output_idx=*/0, /*must_be_true=*/false,
|
||||||
|
&input_data_pred));
|
||||||
|
} else {
|
||||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||||
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
|
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
|
||||||
|
}
|
||||||
|
|
||||||
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
||||||
should_revisit);
|
should_revisit);
|
||||||
@ -948,7 +1034,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
|||||||
std::vector<Predicate*> input_preds;
|
std::vector<Predicate*> input_preds;
|
||||||
TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
|
TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
|
||||||
|
|
||||||
// We're visiting this merge for the first time and it is a acyclic merge.
|
// We're visiting this merge for the first time and it is an acyclic merge.
|
||||||
Predicate* input_data_pred =
|
Predicate* input_data_pred =
|
||||||
predicate_factory_.MakeOrPredicate(input_preds);
|
predicate_factory_.MakeOrPredicate(input_preds);
|
||||||
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
||||||
@ -1022,11 +1108,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status DeadnessAnalysisImpl::HandleNode(Node* n,
|
Status DeadnessAnalysisImpl::HandleNode(Node* n,
|
||||||
std::vector<bool>* should_revisit) {
|
std::vector<bool>* should_revisit,
|
||||||
|
bool use_optimistic_mode) {
|
||||||
if (n->IsSwitch()) {
|
if (n->IsSwitch()) {
|
||||||
TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
|
TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
|
||||||
} else if (n->IsMerge()) {
|
} else if (n->IsMerge()) {
|
||||||
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
|
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
|
||||||
} else if (n->IsControlTrigger()) {
|
} else if (n->IsControlTrigger()) {
|
||||||
SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
|
SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
|
||||||
nullptr);
|
nullptr);
|
||||||
@ -1040,17 +1127,129 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
   return Status::OK();
 }

-Status DeadnessAnalysisImpl::Populate() {
-  std::vector<Node*> rpo;
-  GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
-                      /*edge_filter=*/[](const Edge& edge) {
-                        return !edge.src()->IsNextIteration();
-                      });
-  return PopulateWithReversePostOrder(rpo);
+// Compute a special topological order for the Graph, where nodes having the
+// same root frame are placed adjacent to each other.  The traversal uses a
+// variant of Kahn's algorithm.  num_ready_inputs is used to keep track of how
+// many inputs of each node are ready; a node is ready to be scheduled if all
+// of its inputs are ready.
+// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
+Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
+    std::vector<Node*>* order) {
+  absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
+  absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
+  std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
+  Node* src_node = graph_.source_node();
+  for (const auto* node : graph_.op_nodes()) {
+    const ControlFlowInfo& cf = control_flow_info_[node->id()];
+    if (IsRootEnter(node)) {
+      // Since we care only about the root-level frame, full frame names are
+      // the same as frame names.
+      ++num_enters_for_frame[cf.frame_name];
+    } else if (IsRootExit(node)) {
+      ++num_exits_for_frame[cf.frame_name];
+    }
+    // Edge NextIteration->Merge is counted before starting the traversal to
+    // break the backedges.
+    if (IsMerge(node)) {
+      for (const Edge* e : node->in_edges()) {
+        if (IsNextIteration(e->src())) {
+          ++num_ready_inputs[node->id()];
+        }
+      }
+    }
+  }
+
+  // dequeue is used to ensure that the nodes are first-in-first-out.  This
+  // order guarantees that the exits in the ready queue are visited before
+  // nodes that will become ready in the future.
+  std::deque<Node*> ready;
+  ready.push_back(src_node);
+  // ready_enters_per_frame and ready_exits serve as a staging area to buffer
+  // the ready enters/exits before they are moved to the `ready` queue for
+  // controlling the start and end of a processing frame.
+  absl::flat_hash_map<absl::string_view, std::vector<Node*>>
+      ready_enters_per_frame;
+  // Exit nodes shall all be from the same frame, as we process a frame at a
+  // time.  So, one vector is enough.
+  std::vector<Node*> ready_exits;
+  while (!ready.empty()) {
+    Node* curr_node = ready.front();
+    ready.pop_front();
+
+    VLOG(4) << "Visiting " << curr_node->name();
+    order->push_back(curr_node);
+
+    for (const Edge* out_edge : curr_node->out_edges()) {
+      Node* out = out_edge->dst();
+      int out_id = out->id();
+      if (IsNextIteration(curr_node) && IsMerge(out)) {
+        // Edge NextIteration->Merge has been counted.
+        continue;
+      }
+      ++num_ready_inputs[out->id()];
+      if (!out->IsOp()) continue;  // Skip Sink/Source nodes.
+      if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
+
+      absl::string_view frame_name = control_flow_info_[out_id].frame_name;
+      if (IsRootEnter(out)) {
+        ready_enters_per_frame[frame_name].push_back(out);
+      } else if (IsRootExit(out)) {
+        ready_exits.push_back(out);
+      } else {
+        ready.push_back(out);
+      }
+    }
+
+    if (ready.empty()) {
+      // Try moving nodes from ready_enters_per_frame and ready_exits to
+      // `ready`.
+      if (!ready_exits.empty()) {
+        // If there are nodes in ready_exits we must process them before
+        // processing ready_enters_per_frame to make sure all nodes in the
+        // currently processing frame are visited before starting processing
+        // other frames.
+        absl::string_view frame_name =
+            control_flow_info_[ready_exits.front()->id()].frame_name;
+        CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
+        ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
+        ready_exits.clear();
+      } else {
+        // Otherwise, try moving nodes from ready_enters to `ready`.
+        for (auto iter = ready_enters_per_frame.begin();
+             iter != ready_enters_per_frame.end(); ++iter) {
+          absl::string_view frame_name = iter->first;
+          const std::vector<Node*>& ready_enters = iter->second;
+          if (ready_enters.size() == num_enters_for_frame[frame_name]) {
+            ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
+            ready_enters_per_frame.erase(iter);
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
+    return errors::InvalidArgument(
+        "Some enters/exits have never been visited in the traversal."
+        " Most probably the input graph is malformed.");
+  }
+  return Status::OK();
 }
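For reference, the plain (un-grouped) Kahn's algorithm that the comment above builds on can be sketched as below; the frame-based order additionally stages root Enter/Exit nodes so each frame is emitted contiguously. The graph here is a toy adjacency list, not TensorFlow's Graph class:

// Self-contained sketch of Kahn's topological sort on a toy DAG.
#include <iostream>
#include <queue>
#include <vector>

int main() {
  // Toy DAG: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3.
  std::vector<std::vector<int>> out_edges = {{1, 2}, {3}, {3}, {}};
  std::vector<int> in_degree(out_edges.size(), 0);
  std::vector<int> num_ready_inputs(out_edges.size(), 0);
  for (const auto& outs : out_edges)
    for (int dst : outs) ++in_degree[dst];

  std::queue<int> ready;
  for (int n = 0; n < static_cast<int>(out_edges.size()); ++n)
    if (in_degree[n] == 0) ready.push(n);  // Nodes with no inputs start out ready.

  std::vector<int> order;
  while (!ready.empty()) {
    int n = ready.front();
    ready.pop();
    order.push_back(n);
    for (int dst : out_edges[n]) {
      // A node is scheduled once all of its inputs have been emitted.
      if (++num_ready_inputs[dst] == in_degree[dst]) ready.push(dst);
    }
  }
  for (int n : order) std::cout << n << ' ';  // Prints: 0 1 2 3
  std::cout << '\n';
  return 0;
}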
-Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
-    absl::Span<Node* const> rpo) {
+// We populate the nodes along a special topological order where nodes having
+// the same root frame are placed adjacent to each other.  This grouping enables
+// processing the graph per root frame at a time and guarantees that when a root
+// frame is being processed, nodes in the downstream frames have not yet been
+// processed.  This property is important because we need to process an entire
+// frame to know whether the optimistic mode converges or not.  In other words,
+// nodes in the downstream frames shall not be populated until all of their
+// upstream frames are populated.  In effect, this order enables processing each
+// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
+// (root) frame.  Note that we don't separate while loops belonging to the same
+// nested while, as there is no clean cut for separating them in the topological
+// order.
+Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
   std::vector<string> unreachable_nodes;
   // Compute the loop structure of the graph.
   TF_RETURN_IF_ERROR(
|
|||||||
absl::StrJoin(unreachable_nodes, ", "));
|
absl::StrJoin(unreachable_nodes, ", "));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Node*> topo;
|
||||||
|
TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
|
||||||
|
|
||||||
|
size_t frame_start = 0;
|
||||||
|
while (frame_start < topo.size()) {
|
||||||
|
// Batching nodes who have the same root frame.
|
||||||
|
absl::string_view cur_frame_name;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
|
||||||
|
size_t frame_end = frame_start;
|
||||||
|
for (size_t i = frame_start + 1; i < topo.size(); ++i) {
|
||||||
|
absl::string_view i_frame_name;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
|
||||||
|
if (i_frame_name == cur_frame_name) {
|
||||||
|
frame_end = i;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
absl::Span<Node*> sub_topo(topo.data() + frame_start,
|
||||||
|
/*length=*/frame_end - frame_start + 1);
|
||||||
|
frame_start = frame_end + 1;
|
||||||
|
|
||||||
|
// First, try the optimistic mode.
|
||||||
|
bool success = false;
|
||||||
|
if (enable_optimistic && !cur_frame_name.empty()) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
|
||||||
|
}
|
||||||
|
if (!success) {
|
||||||
|
// The optimistic mode does not converge. Let's fall back to the
|
||||||
|
// pessimistic mode.
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
|
||||||
|
}
|
||||||
|
VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
|
||||||
|
<< (success ? "optimistic" : "pessimistic") << " mode.";
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
|
||||||
|
bool use_optimistic_mode,
|
||||||
|
bool* success) {
|
||||||
|
CHECK(use_optimistic_mode && success != nullptr ||
|
||||||
|
!use_optimistic_mode && success == nullptr);
|
||||||
|
|
||||||
// This an abstract interpretation over the deadness propagation semantics of
|
// This an abstract interpretation over the deadness propagation semantics of
|
||||||
// the graph executor.
|
// the graph executor.
|
||||||
//
|
//
|
||||||
// We iterate over the graph twice, each time in RPO. On the first iteration
|
// We iterate over the graph twice, each time in a topological order. On the
|
||||||
// merge nodes with backedges are mapped to symbolic predicates. On the
|
// first iteration merge nodes with backedges are mapped to symbolic
|
||||||
// second iteration we use the predicates assigned to the backedges in the
|
// predicates. On the second iteration we use the predicates assigned to the
|
||||||
// previous iteration to infer a more precise predicate for the backedge merge
|
// backedges in the previous iteration to infer a more precise predicate for
|
||||||
// nodes and all the nodes that transitively use it.
|
// the backedge merge nodes and all the nodes that transitively use it.
|
||||||
//
|
//
|
||||||
// We don't track the output indices for should_revisit. Instead, putting a
|
// We don't track the output indices for should_revisit. Instead, putting a
|
||||||
// node in `should_revisit` denotes that the deadness flowing out from any
|
// node in `should_revisit` denotes that the deadness flowing out from any
|
||||||
@ -1086,9 +1334,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
   // delta should not change in the second iteration.
   std::vector<bool> should_revisit;
   should_revisit.resize(graph_.num_node_ids());
-  for (Node* n : rpo) {
+  for (Node* n : topo) {
     VLOG(4) << "Visiting " << n->name();
-    TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
+    TF_RETURN_IF_ERROR(
+        HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
     if (n->IsNextIteration()) {
       // If this is a backedge for a merge node then remember to reprocess the
       // merge the next time we run.
@ -1100,11 +1349,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
     }
   }

-  for (Node* n : rpo) {
+  for (Node* n : topo) {
     // The nodes added to should_revisit in the previous loop need to be
     // revisited now.  Reprocessing these initial nodes may add *their*
     // consumers to should_revisit, and these newly added nodes will also be
-    // processed by this very same loop.  Since we're traversing the graph in
-    // reverse post order (producers before consumers) and HandleNode(n) can
+    // processed by this very same loop.  Since we're traversing the graph in
+    // topological order (producers before consumers) and HandleNode(n) can
     // only ever add n's consumers to should_revisit, we won't "miss" an
     // addition to should_revisit.
@ -1114,6 +1363,71 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||

+  // Check if the optimistic analysis converges. Specifically, check whether
+  // all the predicates of the merge nodes in the same frame are the same. If
+  // yes, report success. If not, report failure and clear the assigned
+  // predicates.
+  if (use_optimistic_mode) {
+    bool is_converged = true;
+    absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
+    for (Node* n : topo) {
+      if (!n->IsMerge()) {
+        continue;
+      }
+      const Edge* e;
+      TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
+      if (e == nullptr) {
+        // Skip acyclic merge nodes.
+        continue;
+      }
+      Node* merge = n;
+      // Note that we use frame names here instead of root frame names. In the
+      // case of a nested while loop, each level of while loops can have merges
+      // with different predicate instances, while the merge nodes on the same
+      // level must have the same predicate instances.
+      absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
+      auto it = predicate_map_.find(TensorId(merge->name(), 0));
+      Predicate* merge_pred = it->second;
+      if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
+        is_converged = false;
+        VLOG(2) << "Running the optimistic mode on frame " << frame_name
+                << " does not converge because node " << merge->name()
+                << " cannot be mapped into the AndRecurrence form.";
+        break;
+      }
+
+      auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
+      if (!insert_result.second) {
+        // If we have already seen this frame name, verify the predicate is the
+        // same as the previously seen one's.
+        Predicate* curr_andrec = merge_pred;
+        Predicate* prev_andrec = insert_result.first->second;
+        if (curr_andrec != prev_andrec) {
+          is_converged = false;
+          VLOG(2) << "Running the optimistic mode on frame " << frame_name
+                  << " does not converge. Seeing different Merge predicates: \n"
+                  << curr_andrec->ToString() << " and \n"
+                  << prev_andrec->ToString();
+          break;
+        }
+      }
+    }
+
+    // Clear the assigned predicates if the optimistic mode does not converge.
+    if (!is_converged) {
+      for (Node* n : topo) {
+        for (int oid = 0; oid < n->num_outputs(); ++oid) {
+          predicate_map_.erase(TensorId(n->name(), oid));
+        }
+        predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
+      }
+    }
+
+    if (success != nullptr) {
+      *success = is_converged;
+    }
+  }
+
   return Status::OK();
 }

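The block above follows a two-phase optimistic/pessimistic pattern: assume every backedge merge can be summarized as an AndRecurrence, then verify that all merges within one frame agreed on a single predicate, and discard the result if they did not. The following stand-alone sketch is illustrative only; `Merge` and `Converged` are hypothetical names, not TensorFlow APIs, and it only shows the shape of the convergence check that the `frame_to_pred` map implements:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    struct Merge {
      std::string frame;  // loop frame the merge node belongs to
      std::string pred;   // predicate assigned by the optimistic pass
    };

    // Returns true iff every merge within the same frame received the same
    // predicate, i.e. the optimistic assumption was self-consistent.
    bool Converged(const std::vector<Merge>& merges) {
      std::map<std::string, std::string> frame_to_pred;
      for (const Merge& m : merges) {
        auto inserted = frame_to_pred.insert({m.frame, m.pred});
        if (!inserted.second && inserted.first->second != m.pred) return false;
      }
      return true;
    }

    int main() {
      std::vector<Merge> agree = {{"loop", "{#true,&,c}"}, {"loop", "{#true,&,c}"}};
      std::vector<Merge> clash = {{"loop", "{#true,&,c}"}, {"loop", "iv:0"}};
      std::cout << Converged(agree) << " " << Converged(clash) << "\n";  // prints: 1 0
      return 0;
    }

When the check fails, the pass simply erases the predicate map and reports non-convergence, which is what the `enable_optimistic` plumbing in the hunks below falls back on.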
@ -1149,7 +1463,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
     const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
   std::unique_ptr<DeadnessAnalysisImpl> analysis(
       new DeadnessAnalysisImpl(&graph));
-  TF_RETURN_IF_ERROR(analysis->Populate());
+  TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));

   if (VLOG_IS_ON(2)) {
     analysis->Print();
@ -1170,22 +1484,18 @@ DeadnessAnalysisImpl::PredicateMapAsString() const {
 }

 namespace deadness_analysis_internal {
-Status ComputePredicates(const Graph& graph,
-                         PredicateMapTy* out_predicate_map) {
+Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
+                         bool enable_optimistic) {
   DeadnessAnalysisImpl impl(&graph);
-  TF_RETURN_IF_ERROR(impl.Populate());
+  TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
   *out_predicate_map = impl.PredicateMapAsString();
   return Status::OK();
 }

-Status ComputePredicates(const Graph& graph,
-                         absl::Span<Node* const> reverse_post_order,
-                         PredicateMapTy* out_predicate_map) {
-  DeadnessAnalysisImpl impl(&graph);
-  TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
-  *out_predicate_map = impl.PredicateMapAsString();
-  return Status::OK();
-}
 }  // namespace deadness_analysis_internal

+string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
+  return static_cast<Predicate*>(predicate.pred_)->ToString();
+}
+
 }  // namespace tensorflow
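For context, here is a minimal sketch of how a client pass might drive this analysis, using only the entry points visible in this diff (`DeadnessAnalysis::Run` and `Print`); the wrapper function name is made up and real call sites do more with the result than print it:

    #include <memory>

    #include "tensorflow/compiler/jit/deadness_analysis.h"

    namespace tensorflow {

    Status AnalyzeAndDump(const Graph& graph) {
      std::unique_ptr<DeadnessAnalysis> analysis;
      Status status = DeadnessAnalysis::Run(graph, &analysis);
      if (!status.ok()) return status;
      analysis->Print();  // dump the computed per-tensor predicates
      return Status::OK();
    }

    }  // namespace tensorflow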
@ -82,6 +82,8 @@ class DeadnessAnalysis {
   virtual void Print() const = 0;
   virtual ~DeadnessAnalysis();

+  string DebugString(DeadnessPredicate predicate) const;
+
   // Runs the deadness analysis over `graph` and returns an error or a populated
   // instance of DeadnessAnalysis in `result`.
   static Status Run(const Graph& graph,
@ -25,15 +25,9 @@ namespace deadness_analysis_internal {
 // Returns a map describing the predicate each Tensor was mapped to. For
 // testing purposes only.
 using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
-Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
+Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
+                         bool enable_optimistic = true);

-// Returns a map describing the predicate each Tensor was mapped to. For
-// testing purposes only. Makes deadness analysis visit the graph in the order
-// specified in `reverse_post_order` which must be a valid RPO for the graph
-// minus NextIteration->Merge edges.
-Status ComputePredicates(const Graph& graph,
-                         absl::Span<Node* const> reverse_post_order,
-                         PredicateMapTy* out_predicate_map);
 }  // namespace deadness_analysis_internal
 }  // namespace tensorflow

@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
   }
   {
     PredicateMapTy predicate_map;
-    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/true));
+
+    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
+              "{#true,&,*iv0/cond:0}<loop>");
+    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
+              predicate_map[ControlOutputFor(iv.induction_var)]);
+    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
+              predicate_map[ControlOutputFor(iv.induction_var)]);
+    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
+              predicate_map[ControlOutputFor(iv.induction_var)]);
+  }
+  {
+    PredicateMapTy predicate_map;
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/false));

     EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
               "{#true,&,*iv0/cond:0}<loop>");
@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
       CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
   FixupSourceAndSinkEdges(root.graph());

-  // To make deadness analysis think that dependent_iv is a loop we need an RPO
-  // that visits the merge before the backedge. This is a legal RPO for
-  // deadness analysis since it ignores NextIteration->Merge edges during RPO.
-  // Right now dependent_iv has an edge from Merge to NextIteration so do the
-  // RPO with this edge in place. Then remove this edge to get our test case.
-  std::vector<Node*> rpo;
-  GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
-                      /*edge_filter=*/[](const Edge& edge) {
-                        return !edge.src()->IsNextIteration();
-                      });
   TF_ASSERT_OK(root.graph()->UpdateEdge(
       iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));

@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {

   {
     PredicateMapTy predicate_map;
-    TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map));
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/true));
+
+    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
+              "{#true,&,*iv0/cond:0}<frame>");
+  }
+  {
+    PredicateMapTy predicate_map;
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/false));

     EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
               "div0/iv:0");
@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
   }
   {
     PredicateMapTy predicate_map;
-    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/true));
+
+    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
+              "{#true,&,*iv_outer/cond:0}<outer_loop>");
+    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
+              "{(*iv_outer/cond:0 & "
+              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
+              "cond:0}<inner_loop;outer_loop>");
+
+    // enable_optimistic = true or not should produce the same results because
+    // of fallback. However, note that the order of iv_inner/cond:0 and
+    // iv_inner/iv:0 is different because the optimistic approach does not
+    // create predicates for all merges and it can change the predicate id and
+    // hence the symbol order.
+    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
+              "{{#true,&,(iv_outer/iv:0 & "
+              "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
+              "iv_inner/iv:0)}<inner_loop;outer_loop>");
+    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
+              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
+    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
+              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
+  }
+  {
+    PredicateMapTy predicate_map;
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/false));

     EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
               "{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");

     EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
-              "{{#true,&,(iv_outer/iv:0 & "
-              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
-              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
+              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
     EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
-              "{{#true,&,(iv_outer/iv:0 & "
-              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
-              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
+              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
   }
 }

@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
   }
 }

+TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  InductionVarInfo iv_outer =
+      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
+  Output enter_constant_outer_loop = ops::internal::Enter(
+      root.WithOpName("constant_enter_outer_loop"),
+      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
+      ops::internal::Enter::Attrs().IsConstant(true));
+  ops::Switch inner_value(root.WithOpName("outer_is_live"),
+                          enter_constant_outer_loop, iv_outer.loop_cond);
+  InductionVarInfo iv_inner = CreateInductionVariable(
+      root, "iv_inner", "inner_loop", inner_value.output_true);
+
+  DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
+      root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
+  DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
+      root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);
+
+  DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
+      root, "div0_inner", "inner_loop", iv_inner.loop_cond,
+      div0_outer.induction_var);
+  DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
+      root, "div1_inner", "inner_loop", iv_inner.loop_cond,
+      div1_outer.induction_var);
+
+  Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
+                               "tensor_a", "sender", 0, "receiver");
+  Output capture_enter_outer = ops::internal::Enter(
+      root.WithOpName("capture_enter_outer"), captured, "outer_loop",
+      ops::internal::Enter::Attrs().IsConstant(true));
+  Output capture_enter_inner = ops::internal::Enter(
+      root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
+      ops::internal::Enter::Attrs().IsConstant(true));
+  Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
+                         capture_enter_inner);
+  TF_ASSERT_OK(root.graph()->UpdateEdge(
+      mul0.node(), 0, div1_inner.latch.output_true.node(), 0));
+
+  Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
+                         div1_inner.induction_var);
+
+  VLogGraphIfAsked(*root.graph());
+
+  {
+    std::unique_ptr<DeadnessAnalysis> result;
+    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+    TF_ASSERT_OK_AND_ASSIGN(
+        bool has_inputs_with_mismatching_deadness,
+        HasInputsWithMismatchingDeadness(*result, *add0.node()));
+    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
+  }
+}
+
+TEST(DeadnessAnalysisTest, CyclicRecurrence) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
+  DependentInductionVar div0 =
+      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
+  DependentInductionVar div1 =
+      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
+  FixupSourceAndSinkEdges(root.graph());
+  TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
+                                        div0.latch.output_true.node(), 0));
+  TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
+                                        div1.latch.output_true.node(), 0));
+
+  VLogGraphIfAsked(*root.graph());
+
+  {
+    PredicateMapTy predicate_map;
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/true));
+
+    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
+              "{#true,&,*iv0/cond:0}<loop>");
+    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
+              "{#true,&,*iv0/cond:0}<loop>");
+    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
+              "{#true,&,*iv0/cond:0}<loop>");
+
+    // This tests the rule {S,&,X} & ~X => S.
+    TensorId switch_false_out = {div1.latch.output_false.node()->name(),
+                                 div1.latch.output_false.index()};
+    EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
+  }
+  {
+    PredicateMapTy predicate_map;
+    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
+                                   /*enable_optimistic=*/false));
+
+    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
+              "{#true,&,*iv0/cond:0}<loop>");
+    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
+    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
+  }
+}
+
 TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
   Scope root = Scope::NewRootScope().ExitOnError();
   InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
File diff suppressed because it is too large
@ -52,16 +52,6 @@ typedef std::function<Status(
 // 'group_attribute' must be a string-valued attribute that names the new
 // functions to introduce.
 //
-// 'outside_compilation_attribute' must be a string-valued attribute that is
-// used to tag nodes within a subgraph to be part of an 'outside_compilation'
-// cluster within the subgraph. A cluster is formed from the set of nodes with
-// the same value of outside_compilation_subgraph and group_attribute. The nodes
-// in an outside_compilation cluster are left in the original graph. Edges
-// crossing from the subgraph to an outside_compilation cluster nested in the
-// subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
-// crossing from an outside_compilation cluster into its enclosing subgraph are
-// lifted into a SendFromHost/RecvFromHost pair of nodes.
-//
 // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
 // function conversion.
 //
@ -74,10 +64,9 @@ typedef std::function<Status(
 // dep from B. Originally D must run after C, post-transformation this
 // dependency is lost.
 Status EncapsulateSubgraphsInFunctions(
-    string group_attribute, string outside_compilation_attribute,
-    const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
-    bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
-    FunctionLibraryDefinition* library);
+    string group_attribute, const Graph& graph_in,
+    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
+    std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);

 // The attribute that marks function calls produced by the encapsulate
 // subgraphs pass and that should in turn be compiled via XlaLaunch operators.
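With the 'outside_compilation_attribute' parameter gone, call sites now pass one argument fewer. A hedged sketch of the new call shape follows; the attribute name "_cluster" and the wrapper function are made up for illustration, but the argument order matches the declaration above:

    #include <memory>

    #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"

    namespace tensorflow {

    Status EncapsulateClusters(const Graph& graph_in,
                               FunctionLibraryDefinition* library,
                               std::unique_ptr<Graph>* graph_out) {
      return EncapsulateSubgraphsInFunctions(
          "_cluster", graph_in, /*rewrite_subgraph_fn=*/{},
          /*reuse_existing_functions=*/false, graph_out, library);
    }

    }  // namespace tensorflow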
@ -514,10 +514,10 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
|
|||||||
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||||
|
|
||||||
std::unique_ptr<Graph> graph_out;
|
std::unique_ptr<Graph> graph_out;
|
||||||
s = EncapsulateSubgraphsInFunctions(
|
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
|
||||||
"_encapsulate", /*outside_compilation_attribute=*/"", *graph,
|
|
||||||
/*rewrite_subgraph_fn=*/{},
|
/*rewrite_subgraph_fn=*/{},
|
||||||
/*reuse_existing_functions=*/false, &graph_out, lib_def.get());
|
/*reuse_existing_functions=*/false,
|
||||||
|
&graph_out, lib_def.get());
|
||||||
if (!s.ok()) return s;
|
if (!s.ok()) return s;
|
||||||
|
|
||||||
std::unordered_map<string, XlaClusterInfo> clusters;
|
std::unordered_map<string, XlaClusterInfo> clusters;
|
||||||
@ -746,7 +746,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
|
|||||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||||
std::unique_ptr<Graph> graph;
|
std::unique_ptr<Graph> graph;
|
||||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||||
"_cluster", "", graph_before_encapsulation,
|
"_cluster", graph_before_encapsulation,
|
||||||
/*rewrite_subgraph_fn=*/{},
|
/*rewrite_subgraph_fn=*/{},
|
||||||
/*reuse_existing_functions=*/false, &graph, &library));
|
/*reuse_existing_functions=*/false, &graph, &library));
|
||||||
|
|
||||||
@ -798,7 +798,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
|
|||||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||||
int guaranteed_consts = 0;
|
int guaranteed_consts = 0;
|
||||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||||
"_encapsulate", "", graph_before,
|
"_encapsulate", graph_before,
|
||||||
/*rewrite_subgraph_fn=*/
|
/*rewrite_subgraph_fn=*/
|
||||||
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
|
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
|
||||||
std::unique_ptr<Graph>* graph_ptr,
|
std::unique_ptr<Graph>* graph_ptr,
|
||||||
@ -843,7 +843,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
|
|||||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||||
int guaranteed_consts = 0;
|
int guaranteed_consts = 0;
|
||||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||||
"_encapsulate", "", graph_before,
|
"_encapsulate", graph_before,
|
||||||
/*rewrite_subgraph_fn=*/
|
/*rewrite_subgraph_fn=*/
|
||||||
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
|
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
|
||||||
std::unique_ptr<Graph>* graph_ptr,
|
std::unique_ptr<Graph>* graph_ptr,
|
||||||
@ -1109,7 +1109,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
|
|||||||
absl::Span<const string>(
|
absl::Span<const string>(
|
||||||
{"_xla_token_arg_node",
|
{"_xla_token_arg_node",
|
||||||
"outside_compilation_O1_host_compute"})}},
|
"outside_compilation_O1_host_compute"})}},
|
||||||
{"F"}},
|
{"F", "outside_compilation_O1_host_compute"}},
|
||||||
{{"outside_compilation_O1_host_compute"},
|
{{"outside_compilation_O1_host_compute"},
|
||||||
"XlaHostCompute",
|
"XlaHostCompute",
|
||||||
{"C:o:0", "D:o:0"},
|
{"C:o:0", "D:o:0"},
|
||||||
@ -1990,7 +1990,8 @@ TEST(EncapsulateSubgraphsTest,
|
|||||||
{"_xla_token_input_nodes",
|
{"_xla_token_input_nodes",
|
||||||
absl::Span<const string>(
|
absl::Span<const string>(
|
||||||
{"_xla_token_arg_node",
|
{"_xla_token_arg_node",
|
||||||
"outside_compilation_O1_host_compute"})}}},
|
"outside_compilation_O1_host_compute"})}},
|
||||||
|
{"outside_compilation_O1_host_compute"}},
|
||||||
},
|
},
|
||||||
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
||||||
{"h_0_retval_retval", "H:o:0"}});
|
{"h_0_retval_retval", "H:o:0"}});
|
||||||
@ -2117,7 +2118,8 @@ TEST(EncapsulateSubgraphsTest,
|
|||||||
{"_xla_token_input_nodes",
|
{"_xla_token_input_nodes",
|
||||||
absl::Span<const string>(
|
absl::Span<const string>(
|
||||||
{"_xla_token_arg_node",
|
{"_xla_token_arg_node",
|
||||||
"outside_compilation_O1_host_compute"})}}},
|
"outside_compilation_O1_host_compute"})}},
|
||||||
|
{"outside_compilation_O1_host_compute"}},
|
||||||
{{"outside_compilation_O1_host_compute"},
|
{{"outside_compilation_O1_host_compute"},
|
||||||
"XlaHostCompute",
|
"XlaHostCompute",
|
||||||
{"D:o:0"},
|
{"D:o:0"},
|
||||||
@ -2267,7 +2269,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
|
|||||||
{"_xla_token_input_nodes",
|
{"_xla_token_input_nodes",
|
||||||
absl::Span<const string>(
|
absl::Span<const string>(
|
||||||
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
|
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
|
||||||
{}},
|
{"outside_compilation_O1_host_compute"}},
|
||||||
{{"outside_compilation_O3_host_compute"},
|
{{"outside_compilation_O3_host_compute"},
|
||||||
"XlaHostCompute",
|
"XlaHostCompute",
|
||||||
{"D:o:0"},
|
{"D:o:0"},
|
||||||
@ -2282,7 +2284,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
|
|||||||
absl::Span<const string>({"_xla_token_arg_node",
|
absl::Span<const string>({"_xla_token_arg_node",
|
||||||
"outside_compilation_O1_host_compute",
|
"outside_compilation_O1_host_compute",
|
||||||
"outside_compilation_O2_host_compute"})}},
|
"outside_compilation_O2_host_compute"})}},
|
||||||
{}}},
|
{"outside_compilation_O1_host_compute",
|
||||||
|
"outside_compilation_O2_host_compute"}}},
|
||||||
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
||||||
{"h_0_retval_retval", "H:o:0"}});
|
{"h_0_retval_retval", "H:o:0"}});
|
||||||
|
|
||||||
|
@ -231,9 +231,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
|||||||
|
|
||||||
auto output = absl::make_unique<Graph>((*graph)->op_registry());
|
auto output = absl::make_unique<Graph>((*graph)->op_registry());
|
||||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||||
EncapsulateSubgraphsInFunctions(
|
EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
|
||||||
kXlaClusterAttr, "", **graph, RewriteSubgraph,
|
/*reuse_existing_functions=*/true,
|
||||||
/*reuse_existing_functions=*/true, &output, flib_def),
|
&output, flib_def),
|
||||||
"EncapsulateXlaComputationsPass failed");
|
"EncapsulateXlaComputationsPass failed");
|
||||||
graph->swap(output);
|
graph->swap(output);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -393,7 +393,7 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
 // Replace outside compilation function call node with XlaHostCompute node.
 // If the function call node has no input/output edges, we will just remove it
 // and not create a XlaHostCompute node.
-Status ReplaceOrRemoveOutsideCompilationCallNode(
+xla::StatusOr<Node*> ReplaceOrRemoveOutsideCompilationCallNode(
     Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
   // If the function call node has no input/output edges, just remove it.
@ -413,7 +413,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
   if (!has_edge) {
     VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString();
     g->RemoveNode(call_node);
-    return Status::OK();
+    return nullptr;
   }

   // Build XlaHostCompute NodeDef.
@ -424,7 +424,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
                       ReplaceNode(g, call_node, node_def));
   VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();

-  return Status::OK();
+  return host_compute_node;
 }

 // Resets "device_ordinal" attr to placeholder value for related nodes
@ -1634,7 +1634,7 @@ Status ExtractOutsideCompilationForFunction(
   RewriteOutsideCompilationSubgraphFn rewrite_fn(
       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name);
   TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
-      outside_compilation_attr_name, "", *fbody->graph, rewrite_fn,
+      outside_compilation_attr_name, *fbody->graph, rewrite_fn,
       /*reuse_existing_functions=*/true, &graph_out, fld));

   // Replace outside_compilation function nodes with HostCompute ops.
@ -1670,10 +1670,35 @@ Status ExtractOutsideCompilationForFunction(
       }
     }
   }
+  std::map<string, Node*> host_compute_nodes;
   for (Node* n : outside_compilation_nodes) {
     TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
-    TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
-        graph_out.get(), n, host_compute_core, *cluster_deps));
+    auto host_compute_node_or = ReplaceOrRemoveOutsideCompilationCallNode(
+        graph_out.get(), n, host_compute_core, *cluster_deps);
+    TF_RETURN_IF_ERROR(host_compute_node_or.status());
+    Node* host_compute_node = host_compute_node_or.ValueOrDie();
+    if (host_compute_node) {
+      host_compute_nodes[host_compute_node->name()] = host_compute_node;
+    }
+  }
+  // For XlaHostCompute nodes with dependencies, add control edges between them
+  // so XlaCompiler can handle them in the correct order.
+  for (auto iter : host_compute_nodes) {
+    Node* host_compute_node = iter.second;
+    std::vector<string> token_input_node_names;
+    TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
+                                   kXlaTokenInputNodesAttrName,
+                                   &token_input_node_names));
+    for (const string& node_name : token_input_node_names) {
+      if (node_name == kXlaTokenArgNodeName) {
+        continue;
+      }
+
+      auto iter = host_compute_nodes.find(node_name);
+      if (iter != host_compute_nodes.end()) {
+        graph_out->AddControlEdge(iter->second, host_compute_node);
+      }
+    }
   }

   // Handle nodes with associated functions.
|
@ -990,6 +990,16 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
|
|||||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
||||||
"_xla_token_input_nodes", &token_input_nodes));
|
"_xla_token_input_nodes", &token_input_nodes));
|
||||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
||||||
|
|
||||||
|
// Check there is a control edge from host_compute_0 to host_compute_1.
|
||||||
|
bool has_control_edge = false;
|
||||||
|
for (const Edge *e : host_compute_1->in_edges()) {
|
||||||
|
if (e->IsControlEdge() && e->src() == host_compute_0) {
|
||||||
|
has_control_edge = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_TRUE(has_control_edge);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ExtractOutsideCompilationForFunctionTest,
|
TEST_F(ExtractOutsideCompilationForFunctionTest,
|
||||||
@ -1062,5 +1072,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
|
|||||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
||||||
"_xla_token_input_nodes", &token_input_nodes));
|
"_xla_token_input_nodes", &token_input_nodes));
|
||||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
||||||
|
|
||||||
|
// Check there is a control edge from host_compute_0 to host_compute_1.
|
||||||
|
bool has_control_edge = false;
|
||||||
|
for (const Edge *e : host_compute_1->in_edges()) {
|
||||||
|
if (e->IsControlEdge() && e->src() == host_compute_0) {
|
||||||
|
has_control_edge = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_TRUE(has_control_edge);
|
||||||
}
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
"//tensorflow/compiler/tf2xla:internal",
|
"//tensorflow/compiler/tf2xla:internal",
|
||||||
],
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
"//tensorflow/compiler/tf2xla:internal",
|
"//tensorflow/compiler/tf2xla:internal",
|
||||||
],
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -29,6 +28,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:state_ops_op_lib",
|
"//tensorflow/core:state_ops_op_lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
|
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
|
@ -61,7 +61,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
|||||||
DeviceType device_type = ctx->device_type();
|
DeviceType device_type = ctx->device_type();
|
||||||
se::Platform::Id platform_id = nullptr;
|
se::Platform::Id platform_id = nullptr;
|
||||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||||
std::unique_ptr<XlaAllocator> xla_allocator;
|
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator;
|
||||||
se::DeviceMemoryAllocator* device_allocator = nullptr;
|
se::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||||
|
|
||||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
||||||
@ -93,7 +93,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
|||||||
se::MultiPlatformManager::PlatformWithId(platform_id);
|
se::MultiPlatformManager::PlatformWithId(platform_id);
|
||||||
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
|
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
|
||||||
|
|
||||||
xla_allocator = absl::make_unique<XlaAllocator>(
|
xla_allocator = absl::make_unique<se::TfAllocatorAdapter>(
|
||||||
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
|
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/util/stream_executor_util.h"
|
#include "tensorflow/core/util/stream_executor_util.h"
|
||||||
|
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -36,10 +37,10 @@ class XlaPlatformInfo {
|
|||||||
public:
|
public:
|
||||||
XlaPlatformInfo() : device_type_("") {}
|
XlaPlatformInfo() : device_type_("") {}
|
||||||
XlaPlatformInfo(XlaPlatformInfo&&) = default;
|
XlaPlatformInfo(XlaPlatformInfo&&) = default;
|
||||||
explicit XlaPlatformInfo(const DeviceType device_type,
|
explicit XlaPlatformInfo(
|
||||||
se::Platform::Id platform_id,
|
const DeviceType device_type, se::Platform::Id platform_id,
|
||||||
const XlaDevice::Metadata* xla_device_metadata,
|
const XlaDevice::Metadata* xla_device_metadata,
|
||||||
std::unique_ptr<XlaAllocator> xla_allocator,
|
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator,
|
||||||
se::DeviceMemoryAllocator* device_allocator)
|
se::DeviceMemoryAllocator* device_allocator)
|
||||||
: device_type_(device_type),
|
: device_type_(device_type),
|
||||||
platform_id_(platform_id),
|
platform_id_(platform_id),
|
||||||
@ -84,8 +85,8 @@ class XlaPlatformInfo {
|
|||||||
// then device_allocator_ is the xla::Backend's memory allocator and
|
// then device_allocator_ is the xla::Backend's memory allocator and
|
||||||
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
|
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
|
||||||
// then device_allocator_ is null and xla_allocator_ points to an appropriate
|
// then device_allocator_ is null and xla_allocator_ points to an appropriate
|
||||||
// XlaAllocator instance.
|
// se::TfAllocatorAdapter instance.
|
||||||
std::unique_ptr<XlaAllocator> xla_allocator_;
|
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator_;
|
||||||
se::DeviceMemoryAllocator* device_allocator_;
|
se::DeviceMemoryAllocator* device_allocator_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
|
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
|
||||||
|
@ -229,32 +229,18 @@ class MarkForCompilationPassImpl {
|
|||||||
// Initialize some internal data structures.
|
// Initialize some internal data structures.
|
||||||
Status Initialize();
|
Status Initialize();
|
||||||
|
|
||||||
// Runs through all the nodes in `cycles_graph_` and tries to create clusters.
|
// Runs through the entire cluster graph in post-order and calls `fn(from,
|
||||||
// Returns true if any new clusters were created.
|
// to)` on each edge. `fn(from, to)` is expected to return true if it was
|
||||||
StatusOr<bool> RunEdgeContractionLoopInPostOrderOnce();
|
// able to contract `from`->`to`.
|
||||||
|
//
|
||||||
|
// Returns true if `fn` returned true for any edge.
|
||||||
|
template <typename FnTy>
|
||||||
|
StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
|
||||||
|
|
||||||
// Runs through all the nodes in `cycles_graph_` and tries to contract high
|
// If from->to is a "preferred" edge (i.e. if we have a choice, we want to
|
||||||
// priority edges for clusters. Returns true if any new clusters were created.
|
// prioritize contracting from->to over contracting other edges) then
|
||||||
//
|
// contracts it and returns true. Else returns false.
|
||||||
// There are potentially many maximal clustering results, but they will not
|
StatusOr<bool> ContractEdgeIfPreferred(Cluster* from, Cluster* to);
|
||||||
// all be equally performant. Some clustering decision are likely to improve
|
|
||||||
// performance much more than others, and we cannot order contractions on this
|
|
||||||
// cost function, nor can we look at global information while deciding on
|
|
||||||
// individual edges to contract. Instead, we will make decisions on these
|
|
||||||
// important edges then make decisions on all other edges, causing the highest
|
|
||||||
// chance of all most important edges to be contracted.
|
|
||||||
//
|
|
||||||
// An example of where this might occur is with a digraph:
|
|
||||||
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
|
|
||||||
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
|
|
||||||
// should be clustered with A because it will prevent a potentially large
|
|
||||||
// tensor from A being computed and copied.
|
|
||||||
//
|
|
||||||
// This pass will ensure that contraction happens, which cannot be enforced in
|
|
||||||
// a single pass with the current algorithm.
|
|
||||||
// graph and prevent B->C from being clusterd in anticipation of a later A->B
|
|
||||||
// cluster.
|
|
||||||
StatusOr<bool> ContractPreferredEdges();
|
|
||||||
|
|
||||||
// Contracts as many edges as possible to create XLA clusters. After this
|
// Contracts as many edges as possible to create XLA clusters. After this
|
||||||
// finishes the clustering decisions made are implicitly stored in
|
// finishes the clustering decisions made are implicitly stored in
|
||||||
@ -276,10 +262,6 @@ class MarkForCompilationPassImpl {
|
|||||||
// true if successful.
|
// true if successful.
|
||||||
StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
|
StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
|
||||||
|
|
||||||
// Tries to contract each edge from `cluster_from`. Returns true if any edges
|
|
||||||
// were contracted, false otherwise.
|
|
||||||
StatusOr<bool> TryToContractEdgesFrom(Cluster* cluster_from);
|
|
||||||
|
|
||||||
// Nodes that XLA can compile are put in `compilation_candidates_`.
|
// Nodes that XLA can compile are put in `compilation_candidates_`.
|
||||||
Status FindCompilationCandidates();
|
Status FindCompilationCandidates();
|
||||||
|
|
||||||
@ -401,6 +383,13 @@ class MarkForCompilationPassImpl {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
|
||||||
|
absl::string_view reason) {
|
||||||
|
return absl::StrCat("Could not contract ", from->DebugString(*graph_),
|
||||||
|
" -> ", to->DebugString(*graph_), " because ", reason,
|
||||||
|
".");
|
||||||
|
}
|
||||||
|
|
||||||
DebugOptions debug_options_;
|
DebugOptions debug_options_;
|
||||||
Graph* graph_;
|
Graph* graph_;
|
||||||
FunctionLibraryDefinition* flib_def_;
|
FunctionLibraryDefinition* flib_def_;
|
||||||
@ -611,7 +600,8 @@ Status MarkForCompilationPassImpl::Initialize() {
|
|||||||
return BuildInitialClusterSet();
|
return BuildInitialClusterSet();
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
|
template <typename FnTy>
|
||||||
|
StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
|
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
|
||||||
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
||||||
@ -632,8 +622,18 @@ StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cluster_to->cluster_size() == 1) {
|
TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
|
||||||
Node* n = graph_->FindNodeId(cluster_to->GetIdOfOnlyNode());
|
changed |= contracted_edge;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<bool> MarkForCompilationPassImpl::ContractEdgeIfPreferred(
|
||||||
|
Cluster* from, Cluster* to) {
|
||||||
|
if (to->cluster_size() == 1) {
|
||||||
|
Node* n = graph_->FindNodeId(to->GetIdOfOnlyNode());
|
||||||
|
|
||||||
// Shape consuming operations are desirable to cluster with their
|
// Shape consuming operations are desirable to cluster with their
|
||||||
// operands because they return a small set of scalar values after
|
// operands because they return a small set of scalar values after
|
||||||
@ -644,43 +644,11 @@ StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
|
|||||||
// tensor that must be computed and possible transposed/copied before
|
// tensor that must be computed and possible transposed/copied before
|
||||||
// the second cluster executes.
|
// the second cluster executes.
|
||||||
if (IsShapeConsumerOp(*n)) {
|
if (IsShapeConsumerOp(*n)) {
|
||||||
TF_ASSIGN_OR_RETURN(bool contracted_edge,
|
return TryToContractEdge(from, to);
|
||||||
TryToContractEdge(cluster_from, cluster_to));
|
|
||||||
changed |= contracted_edge;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return changed;
|
return false;
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<bool>
|
|
||||||
MarkForCompilationPassImpl::RunEdgeContractionLoopInPostOrderOnce() {
|
|
||||||
bool changed = false;
|
|
||||||
// Iterating over the graph once in post-order is sufficient to produce a
|
|
||||||
// maximal clustering:
|
|
||||||
//
|
|
||||||
// A. We visit a cluster only after maximally clustering all its children.
|
|
||||||
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
|
|
||||||
// its children that could have been absorbed into `node` have been
|
|
||||||
// absorbed.
|
|
||||||
// C. We have an invariant that making a cluster larger does not make edges
|
|
||||||
// leaving it more contractable. That is, if we have
|
|
||||||
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
|
|
||||||
// to contract Y->Z if Y->Z was not contractible originally.
|
|
||||||
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
|
|
||||||
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
|
||||||
if (!cluster_from) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(bool contracted_one_edge,
|
|
||||||
TryToContractEdgesFrom(cluster_from));
|
|
||||||
changed |= contracted_one_edge;
|
|
||||||
}
|
|
||||||
|
|
||||||
return changed;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
|
Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
|
||||||
@ -694,25 +662,68 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
|
|||||||
// without restrictions. This helps to minimize data output from clusters (and
|
// without restrictions. This helps to minimize data output from clusters (and
|
||||||
// possible transpose operations before outputs) that might occur if a
|
// possible transpose operations before outputs) that might occur if a
|
||||||
// ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations.
|
// ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations.
|
||||||
TF_ASSIGN_OR_RETURN(bool changed, ContractPreferredEdges());
|
//
|
||||||
|
// There are potentially many maximal clustering results, but they will not
|
||||||
|
// all be equally performant. Some clustering decision are likely to improve
|
||||||
|
// performance much more than others, and we cannot order contractions on this
|
||||||
|
// cost function, nor can we look at global information while deciding on
|
||||||
|
// individual edges to contract. Instead, we will make decisions on these
|
||||||
|
// important edges then make decisions on all other edges, causing the highest
|
||||||
|
// chance of all most important edges to be contracted.
|
||||||
|
//
|
||||||
|
// An example of where this might occur is with a digraph:
|
||||||
|
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
|
||||||
|
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
|
||||||
|
// should be clustered with A because it will prevent a potentially large
|
||||||
|
// tensor from A being computed and copied.
|
||||||
|
//
|
||||||
|
// This pass will ensure that contraction happens, which cannot be enforced in
|
||||||
|
// a single pass with the current algorithm.
|
||||||
|
// graph and prevent B->C from being clusterd in anticipation of a later A->B
|
||||||
|
// cluster.
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce());
|
TF_ASSIGN_OR_RETURN(bool changed,
|
||||||
|
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
|
||||||
|
return ContractEdgeIfPreferred(from, to);
|
||||||
|
}));
|
||||||
|
|
||||||
// Check that RunEdgeContractionLoopInPostOrderOnce is idempotent. Once the
|
// Iterating over the whole graph once in post-order is sufficient to produce
|
||||||
// linear time post-order scheme has been battle tested we can move this to
|
// a maximal clustering:
|
||||||
// happen only in debug builds.
|
//
|
||||||
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce());
|
// A. We visit a cluster only after maximally clustering all its children.
|
||||||
|
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
|
||||||
|
// its children that could have been absorbed into `node` have been
|
||||||
|
// absorbed.
|
||||||
|
// C. We have an invariant that making a cluster larger does not make edges
|
||||||
|
// leaving it more contractable. That is, if we have
|
||||||
|
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
|
||||||
|
// to contract Y->Z if Y->Z was not contractible originally.
|
||||||
|
TF_ASSIGN_OR_RETURN(changed,
|
||||||
|
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
|
||||||
|
return TryToContractEdge(from, to);
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Check that the conclusion made above (that iterating over the graph once in
|
||||||
|
// post order gives a maximal clustering) holds. Once the linear time
|
||||||
|
// post-order scheme has been battle tested we can move this to happen only in
|
||||||
|
// debug builds.
|
||||||
|
TF_ASSIGN_OR_RETURN(changed,
|
||||||
|
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
|
||||||
|
return TryToContractEdge(from, to);
|
||||||
|
}));
|
||||||
TF_RET_CHECK(!changed);
|
TF_RET_CHECK(!changed);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::atomic<int64> cluster_sequence_num;
|
||||||
|
|
||||||
|
int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
|
||||||
|
|
||||||
Status MarkForCompilationPassImpl::CreateClusters() {
|
Status MarkForCompilationPassImpl::CreateClusters() {
|
||||||
 TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
 clusters_created_ = true;

-static std::atomic<int64> cluster_sequence_num;
-
 // Names for each cluster.
 std::unordered_map<int, string> cluster_names;

@@ -745,7 +756,7 @@ Status MarkForCompilationPassImpl::CreateClusters() {
 string& name = cluster_names[cluster->cycles_graph_node_id()];

 if (name.empty()) {
-name = absl::StrCat("cluster_", cluster_sequence_num++);
+name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
 }

 n->AddAttr(kXlaClusterAttr, name);
@@ -1065,8 +1076,7 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(

 bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
 Cluster* from, Cluster* to, absl::string_view reason) {
-VLOG(3) << "Could not contract " << from->DebugString(*graph_) << " -> "
-<< to->DebugString(*graph_) << " because " << reason << ".";
+VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
 return false;
 }

@@ -1075,8 +1085,14 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
 DCHECK(from->deadness_predicate().has_value() ==
 to->deadness_predicate().has_value());
 if (from->deadness_predicate() != to->deadness_predicate()) {
-return LogNotContractableAndReturnFalse(
-from, to, "the two nodes have mismatching deadness");
+VLOG(3) << EdgeContractionFailureMsg(
+from, to,
+absl::StrCat(
+"the two nodes have mismatching deadness: ",
+deadness_analysis_->DebugString(*from->deadness_predicate()),
+" and ",
+deadness_analysis_->DebugString(*to->deadness_predicate())));
+return false;
 }

 TF_ASSIGN_OR_RETURN(bool devices_compatible,
@@ -1133,32 +1149,6 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
 return MergeClusters(from, to);
 }

-StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgesFrom(
-Cluster* cluster_from) {
-bool changed = false;
-
-// Make a copy of the set of successors because we may modify the graph in
-// TryToContractEdge.
-std::vector<int32> successors_copy =
-cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
-
-for (int to : successors_copy) {
-iteration_count_++;
-
-Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
-if (!cluster_to) {
-continue;
-}
-
-TF_ASSIGN_OR_RETURN(bool contracted_edge,
-TryToContractEdge(cluster_from, cluster_to));
-
-changed |= contracted_edge;
-}
-
-return changed;
-}
-
 Status MarkForCompilationPassImpl::Run() {
 // Make sure that kernels have been registered on the JIT device.
 XlaOpRegistry::RegisterCompilationKernels();
@@ -1485,7 +1475,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
 op_filter.allow_control_trigger = true;
 op_filter.allow_eliding_assert_and_checknumerics_ops = true;
 op_filter.allow_ops_producing_or_consuming_variant = true;
-op_filter.allow_slow_and_inaccurate_ops = true;
+op_filter.allow_slow_ops = true;
+op_filter.allow_inaccurate_ops = true;

 return RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
 .IsCompilableCall(ndef, flr);
@@ -1522,4 +1513,8 @@ Status MarkForCompilationPass::RunForTest(

 return MarkForCompilation(options, debug_options);
 }
+
+namespace testing {
+void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
+} // namespace testing
 } // namespace tensorflow

@@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass {
 // function is compilable iff every operator in the function body is
 // compilable.
 bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
+
+namespace testing {
+// DO NOT USE IN PRODUCTION.
+//
+// Resets some internal state to let us write reliable unit tests.
+void ResetClusterSequenceNumber();
+} // namespace testing
 } // namespace tensorflow

 #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_

@@ -1,7 +1,6 @@
-licenses(["notice"]) # Apache 2.0
-
 package(
 default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
+licenses = ["notice"], # Apache 2.0
 )

 load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")

@@ -49,7 +49,7 @@ Status ShapeAnnotationsMatch(
 missing.push_back(entry.first);
 }
 return errors::InvalidArgument("Missing shapes for nodes: ",
-str_util::Join(missing, ","));
+absl::StrJoin(missing, ","));
 }
 return Status::OK();
 }

@@ -60,7 +60,8 @@ Status XlaCpuDeviceFactory::CreateDevices(
 registration.cluster_control_trigger = true;
 registration.elide_assert_and_checknumerics = true;
 registration.cluster_variant_ops = true;
-registration.cluster_slow_and_inaccurate_ops = true;
+registration.cluster_slow_ops = true;
+registration.cluster_inaccurate_ops = true;
 XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);

 static XlaDeviceOpRegistrations* registrations =

@@ -71,7 +71,7 @@ class XlaDeviceContext : public DeviceContext {
 StatusCallback done) const override;

 xla::LocalClient* client() const { return client_; }
-se::Stream* stream() const { return stream_.get(); }
+se::Stream* stream() const override { return stream_.get(); }
 se::Stream* host_to_device_stream() const {
 return host_to_device_stream_.get();
 }

@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/host_constant_op.h"
 #include "tensorflow/core/kernels/identity_n_op.h"
 #include "tensorflow/core/kernels/identity_op.h"
+#include "tensorflow/core/kernels/logging_ops.h"
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/kernels/queue_op.h"
 #include "tensorflow/core/kernels/resource_variable_ops.h"
@@ -81,6 +82,11 @@ class XlaAssignVariableOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);

 #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
+REGISTER_KERNEL_BUILDER(Name("Assert") \
+.Device(DEVICE) \
+.HostMemory("condition") \
+.HostMemory("data"), \
+AssertOp); \
 REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
 REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
 REGISTER_KERNEL_BUILDER( \

@@ -95,7 +95,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
 registration.cluster_control_trigger = true;
 registration.elide_assert_and_checknumerics = true;
 registration.cluster_variant_ops = true;
-registration.cluster_slow_and_inaccurate_ops = true;
+registration.cluster_slow_ops = true;
+registration.cluster_inaccurate_ops = true;
 XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);

 static XlaDeviceOpRegistrations* registrations =

@@ -63,7 +63,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
 registration.cluster_control_trigger = true;
 registration.elide_assert_and_checknumerics = true;
 registration.cluster_variant_ops = true;
-registration.cluster_slow_and_inaccurate_ops = true;
+registration.cluster_slow_ops = true;
+registration.cluster_inaccurate_ops = true;
 XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
 registration);

@@ -167,32 +167,6 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
 return Status::OK();
 }

-XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
-: se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
-
-XlaAllocator::~XlaAllocator() {}
-
-xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
-int device_ordinal, uint64 size, bool retry_on_failure) {
-AllocationAttributes attrs;
-attrs.no_retry_on_failure = !retry_on_failure;
-void* data = nullptr;
-if (size != 0) {
-data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
-if (data == nullptr) {
-return errors::ResourceExhausted(
-"Out of memory while trying to allocate ", size, " bytes.");
-}
-}
-return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
-device_ordinal, this);
-}
-
-Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
-wrapped_->DeallocateRaw(mem.opaque());
-return Status::OK();
-}
-
 XlaComputationLaunchContext::XlaComputationLaunchContext(
 xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
 bool allocate_xla_tensors, bool use_multiple_streams)

@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/device_memory_allocator.h"

 namespace tensorflow {
-class XlaAllocator;

 // Struct that represents a possibly-absent Tensor.
 struct OptionalTensor {
@@ -104,74 +103,6 @@ class VariableInfo {
 Status LockVariables(absl::Span<VariableInfo> variables)
 EXCLUSIVE_LOCK_FUNCTION();

-// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
-// Assumes that the Tensorflow allocator permits asynchronous deallocation:
-// see comment on `AllowsAsynchronousDeallocation()`.
-class XlaAllocator : public se::DeviceMemoryAllocator {
-public:
-XlaAllocator(const se::Platform* platform, Allocator* wrapped);
-~XlaAllocator() override;
-xla::StatusOr<se::OwningDeviceMemory> Allocate(
-int device_ordinal, uint64 size, bool retry_on_failure) override;
-Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
-
-// The Tensorflow BFC allocator used on GPU allows host-side deallocation
-// before GPU execution takes place. Tensorflow uses the ordering of the main
-// compute stream to enforce a happens-before relationship between a memory
-// allocation and code that reuses the same memory. If Tensorflow adds
-// support for multiple GPU streams or allocators with different ordering
-// requirements, this code may need to change.
-// (This attribute has no effect on CPU.)
-bool AllowsAsynchronousDeallocation() const override { return true; }
-
-private:
-Allocator* wrapped_;
-};
-
-// Adapter class that wraps per-device TF allocators as an XLA allocator.
-// Assumes that the Tensorflow allocator permits asynchronous deallocation;
-// see comment on `AllowsAsynchronousDeallocation()`.
-class MultiDeviceAdapter : public se::DeviceMemoryAllocator {
-public:
-MultiDeviceAdapter(
-const se::Platform* platform,
-std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators)
-: DeviceMemoryAllocator(platform),
-tf_allocators_(std::move(tf_allocators)) {
-for (const auto& tf_allocator : tf_allocators_) {
-per_device_allocators_.emplace_back(platform, tf_allocator.get());
-}
-}
-
-xla::StatusOr<se::OwningDeviceMemory> Allocate(
-int device_ordinal, uint64 size, bool retry_on_failure) override {
-CHECK_LT(device_ordinal, per_device_allocators_.size());
-return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size,
-retry_on_failure);
-}
-
-Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override {
-CHECK_LT(device_ordinal, per_device_allocators_.size());
-return per_device_allocators_[device_ordinal].Deallocate(device_ordinal,
-mem);
-}
-
-// The Tensorflow BFC allocator used on GPU allows host-side deallocation
-// before GPU execution takes place. Tensorflow uses the ordering of the main
-// compute stream to enforce a happens-before relationship between a memory
-// allocation and code that reuses the same memory. If Tensorflow adds
-// support for multiple GPU streams or allocators with different ordering
-// requirements, this code may need to change.
-// (This attribute has no effect on CPU.)
-bool AllowsAsynchronousDeallocation() const override { return true; }
-
-private:
-std::vector<tensorflow::XlaAllocator> per_device_allocators_;
-// The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does
-// not take ownership of its underlying Allocator).
-std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators_;
-};
-
 // Helper class to perform the marshalling of TensorFlow inputs and outputs to
 // ShapedBuffers suitable for passing to an XLA computation.
 class XlaComputationLaunchContext {

@@ -28,10 +28,9 @@
 ** Please don't remove this file - it is supporting some 3rd party plugins **
 """

-licenses(["notice"])
-
 package(
 default_visibility = ["//visibility:public"],
+licenses = ["notice"],
 )

 cc_library(

@@ -954,7 +954,7 @@ tf_xla_py_test(

 tf_xla_py_test(
 name = "ternary_ops_test",
-size = "small",
+size = "medium",
 srcs = ["ternary_ops_test.py"],
 deps = [
 ":xla_test",

@@ -19,11 +19,13 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.client import session
 from tensorflow.python.compiler.xla import xla
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase):
 def f():
 ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
 output = control_flow_ops.cond(
-constant_op.constant(
-True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+constant_op.constant(True),
+lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))

 return output.stack()

@@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase):

 xla_context.Exit()

+def testCondAndTensorArrayInDefun_constFolding(self):
+g = ops.Graph()
+with session.Session(graph=g), g.as_default(), self.test_scope():
+xla_context = control_flow_ops.XLAControlFlowContext()
+xla_context.Enter()
+
+@function.defun
+def f():
+ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+output = control_flow_ops.cond(
+constant_op.constant(False),
+lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+
+return output.stack()
+
+output_t = f()
+self.assertAllEqual([10.], self.evaluate(output_t))
+
+xla_context.Exit()
+
+def testCondAndTensorArray_xlaCompile(self):
+self.skipTest("b/127846988")
+# Fails with "Uninitialized arguments" in XlaIfOp::Compile
+with self.session(), self.test_scope():
+xla_context = control_flow_ops.XLAControlFlowContext()
+xla_context.Enter()
+
+def f():
+ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+output = control_flow_ops.cond(
+constant_op.constant(True),
+lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+
+return output.stack()
+
+output_t, = xla.compile(f)
+self.assertAllEqual([5.], self.evaluate(output_t))
+
+xla_context.Exit()
+
 def testCondConstPropagation(self):
 with self.session() as sess, self.test_scope():
 xla_context = control_flow_ops.XLAControlFlowContext()
@@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase):

 xla_context.Exit()

+def testSwitchCaseAndTensorArray_xlaCompile(self):
+self.skipTest("b/127846988")
+with self.session(), self.test_scope():
+xla_context = control_flow_ops.XLAControlFlowContext()
+xla_context.Enter()
+
+def f():
+ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+output = control_flow_ops.switch_case(
+constant_op.constant(1), {
+0: lambda: ta.write(0, 5.),
+1: lambda: ta.write(0, 10.),
+2: lambda: ta.write(0, 15.),
+})
+
+return output.stack()
+
+output_t, = xla.compile(f)
+self.assertAllEqual([10.], self.evaluate(output_t))
+
+xla_context.Exit()
+
 def testSwitchCaseConstPropagation(self):
 self.skipTest("b/127846988")
 with self.session() as sess, self.test_scope():

@@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase):
 padding="VALID",
 patches=patches)

+def testKsize2x2Stride1x1Rate1x1ValidDepth2(self):
+"""Test for 2x2 kernel with VALID padding."""
+# [1, 2, 2, 2]
+image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]]
+# [1, 1, 1, 8]
+patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]]
+self._VerifyValues(
+image,
+ksizes=[2, 2],
+strides=[1, 1],
+rates=[1, 1],
+padding="VALID",
+patches=patches)
+
+
 if __name__ == "__main__":
 test.main()

@@ -72,21 +72,21 @@ class TernaryOpsTest(xla_test.XLATestCase):
 for dtype in self.numeric_types:
 self._testTernary(
 array_ops.where,
-np.array(0, dtype=np.bool),
+np.array(False),
 np.array(2, dtype=dtype),
 np.array(7, dtype=dtype),
 expected=np.array(7, dtype=dtype))

 self._testTernary(
 array_ops.where,
-np.array(1, dtype=np.bool),
+np.array(True),
 np.array([1, 2, 3, 4], dtype=dtype),
 np.array([5, 6, 7, 8], dtype=dtype),
 expected=np.array([1, 2, 3, 4], dtype=dtype))

 self._testTernary(
 array_ops.where,
-np.array(0, dtype=np.bool),
+np.array(False),
 np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
 expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
@@ -105,6 +105,74 @@ class TernaryOpsTest(xla_test.XLATestCase):
 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
 expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype))

+def testSelectV2(self):
+for dtype in self.numeric_types:
+self._testTernary(
+array_ops.where_v2,
+np.array(False),
+np.array(2, dtype=dtype),
+np.array(7, dtype=dtype),
+expected=np.array(7, dtype=dtype))
+
+self._testTernary(
+array_ops.where_v2,
+np.array(True),
+np.array([1, 2, 3, 4], dtype=dtype),
+np.array([5, 6, 7, 8], dtype=dtype),
+expected=np.array([1, 2, 3, 4], dtype=dtype))
+
+self._testTernary(
+array_ops.where_v2,
+np.array(False),
+np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
+np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
+expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
+
+self._testTernary(
+array_ops.where_v2,
+np.array([0, 1, 1, 0], dtype=np.bool),
+np.array([1, 2, 3, 4], dtype=dtype),
+np.array([5, 6, 7, 8], dtype=dtype),
+expected=np.array([5, 2, 3, 8], dtype=dtype))
+
+# Broadcast the condition
+self._testTernary(
+array_ops.where_v2,
+np.array([0, 1], dtype=np.bool),
+np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
+np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
+expected=np.array([[7, 2], [9, 4], [11, 6]], dtype=dtype))
+
+# Broadcast the then branch to the else
+self._testTernary(
+array_ops.where_v2,
+np.array([[0, 1], [1, 0], [1, 1]], dtype=np.bool),
+np.array([[1, 2]], dtype=dtype),
+np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
+expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
+
+# Broadcast the else branch to the then
+self._testTernary(
+array_ops.where_v2,
+np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
+np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
+np.array([[1, 2]], dtype=dtype),
+expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
+
+# Broadcast the then/else branches to the condition
+self._testTernary(
+array_ops.where_v2,
+np.array([[1, 0], [0, 1], [1, 1]], dtype=np.bool),
+np.array(7, dtype=dtype),
+np.array(8, dtype=dtype),
+expected=np.array([[7, 8], [8, 7], [7, 7]], dtype=dtype))
+self._testTernary(
+array_ops.where_v2,
+np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
+np.array(7, dtype=dtype),
+np.array([8, 9], dtype=dtype),
+expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
+
 def testSlice(self):
 for dtype in self.numeric_types:
 self._testTernary(

@@ -104,6 +104,24 @@ struct EdgePtrCompare {
 }
 };

+// TODO(laigd): instead of deciding the device here, the converter should accept
+// a device name as one of the conversion parameter so users can control on
+// which device they want to run the conversion.
+std::pair<TfGpuId, PlatformGpuId> GetFirstValidDeviceId() {
+for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
+TfGpuId tf_gpu_id(tf_gpu_id_value);
+PlatformGpuId platform_gpu_id;
+Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
+if (s.ok()) {
+VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
+<< platform_gpu_id.value();
+return std::make_pair(tf_gpu_id, platform_gpu_id);
+}
+}
+LOG(ERROR) << "Could not find any TF GPUs";
+return std::make_pair(TfGpuId(-1), PlatformGpuId(-1));
+}
+
 // Function to get subsegment information structure.
 Status GetEngineInfo(const Graph* g,
 const grappler::GraphProperties& graph_properties,
@@ -128,20 +146,37 @@ Status GetEngineInfo(const Graph* g,
 if (segment_nodes.count(node) == 0) continue;
 auto node_device = node->requested_device();
 if (!node_device.empty()) {
-// If device is CPU, treat as if no device was assigned. Don't add CPU to
-// segment_device because that would cause a segfault in
-// GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes
-// any already set device is a GPU.
+// If device is set, it means device placement may have been done before,
+// so we need to assign a device for the TRTEngineOp to maintain the
+// invariance.
+// If the device is CPU in this case, it tries to find the first available
+// GPU and use it as the device.
 DeviceNameUtils::ParsedName parsed_name;
+const bool parse_succeeded =
 DeviceNameUtils::ParseFullName(node_device, &parsed_name);
-if (parsed_name.type == "CPU") {
-VLOG(1) << "Node " << node->name() << " was assigned to the CPU. "
-<< "Attempting to place on GPU.";
+if (!parse_succeeded || (parse_succeeded && parsed_name.type == "CPU")) {
+string msg;
+if (!parse_succeeded) {
+msg = StrCat("Failed to parse assigned device of node ", node->name(),
+". ");
+} else {
+msg = StrCat("Node ", node->name(), " was assigned to the CPU. ");
+}
+VLOG(1) << msg << "Attempting to place on GPU.";
+TfGpuId tf_gpu_id;
+PlatformGpuId platform_gpu_id;
+std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
+if (tf_gpu_id.value() >= 0) {
+parsed_name.type = "GPU";
+parsed_name.id = tf_gpu_id.value();
+segment_devices.insert(DeviceNameUtils::FullName(
+parsed_name.job, parsed_name.replica, parsed_name.task,
+parsed_name.type, parsed_name.id));
+}
 } else {
 segment_devices.insert(node_device);
 }
-} else {
-if (node->has_assigned_device_name()) {
+} else if (node->has_assigned_device_name()) {
 // It appears that nodes will not have assigned devices at this point in
 // execution.
 segment_devices.insert(node->assigned_device_name());
@@ -149,7 +184,6 @@ Status GetEngineInfo(const Graph* g,
 VLOG(2) << "Node " << node->name()
 << " neither have requested device nor assigned device";
 }
-}
 subgraph_nodes.push_back(node);

 const int node_id = node->id();
@@ -251,13 +285,11 @@ Status GetEngineInfo(const Graph* g,
 info->engine_name = StrCat(scope_name, info->engine_name);
 VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
 << "' to a GraphDef";
-// TODO(sami): This should not happen once segmenter is updated.
 if (segment_devices.size() == 1) {
 info->device = *segment_devices.begin();
 } else if (segment_devices.size() > 1) {
-LOG(WARNING) << "Detected multiple(" << segment_devices.size()
-<< ") devices for the segment. Picking first one to continue "
-<< "but this shouldn't have happened";
+LOG(WARNING) << "Detected multiple (" << segment_devices.size()
+<< ") devices for the segment. Picking first one to continue.";
 info->device = *segment_devices.begin();
 } else {
 VLOG(1) << "No device is assigned to the segment. "
@@ -543,10 +575,10 @@ Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
 std::map<string, Node*> io_nodes;
 int num_inputs = 0;
 for (auto n : sgraph.op_nodes()) {
-if (str_util::StartsWith(n->name(), kInputPHName)) {
+if (absl::StartsWith(n->name(), kInputPHName)) {
 num_inputs++;
 io_nodes.insert({n->name(), n});
-} else if (str_util::StartsWith(n->name(), kOutputPHName)) {
+} else if (absl::StartsWith(n->name(), kOutputPHName)) {
 io_nodes.insert({n->name(), n});
 }
 }
@@ -640,24 +672,17 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
 if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
 engine.device.empty()) {
 // If device is not set, use the first found GPU device for the conversion.
-for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
-TfGpuId tf_gpu_id(tf_gpu_id_value);
+TfGpuId tf_gpu_id;
 PlatformGpuId platform_gpu_id;
-Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
-if (s.ok()) {
-VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
-<< platform_gpu_id.value();
+std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
 cuda_device_id = platform_gpu_id.value();
+if (cuda_device_id >= 0) {
 GPUOptions gpu_options;
 // If the TF to Cuda gpu id mapping exist, the device and corresponding
 // allocator must have been initialized already, so the
 // GetGPUAllocator() call won't create a new allocator.
 dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
 gpu_options, tf_gpu_id, 1);
-break;
-}
-LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist "
-<< s;
 }
 return std::make_pair(cuda_device_id, dev_allocator);
 }
@@ -750,8 +775,8 @@ Status ConvertAfterShapes(const ConversionParams& params) {
 EngineInfo curr_engine;
 curr_engine.engine_name = StrCat("TRTEngineOp_", t);
 Status status =
-GetEngineInfo(&graph, *params.graph_properties, curr_segment.first,
-node_map, reverse_topo_order, &curr_engine);
+GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
+reverse_topo_order, &curr_engine);
 if (!status.ok()) {
 LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
 << status;
@@ -776,7 +801,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {

 engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
 total_engine_bytes_size += engine_bytes_size.back();
-total_num_nodes_in_segments += curr_segment.first.size();
+total_num_nodes_in_segments += curr_segment.size();
 engine_segments.push_back(std::move(curr_engine));
 converted_segments.push_back(std::move(curr_segment));

@@ -806,7 +831,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
 engine.max_workspace_size_bytes =
 params.max_workspace_size_bytes *
 (engine_bytes_size.at(i) / total_engine_bytes_size +
-converted_segments.at(i).first.size() / total_num_nodes_in_segments) /
+converted_segments.at(i).size() / total_num_nodes_in_segments) /
 2.0;
 VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
 << engine.engine_name;
@@ -828,9 +853,9 @@ Status ConvertAfterShapes(const ConversionParams& params) {
 CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
 alloc.get(), &engine_nodes);

-string msg = StrCat("TensorRT node ", engine.engine_name,
-" added for segment ", i, " consisting of ",
-converted_segments.at(i).first.size(), " nodes");
+string msg =
+StrCat("TensorRT node ", engine.engine_name, " added for segment ", i,
+" consisting of ", converted_segments.at(i).size(), " nodes");
 if (status.ok()) {
 LOG(INFO) << msg << " succeeded.";
 } else {
@@ -839,7 +864,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
 }
 if (VLOG_IS_ON(1)) {
 msg = "Segment consists of nodes: ";
-for (const Node* node : converted_segments.at(i).first) {
+for (const Node* node : converted_segments.at(i)) {
 StrAppend(&msg, node->name(), ", ");
 }
 VLOG(1) << msg;
@@ -848,7 +873,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
 // If status is ok, we successfully added the node to the graph and can
 // remove segment ops. Otherwise graph is not modified.
 if (status.ok()) {
-for (const Node* node : converted_segments.at(i).first) {
+for (const Node* node : converted_segments.at(i)) {
 graph.RemoveNode(const_cast<Node*>(node));
 }
 }

@@ -239,7 +239,7 @@ class ConvertAfterShapesTest : public ::testing::Test {
 params.output_names = &output_names;
 params.max_workspace_size_bytes = 8 << 20;
 params.output_graph_def = output_graph_def;
-params.minimum_segment_size = 2;
+params.minimum_segment_size = 1;
 params.graph_properties = &graph_properties;
 params.use_calibration = false;

@@ -385,11 +385,10 @@ string DebugString(const nvinfer1::ITensor& tensor) {
 ", dims=", DebugString(tensor.getDimensions()), ")");
 }

-Status Converter::GetTrtBroadcastShape(
-const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r,
+Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
+const TRT_TensorOrWeights& operand_r,
 nvinfer1::Dims* operand_l_new_dims,
-nvinfer1::Dims* operand_r_new_dims) const {
-// ***************************************************************************
+nvinfer1::Dims* operand_r_new_dims) {
 // TensorRT Elementwise op supports broadcast but requires both tensor to be
 // of Identical rank
 //
@@ -473,14 +472,13 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
 nvinfer1::Weights trt_weights = weights.GetTrtWeights();
 nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights);
 if (!layer) return nullptr;
-const nvinfer1::DataType trt_dtype = trt_weights.type;
 nvinfer1::ITensor* trt_tensor = layer->getOutput(0);
 #if !IS_TRT_VERSION_GE(5, 1, 3, 0)
 // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set
 // the data type below, it will always be kFLOAT regardless what the data type
 // of the weights is. Once NVIDIA fixes this bug, we should remove the data
 // type setting logic below and test should still pass.
-trt_tensor->setType(trt_dtype);
+trt_tensor->setType(trt_weights.type);
 #endif
 return trt_tensor;
 }
@@ -1677,190 +1675,6 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights,
 return Status::OK();
 }

-// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the
-// right operand. If swapped_inputs is true, those two are swapped.
-//
-// TODO(jie): broadcast is needed yet not implemented.
-// Only implemented channel wise for the time being.
-Status BinaryTensorOpWeight(OpConverterParams* params,
-nvinfer1::ITensor* tensor,
-TRT_ShapedWeights weights, bool swapped_inputs) {
-static const std::unordered_set<string> supported_ops = {"Sub", "Add", "Mul",
-"Div", "RealDiv"};
-const auto& node_def = params->node_def;
-if (!supported_ops.count(node_def.op())) {
-return errors::Unimplemented(node_def.op(), " is not supported, at ",
-node_def.name());
-}
-
-// Check scale mode.
-auto dims_w = weights.shape_;
-const auto dims_t = tensor->getDimensions();
-
-// TODO(jie): addScale checks for input tensor dimension
-if (dims_t.nbDims != 3) {
-return errors::InvalidArgument("addScale requires tensor with rank 3, at ",
-node_def.name());
-}
-
-// Default to element-wise
-auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
-
-// TODO(jie): maybe use a permutation instead to support more cases;
-bool need_to_permute = false;
-
-if (weights.count() == 1) {
-scale_mode = nvinfer1::ScaleMode::kUNIFORM;
-} else {
-VLOG(2) << "weights dims: " << DebugString(dims_w)
-<< "; tensor dims: " << DebugString(dims_t);
-// Make sure no broadcasting on batch dimension.
-if (dims_w.nbDims == dims_t.nbDims + 1) {
-if (dims_w.d[0] == 1) {
-for (int i = 1; i < dims_w.nbDims; i++) {
-dims_w.d[i - 1] = dims_w.d[i];
-}
-dims_w.nbDims--;
-} else {
-return errors::InvalidArgument("Binary op cannot operate on batch, at ",
-node_def.name());
-}
-}
-
-if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) {
-scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
-// Default is element-wise
-for (int i = 1; i < dims_w.nbDims; i++) {
-if (dims_w.d[i] != dims_t.d[i]) {
-// If dimension does not match, switch back to per-channel
-scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-break;
-}
-}
-// If the mode is per-channel, since channel dimension is assumed to be
-// the third to last dimension, we need to make sure all other dimensions
-// have size 1.
-if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
-for (int i = 1; i < dims_w.nbDims; i++) {
-if (dims_w.d[i] != 1)
-return errors::InvalidArgument(
-"Weight dims not compatible for channel-wise broadcast at ",
-node_def.name());
-}
-}
-} else if (dims_w.nbDims == 1 &&
-dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) {
-// Channel wise and broadcast required. We compare the last dimension of
-// the tensor shape because of tensorflow default broadcasting rules.
-need_to_permute = true;
-scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-} else {
-return errors::InvalidArgument("Weight dims not compatible at ",
-node_def.name());
-}
-}
-// TODO(laigd): we should add validation_only support in TransposeTensor() and
-// PrepareTensorForShape().
-if (params->validation_only) return Status::OK();
-
-// Transpose last dimension.
-std::vector<int> permutation(dims_t.nbDims + 1);
-if (need_to_permute) {
-// We swap the last dimension into channel for trt, because of tensorflow
-// default broadcasting rules.
-for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
-permutation[i] = i;
-}
-permutation[1] = dims_t.nbDims;
-permutation[dims_t.nbDims] = 1;
-TF_RETURN_IF_ERROR(
-params->converter->TransposeTensor(tensor, permutation, &tensor));
-}
-
-// Prepare weights
-TRT_ShapedWeights shift_weights(weights.TrtDType());
-TRT_ShapedWeights scale_weights(weights.TrtDType());
-TRT_ShapedWeights power_weights(weights.TrtDType());
-
-if (node_def.op() == "Sub") {
-if (swapped_inputs) {
-shift_weights = weights;
-nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
-*tensor, nvinfer1::UnaryOperation::kNEG);
-TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
-// Since quantization ranges are symmetric, the same range as the input
-// will work for the negation of the input.
-params->converter->MarkQuantizationRangesAsInferrable(
-tensor, layer->getOutput(0));
-tensor = layer->getOutput(0);
-} else {
-TRT_ShapedWeights neg_weights =
-params->weight_store->GetTempWeights(weights);
-LambdaFactory unary_op;
-unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
-TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
-shift_weights = neg_weights;
-}
-} else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
-if (swapped_inputs) {
-// We need to infer the quantization range for this intermediate tensor.
-//
-// x -> [Recip] -> 1/x -> [Scale] -> s/x
-// ^
-// need range for this
-//
-// We have the quantization scales for x and s/x - can we divide the scale
-// for s/x by s? Only if it is a scalar.
-//
-// Because of this issue, fall back to BinaryTensorOpTensor if we are
-// doing INT8 with no calibration. There is most likely no performance
-// penalty by falling back here.
-if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
-!params->converter->use_calibration()) {
-return errors::Unimplemented(
-"Intermediate quantization range cannot be determined without"
-" calibration. Falling back to BinaryTensorOpTensor for ",
-node_def.op(), ", at ", node_def.name());
-}
-scale_weights = weights;
-nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
-*tensor, nvinfer1::UnaryOperation::kRECIP);
-TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
-tensor = layer->getOutput(0);
-} else {
-TRT_ShapedWeights recip_weights =
-params->weight_store->GetTempWeights(weights);
-LambdaFactory unary_op;
-unary_op.op = LambdaFactory::OP_CATEGORY::RECIP;
-TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op));
-scale_weights = recip_weights;
-}
-} else if (node_def.op() == "Mul") {
-scale_weights = weights;
-} else if (node_def.op() == "Add") {
-shift_weights = weights;
-} else {
-// This should not happen.
-return errors::Unimplemented("Binary op not supported at ", node_def.op());
-}
-
-nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
-*tensor, scale_mode, shift_weights.GetTrtWeights(),
-scale_weights.GetTrtWeights(), power_weights.GetTrtWeights());
-TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
-
-nvinfer1::ITensor* output_tensor = layer->getOutput(0);
-// Transpose back dimension
-if (need_to_permute) {
-TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
-output_tensor, permutation, &output_tensor));
-}
-
-// Pass the output
-params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
-return Status::OK();
-}
-
 Status ConvertConv2DHelper(OpConverterParams* params, int group,
 bool is_conv2d_backprop_input) {
 const auto& inputs = params->inputs;
@@ -1951,7 +1765,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
 kernel_size.h() = weights.shape_.d[2];
 kernel_size.w() = weights.shape_.d[3];

-// Add padding.
+// Before TRT 5.1.3, we have to calculate padding ourselves.
+#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
 std::vector<std::pair<int, int>> padding;
 if (attrs.get<string>("padding") == "SAME") {
 nvinfer1::DimsHW effective_kernel_size = kernel_size;
@@ -1978,12 +1793,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
 padding = {{0, 0}, {0, 0}};
 }

-// TensorRT 5.1 added support for asymmetric padding. Due to a bug in 5.1.2, we
-// can only use asymmetric padding in convolutions with 5.1.3+.
-#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
+// Handle asymmetric padding. TensorRT 5.1 added support for asymmetric
+// padding via setPrePadding and setPostPadding. Due to a bug in 5.1.2, we can
+// only use asymmetric padding in convolutions with 5.1.3+. But in 5.1.3, we
+// will always use setPaddingMode for simplicity.
 if (padding[0].first != padding[0].second ||
 padding[1].first != padding[1].second) {
-// Handle asymmetric padding.
 auto pad_layer = params->converter->network()->addPadding(
 *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
 nvinfer1::DimsHW(padding[0].second, padding[1].second));
@@ -2006,20 +1821,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
 layer->setStride(stride);
 // TensorRT 5.1.3 added support for padding modes.
 #if IS_TRT_VERSION_GE(5, 1, 3, 0)
+// VALID padding is the default TRT behavior.
 if (attrs.get<string>("padding") == "SAME") {
-VLOG(2) << "Using SAME padding";
 // SAME_UPPER means that post padding is preferred.
 layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
 }
-// For VALID padding, we need to manually set the padding.
-layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
-layer->setPostPadding(
-nvinfer1::DimsHW{padding[0].second, padding[1].second});
-VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
-<< " and post-padding to: " << DebugString(layer->getPostPadding());
 #else
 layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
-VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
 #endif
 layer->setName(node_def.name().c_str());
 layer->setNbGroups(num_groups);
@@ -2033,17 +1841,10 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
 layer->setStride(stride);
 #if IS_TRT_VERSION_GE(5, 1, 3, 0)
 if (attrs.get<string>("padding") == "SAME") {
-VLOG(2) << "Using SAME padding";
 layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
 }
-layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
-layer->setPostPadding(
-nvinfer1::DimsHW{padding[0].second, padding[1].second});
-VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
-<< " and post-padding to: " << DebugString(layer->getPostPadding());
 #else
 layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
-VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
 #endif
 layer->setName(node_def.name().c_str());
 layer->setNbGroups(num_groups);
@@ -2061,74 +1862,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
 return Status::OK();
 }

-Status BinaryTensorOpTensor(OpConverterParams* params,
-const TRT_TensorOrWeights& operand_l,
-const TRT_TensorOrWeights& operand_r) {
-const auto& node_def = params->node_def;
-static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
-{"Add", nvinfer1::ElementWiseOperation::kSUM},
-{"Mul", nvinfer1::ElementWiseOperation::kPROD},
-{"Sub", nvinfer1::ElementWiseOperation::kSUB},
-{"Div", nvinfer1::ElementWiseOperation::kDIV},
-{"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
-{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
-{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
-{"Pow", nvinfer1::ElementWiseOperation::kPOW},
-};
-auto op_pair = ops.find(node_def.op());
-if (op_pair == ops.end()) {
-return errors::Unimplemented("Binary op ", node_def.op(),
-" not supported at: ", node_def.name());
-}
-
-nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
-Status status = params->converter->GetTrtBroadcastShape(
-operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r);
-if (!status.ok()) {
-return errors::InvalidArgument(
-"Unsupported binary op broadcast scheme for op ", node_def.name(), ": ",
-status.error_message());
-}
-TFAttrs attrs(node_def);
-nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
-if (dtype == nvinfer1::DataType::kINT32) {
-return errors::Unimplemented("Binary op ", node_def.op(),
-" does not support INT32, at ",
-node_def.name());
-}
-if (params->validation_only) return Status::OK();
-
-nvinfer1::ITensor* tensor_l = nullptr;
-nvinfer1::ITensor* tensor_r = nullptr;
-status = params->converter->PrepareTensorForShape(
-operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l);
-if (status.ok()) {
-status = params->converter->PrepareTensorForShape(
-operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r);
-}
-if (!status.ok()) {
-return errors::Internal("Failed to convert binary op ", node_def.name(),
-": ", status.error_message());
-}
-
-// Check type consistency.
-TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype)
-<< DebugString(tensor_l->getType()) << " vs " << DebugString(dtype);
-TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype)
-<< DebugString(tensor_r->getType()) << " vs " << DebugString(dtype);
-
-// Add ElementWise layer.
-nvinfer1::IElementWiseLayer* layer =
-params->converter->network()->addElementWise(*tensor_l, *tensor_r,
-op_pair->second);
-TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
-nvinfer1::ITensor* output_tensor = layer->getOutput(0);
-
-// Pass the output
-params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
-return Status::OK();
-}
-
 Status ConvertPlugin(OpConverterParams* params) {
 const auto& inputs = params->inputs;
 const auto& node_def = params->node_def;
@@ -2777,6 +2510,8 @@ Status ConvertPool(OpConverterParams* params) {
 const auto tf_kernel = attrs.get<std::vector<int64>>("ksize");
 const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);

+// Before TRT 5.1.3, we have to calculate padding ourselves.
+#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
 auto tensor_dim = tensor->getDimensions();
 std::vector<std::pair<int, int>> padding;
 if (padding_type == "SAME") {
@@ -2789,13 +2524,13 @@ Status ConvertPool(OpConverterParams* params) {
 } else if (padding_type == "VALID") {
 padding = {{0, 0}, {0, 0}};
 }
+#endif
-// TensorRT 5.1 added support for asymmetric padding.
+// TensorRT 5.1 added support for asymmetric padding. Before that, we need an
+// extra padding layer.
 #if !IS_TRT_VERSION_GE(5, 1, 0, 0)
+// Asymmetric padding case.
 if (padding[0].first != padding[0].second ||
 padding[1].first != padding[1].second) {
-VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
-<< padding[1].first << padding[1].second;
 auto pad_layer = params->converter->network()->addPadding(
 *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
 nvinfer1::DimsHW(padding[0].second, padding[1].second));
@@ -2817,16 +2552,13 @@ Status ConvertPool(OpConverterParams* params) {
 layer->getOutput(0));

 layer->setStride(stride);
-// TensorRT 5.1.3 added support for padding modes.
 #if IS_TRT_VERSION_GE(5, 1, 3, 0)
+// VALID padding is the default TRT behavior.
 if (attrs.get<string>("padding") == "SAME") {
 // SAME_UPPER means that post padding is preferred.
 layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
 }
-#endif
-// TensorRT 5.1 has support for asymmetric padding.
-#if IS_TRT_VERSION_GE(5, 1, 0, 0)
-// If padding mode is not SAME, then these values will be used instead.
+#elif IS_TRT_VERSION_GE(5, 1, 0, 0)
 layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
 layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second});
 #else
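Reviewer note, not part of the diff: after this hunk, ConvertPool (like ConvertConv2DHelper above) picks one of three padding strategies purely from the TensorRT version. A condensed sketch of that selection, reusing the names from the hunk (layer, attrs, padding) and assuming the IS_TRT_VERSION_GE macro defined in this file:

// Sketch only; assumes `layer`, `attrs`, and the precomputed `padding` pairs
// shown in the hunk above.
#if IS_TRT_VERSION_GE(5, 1, 3, 0)
  // TRT >= 5.1.3: let TRT compute SAME padding itself; VALID is the default.
  if (attrs.get<string>("padding") == "SAME") {
    layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
  }
#elif IS_TRT_VERSION_GE(5, 1, 0, 0)
  // TRT 5.1.0 to 5.1.2: asymmetric padding is expressed as pre/post padding.
  layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
  layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second});
#else
  // Older TRT: only symmetric padding here; an IPaddingLayer added earlier
  // absorbs any asymmetric part.
  layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
#endif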
@@ -3350,9 +3082,6 @@ Status ConvertIdentity(OpConverterParams* params) {
 Status ConvertBinary(OpConverterParams* params) {
 const auto& inputs = params->inputs;
 const auto& node_def = params->node_def;
-// TODO(tmorris): Enable once false is updated to mean either tensor or weight
-// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
-// false}}));
 if (inputs.size() != 2) {
 return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
 " inputs but expected 2, at ",
@@ -3368,33 +3097,45 @@ Status ConvertBinary(OpConverterParams* params) {
 "both input as constant at: ",
 node_def.name());
 }
-// TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with
-// IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For
-// now, the performance will be slightly better with IScaleLayer because it
-// can be fused in more situations. However, most of the benefits of
-// IScaleLayer are when the layer performs both a shift and a scale, which we
-// don't do except for convolutions.
-//
-// Try to convert into Scale layer first (for better performance).
-// Since scale layer supports restricted broadcast policy and op types, we
-// allow failure and try to handle it through Elementwise op
-// (BinaryTensorOpTensor).
-Status status = Status::OK();
-if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) {
-status = BinaryTensorOpWeight(params, inputs.at(0).tensor(),
-inputs.at(1).weights(), false);
-} else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) {
-status = BinaryTensorOpWeight(params, inputs.at(1).tensor(),
-inputs.at(0).weights(), true);
+const TRT_TensorOrWeights& operand_l = inputs.at(0);
+const TRT_TensorOrWeights& operand_r = inputs.at(1);
+
+static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
+{"Add", nvinfer1::ElementWiseOperation::kSUM},
+{"Mul", nvinfer1::ElementWiseOperation::kPROD},
+{"Sub", nvinfer1::ElementWiseOperation::kSUB},
+{"Div", nvinfer1::ElementWiseOperation::kDIV},
+{"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
+{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
+{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
+{"Pow", nvinfer1::ElementWiseOperation::kPOW},
+};
+auto op_pair = ops.find(node_def.op());
+if (op_pair == ops.end()) {
+return errors::Unimplemented("Binary op ", node_def.op(),
+" not supported at: ", node_def.name());
 }
-// If both input are tensors, or one of them is weights but the conversion
-// above failed, try the conversion using BinaryTensorOpTensor.
-if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) {
-if (!status.ok()) VLOG(2) << status;
-status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1));
-}
-return status;
+
+nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
+TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
+operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r));
+nvinfer1::ITensor* tensor_l = nullptr;
+nvinfer1::ITensor* tensor_r = nullptr;
+// This will also convert constants to tensors, and set quantization ranges.
+TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
+operand_l, broadcasted_dims_l, params->validation_only, &tensor_l));
+TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
+operand_r, broadcasted_dims_r, params->validation_only, &tensor_r));
+if (params->validation_only) return Status::OK();
+
+// Add ElementWise layer.
+nvinfer1::IElementWiseLayer* layer =
+params->converter->network()->addElementWise(*tensor_l, *tensor_r,
+op_pair->second);
+TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
+return Status::OK();
 }

 Status ConvertRsqrt(OpConverterParams* params) {
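Reviewer note, not part of the diff: the rewritten ConvertBinary resolves the TF op name through a single lookup table and no longer special-cases weight operands. A self-contained sketch of just that lookup step, with a stand-in enum for nvinfer1::ElementWiseOperation (LookupEltwiseOp and EltwiseOp are hypothetical names, not from the codebase):

#include <string>
#include <unordered_map>

// Stand-in for nvinfer1::ElementWiseOperation in this sketch.
enum class EltwiseOp { kSUM, kPROD, kSUB, kDIV, kMIN, kMAX, kPOW };

// Returns false for unsupported op names, mirroring the
// errors::Unimplemented path in the hunk above.
bool LookupEltwiseOp(const std::string& tf_op, EltwiseOp* out) {
  static const std::unordered_map<std::string, EltwiseOp> kOps = {
      {"Add", EltwiseOp::kSUM},     {"Mul", EltwiseOp::kPROD},
      {"Sub", EltwiseOp::kSUB},     {"Div", EltwiseOp::kDIV},
      {"RealDiv", EltwiseOp::kDIV}, {"Minimum", EltwiseOp::kMIN},
      {"Maximum", EltwiseOp::kMAX}, {"Pow", EltwiseOp::kPOW}};
  auto it = kOps.find(tf_op);
  if (it == kOps.end()) return false;
  *out = it->second;
  return true;
}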
@@ -4547,7 +4288,7 @@ Status ConvertSquaredDifference(OpConverterParams* params) {
 const auto& node_def = params->node_def;
 // Broadcast inputs.
 nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
-TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape(
+TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
 inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r));
 nvinfer1::ITensor* tensor_l = nullptr;
 nvinfer1::ITensor* tensor_r = nullptr;
@@ -4692,8 +4433,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) {
 TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name());

 // Create plugin
-nvinfer1::IPluginV2* plugin =
-creator->createPlugin(node_def.name().c_str(), &fc);
+TrtUniquePtrType<nvinfer1::IPluginV2> plugin(
+creator->createPlugin(node_def.name().c_str(), &fc));
 TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name());

 // Set plugin inputs
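Reviewer note, not part of the diff: wrapping the raw IPluginV2* in TrtUniquePtrType means the plugin is released even on early error returns from this converter. A rough sketch of what such an alias is assumed to look like (the real definition lives elsewhere in the codebase; TrtDestroyerSketch and TrtUniquePtrSketch are hypothetical names):

#include <memory>

// Assumption: TensorRT objects are released via their destroy() method, and
// TrtUniquePtrType is a unique_ptr alias with a deleter along these lines.
struct TrtDestroyerSketch {
  template <typename T>
  void operator()(T* t) const {
    if (t != nullptr) t->destroy();
  }
};
template <typename T>
using TrtUniquePtrSketch = std::unique_ptr<T, TrtDestroyerSketch>;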
@@ -4875,7 +4616,8 @@ static void RegisterValidatableOpConverters(
 for (auto pool_op_type : {"AvgPool", "MaxPool"}) {
 (*registration)[pool_op_type] = ConvertPool;
 }
-for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) {
+for (auto normalization_op_type :
+{"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"}) {
 (*registration)[normalization_op_type] = ConvertFusedBatchNorm;
 }
 for (auto unary_op_pair : *UnaryOperationMap()) {
@@ -512,13 +512,6 @@ class Converter {
 const bool validation_only,
 nvinfer1::ITensor** tensor);

-// Return OK if the broadcast scheme is supported and compute the shapes after
-// broadcasting.
-Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
-const TRT_TensorOrWeights& operand_r,
-nvinfer1::Dims* operand_l_new_dims,
-nvinfer1::Dims* operand_r_new_dims) const;
-
 // Creates an IConstantLayer using 'weights' whose dimensions are specified by
 // 'dims', and returns the output ITensor.
 nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights,
@@ -592,6 +585,13 @@ class Converter {
 friend class OpConverterTest;
 };

+// Return OK if the broadcast scheme is supported and compute the shapes after
+// broadcasting.
+Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
+const TRT_TensorOrWeights& operand_r,
+nvinfer1::Dims* operand_l_new_dims,
+nvinfer1::Dims* operand_r_new_dims);
+
 // Map of all supported UnaryOperations
 const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
 // Map of all supported ActivationTypes
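Reviewer note, not part of the diff: GetTrtBroadcastShape is now a free function declared next to UnaryOperationMap(), so callers no longer need a Converter instance. A short usage sketch with the types from this header (error propagation via TF_RETURN_IF_ERROR, as elsewhere in the file):

// Sketch: operand_l and operand_r are TRT_TensorOrWeights as declared above.
nvinfer1::Dims new_dims_l, new_dims_r;
TF_RETURN_IF_ERROR(
    GetTrtBroadcastShape(operand_l, operand_r, &new_dims_l, &new_dims_r));
// new_dims_l / new_dims_r now hold the operand shapes after broadcasting.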
@@ -988,18 +988,16 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
 operand_2_shape, operand_2_is_tensor, operand_2_batch_size);

 // operand_1 broadcast operand_2
-ExpectStatus(
-this->converter_->GetTrtBroadcastShape(
-operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims),
+ExpectStatus(GetTrtBroadcastShape(operand_1, operand_2, &operand_1_new_dims,
+&operand_2_new_dims),
 expected_code, expected_error_msg_substr);
 if (expected_code == error::OK) {
 ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
 ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
 }
 // operand_2 broadcast operand_1
-ExpectStatus(
-this->converter_->GetTrtBroadcastShape(
-operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims),
+ExpectStatus(GetTrtBroadcastShape(operand_2, operand_1, &operand_2_new_dims,
+&operand_1_new_dims),
 expected_code, expected_error_msg_substr);
 if (expected_code == error::OK) {
 ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
@@ -1033,18 +1031,29 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
 error::INVALID_ARGUMENT,
 "Broadcasting beyond batch dimension is not supported "
 "(tensor #dims 4 vs broadcast #dims 5)");
+symmetric_test({3}, {1, 1, 3}, kIsTensor, kIsNotTensor, {}, {},
+error::INVALID_ARGUMENT,
+"Broadcasting beyond batch dimension is not supported "
+"(tensor #dims 2 vs broadcast #dims 3)",
+/*operand_1_batch_size=*/2);
+
 // Both inputs are tensors.
 symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
 error::INVALID_ARGUMENT,
 "Broadcasting beyond batch dimension is not supported "
 "(tensor #dims 3 vs broadcast #dims 4)");
+symmetric_test({1, 3}, {3}, kIsTensor, kIsTensor, {}, {},
+error::INVALID_ARGUMENT,
+"Broadcasting beyond batch dimension is not supported "
+"(tensor #dims 2 vs broadcast #dims 3)");
 symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
 {2, 1, 4});
 symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
 error::INVALID_ARGUMENT,
 "Broadcasting beyond batch dimension is not supported "
 "(tensor #dims 4 vs broadcast #dims 5)");
+symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {},
+error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
 }

 TEST_F(ConverterTest, CreateConstantLayer) {
@@ -1070,7 +1079,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
 int batch_size = -1;
 for (const NodeDef& node : gdef.node()) {
 absl::string_view node_name(node.name());
-if (str_util::ConsumePrefix(&node_name, kInputPHName)) {
+if (absl::ConsumePrefix(&node_name, kInputPHName)) {
 int port = -1;
 EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
 if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
@@ -1351,10 +1360,6 @@ class OpConverterTest : public ::testing::Test {
 }
 }

-void TestMatMulHelper(
-const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
-const std::string& op_name);
-
 // Expose quantization_ranges_ for tests
 std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
 return converter_->quantization_ranges_;
@@ -1682,59 +1687,60 @@ TEST_F(OpConverterTest, ConvertReshape) {
 // Helper function for testing MatMul and BatchMatMul
 // get_matmul corresponds to the function used to generate the node. It should
 // accept (DataType, transpose_a, transpose_b) as parameters.
-void OpConverterTest::TestMatMulHelper(
+void TestMatMulHelper(
+OpConverterTest* test,
 const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
 const std::string& op_name) {
 // HACK: This needs to be done in a better way.
 const bool is_batch_matmul = op_name == "BatchMatMul";
 {
 // Unsupported data type.
-Reset();
+test->Reset();
 NodeDef node_def = get_matmul(DT_INT32, false, false);
-AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
-AddTestWeights<int32>("weights", {2, 1}, {3, 5});
-RunValidationAndConversion(
+test->AddTestTensor("input", {2}, /*batch_size=*/1,
+nvinfer1::DataType::kINT32);
+test->AddTestWeights<int32>("weights", {2, 1}, {3, 5});
+test->RunValidationAndConversion(
 node_def, error::UNIMPLEMENTED,
-("Data type int32 is not supported for " + op_name +
-", "
-"must be one of [float, half], at my_matmul")
+StrCat("Data type int32 is not supported for ", op_name,
+", must be one of [float, half], at my_matmul")
 .c_str());
 }
 // OK.
 for (bool transpose_a : {false, true}) {
 for (bool transpose_b : {false, true}) {
-Reset();
+test->Reset();
 NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b);
-AddTestTensor("input", {2}, /*batch_size=*/1);
-AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
+test->AddTestTensor("input", {2}, /*batch_size=*/1);
+test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
 if (is_batch_matmul) {
 if (transpose_a || transpose_b) {
-RunValidationAndConversion(
+test->RunValidationAndConversion(
 node_def, error::INVALID_ARGUMENT,
 "Input weight attempts to broadcast across batch dimension for "
 "BatchMatMul, at my_matmul");
 } else {
-RunValidationAndConversion(
+test->RunValidationAndConversion(
 node_def, error::INVALID_ARGUMENT,
 "Input weight attempts to broadcast across batch dimension");
 }
 continue;
 } else if (transpose_a) {
-RunValidationAndConversion(
+test->RunValidationAndConversion(
 node_def, error::INVALID_ARGUMENT,
 "Cannot transpose first input if it is a tensor with fewer than 2 "
 "non-batch dimensions");
 continue;
 }
-RunValidationAndConversion(node_def);
+test->RunValidationAndConversion(node_def);
 TRT_TensorOrWeights output;
-TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
+TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
 ASSERT_TRUE(output.is_tensor());
 ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());

 const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
 DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
-BuildAndRun(input_data, &output_data);
+test->BuildAndRun(input_data, &output_data);
 if (transpose_b) {
 EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
 } else {
@@ -1744,31 +1750,31 @@ void OpConverterTest::TestMatMulHelper(
 }
 // OK, 3D inputs
 for (bool transpose_b : {false, true}) {
-Reset();
+test->Reset();
 NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b);
-AddTestTensor("input", {2}, /*batch_size=*/1);
-AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
+test->AddTestTensor("input", {2}, /*batch_size=*/1);
+test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
 if (is_batch_matmul) {
 if (transpose_b) {
-RunValidationAndConversion(
+test->RunValidationAndConversion(
 node_def, error::INVALID_ARGUMENT,
 "Input weight attempts to broadcast across batch dimension for "
 "BatchMatMul, at my_matmul");
 } else {
-RunValidationAndConversion(
+test->RunValidationAndConversion(
 node_def, error::INVALID_ARGUMENT,
 "Input weight attempts to broadcast across batch dimension");
 }
 continue;
 }
-RunValidationAndConversion(node_def);
+test->RunValidationAndConversion(node_def);
 TRT_TensorOrWeights output;
-TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
+TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
 ASSERT_TRUE(output.is_tensor());
 ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
 const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
 DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
-BuildAndRun(input_data, &output_data);
+test->BuildAndRun(input_data, &output_data);
 if (transpose_b) {
 EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
 } else {
@@ -1832,7 +1838,7 @@ TEST_F(OpConverterTest, ConvertMatMul) {
 node_def, error::INVALID_ARGUMENT,
 "Cannot currently transpose constant input if it is not 2 dimensional");
 }
-TestMatMulHelper(get_matmul_nodedef, "MatMul");
+TestMatMulHelper(this, get_matmul_nodedef, "MatMul");
 }

 TEST_F(OpConverterTest, ConvertBatchMatMul) {
@@ -1889,7 +1895,7 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) {
 }
 }

-TestMatMulHelper(get_batch_matmul_nodedef, "BatchMatMul");
+TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul");
 }

 template <DataType dtype>
@@ -2010,250 +2016,82 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) {
 }

 template <typename OpType, DataType dtype>
-void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
-typedef typename EnumToDataType<dtype>::Type CType;
-for (auto swap_inputs : {false, true}) {
-test->Reset();
-NodeDef node_def;
-if (swap_inputs) {
-node_def = GetBinaryOpNodeDef<OpType>("weights", "input", dtype);
-} else {
-node_def = GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
-}
-
-const std::vector<CType> operand1{CType(3), CType(7.5)};
-const std::vector<CType> operand2{CType(2), CType(3)};
-
-// It requires the dims to be at least of rank 3 to apply an IScaleLayer.
-test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1,
-TfDataTypeToTrt(dtype));
-test->AddTestWeights<CType>("weights", /*dims=*/{1, 1, 2},
-/*values=*/swap_inputs ? operand1 : operand2);
-test->RunValidationAndConversion(node_def);
-
-// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
-CheckAddedLayers(test, /*expect_scale_layer=*/true);
-
-// Check the dims of the output ITensor.
-TRT_TensorOrWeights output;
-TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
-ASSERT_TRUE(output.is_tensor());
-ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
-
-const DataVec input_data{
-{"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
-DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
-test->BuildAndRun(
-input_data, &output_data,
-dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
-if (node_def.op() == "Add") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(5), CType(10.5)));
-} else if (node_def.op() == "Sub") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(1), CType(4.5)));
-} else if (node_def.op() == "Mul") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(6), CType(22.5)));
-} else if (node_def.op() == "Div") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(1.5), CType(2.5)));
-} else if (node_def.op() == "RealDiv") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(1.5), CType(2.5)));
-} else {
-ASSERT_TRUE(false);
-}
-}
-}
-
-template <DataType dtype>
-void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) {
-typedef typename EnumToDataType<dtype>::Type CType;
-const NodeDef node_def =
-GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
-const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
-const std::vector<CType> weights{CType(10), CType(20)};
-// There are two types of valid dim pairs which requires channel-wise
-// broadcasting:
-// - input dims (X Y Z) vs weights dims (X 1 1)
-// - input dims (X Y Z) vs weights dims (Z)
-// Here X=Z=2 and Y=1.
-for (auto weights_dims : std::vector<std::vector<int>>{{2, 1, 1}, {2}}) {
-test->Reset();
-test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
-TfDataTypeToTrt(dtype));
-test->AddTestWeights<CType>("weights", weights_dims, weights);
-test->RunValidationAndConversion(node_def);
-
-// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
-CheckAddedLayers(test, /*expect_scale_layer=*/true);
-
-// Check the dims of the output ITensor.
-TRT_TensorOrWeights output;
-TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
-ASSERT_TRUE(output.is_tensor());
-ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
-
-const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
-DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
-test->BuildAndRun(input_data, &output_data);
-if (weights_dims.size() == 1) {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(11), CType(22), CType(13), CType(24)));
-} else {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(11), CType(12), CType(23), CType(24)));
-}
-}
-}
-
-template <DataType dtype>
-void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) {
-typedef typename EnumToDataType<dtype>::Type CType;
-const NodeDef node_def =
-GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
-const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
-const std::vector<CType> weights{CType(10)};
-test->Reset();
-test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
-TfDataTypeToTrt(dtype));
-test->AddTestWeights<CType>("weights", {1, 1, 1, 1}, weights);
-test->RunValidationAndConversion(node_def);
-
-// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
-CheckAddedLayers(test, /*expect_scale_layer=*/true);
-
-// Check the dims of the output ITensor.
-TRT_TensorOrWeights output;
-TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
-ASSERT_TRUE(output.is_tensor());
-ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
-
-const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
-DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
-test->BuildAndRun(input_data, &output_data);
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(11), CType(12), CType(13), CType(14)));
-}
-
-template <typename OpType>
-void TestBinaryTensorOpWeightFallback(OpConverterTest* test,
-const std::vector<int32>& input_dims,
-const std::vector<int>& weights_dims,
-error::Code code = error::OK,
-const char* error_msg_substr = nullptr,
-const int input_batch_size = 1) {
-const DataType dtype = DT_FLOAT;
-typedef typename EnumToDataType<dtype>::Type CType;
-const size_t num_inputs = TrtTensorDimsNumElements(GetTestDims(input_dims));
-const size_t num_weights =
-TrtWeightDimsNumElements(GetTestDims(weights_dims));
-
-test->Reset();
-const NodeDef node_def =
-GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
-test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size,
-TfDataTypeToTrt(dtype));
-test->AddTestWeights<CType>(
-"weights", /*dims=*/weights_dims,
-/*values=*/std::vector<CType>(num_weights, CType(1)));
-test->RunValidationAndConversion(node_def, code, error_msg_substr);
-if (code != error::OK) return;
-
-// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
-CheckAddedLayers(test, /*expect_scale_layer=*/false);
-
-TRT_TensorOrWeights output;
-TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
-ASSERT_TRUE(output.is_tensor());
-
-// Check the dims of the output ITensor.
-std::vector<int> expected_output_dims = input_dims;
-for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1;
-i >= 0 && j >= 0; --i, --j) {
-if (expected_output_dims[i] == 1) {
-expected_output_dims[i] = weights_dims[j];
-}
-}
-ExpectTrtDimsEqualsArray(expected_output_dims,
-output.tensor()->getDimensions());
-
-// Check the result of running the engine.
-const int expected_num_outputs =
-TrtTensorDimsNumElements(GetTestDims(expected_output_dims));
-const DataVec input_data{
-{"input", ConstructTensor<CType>(num_inputs, CType(2))}};
-DataVec output_data{
-{"my_binary", ConstructTensor<CType>(expected_num_outputs)}};
-test->BuildAndRun(input_data, &output_data);
-if (node_def.op() == "Add") {
-EXPECT_THAT(
-GetSpanForData<CType>(output_data[0]),
-ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(3))));
-} else if (node_def.op() == "Minimum") {
-EXPECT_THAT(
-GetSpanForData<CType>(output_data[0]),
-ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(1))));
-} else {
-ASSERT_TRUE(false);
-}
-}
-
-template <typename OpType, DataType dtype>
-void TestBinaryTensorOpTensor(OpConverterTest* test) {
+void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor,
+bool operand_2_is_tensor) {
 typedef typename EnumToDataType<dtype>::Type CType;
 test->Reset();
 const NodeDef node_def =
 GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
-test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1,
+if (operand_1_is_tensor) {
+test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/2,
 TfDataTypeToTrt(dtype));
-test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1,
+} else {
+test->AddTestWeights("input1", /*dims=*/{1, 2},
+/*values=*/std::vector<CType>{CType(3), CType(6)});
+}
+if (operand_2_is_tensor) {
+test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/2,
 TfDataTypeToTrt(dtype));
+} else {
+test->AddTestWeights("input2", /*dims=*/{2, 1},
+/*values=*/std::vector<CType>{CType(2), CType(3)});
+}
 test->RunValidationAndConversion(node_def);

-// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
-CheckAddedLayers(test, /*expect_scale_layer=*/false);
+DataVec input_data;
+if (operand_1_is_tensor) {
+input_data.push_back(
+{"input1",
+test::AsTensor<CType>({CType(3), CType(6), CType(3), CType(6)})});
+}
+if (operand_2_is_tensor) {
+input_data.push_back(
+{"input2",
+test::AsTensor<CType>({CType(2), CType(3), CType(2), CType(3)})});
+}
+DataVec output_data{{"my_binary", ConstructTensor<CType>(8)}};
 // Check output dims.
 TRT_TensorOrWeights output;
 TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
 ASSERT_TRUE(output.is_tensor());
 ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());

-const DataVec input_data{
-{"input1", test::AsTensor<CType>({CType(3), CType(6)})},
-{"input2", test::AsTensor<CType>({CType(2), CType(3)})}};
-DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
 // After broadcasting first input becomes {3, 6, 3, 6} and second input
 // becomes {2, 3, 2, 3}.
 test->BuildAndRun(
 input_data, &output_data,
-dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
+dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32,
+/*batch_size=*/2);
 if (node_def.op() == "Add") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(5), CType(8), CType(6), CType(9)));
+EXPECT_THAT(
+GetSpanForData<CType>(output_data[0]),
+ElementsAreArray(CastTestVector<int, CType>({5, 8, 6, 9, 5, 8, 6, 9})));
 } else if (node_def.op() == "Sub") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(1), CType(4), CType(0), CType(3)));
+EXPECT_THAT(
+GetSpanForData<CType>(output_data[0]),
+ElementsAreArray(CastTestVector<int, CType>({1, 4, 0, 3, 1, 4, 0, 3})));
 } else if (node_def.op() == "Mul") {
 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(6), CType(12), CType(9), CType(18)));
+ElementsAreArray(
+CastTestVector<int, CType>({6, 12, 9, 18, 6, 12, 9, 18})));
 } else if (node_def.op() == "Div") {
 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
+ElementsAreArray(CastTestVector<float, CType>(
+{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
 } else if (node_def.op() == "RealDiv") {
 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
+ElementsAreArray(CastTestVector<float, CType>(
+{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
 } else if (node_def.op() == "Minimum") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(2), CType(2), CType(3), CType(3)));
+EXPECT_THAT(
+GetSpanForData<CType>(output_data[0]),
+ElementsAreArray(CastTestVector<int, CType>({2, 2, 3, 3, 2, 2, 3, 3})));
 } else if (node_def.op() == "Maximum") {
-EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
-ElementsAre(CType(3), CType(6), CType(3), CType(6)));
+EXPECT_THAT(
+GetSpanForData<CType>(output_data[0]),
+ElementsAreArray(CastTestVector<int, CType>({3, 6, 3, 6, 3, 6, 3, 6})));
 } else if (node_def.op() == "Pow") {
 ExpectArrayNear(
-std::vector<CType>{CType(9), CType(36), CType(27), CType(216)},
+CastTestVector<int, CType>({9, 36, 27, 216, 9, 36, 27, 216}),
 GetSpanForData<CType>(output_data[0]));
 } else {
 ASSERT_TRUE(false);
@@ -2287,58 +2125,48 @@ TEST_F(OpConverterTest, ConvertBinary) {
 "both input as constant at: my_add");
 }

-// Test BinaryTensorOpWeight() without broadcasting.
-TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_FLOAT>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_FLOAT>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
-TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
-// Test BinaryTensorOpWeight() with channel-wise broadcasting.
-TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
-// Test BinaryTensorOpWeight() with uniformly broadcasting.
-TestBinaryTensorOpWeightWithUniformlyBroadcast<DT_FLOAT>(this);
-// Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor().
-// Unsupported op.
-TestBinaryTensorOpWeightFallback<ops::Minimum>(this, {1, 1, 1}, {1});
-// Rank of input tensor dimension <3.
-TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1}, {1});
-// Broadcast on batch dimension, should fail.
-TestBinaryTensorOpWeightFallback<ops::Add>(
-this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT,
-"Unsupported binary op broadcast scheme for op my_binary",
-/*input_batch_size=*/2);
-// Incompatible dims with per-channel mode.
-TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1, 1}, {1, 2, 1});
-// Incompatible dims.
-TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 2, 1}, {2});
-// Test BinaryTensorOpTensor() with broadcasting.
-TestBinaryTensorOpTensor<ops::Add, DT_FLOAT>(this);
-TestBinaryTensorOpTensor<ops::Sub, DT_FLOAT>(this);
-TestBinaryTensorOpTensor<ops::Mul, DT_FLOAT>(this);
-TestBinaryTensorOpTensor<ops::Div, DT_FLOAT>(this);
-TestBinaryTensorOpTensor<ops::RealDiv, DT_FLOAT>(this);
-TestBinaryTensorOpTensor<ops::Minimum, DT_FLOAT>(this);
-TestBinaryTensorOpTensor<ops::Maximum, DT_FLOAT>(this);
-TestBinaryTensorOpTensor<ops::Pow, DT_FLOAT>(this);
-
-TestBinaryTensorOpTensor<ops::Add, DT_HALF>(this);
-TestBinaryTensorOpTensor<ops::Sub, DT_HALF>(this);
-TestBinaryTensorOpTensor<ops::Mul, DT_HALF>(this);
-TestBinaryTensorOpTensor<ops::Div, DT_HALF>(this);
-TestBinaryTensorOpTensor<ops::RealDiv, DT_HALF>(this);
-TestBinaryTensorOpTensor<ops::Minimum, DT_HALF>(this);
-TestBinaryTensorOpTensor<ops::Maximum, DT_HALF>(this);
-TestBinaryTensorOpTensor<ops::Pow, DT_HALF>(this);
+// Test combinations of tensor vs weight inputs (except when both inputs are
+// weights).
+for (const bool operand_1_is_tensor : {true, false}) {
+for (const bool operand_2_is_tensor : {true, false}) {
+if (!operand_1_is_tensor && !operand_2_is_tensor) continue;
+// FP32 tests
+TestBinaryOp<ops::Add, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Sub, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Mul, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Div, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::RealDiv, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Minimum, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Maximum, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Pow, DT_FLOAT>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+// FP16 tests
+// TODO(tmorris): Use templates to avoid duplication.
+TestBinaryOp<ops::Add, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Sub, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Mul, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Div, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::RealDiv, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Minimum, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Maximum, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+TestBinaryOp<ops::Pow, DT_HALF>(this, operand_1_is_tensor,
+operand_2_is_tensor);
+}
+}
 }

 TEST_F(OpConverterTest, ConvertQuantize) {
@@ -2583,7 +2411,6 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
 // implementation that, the extra output classes that are outside of the
 // range specified by valid_detections[i] are not zeros but -1s.
 TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}};
-const int batch_size = 1;

 for (int i = 0; i < kCombinedNMSOKCases; ++i) {
 Reset();
@@ -14,6 +14,8 @@ limitations under the License.

 #include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"

+#include "absl/strings/ascii.h"
+#include "absl/strings/escaping.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
@@ -32,9 +34,9 @@ namespace tensorflow {
 namespace tensorrt {
 namespace convert {
 // TODO(sami): Remove VLOG messages once the code matures
+using absl::AsciiStrToUpper;
 using absl::StrAppend;
 using absl::StrCat;
-using str_util::Uppercase;

 Status TRTOptimizationPass::Init(
 const RewriterConfig_CustomGraphOptimizer* config) {
@@ -67,7 +69,7 @@ Status TRTOptimizationPass::Init(
 }
 if (params.count("precision_mode")) {
 TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
-Uppercase(params.at("precision_mode").s()), &precision_mode_));
+AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_));
 }
 if (params.count("use_calibration")) {
 use_calibration_ = params.at("use_calibration").b();
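Reviewer note, not part of the diff: absl::AsciiStrToUpper returns an upper-cased copy of its argument, which is why it can replace str_util::Uppercase here without other changes. Minimal illustration:

#include <string>
#include "absl/strings/ascii.h"

// "fp16" -> "FP16", matching what TrtPrecisionModeFromName expects.
std::string precision_mode = absl::AsciiStrToUpper("fp16");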
@@ -15,6 +15,7 @@ limitations under the License.

 #include <dirent.h>
 #include <string.h>
+
 #include <fstream>
 #include <vector>

@@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
 TF_ASSERT_OK(RunOpKernel());

 // Verify the result.
-// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
-Tensor* output = context_->mutable_output(0);
-EXPECT_EQ("my_serialized_str", output->scalar<string>()());
+// string type output will remain on CPU, so we're not using GetOutput() here.
+EXPECT_EQ("my_serialized_str",
+context_->mutable_output(0)->scalar<string>()());
 }

 } // namespace tensorrt
@@ -87,15 +87,10 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
 TF_ASSERT_OK(OpsTestBase::RunOpKernel());

 // Verify the result.
-// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
-Tensor* output = OpsTestBase::context_->mutable_output(0);
-const auto& tensor_map = output->flat<TypeParam>();
-std::vector<TypeParam> output_data(tensor_map.size());
-ASSERT_EQ(0, cudaDeviceSynchronize());
-ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
-sizeof(TypeParam) * tensor_map.size(),
-cudaMemcpyDeviceToHost));
-EXPECT_THAT(absl::Span<const TypeParam>(output_data),
+Tensor* output = OpsTestBase::GetOutput(0);
+EXPECT_THAT(
+absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
+output->NumElements()),
 ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
 }

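Reviewer note, not part of the diff: the rewritten expectation views the output tensor's buffer through absl::Span, a non-owning (pointer, length) view, instead of copying it out with cudaMemcpy. Standalone illustration of the Span construction used above:

#include <vector>
#include "absl/types/span.h"

std::vector<float> buffer = {0.0f, 2.0f};
// Wraps (data, size) without owning or copying the elements.
absl::Span<const float> view(buffer.data(), buffer.size());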
@@ -681,31 +681,33 @@ Status SegmentGraph(const Graph* tf_graph,
 << " with parent=" << segment_root << ":" << s;
 }

-// Don't use small segments.
-if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) {
+const int num_effective_nodes = std::count_if(
+segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
+static auto noops =
+new std::set<string>{"Identity", "Snapshot", "StopGradient"};
+return noops->count(node->type_string()) == 0;
+});
+
+// Don't use segments whose number of effective nodes is small.
+if (num_effective_nodes < options.minimum_segment_size) {
 VLOG(1) << "Segment " << segments->size() << " has only "
-<< segment_nodes.size() << " nodes, dropping";
+<< num_effective_nodes << " effective nodes, dropping";
 continue;
 }

-// TODO(sami): Make segmenter placement aware once trtscopes are in place
 const auto& dev_itr = device_maps.find(segment_root);
 if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
 VLOG(1) << "No device assigned to segment " << segments->size();
-segments->emplace_back(std::make_pair(segment_nodes, string()));
 } else if (dev_itr->second.size() > 1) {
-string s("Segment ");
-StrAppend(&s, segments->size(), " has multiple devices attached: ");
+string s = StrCat("Segment ", segments->size(),
+" has multiple devices attached: ");
 for (const auto& dev : dev_itr->second) {
 StrAppend(&s, dev, ", ");
 }
-LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin());
+LOG(WARNING) << s;
-segments->emplace_back(
-std::make_pair(segment_nodes, *(dev_itr->second.begin())));
-} else {
-segments->emplace_back(
-std::make_pair(segment_nodes, *(dev_itr->second.begin())));
 }

+segments->emplace_back(segment_nodes);
 }
 if (VLOG_IS_ON(1)) {
 for (const auto& d : device_maps) {
|
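The substantive change in SegmentGraph is that a candidate segment is now sized by its number of effective nodes: Identity, Snapshot, and StopGradient ops no longer count toward minimum_segment_size, so chains made only of such ops are dropped (which is what the new IdentityOps test later in this diff exercises). A standalone sketch of the filter, assuming op types are plain strings rather than Node* objects:

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Ops that do no real work and therefore should not justify building a
// TensorRT engine on their own; mirrors the no-op list in the diff.
bool IsEffectiveOp(const std::string& op_type) {
  static const std::set<std::string> kNoOps = {"Identity", "Snapshot",
                                               "StopGradient"};
  return kNoOps.count(op_type) == 0;
}

int main() {
  const std::vector<std::string> segment = {"Identity", "Identity", "Conv2D",
                                            "Relu", "Snapshot"};
  const int minimum_segment_size = 3;

  const int num_effective_nodes = static_cast<int>(std::count_if(
      segment.begin(), segment.end(),
      [](const std::string& op) { return IsEffectiveOp(op); }));

  // Only 2 of the 5 nodes are effective, so this segment is dropped.
  if (num_effective_nodes < minimum_segment_size) {
    std::cout << "dropping segment: " << num_effective_nodes
              << " effective nodes\n";
  } else {
    std::cout << "keeping segment\n";
  }
  return 0;
}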
@ -31,10 +31,8 @@ namespace tensorflow {
|
|||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
namespace segment {
|
namespace segment {
|
||||||
|
|
||||||
// Vector of segments, each entry contains a set of node pointers and a device
|
// Vector of segments, each entry contains a set of node pointers.
|
||||||
// name in the segment.
|
using SegmentNodesVector = std::vector<std::set<const Node*>>;
|
||||||
using SegmentNodesVector =
|
|
||||||
std::vector<std::pair<std::set<const Node*>, string>>;
|
|
||||||
|
|
||||||
struct SegmentOptions {
|
struct SegmentOptions {
|
||||||
// Segment must contain at least this many nodes.
|
// Segment must contain at least this many nodes.
|
||||||
|
@ -77,7 +77,7 @@ class SegmentTest : public ::testing::Test {
|
|||||||
EXPECT_EQ(expected_segments.size(), segments.size());
|
EXPECT_EQ(expected_segments.size(), segments.size());
|
||||||
for (int i = 0; i < segments.size(); ++i) {
|
for (int i = 0; i < segments.size(); ++i) {
|
||||||
std::set<string> segment_node_names;
|
std::set<string> segment_node_names;
|
||||||
for (const Node* node : segments[i].first) {
|
for (const Node* node : segments[i]) {
|
||||||
segment_node_names.insert(node->name());
|
segment_node_names.insert(node->name());
|
||||||
}
|
}
|
||||||
const auto& expected = expected_segments[i];
|
const auto& expected = expected_segments[i];
|
||||||
@ -262,6 +262,23 @@ TEST_F(SegmentTest, BigIfElse) {
|
|||||||
{{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
|
{{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SegmentTest, IdentityOps) {
|
||||||
|
Scope s = Scope::NewRootScope();
|
||||||
|
auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
|
||||||
|
auto identity0 = ops::Identity(s.WithOpName("identity0"), feed);
|
||||||
|
auto identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
|
||||||
|
auto identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
|
||||||
|
auto identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
|
||||||
|
Graph g(OpRegistry::Global());
|
||||||
|
TF_EXPECT_OK(s.ToGraph(&g));
|
||||||
|
|
||||||
|
const std::set<string> all_identities = {"identity0", "identity1",
|
||||||
|
"identity2", "identity3"};
|
||||||
|
// Identity ops are not counted as effective ops in the segment, so no segment
|
||||||
|
// will be formed in this case.
|
||||||
|
RunTest(&g, all_identities, all_identities, all_identities, {});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace segment
|
} // namespace segment
|
||||||
} // namespace tensorrt
|
} // namespace tensorrt
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [":internal"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
package_group(
|
package_group(
|
||||||
name = "internal",
|
name = "internal",
|
||||||
packages = [
|
packages = [
|
||||||
@ -23,15 +26,12 @@ package_group(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
package(
|
|
||||||
default_visibility = [":internal"],
|
|
||||||
)
|
|
||||||
|
|
||||||
load(
|
load(
|
||||||
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
|
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
|
||||||
"if_cuda_is_configured",
|
"if_cuda_is_configured",
|
||||||
)
|
)
|
||||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
|
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tf2xla_supported_ops_lib",
|
name = "tf2xla_supported_ops_lib",
|
||||||
@ -67,6 +67,19 @@ xla_proto_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A proto library that is minimal in size and dependencies for platforms like Android.
|
||||||
|
tf_portable_proto_library(
|
||||||
|
name = "portable_tf2xla_proto",
|
||||||
|
config_string = "allow_all:true",
|
||||||
|
header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"],
|
||||||
|
portable_deps = ["//tensorflow/core:android_proto_lib"],
|
||||||
|
proto_deps = [
|
||||||
|
":tf2xla_proto",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
xla_py_proto_library(
|
xla_py_proto_library(
|
||||||
name = "tf2xla_py",
|
name = "tf2xla_py",
|
||||||
has_services = False,
|
has_services = False,
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
|
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
|
||||||
|
|
||||||
tf_gen_op_wrapper_cc(
|
tf_gen_op_wrapper_cc(
|
||||||
|
@ -918,10 +918,16 @@ string Conditional::name() const {
|
|||||||
|
|
||||||
Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
|
Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
|
||||||
int port) {
|
int port) {
|
||||||
|
NodeBuilder id_builder(replacee->name(), "Identity");
|
||||||
|
id_builder.Input(if_node, port);
|
||||||
|
string outside_compilation;
|
||||||
|
if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName,
|
||||||
|
&outside_compilation)
|
||||||
|
.ok()) {
|
||||||
|
id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
|
||||||
|
}
|
||||||
Node* id;
|
Node* id;
|
||||||
TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
|
TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
|
||||||
.Input(if_node, port)
|
|
||||||
.Finalize(graph_, &id));
|
|
||||||
state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
|
state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
|
||||||
state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
|
state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
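AddIdentityNode now builds the replacement Identity through a named NodeBuilder so it can forward the outside-compilation attribute from the If node when that attribute is present. A toy sketch of the copy-only-if-present pattern, with a plain map standing in for TensorFlow's attr machinery (the attribute name below is illustrative):

#include <iostream>
#include <map>
#include <string>

using AttrMap = std::map<std::string, std::string>;

// Copy attr_name from src to dst only when src actually carries it,
// mirroring the GetNodeAttr(...).ok() guard in the diff.
void MaybeCopyAttr(const AttrMap& src, const std::string& attr_name,
                   AttrMap* dst) {
  auto it = src.find(attr_name);
  if (it != src.end()) (*dst)[attr_name] = it->second;
}

int main() {
  const std::string kOutsideCompilation = "_xla_outside_compilation";
  const AttrMap if_node_attrs = {{kOutsideCompilation, "cluster_0"}};
  AttrMap identity_attrs;

  MaybeCopyAttr(if_node_attrs, kOutsideCompilation, &identity_attrs);
  std::cout << "identity carries attr: "
            << (identity_attrs.count(kOutsideCompilation) ? "yes" : "no")
            << "\n";  // prints "yes"
  return 0;
}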
@ -247,8 +247,8 @@ Status FunctionalizeControlFlowPass::Run(
|
|||||||
// multiple times, and we want to avoid functionalizing it again.
|
// multiple times, and we want to avoid functionalizing it again.
|
||||||
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
|
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
|
||||||
new std::map<string, string>{
|
new std::map<string, string>{
|
||||||
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
|
// _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
|
||||||
{"TPUReplicate", "computation"},
|
{"_TPUReplicate", "computation"},
|
||||||
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
|
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
|
||||||
{"XlaLaunch", "function"},
|
{"XlaLaunch", "function"},
|
||||||
};
|
};
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
|
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
@ -195,6 +194,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core/kernels:training_ops",
|
"//tensorflow/core/kernels:training_ops",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -43,7 +43,7 @@ class AssertOp : public XlaOpKernel {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(AssertOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(AssertOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_XLA_OP(Name("Assert"), AssertOp);
|
REGISTER_XLA_OP(Name("Assert").CompilationOnly(), AssertOp);
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -39,7 +39,10 @@ class FusedBatchNormOp : public XlaOpKernel {
|
|||||||
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
|
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override { CompileImpl(ctx); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual void CompileImpl(XlaOpKernelContext* ctx) {
|
||||||
xla::PrimitiveType input_type;
|
xla::PrimitiveType input_type;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
|
DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
|
||||||
@ -116,8 +119,29 @@ class FusedBatchNormOp : public XlaOpKernel {
|
|||||||
bool is_on_gpu_;
|
bool is_on_gpu_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class FusedBatchNormOpV3 : public FusedBatchNormOp {
|
||||||
|
public:
|
||||||
|
explicit FusedBatchNormOpV3(OpKernelConstruction* ctx)
|
||||||
|
: FusedBatchNormOp(ctx) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
FusedBatchNormOp::CompileImpl(ctx);
|
||||||
|
if (!ctx->status().ok()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ctx->SetConstantOutput(5, Tensor());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
float epsilon_;
|
||||||
|
TensorFormat data_format_;
|
||||||
|
bool is_training_;
|
||||||
|
bool is_on_gpu_;
|
||||||
|
};
|
||||||
|
|
||||||
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
|
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
|
||||||
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
|
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
|
||||||
|
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
|
||||||
|
|
||||||
class FusedBatchNormGradOp : public XlaOpKernel {
|
class FusedBatchNormGradOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
@ -233,6 +257,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
|
|||||||
|
|
||||||
REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
|
REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
|
||||||
REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
|
REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
|
||||||
|
REGISTER_XLA_OP(Name("FusedBatchNormGradV3"), FusedBatchNormGradOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
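FusedBatchNormOp::Compile now just forwards to a virtual CompileImpl, which lets FusedBatchNormOpV3 reuse the entire lowering and only append the sixth output (set to an empty constant) that the V3 signature adds. A stripped-down sketch of that subclassing shape, with the XLA context replaced by a trivial struct:

#include <iostream>
#include <vector>

// Minimal stand-in for the op-kernel context: outputs are plain ints.
struct Context {
  std::vector<int> outputs;
};

class BatchNormOp {
 public:
  virtual ~BatchNormOp() = default;
  // Public entry point forwards to the overridable implementation, as the
  // diff does with Compile() calling CompileImpl().
  virtual void Compile(Context* ctx) { CompileImpl(ctx); }

 protected:
  virtual void CompileImpl(Context* ctx) {
    ctx->outputs = {1, 2, 3, 4, 5};  // the five V1/V2 outputs
  }
};

class BatchNormOpV3 : public BatchNormOp {
 public:
  void Compile(Context* ctx) override {
    BatchNormOp::CompileImpl(ctx);  // run the shared lowering first
    ctx->outputs.push_back(0);      // then add the extra V3 output
  }
};

int main() {
  Context ctx;
  BatchNormOpV3 op;
  op.Compile(&ctx);
  std::cout << "num outputs: " << ctx.outputs.size() << "\n";  // prints 6
  return 0;
}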
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel {
|
|||||||
xla::XlaOp conv =
|
xla::XlaOp conv =
|
||||||
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
|
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
|
||||||
lhs_dilation, rhs_dilation, dims, depth);
|
lhs_dilation, rhs_dilation, dims, depth);
|
||||||
|
// Feature group convolution will end up with the kernel_size changing more
|
||||||
|
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
|
||||||
|
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
|
||||||
|
conv_dims.back() = depth;
|
||||||
|
conv_dims.push_back(kernel_size);
|
||||||
|
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
|
||||||
|
conv_dims.pop_back();
|
||||||
|
conv_dims.back() *= kernel_size;
|
||||||
|
conv = xla::Reshape(conv, conv_dims);
|
||||||
ctx->SetOutput(0, conv);
|
ctx->SetOutput(0, conv);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
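The added lines fix the output layout of the feature-group convolution: in the convolution result the kernel-position index varies faster than the depth index, so the last dimension is reshaped to [depth, kernel_size], the two minor dimensions are transposed, and the tensor is flattened back so that depth becomes the fastest-varying index. A small sketch of the equivalent reindexing on a flat buffer (sizes are made up; only the innermost dimension of one row is shown):

#include <iostream>
#include <vector>

// Reorders the innermost dimension from (depth-major, kernel-minor),
// i.e. index = d * kernel_size + k, to (kernel-major, depth-minor),
// i.e. index = k * depth + d; this is the effect of the reshape /
// transpose-in-minor-dims / reshape sequence in the diff.
std::vector<int> ReorderMinorDim(const std::vector<int>& in, int depth,
                                 int kernel_size) {
  std::vector<int> out(in.size());
  for (int d = 0; d < depth; ++d) {
    for (int k = 0; k < kernel_size; ++k) {
      out[k * depth + d] = in[d * kernel_size + k];
    }
  }
  return out;
}

int main() {
  const int depth = 2, kernel_size = 3;
  // Element values encode (d, k) as 10 * d + k for readability.
  const std::vector<int> conv = {0, 1, 2, 10, 11, 12};
  for (int v : ReorderMinorDim(conv, depth, kernel_size)) std::cout << v << " ";
  std::cout << "\n";  // prints: 0 10 1 11 2 12
  return 0;
}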
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
@ -20,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
@ -148,15 +152,22 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
|
|||||||
|
|
||||||
class GatherOp : public XlaOpKernel {
|
class GatherOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
|
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
|
||||||
|
// Set batch_dims_ to 0 if the attribute does not exist.
|
||||||
|
if (context->HasAttr("batch_dims")) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
|
||||||
|
} else {
|
||||||
|
batch_dims_ = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* context) override {
|
void Compile(XlaOpKernelContext* context) override {
|
||||||
xla::XlaBuilder* builder = context->builder();
|
|
||||||
auto input = context->Input(0);
|
auto input = context->Input(0);
|
||||||
auto input_shape = context->InputShape(0);
|
auto input_shape = context->InputShape(0);
|
||||||
auto indices = context->Input(1);
|
auto indices = context->Input(1);
|
||||||
auto indices_shape = context->InputShape(1);
|
auto indices_shape = context->InputShape(1);
|
||||||
int64 axis = 0;
|
|
||||||
|
absl::optional<int64> axis;
|
||||||
if (context->num_inputs() == 3) {
|
if (context->num_inputs() == 3) {
|
||||||
const TensorShape axis_shape = context->InputShape(2);
|
const TensorShape axis_shape = context->InputShape(2);
|
||||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
|
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
|
||||||
@ -165,31 +176,73 @@ class GatherOp : public XlaOpKernel {
|
|||||||
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
|
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
|
||||||
errors::InvalidArgument("axis must be int32 or int64"));
|
errors::InvalidArgument("axis must be int32 or int64"));
|
||||||
|
|
||||||
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
|
int64 axis_input;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->ConstantInputAsIntScalar(2, &axis_input));
|
||||||
|
|
||||||
const auto params_dims = input_shape.dims();
|
const auto params_dims = input_shape.dims();
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(context,
|
||||||
context, -params_dims <= axis && axis < params_dims,
|
-params_dims <= axis_input && axis_input < params_dims,
|
||||||
errors::InvalidArgument("Expected axis in the range [", -params_dims,
|
errors::InvalidArgument("Expected axis in the range [",
|
||||||
", ", params_dims, "), but got ", axis));
|
-params_dims, ", ", params_dims,
|
||||||
if (axis < 0) {
|
"), but got ", axis_input));
|
||||||
axis += params_dims;
|
if (axis_input < 0) {
|
||||||
|
axis_input += params_dims;
|
||||||
}
|
}
|
||||||
|
axis = axis_input;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (batch_dims_ != 0) {
|
||||||
|
if (batch_dims_ < 0) {
|
||||||
|
batch_dims_ = indices_shape.dims() + batch_dims_;
|
||||||
|
}
|
||||||
|
|
||||||
|
axis = axis.value_or(batch_dims_);
|
||||||
|
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
batch_dims_ >= -indices_shape.dims() &&
|
||||||
|
batch_dims_ < indices_shape.dims(),
|
||||||
|
errors::InvalidArgument("Expected batch_dims in the range [",
|
||||||
|
-indices_shape.dims(), ", ",
|
||||||
|
indices_shape.dims(), "), but got ",
|
||||||
|
batch_dims_));
|
||||||
|
|
||||||
|
OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
|
||||||
|
errors::InvalidArgument("batch_dims (", batch_dims_,
|
||||||
|
") must be less than rank(input) (",
|
||||||
|
input_shape.dims(), ")."));
|
||||||
|
|
||||||
|
OP_REQUIRES(context, *axis >= batch_dims_,
|
||||||
|
errors::InvalidArgument("batch_dims (", batch_dims_,
|
||||||
|
") must be less than or equal to ",
|
||||||
|
"axis (", *axis, ")."));
|
||||||
|
}
|
||||||
|
|
||||||
|
axis = axis.value_or(0);
|
||||||
DataType index_type = input_type(1);
|
DataType index_type = input_type(1);
|
||||||
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
|
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
|
||||||
errors::InvalidArgument("indices must be int32 or int64"));
|
errors::InvalidArgument("indices must be int32 or int64"));
|
||||||
|
|
||||||
xla::XlaOp gather;
|
xla::XlaOp gather;
|
||||||
|
if (batch_dims_ > 0) {
|
||||||
|
gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
|
||||||
|
} else {
|
||||||
|
// XlaGather() manages degenerate cases, like empty-indices, which are
|
||||||
|
// error conditions and caught above if batch_dims is not 0.
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context, XlaGather(input, input_shape, indices, indices_shape, axis,
|
context, XlaGather(input, input_shape, indices, indices_shape, *axis,
|
||||||
/*indices_are_nd=*/false, input_type(0), index_type,
|
/*indices_are_nd=*/false, input_type(0),
|
||||||
builder, &gather));
|
index_type, context->builder(), &gather));
|
||||||
|
}
|
||||||
context->SetOutput(0, gather);
|
context->SetOutput(0, gather);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
|
||||||
|
|
||||||
|
// The number of batch dimensions, as passed in the batch_dims attribute.
|
||||||
|
// It must be less than rank(indices).
|
||||||
|
int32 batch_dims_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_XLA_OP(Name("Gather"), GatherOp);
|
REGISTER_XLA_OP(Name("Gather"), GatherOp);
|
||||||
|
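The new batch_dims attribute makes the kernel dispatch to xla::TorchIndexSelect when batch_dims > 0; semantically, the leading batch_dims dimensions of params and indices are paired up and the gather runs independently per batch. A reference sketch of the simplest case (batch_dims = 1, axis = 1, rank-2 inputs), using only the standard library:

#include <cstddef>
#include <iostream>
#include <vector>

// Gather with batch_dims = 1 and axis = 1 on rank-2 inputs:
// out[b][i] = params[b][indices[b][i]], i.e. each batch row gathers from its
// own slice of params.
std::vector<std::vector<float>> BatchedGather(
    const std::vector<std::vector<float>>& params,
    const std::vector<std::vector<int>>& indices) {
  std::vector<std::vector<float>> out(params.size());
  for (std::size_t b = 0; b < params.size(); ++b) {
    for (int idx : indices[b]) out[b].push_back(params[b][idx]);
  }
  return out;
}

int main() {
  const std::vector<std::vector<float>> params = {{1, 2, 3}, {4, 5, 6}};
  const std::vector<std::vector<int>> indices = {{2, 0}, {1, 1}};
  for (const auto& row : BatchedGather(params, indices)) {
    for (float v : row) std::cout << v << " ";
    std::cout << "\n";
  }
  // Prints:
  // 3 1
  // 5 5
  return 0;
}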
@ -81,20 +81,21 @@ class InTopKOp : public XlaOpKernel {
|
|||||||
xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});
|
xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});
|
||||||
|
|
||||||
// Calculate in each row of `predictions`, how many values are larger than
|
// Calculate in each row of `predictions`, how many values are larger than
|
||||||
// the value of target class. Then return the result whether the count <= k,
|
// the value of the target class. Then return whether the count < k,
|
||||||
// which indicates the target is in topk.
|
// which indicates the target is in topk.
|
||||||
xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0});
|
xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0});
|
||||||
xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
|
xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
|
||||||
xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
|
xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
|
||||||
xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
|
xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
|
||||||
xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
|
xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
|
||||||
xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2);
|
xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2);
|
||||||
xla::XlaOp num_ge_r1 = xla::Reduce(
|
xla::XlaOp num_gt_r1 = xla::Reduce(
|
||||||
one_hot_r2, zero_r0,
|
one_hot_r2, zero_r0,
|
||||||
xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});
|
xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});
|
||||||
|
|
||||||
xla::XlaOp result =
|
xla::XlaOp result =
|
||||||
xla::Le(num_ge_r1, xla::ConstantR0<int32>(xla_builder, k));
|
xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)),
|
||||||
|
xla::IsFinite(targets_values_r1));
|
||||||
|
|
||||||
context->SetOutput(0, result);
|
context->SetOutput(0, result);
|
||||||
}
|
}
|
||||||
|
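The revised InTopK lowering counts, per row, how many predictions are strictly greater than the target class's score, and reports the target as in the top k only when that count is < k and the score is finite; ties at the target's value therefore no longer push it out, and non-finite target scores are rejected. A plain reference of that predicate for a single row, assuming only the standard library:

#include <cmath>
#include <iostream>
#include <vector>

// Reference semantics of the lowering: the target is in the top k iff fewer
// than k predictions strictly beat its own score and that score is finite.
bool InTopK(const std::vector<float>& predictions, int target, int k) {
  const float target_value = predictions[target];
  if (!std::isfinite(target_value)) return false;
  int num_greater = 0;
  for (float p : predictions) {
    if (p > target_value) ++num_greater;
  }
  return num_greater < k;
}

int main() {
  const std::vector<float> row = {0.1f, 0.9f, 0.5f, 0.5f};
  std::cout << InTopK(row, /*target=*/2, /*k=*/2) << "\n";  // 1: only 0.9 beats 0.5
  std::cout << InTopK(row, /*target=*/0, /*k=*/2) << "\n";  // 0: three values beat 0.1
  return 0;
}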
@ -67,9 +67,9 @@ class MatMulOp : public XlaOpKernel {
|
|||||||
|
|
||||||
OP_REQUIRES(ctx,
|
OP_REQUIRES(ctx,
|
||||||
a_shape.dim_size(first_index) == b_shape.dim_size(second_index),
|
a_shape.dim_size(first_index) == b_shape.dim_size(second_index),
|
||||||
errors::InvalidArgument("Matrix size-compatible: In[0]: ",
|
errors::InvalidArgument(
|
||||||
a_shape.DebugString(), ", In[1]: ",
|
"Matrix size-incompatible: In[0]: ", a_shape.DebugString(),
|
||||||
b_shape.DebugString()));
|
", In[1]: ", b_shape.DebugString()));
|
||||||
|
|
||||||
xla::XlaOp a = ctx->Input(0);
|
xla::XlaOp a = ctx->Input(0);
|
||||||
xla::XlaOp b = ctx->Input(1);
|
xla::XlaOp b = ctx->Input(1);
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
@ -22,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/util/bcast.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -77,5 +79,58 @@ class SelectOp : public XlaOpKernel {
|
|||||||
|
|
||||||
REGISTER_XLA_OP(Name("Select"), SelectOp);
|
REGISTER_XLA_OP(Name("Select"), SelectOp);
|
||||||
|
|
||||||
|
class SelectOpV2 : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit SelectOpV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
const TensorShape cond_shape = ctx->InputShape(0);
|
||||||
|
const TensorShape then_shape = ctx->InputShape(1);
|
||||||
|
const TensorShape else_shape = ctx->InputShape(2);
|
||||||
|
|
||||||
|
// Compute the output shape from the broadcast of the two data inputs, with
|
||||||
|
// the broadcast of the conditional.
|
||||||
|
// Then Broadcast all three inputs to the output shape and emit a select.
|
||||||
|
|
||||||
|
BCast bcast_then_else(BCast::FromShape(then_shape),
|
||||||
|
BCast::FromShape(else_shape),
|
||||||
|
/*fewer_dims_optimization=*/false);
|
||||||
|
if (!bcast_then_else.IsValid()) {
|
||||||
|
ctx->SetStatus(errors::InvalidArgument(
|
||||||
|
"Incompatible shapes: ", then_shape.DebugString(), " vs. ",
|
||||||
|
else_shape.DebugString()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BCast bcast(bcast_then_else.output_shape(), BCast::FromShape(cond_shape),
|
||||||
|
/*fewer_dims_optimization=*/false);
|
||||||
|
if (!bcast.IsValid()) {
|
||||||
|
ctx->SetStatus(errors::InvalidArgument(
|
||||||
|
"Incompatible shapes: ",
|
||||||
|
BCast::ToShape(bcast_then_else.output_shape()).DebugString(), " vs. ",
|
||||||
|
cond_shape.DebugString()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto bcasted_cond = BroadcastTo(ctx->Input(0), bcast.output_shape());
|
||||||
|
OP_REQUIRES_OK(ctx, bcasted_cond.status());
|
||||||
|
auto cond_handle = bcasted_cond.ValueOrDie();
|
||||||
|
|
||||||
|
auto bcasted_then = BroadcastTo(ctx->Input(1), bcast.output_shape());
|
||||||
|
OP_REQUIRES_OK(ctx, bcasted_then.status());
|
||||||
|
auto then_handle = bcasted_then.ValueOrDie();
|
||||||
|
|
||||||
|
auto bcasted_else = BroadcastTo(ctx->Input(2), bcast.output_shape());
|
||||||
|
OP_REQUIRES_OK(ctx, bcasted_else.status());
|
||||||
|
auto else_handle = bcasted_else.ValueOrDie();
|
||||||
|
|
||||||
|
ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(SelectOpV2);
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("SelectV2"), SelectOpV2);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
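SelectOpV2 first broadcasts the then/else shapes against each other and then broadcasts that result against the condition shape (with fewer_dims_optimization disabled), after which all three inputs are expanded to the common shape and fed to a single xla::Select. A sketch of the NumPy-style shape broadcast used twice in that sequence, standard library only:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// Broadcast two shapes: align from the right; each pair of dimensions must
// match or one of them must be 1. Returns nullopt when incompatible.
std::optional<std::vector<long>> BroadcastShapes(std::vector<long> a,
                                                 std::vector<long> b) {
  if (a.size() < b.size()) std::swap(a, b);
  std::vector<long> out(a);
  const std::size_t offset = a.size() - b.size();
  for (std::size_t i = 0; i < b.size(); ++i) {
    const long da = a[offset + i], db = b[i];
    if (da == db || db == 1) {
      out[offset + i] = da;
    } else if (da == 1) {
      out[offset + i] = db;
    } else {
      return std::nullopt;  // e.g. {2, 3} vs {4, 3}
    }
  }
  return out;
}

int main() {
  // then/else broadcast first, then the result against the condition shape.
  const auto then_else = BroadcastShapes({4, 1, 3}, {5, 3});      // {4, 5, 3}
  const auto with_cond = BroadcastShapes(*then_else, {4, 5, 1});  // {4, 5, 3}
  for (long d : *with_cond) std::cout << d << " ";
  std::cout << "\n";
  return 0;
}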
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
// XLA-specific Ops for softmax.
|
// XLA-specific Ops for softmax.
|
||||||
|
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
@ -145,23 +146,36 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
|
|||||||
: XlaOpKernel(ctx) {}
|
: XlaOpKernel(ctx) {}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
const TensorShape logits_shape = ctx->InputShape(0);
|
|
||||||
const TensorShape labels_shape = ctx->InputShape(1);
|
|
||||||
OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"logits and labels must be same size: logits_size=",
|
|
||||||
logits_shape.DebugString(),
|
|
||||||
" labels_size=", labels_shape.DebugString()));
|
|
||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
|
|
||||||
errors::InvalidArgument("logits must be 2-dimensional"));
|
|
||||||
// As we already tested that both inputs have the same shape no need to
|
|
||||||
// check that "labels" is a matrix too.
|
|
||||||
|
|
||||||
const DataType type = input_type(0);
|
const DataType type = input_type(0);
|
||||||
const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
|
const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
|
||||||
auto logits = ctx->Input(0);
|
auto logits = ctx->Input(0);
|
||||||
auto labels = ctx->Input(1);
|
auto labels = ctx->Input(1);
|
||||||
|
|
||||||
|
const TensorShape logits_shape = ctx->InputShape(0);
|
||||||
|
const TensorShape labels_shape = ctx->InputShape(1);
|
||||||
|
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
|
||||||
|
errors::InvalidArgument("logits must be 2-dimensional"));
|
||||||
|
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_shape),
|
||||||
|
errors::InvalidArgument("labels must be 2-dimensional"));
|
||||||
|
|
||||||
|
// Confirm that any necessary broadcasting to make the shapes the same will
|
||||||
|
// succeed.
|
||||||
|
for (int dim = 0; dim < 2; dim++) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
labels_shape.dim_size(dim) == 1 ||
|
||||||
|
logits_shape.dim_size(dim) == labels_shape.dim_size(dim),
|
||||||
|
errors::InvalidArgument("logits and labels must be same size after "
|
||||||
|
"broadcasting of labels: logits_size=",
|
||||||
|
logits_shape.DebugString(),
|
||||||
|
" labels_size=", labels_shape.DebugString()));
|
||||||
|
}
|
||||||
|
if (!logits_shape.IsSameSize(labels_shape)) {
|
||||||
|
auto labels_or = BroadcastTo(labels, logits_shape.dim_sizes());
|
||||||
|
OP_REQUIRES_OK(ctx, labels_or.status());
|
||||||
|
labels = labels_or.ConsumeValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
xla::XlaOp loss, backprop;
|
xla::XlaOp loss, backprop;
|
||||||
std::tie(loss, backprop) =
|
std::tie(loss, backprop) =
|
||||||
CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
|
CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
|
||||||
|
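Instead of demanding identical shapes, the softmax cross-entropy kernel now accepts labels whose dimensions either match the corresponding logits dimension or are 1, and broadcasts labels up to the logits shape with BroadcastTo when they differ. A compact sketch of that per-dimension check for the 2-D case, standard library only:

#include <array>
#include <iostream>

// labels can be broadcast to logits iff, for each of the two dimensions, the
// sizes match or the labels size is 1 (the same loop the kernel runs before
// calling BroadcastTo).
bool LabelsBroadcastableToLogits(const std::array<long, 2>& logits,
                                 const std::array<long, 2>& labels) {
  for (int dim = 0; dim < 2; ++dim) {
    if (labels[dim] != 1 && labels[dim] != logits[dim]) return false;
  }
  return true;
}

int main() {
  std::cout << LabelsBroadcastableToLogits({32, 10}, {1, 10}) << "\n";   // 1
  std::cout << LabelsBroadcastableToLogits({32, 10}, {32, 10}) << "\n";  // 1
  std::cout << LabelsBroadcastableToLogits({32, 10}, {16, 10}) << "\n";  // 0
  return 0;
}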
@ -1,9 +1,8 @@
|
|||||||
# Utilities for building XLA computations.
|
# Utilities for building XLA computations.
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow/compiler/tf2xla:friends"],
|
default_visibility = ["//tensorflow/compiler/tf2xla:friends"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filegroup used to collect source files for dependency checking.
|
# Filegroup used to collect source files for dependency checking.
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow:internal"],
|
default_visibility = ["//tensorflow:internal"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
load(
|
load(
|
||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
"tf_custom_op_library",
|
"tf_custom_op_library",
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
"//visibility:public",
|
"//visibility:public",
|
||||||
],
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
load(
|
load(
|
||||||
|
@ -550,6 +550,7 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
|
|||||||
};
|
};
|
||||||
GraphOptimizer::Options graph_optimizer_options;
|
GraphOptimizer::Options graph_optimizer_options;
|
||||||
graph_optimizer_options.cf_consider_fn = cf_consider_fn;
|
graph_optimizer_options.cf_consider_fn = cf_consider_fn;
|
||||||
|
graph_optimizer_options.inline_multi_device_functions = true;
|
||||||
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
|
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
|
||||||
/*device=*/nullptr, &graph, graph_optimizer_options);
|
/*device=*/nullptr, &graph, graph_optimizer_options);
|
||||||
|
|
||||||
|
@ -116,9 +116,12 @@ class XlaOpRegistry {
|
|||||||
// If we should cluster operations returning DT_VARIANT.
|
// If we should cluster operations returning DT_VARIANT.
|
||||||
bool cluster_variant_ops = false;
|
bool cluster_variant_ops = false;
|
||||||
|
|
||||||
// Whether ops known to be slow or to have correctness issues should be
|
// Whether ops known to be slow should be auto-clustered.
|
||||||
|
bool cluster_slow_ops = false;
|
||||||
|
|
||||||
|
// Whether ops known to have numerical accuracy issues should be
|
||||||
// auto-clustered.
|
// auto-clustered.
|
||||||
bool cluster_slow_and_inaccurate_ops = false;
|
bool cluster_inaccurate_ops = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Registers an XLA backend. `compilation_device_name` is the name of the
|
// Registers an XLA backend. `compilation_device_name` is the name of the
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
licenses(["notice"]) # Apache 2.0
|
package(
|
||||||
|
default_visibility = ["//tensorflow:internal"],
|
||||||
package(default_visibility = ["//tensorflow:internal"])
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
package_group(
|
package_group(
|
||||||
name = "friends",
|
name = "friends",
|
||||||
@ -575,6 +576,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":types",
|
":types",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -881,6 +883,26 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "refcounting_hash_map",
|
||||||
|
hdrs = ["refcounting_hash_map.h"],
|
||||||
|
deps = [
|
||||||
|
"@com_google_absl//absl/container:node_hash_map",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "refcounting_hash_map_test",
|
||||||
|
srcs = ["refcounting_hash_map_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":refcounting_hash_map",
|
||||||
|
":test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
|
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
|
||||||
|
Some files were not shown because too many files have changed in this diff.