Merge branch 'tensorflow-master'

minds 2019-05-27 18:47:30 +09:00
commit 33dffe53fb
1650 changed files with 82482 additions and 58113 deletions

View File

@ -39,32 +39,46 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true
+build:download_clang --action_env TF_DOWNLOAD_CLANG=1
# Instruct clang to use LLD for linking.
# This only works with GPU builds currently, since Bazel sets -B/usr/bin in
# auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over
# the downloaded one.
build:download_clang_use_lld --linkopt='-fuse-ld=lld'
-build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
-build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
+# This config refers to building with CUDA available. It does not necessarily
+# mean that we build CUDA op kernels.
+build:using_cuda --define=using_cuda=true
+build:using_cuda --action_env TF_NEED_CUDA=1
+build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
+# This config refers to building CUDA op kernels with nvcc.
+build:cuda --config=using_cuda
+build:cuda --define=using_cuda_nvcc=true
+# This config refers to building CUDA op kernels with clang.
+build:cuda_clang --config=using_cuda
+build:cuda_clang --define=using_cuda_clang=true
+build:cuda_clang --define=using_clang=true
+build:tensorrt --action_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
+build:rocm --action_env TF_NEED_ROCM=1
-build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
-build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl --define=using_sycl=true --define=using_trisycl=false
+build:sycl --define=using_sycl=true
+build:sycl --action_env TF_NEED_OPENCL_SYCL=1
-build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
+build:sycl_nodouble --config=sycl
+build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
-build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
+build:sycl_nodouble --config=sycl
+build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
-build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true
+build:sycl_nodouble --config=sycl
+build:sycl_trisycl --define=using_trisycl=true
# Options extracted from configure script
build:gdr --define=with_gdr_support=true
@ -87,6 +101,9 @@ build --spawn_strategy=standalone
build --strategy=Genrule=standalone
build -c opt
+# Make Bazel print out all options from rc files.
+build --announce_rc
# Other build flags.
build --define=grpc_no_ares=true
@ -97,8 +114,7 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
# Build TF with C++ 17 features.
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
-build:c++1z --cxxopt=-std=c++1z
-build:c++1z --cxxopt=-stdlib=libc++
+build:c++1z --config=c++17
# Default paths for TF_SYSTEM_LIBS
build --define=PREFIX=/usr

View File

@ -38,7 +38,13 @@ working on getting your pull request submitted to our internal repository. After
the change has been submitted internally, your pull request will be merged
automatically on GitHub.
-If you want to contribute but you're not sure where to start, take a look at the
+If you want to contribute, start working through the TensorFlow codebase,
+navigate to the
+[Github "issues" tab](https://github.com/tensorflow/tensorflow/issues) and start
+looking through interesting issues. If you are not sure of where to start, then
+start by trying one of the smaller/easier issues here i.e.
+[issues with the "good first issue" label](https://github.com/tensorflow/tensorflow/labels/good%20first%20issue)
+and then take a look at the
[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
These are issues that we believe are particularly well suited for outside
contributions, often because we probably won't get to them right now. If you

View File

@ -403,7 +403,8 @@ def set_action_env_var(environ_cp,
enabled_by_default,
question=None,
yes_reply=None,
-no_reply=None):
+no_reply=None,
+bazel_config_name=None):
"""Set boolean action_env variable.
Ask user if query_item will be enabled. Default is used if no input is given.
@ -418,12 +419,16 @@
question: optional string for how to ask for user input.
yes_reply: optional string for reply when feature is enabled.
no_reply: optional string for reply when feature is disabled.
+bazel_config_name: adding config to .bazelrc instead of action_env.
"""
var = int(
get_var(environ_cp, var_name, query_item, enabled_by_default, question,
yes_reply, no_reply))
+if not bazel_config_name:
write_action_env_to_bazelrc(var_name, var)
+elif var:
+write_to_bazelrc('build --config=%s' % bazel_config_name)
environ_cp[var_name] = str(var)
@ -543,7 +548,8 @@
False,
question=question,
yes_reply=yes_reply,
-no_reply=no_reply)
+no_reply=no_reply,
+bazel_config_name='cuda_clang')
def set_tf_download_clang(environ_cp):
@ -558,7 +564,8 @@
False,
question=question,
yes_reply=yes_reply,
-no_reply=no_reply)
+no_reply=no_reply,
+bazel_config_name='download_clang')
def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
@ -782,8 +789,8 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path):
print('WARNING: The NDK version in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
-'errors.\n' % (android_ndk_home_path, ndk_version,
-_SUPPORTED_ANDROID_NDK_VERSIONS))
+'errors.\n' %
+(android_ndk_home_path, ndk_version, _SUPPORTED_ANDROID_NDK_VERSIONS))
# Now grab the NDK API level to use. Note that this is different from the
# SDK API level, as the NDK API level is effectively the *min* target SDK
@ -952,6 +959,7 @@ def set_tf_nccl_version(environ_cp):
ask_nccl_version, '')
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version

def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@ -1293,9 +1301,6 @@ def configure_ios():
""" """
if not is_macos(): if not is_macos():
return return
if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000:
print(
'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.')
for filepath in APPLE_BAZEL_FILES: for filepath in APPLE_BAZEL_FILES:
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
@ -1386,7 +1391,7 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
-current_bazel_version = check_bazel_version('0.24.1', '0.25.2')
+current_bazel_version = check_bazel_version('0.24.1', '0.25.3')
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1422,7 +1427,12 @@ def main():
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
xla_enabled_by_default, 'xla')
-set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
+set_action_env_var(
+environ_cp,
+'TF_NEED_OPENCL_SYCL',
+'OpenCL SYCL',
+False,
+bazel_config_name='sycl')
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
set_host_cxx_compiler(environ_cp)
set_host_c_compiler(environ_cp)
@ -1432,30 +1442,44 @@ def main():
else:
set_trisycl_include_dir(environ_cp)
-set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
+set_action_env_var(
+environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
if (environ_cp.get('TF_NEED_ROCM') == '1' and
'LD_LIBRARY_PATH' in environ_cp and
environ_cp.get('LD_LIBRARY_PATH') != '1'):
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
environ_cp.get('LD_LIBRARY_PATH'))
-set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
+environ_cp['TF_NEED_CUDA'] = str(
+int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
-set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
+set_action_env_var(
+environ_cp,
+'TF_NEED_TENSORRT',
+'TensorRT',
+False,
+bazel_config_name='tensorrt')
environ_save = dict(environ_cp)
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
if validate_cuda_config(environ_cp):
cuda_env_names = [
-'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
-'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
+'TF_CUDA_VERSION',
+'TF_CUBLAS_VERSION',
+'TF_CUDNN_VERSION',
+'TF_TENSORRT_VERSION',
+'TF_NCCL_VERSION',
+'TF_CUDA_PATHS',
# Items below are for backwards compatibility when not using
# TF_CUDA_PATHS.
-'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
-'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
+'CUDA_TOOLKIT_PATH',
+'CUDNN_INSTALL_PATH',
+'NCCL_INSTALL_PATH',
+'NCCL_HDR_PATH',
+'TENSORRT_INSTALL_PATH'
]
# Note: set_action_env_var above already writes to bazelrc.
for name in cuda_env_names:
@ -1506,8 +1530,6 @@ def main():
# CUDA not required. Ask whether we should download the clang toolchain and
# use it for the CPU build.
set_tf_download_clang(environ_cp)
-if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
-write_to_bazelrc('build --config=download_clang')
# SYCL / ROCm / CUDA are mutually exclusive.
# At most 1 GPU platform can be configured.

View File

@ -59,7 +59,7 @@ except ImportError:
from tensorflow.python.util.lazy_loader import LazyLoader  # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """
-WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
+The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons

View File

@ -21,6 +21,9 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
+"tf_attrtype.h",
+"tf_datatype.h",
+"tf_status.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
@ -51,6 +54,8 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"c_api_internal.h",
+"tf_datatype.h",
+"tf_status.h",
],
visibility = [
"//tensorflow:internal",
@ -61,6 +66,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:android_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":tf_attrtype",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -71,16 +77,26 @@ tf_cuda_library(
}),
)
+cc_library(
+name = "tf_attrtype",
+hdrs = ["tf_attrtype.h"],
+visibility = ["//visibility:public"],
+)
tf_cuda_library(
name = "c_api",
hdrs = [
"c_api.h",
+"tf_attrtype.h",
+"tf_datatype.h",
+"tf_status.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api_no_xla",
":c_api_internal",
+":tf_attrtype",
] + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -96,14 +112,21 @@ tf_cuda_library(
"c_api.cc", "c_api.cc",
"c_api_function.cc", "c_api_function.cc",
], ],
hdrs = ["c_api.h"], hdrs = [
"c_api.h",
],
copts = tf_copts(), copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"], visibility = ["//tensorflow/c:__subpackages__"],
deps = [":c_api_internal"] + select({ deps = [
":c_api_internal",
":tf_attrtype",
":tf_datatype",
] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:android_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":tf_status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients", "//tensorflow/cc:gradients",
@ -124,6 +147,37 @@ tf_cuda_library(
}),
)
cc_library(
name = "tf_status",
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/c:c_api_internal",
"//tensorflow/core:lib",
],
}),
)
cc_library(
name = "tf_datatype",
srcs = ["tf_datatype.cc"],
hdrs = ["tf_datatype.h"],
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
],
}),
)
tf_cuda_library(
name = "c_api_experimental",
srcs = [
@ -137,6 +191,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
+":checkpoint_reader",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/compiler/jit:flags",
@ -151,15 +206,6 @@ tf_cuda_library(
],
)
-cc_library(
-name = "c_api_headers",
-hdrs = [
-"c_api.h",
-],
-copts = tf_copts(),
-visibility = ["//tensorflow:__subpackages__"],
-)
exports_files(
[
"version_script.lds",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/match.h"
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"  // NOLINT
@ -97,7 +98,6 @@ using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
using tensorflow::VersionDef;
-using tensorflow::error::Code;
using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
@ -108,34 +108,6 @@ extern "C" {
// --------------------------------------------------------------------------
const char* TF_Version() { return TF_VERSION_STRING; }
// --------------------------------------------------------------------------
size_t TF_DataTypeSize(TF_DataType dt) {
return static_cast<size_t>(
tensorflow::DataTypeSize(static_cast<DataType>(dt)));
}
// --------------------------------------------------------------------------
TF_Status* TF_NewStatus() { return new TF_Status; }
void TF_DeleteStatus(TF_Status* s) { delete s; }
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
if (code == TF_OK) {
s->status = Status::OK();
return;
}
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}
TF_Code TF_GetCode(const TF_Status* s) {
return static_cast<TF_Code>(s->status.code());
}
const char* TF_Message(const TF_Status* s) {
return s->status.error_message().c_str();
}
// --------------------------------------------------------------------------
namespace {
@ -1697,7 +1669,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
if (metadata.list_size == 0) {
for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
const auto& a = oper->node.op_def().attr(i);
-if (a.name().compare(attr_name) != 0) continue;
+if (a.name() != attr_name) continue;
const string& typestr = a.type();
if (typestr == "list(string)") {
metadata.type = TF_ATTR_STRING;
@ -2517,8 +2489,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
// used in this graph
for (const auto& pair : g->name_map) {
const string& name = pair.first;
-if (name.compare(prefix) == 0 ||
-tensorflow::str_util::StartsWith(name, prefix_cmp)) {
+if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
status->status = InvalidArgument(
"prefix [", prefix,
"] conflicts with existing node in the graph named [", name, "]");
@ -2548,8 +2519,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
// Adding the gradients to the graph can alter the prefix to prevent
// name collisions only if this prefix has not been provided explicitly
// by the user. If it was provided, assert that it remained intact.
-if (prefix != nullptr &&
-!tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
+if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
status->status = tensorflow::errors::Internal(
"BUG: The gradients prefix have been unexpectedly altered when "
"adding the nodes to the graph. This is a bug. Please file an "

View File

@ -19,6 +19,10 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
+#include "tensorflow/c/tf_attrtype.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/tf_status.h"
// --------------------------------------------------------------------------
// C API for TensorFlow.
//
@ -69,7 +73,7 @@ limitations under the License.
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
-// of any other includes.$a
+// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
@ -93,89 +97,6 @@ extern "C" {
// TensorFlow library. TensorFlow using semantic versioning.
TF_CAPI_EXPORT extern const char* TF_Version(void);
// --------------------------------------------------------------------------
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
// The enum values here are identical to corresponding values in types.proto.
typedef enum TF_DataType {
TF_FLOAT = 1,
TF_DOUBLE = 2,
TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
TF_UINT8 = 4,
TF_INT16 = 5,
TF_INT8 = 6,
TF_STRING = 7,
TF_COMPLEX64 = 8, // Single-precision complex
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility
TF_INT64 = 9,
TF_BOOL = 10,
TF_QINT8 = 11, // Quantized int8
TF_QUINT8 = 12, // Quantized uint8
TF_QINT32 = 13, // Quantized int32
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
TF_QINT16 = 15, // Quantized int16
TF_QUINT16 = 16, // Quantized uint16
TF_UINT16 = 17,
TF_COMPLEX128 = 18, // Double-precision complex
TF_HALF = 19,
TF_RESOURCE = 20,
TF_VARIANT = 21,
TF_UINT32 = 22,
TF_UINT64 = 23,
} TF_DataType;
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
// to the given TF_DataType enum value. Returns 0 for variable length types
// (eg. TF_STRING) or on failure.
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);
// --------------------------------------------------------------------------
// TF_Code holds an error code. The enum values here are identical to
// corresponding values in error_codes.proto.
typedef enum TF_Code {
TF_OK = 0,
TF_CANCELLED = 1,
TF_UNKNOWN = 2,
TF_INVALID_ARGUMENT = 3,
TF_DEADLINE_EXCEEDED = 4,
TF_NOT_FOUND = 5,
TF_ALREADY_EXISTS = 6,
TF_PERMISSION_DENIED = 7,
TF_UNAUTHENTICATED = 16,
TF_RESOURCE_EXHAUSTED = 8,
TF_FAILED_PRECONDITION = 9,
TF_ABORTED = 10,
TF_OUT_OF_RANGE = 11,
TF_UNIMPLEMENTED = 12,
TF_INTERNAL = 13,
TF_UNAVAILABLE = 14,
TF_DATA_LOSS = 15,
} TF_Code;
// --------------------------------------------------------------------------
// TF_Status holds error information. It either has an OK code, or
// else an error code with an associated error message.
typedef struct TF_Status TF_Status;
// Return a new status object.
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
// Record <code, msg> in *s. Any previous information is lost.
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
const char* msg);
// Return the code record in *s.
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);
// Return a pointer to the (null-terminated) error message in *s. The
// return value points to memory that is only usable until the next
// mutation to *s. Always returns an empty string if TF_GetCode(s) is
// TF_OK.
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);
// --------------------------------------------------------------------------
// TF_Buffer holds a pointer to a block of data and its associated length.
// Typically, the data consists of a serialized protocol buffer, but other data
@ -686,19 +607,6 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs(
TF_Operation* oper, TF_Operation** control_outputs,
int max_control_outputs);
// TF_AttrType describes the type of the value of an attribute on an operation.
typedef enum TF_AttrType {
TF_ATTR_STRING = 0,
TF_ATTR_INT = 1,
TF_ATTR_FLOAT = 2,
TF_ATTR_BOOL = 3,
TF_ATTR_TYPE = 4,
TF_ATTR_SHAPE = 5,
TF_ATTR_TENSOR = 6,
TF_ATTR_PLACEHOLDER = 7,
TF_ATTR_FUNC = 8,
} TF_AttrType;
// TF_AttrMetadata describes the value of an attribute on an operation.
typedef struct TF_AttrMetadata {
// A boolean: 1 if the attribute value is a list, 0 otherwise.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
@ -37,6 +38,7 @@ using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
+using tensorflow::errors::InvalidArgument;
namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
@ -149,7 +151,7 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
}
char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
-const auto& debug_str = func->fdef.DebugString();
+const auto& debug_str = DebugString(func->fdef);
*len = debug_str.size();
char* ret = static_cast<char*>(malloc(*len + 1));
memcpy(ret, debug_str.c_str(), *len + 1);
@ -576,6 +578,73 @@ void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
std::vector<std::string> variable_list;
};
TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
TF_Status* status) {
TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
if (!status->status.ok()) return nullptr;
const auto& m = reader->GetVariableToDataTypeMap();
for (auto it = m.begin(); it != m.end(); ++it)
reader->variable_list.push_back(it->first);
std::sort(reader->variable_list.begin(), reader->variable_list.end());
return reader;
}
void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }
int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
const char* name) {
return reader->HasTensor(name);
}
const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
int index) {
return reader->variable_list[index].c_str();
}
int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
return reader->variable_list.size();
}
TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToDataTypeMap();
return static_cast<TF_DataType>(m.at(name));
}
TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
const char* name, TF_Status* status) {
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (!status->status.ok()) return nullptr;
return tensorflow::TF_TensorFromTensor(*tensor.get(), status);
}
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
const char* name, int64_t* dims,
int num_dims, TF_Status* status) {
const auto& shape = reader->GetVariableToShapeMap().at(name);
int rank = shape.dims();
if (num_dims != rank) {
status->status = InvalidArgument("Expected rank is ", num_dims,
" but actual rank is ", rank);
return;
}
for (int i = 0; i < num_dims; i++) {
dims[i] = shape.dim_size(i);
}
}
int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToShapeMap();
return m.at(name).dims();
}
// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
using tensorflow::AttrBuilder::AttrBuilder;
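A minimal caller-side sketch of the checkpoint-reader helpers implemented above; ReadVariable is a hypothetical wrapper name introduced here for illustration, not part of the API, and it only exercises the shape and tensor accessors.

#include <stdlib.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Look up a variable's shape, then load its value as a TF_Tensor.
// The caller owns the returned tensor and must TF_DeleteTensor() it.
static TF_Tensor* ReadVariable(TF_CheckpointReader* reader, const char* name,
                               TF_Status* status) {
  if (!TF_CheckpointReaderHasTensor(reader, name)) return NULL;
  int num_dims = TF_CheckpointReaderGetVariableNumDims(reader, name);
  int64_t* dims = (int64_t*)malloc(num_dims * sizeof(int64_t));
  TF_CheckpointReaderGetVariableShape(reader, name, dims, num_dims, status);
  free(dims);  // Shape is only queried for illustration here.
  if (TF_GetCode(status) != TF_OK) return NULL;
  return TF_CheckpointReaderGetTensor(reader, name, status);
}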

View File

@ -208,6 +208,34 @@ TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);
// TF_NewCheckpointReader() returns the CheckpointReader that can be used to
// investigate or load the variables from the checkpoint file.
typedef struct TF_CheckpointReader TF_CheckpointReader;
TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader(
const char* filename, TF_Status* status);
TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader(
TF_CheckpointReader* reader);
TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor(
TF_CheckpointReader* reader, const char* name);
// Get the variable name at the given index
TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable(
TF_CheckpointReader* reader, int index);
// Get the number of variables in the checkpoint
TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader);
// Get the DataType of a variable
TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType(
TF_CheckpointReader* reader, const char* name);
// Read the shape of a variable and write to `dims`
TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape(
TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims,
TF_Status* status);
// Get the number of dimensions of a variable
TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims(
TF_CheckpointReader* reader, const char* name);
// Load the weight of a variable
TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor(
TF_CheckpointReader* reader, const char* name, TF_Status* status);
// TF_NewAttrBuilder() returns an object that you can set attributes on as
// though it were an op. This allows querying properties of that op for
// type-checking purposes like if the op will run on a particular device type.
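As a rough usage sketch of the checkpoint-reader declarations above (not part of this change), the reader can enumerate what a checkpoint contains; the checkpoint_prefix argument is assumed to be the usual checkpoint file prefix.

#include <stdio.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Print every variable stored in a checkpoint together with its dtype enum.
static void ListCheckpointVariables(const char* checkpoint_prefix) {
  TF_Status* status = TF_NewStatus();
  TF_CheckpointReader* reader =
      TF_NewCheckpointReader(checkpoint_prefix, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "open failed: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return;
  }
  int n = TF_CheckpointReaderSize(reader);
  for (int i = 0; i < n; ++i) {
    const char* name = TF_CheckpointReaderGetVariable(reader, i);
    TF_DataType dtype = TF_CheckpointReaderGetVariableDataType(reader, name);
    printf("%s (dtype enum %d, %zu bytes per element)\n", name, (int)dtype,
           TF_DataTypeSize(dtype));
  }
  TF_DeleteCheckpointReader(reader);
  TF_DeleteStatus(status);
}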

View File

@ -62,8 +62,8 @@ protocol: "grpc"
TF_Buffer* null_result =
TFE_GetServerDef(malformed_text_proto.c_str(), status);
EXPECT_NE(TF_GetCode(status), TF_OK);
-EXPECT_TRUE(tensorflow::str_util::StrContains(
-TF_Message(status), "Invalid text proto for ServerDef"));
+EXPECT_TRUE(absl::StrContains(TF_Message(status),
+"Invalid text proto for ServerDef"));
EXPECT_EQ(null_result, nullptr);
// Cleanup

View File

@ -253,7 +253,7 @@ class CApiFunctionTest : public ::testing::Test {
const std::unordered_set<string>& nodes) {
ASSERT_EQ(nodes.size(), fdef.node_def_size())
<< "Got unexpected number of nodes. Expected: ["
-<< str_util::Join(nodes, ", ")
+<< absl::StrJoin(nodes, ", ")
<< "] Actual nodes in fdef: " << fdef.DebugString();
for (const NodeDef& node_def : fdef.node_def()) {
ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())

View File

@ -56,7 +56,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
-EXPECT_TRUE(str_util::StrContains(s, expected))
+EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}

View File

@ -1,6 +1,8 @@
# Experimental extensions to the C API for eager execution of kernels.
-licenses(["notice"])  # Apache 2.0
+package(
+licenses = ["notice"],  # Apache 2.0
+)
load(
"//tensorflow:tensorflow.bzl",

View File

@ -30,7 +30,9 @@ limitations under the License.
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#ifdef TENSORFLOW_EAGER_USE_XLA #ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA #endif // TENSORFLOW_EAGER_USE_XLA
@ -135,11 +137,12 @@ tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, int64 rendezvous_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
+const tensorflow::eager::CreateContextRequest& base_request,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
const string& remote_worker = remote_workers[i];
-tensorflow::eager::CreateContextRequest request;
+tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextResponse response;
request.set_rendezvous_id(rendezvous_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
@ -221,6 +224,23 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
remote_workers, grpc_server->master_env()->worker_cache,
&remote_device_mgr));
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request makes sure that we can create Rendezvous properly between
// local and remote contexts.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
grpc_server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
@ -230,14 +250,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, keep_alive_secs, server_def,
-remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
+remote_eager_workers.get(), ctx->context->Async(), base_request,
+&remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
-session_name, server_def, true));
+session_name, server_def, base_request.cluster_device_attributes(),
+true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
@ -250,9 +272,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr;
return ctx->context->InitializeRemote(
-std::move(server), std::move(remote_eager_workers),
-std::move(remote_device_mgr), remote_contexts, r, device_mgr,
-keep_alive_secs);
+std::move(server), grpc_server->worker_env(), worker_session,
+std::move(remote_eager_workers), std::move(remote_device_mgr),
+remote_contexts, r, device_mgr, keep_alive_secs,
+worker_session->cluster_flr.get());
#undef LOG_AND_RETURN_IF_ERROR
}
#endif  // !IS_MOBILE_PLATFORM
@ -970,6 +993,23 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
return t;
}
TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
TF_Status* status) {
// TensorHandles created by PyFuncOp lack context and therefore could
// not be copied.
if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
tensorflow::TensorHandle* handle;
status->status = tensorflow::EagerCopyToDevice(
h->handle, h->handle->Context(), "CPU:0", &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
} else {
return nullptr;
}
}
return h;
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);

View File

@ -462,6 +462,9 @@ class Tensor;
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* h, TF_Status* status);
+TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
+TF_Status* status);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
#endif

View File

@ -78,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
status->status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
-<< tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
+<< absl::StrJoin(shape_to_log, ", ") << "] is "
<< padded_shape.DebugString();
}
}

View File

@ -33,7 +33,7 @@ namespace tensorflow {
namespace {
static bool HasSubstr(absl::string_view base, absl::string_view substr) {
-bool ok = str_util::StrContains(base, substr);
+bool ok = absl::StrContains(base, substr);
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
return ok;
}

View File

@ -1408,6 +1408,10 @@ void FunctionDefAndExecute(bool async) {
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+for (bool clear_cache : {true, false, true}) {
+if (clear_cache) {
+TFE_ContextClearCaches(ctx);
+}
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
@ -1431,6 +1435,7 @@ void FunctionDefAndExecute(bool async) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
+}
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);

View File

@ -1,7 +1,9 @@
# Description:
#   Experimental C APIs for TensorFlow.
-licenses(["notice"])  # Apache 2.0
+package(
+licenses = ["notice"],  # Apache 2.0
+)
load(
"//tensorflow:tensorflow.bzl",

View File

@ -6,10 +6,9 @@ load(
package(
default_visibility = ["//visibility:public"],
+licenses = ["notice"],  # Apache 2.0
)
-licenses(["notice"])  # Apache 2.0
tf_kernel_library(
name = "bitcast_op",
prefix = "bitcast_op",

View File

@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_
#define TENSORFLOW_C_TF_ATTRTYPE_H_
#ifdef __cplusplus
extern "C" {
#endif
// TF_AttrType describes the type of the value of an attribute on an operation.
typedef enum TF_AttrType {
TF_ATTR_STRING = 0,
TF_ATTR_INT = 1,
TF_ATTR_FLOAT = 2,
TF_ATTR_BOOL = 3,
TF_ATTR_TYPE = 4,
TF_ATTR_SHAPE = 5,
TF_ATTR_TENSOR = 6,
TF_ATTR_PLACEHOLDER = 7,
TF_ATTR_FUNC = 8,
} TF_AttrType;
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_ATTRTYPE_H_
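TF_AttrType is consumed through TF_AttrMetadata in the main C API; the sketch below is a hedged illustration (the helper name PrintAttrKind is made up here, not part of TensorFlow).

#include <stdio.h>
#include "tensorflow/c/c_api.h"

// Report what kind of value an operation attribute holds.
static void PrintAttrKind(TF_Operation* oper, const char* attr_name) {
  TF_Status* status = TF_NewStatus();
  TF_AttrMetadata meta = TF_OperationGetAttrMetadata(oper, attr_name, status);
  if (TF_GetCode(status) == TF_OK) {
    switch (meta.type) {
      case TF_ATTR_STRING: printf("%s: string\n", attr_name); break;
      case TF_ATTR_INT:    printf("%s: int\n", attr_name); break;
      case TF_ATTR_TYPE:   printf("%s: dtype\n", attr_name); break;
      default: printf("%s: TF_AttrType %d\n", attr_name, (int)meta.type); break;
    }
  }
  TF_DeleteStatus(status);
}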

View File

@ -0,0 +1,23 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/types.h"
size_t TF_DataTypeSize(TF_DataType dt) {
return static_cast<size_t>(
tensorflow::DataTypeSize(static_cast<tensorflow::DataType>(dt)));
}

View File

@ -0,0 +1,83 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_DATATYPE_H_
#define TENSORFLOW_C_TF_DATATYPE_H_
#include <stddef.h>
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
// The enum values here are identical to corresponding values in types.proto.
typedef enum TF_DataType {
TF_FLOAT = 1,
TF_DOUBLE = 2,
TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
TF_UINT8 = 4,
TF_INT16 = 5,
TF_INT8 = 6,
TF_STRING = 7,
TF_COMPLEX64 = 8, // Single-precision complex
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility
TF_INT64 = 9,
TF_BOOL = 10,
TF_QINT8 = 11, // Quantized int8
TF_QUINT8 = 12, // Quantized uint8
TF_QINT32 = 13, // Quantized int32
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
TF_QINT16 = 15, // Quantized int16
TF_QUINT16 = 16, // Quantized uint16
TF_UINT16 = 17,
TF_COMPLEX128 = 18, // Double-precision complex
TF_HALF = 19,
TF_RESOURCE = 20,
TF_VARIANT = 21,
TF_UINT32 = 22,
TF_UINT64 = 23,
} TF_DataType;
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
// to the given TF_DataType enum value. Returns 0 for variable length types
// (eg. TF_STRING) or on failure.
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_DATATYPE_H_
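A minimal sketch exercising TF_DataTypeSize as documented above: fixed-width types report their element size, while variable-length types such as TF_STRING report 0.

#include <stdio.h>
#include "tensorflow/c/tf_datatype.h"

int main(void) {
  printf("TF_FLOAT:  %zu bytes\n", TF_DataTypeSize(TF_FLOAT));   // 4
  printf("TF_HALF:   %zu bytes\n", TF_DataTypeSize(TF_HALF));    // 2
  printf("TF_STRING: %zu bytes\n", TF_DataTypeSize(TF_STRING));  // 0 (variable length)
  return 0;
}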

tensorflow/c/tf_status.cc Normal file
View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/lib/core/status.h"
using ::tensorflow::Status;
using ::tensorflow::error::Code;
TF_Status* TF_NewStatus() { return new TF_Status; }
void TF_DeleteStatus(TF_Status* s) { delete s; }
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
if (code == TF_OK) {
s->status = Status::OK();
return;
}
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}
TF_Code TF_GetCode(const TF_Status* s) {
return static_cast<TF_Code>(s->status.code());
}
const char* TF_Message(const TF_Status* s) {
return s->status.error_message().c_str();
}

tensorflow/c/tf_status.h Normal file
View File

@ -0,0 +1,88 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_STATUS_H_
#define TENSORFLOW_C_TF_STATUS_H_
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
typedef struct TF_Status TF_Status;
// --------------------------------------------------------------------------
// TF_Code holds an error code. The enum values here are identical to
// corresponding values in error_codes.proto.
typedef enum TF_Code {
TF_OK = 0,
TF_CANCELLED = 1,
TF_UNKNOWN = 2,
TF_INVALID_ARGUMENT = 3,
TF_DEADLINE_EXCEEDED = 4,
TF_NOT_FOUND = 5,
TF_ALREADY_EXISTS = 6,
TF_PERMISSION_DENIED = 7,
TF_UNAUTHENTICATED = 16,
TF_RESOURCE_EXHAUSTED = 8,
TF_FAILED_PRECONDITION = 9,
TF_ABORTED = 10,
TF_OUT_OF_RANGE = 11,
TF_UNIMPLEMENTED = 12,
TF_INTERNAL = 13,
TF_UNAVAILABLE = 14,
TF_DATA_LOSS = 15,
} TF_Code;
// --------------------------------------------------------------------------
// Return a new status object.
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
// Record <code, msg> in *s. Any previous information is lost.
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
const char* msg);
// Return the code record in *s.
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);
// Return a pointer to the (null-terminated) error message in *s. The
// return value points to memory that is only usable until the next
// mutation to *s. Always returns an empty string if TF_GetCode(s) is
// TF_OK.
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_STATUS_H_
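A minimal sketch of the standalone status API declared above, assuming nothing beyond this header: create a status, set it, inspect it, clear it with the TF_OK idiom, and delete it.

#include <stdio.h>
#include "tensorflow/c/tf_status.h"

int main(void) {
  TF_Status* s = TF_NewStatus();  // Newly created statuses are TF_OK.
  TF_SetStatus(s, TF_INVALID_ARGUMENT, "tensor must have rank 2");
  if (TF_GetCode(s) != TF_OK) {
    fprintf(stderr, "error %d: %s\n", (int)TF_GetCode(s), TF_Message(s));
  }
  TF_SetStatus(s, TF_OK, "");  // Common idiom to clear a status.
  TF_DeleteStatus(s);
  return 0;
}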

View File

@ -4,10 +4,9 @@
package(
default_visibility = ["//visibility:public"],
+licenses = ["notice"],  # Apache 2.0
)
-licenses(["notice"])  # Apache 2.0
filegroup(
name = "srcs",
srcs = [
@ -638,6 +637,7 @@ cc_library(
"//tensorflow/core:op_gen_lib", "//tensorflow/core:op_gen_lib",
"//tensorflow/core:proto_text", "//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )
@ -657,6 +657,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/cc/framework/cc_op_gen.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include "tensorflow/cc/framework/cc_op_gen.h"
+#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@ -133,7 +135,7 @@ string MakeComment(StringPiece text, StringPiece indent) {
}
string PrintString(const string& str) {
-return strings::StrCat("\"", str_util::CEscape(str), "\"");
+return strings::StrCat("\"", absl::CEscape(str), "\"");
}
string PrintTensorShape(const TensorShapeProto& shape_proto) {
@ -191,7 +193,7 @@ string PrintTensor(const TensorProto& tensor_proto) {
string ret; string ret;
for (int64 i = 0; i < num_elts; ++i) { for (int64 i = 0; i < num_elts; ++i) {
if (i > 0) strings::StrAppend(&ret, " "); if (i > 0) strings::StrAppend(&ret, " ");
strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i))); strings::StrAppend(&ret, absl::CEscape(t.flat<string>()(i)));
} }
return ret; return ret;
} }

View File

@ -62,12 +62,12 @@ op {
)"; )";
void ExpectHasSubstr(StringPiece s, StringPiece expected) { void ExpectHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(str_util::StrContains(s, expected)) EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'"; << "'" << s << "' does not contain '" << expected << "'";
} }
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
EXPECT_FALSE(str_util::StrContains(s, expected)) EXPECT_FALSE(absl::StrContains(s, expected))
<< "'" << s << "' contains '" << expected << "'"; << "'" << s << "' contains '" << expected << "'";
} }
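
Editorial note: the hunks above replace tensorflow::str_util helpers with their absl counterparts. A self-contained sketch of the two absl calls involved (illustrative values, not TensorFlow code):

#include <iostream>
#include <string>

#include "absl/strings/escaping.h"  // absl::CEscape
#include "absl/strings/match.h"     // absl::StrContains

int main() {
  const std::string raw = "tag\tname\n";
  // absl::CEscape is the drop-in replacement for str_util::CEscape.
  std::string quoted = "\"" + absl::CEscape(raw) + "\"";
  std::cout << quoted << "\n";  // prints "tag\tname\n" with the escapes spelled out

  // absl::StrContains is the drop-in replacement for str_util::StrContains.
  std::cout << absl::StrContains(quoted, "tag") << "\n";  // 1
  return 0;
}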

View File

@ -275,7 +275,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
for (const string& entry : node_constraints) { for (const string& entry : node_constraints) {
StringPiece s(entry); StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {
current_constraints.emplace(s); current_constraints.emplace(s);
} }
} }

View File

@ -3,10 +3,9 @@
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
) )
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"]) exports_files(["LICENSE"])
load( load(

View File

@ -308,7 +308,7 @@ Status LoadSavedModel(const SessionOptions& session_options,
const Status status = LoadSavedModelInternal(session_options, run_options, const Status status = LoadSavedModelInternal(session_options, run_options,
export_dir, tags, bundle); export_dir, tags, bundle);
auto log_and_count = [&](const string& status_str) { auto log_and_count = [&](const string& status_str) {
LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
<< " }; Status: " << status_str << ". Took " << " }; Status: " << status_str << ". Took "
<< GetLatencyMicroseconds(start_microseconds) << " microseconds."; << GetLatencyMicroseconds(start_microseconds) << " microseconds.";
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);

View File

@ -136,7 +136,7 @@ TEST_F(LoaderTest, NoTagMatch) {
Status st = LoadSavedModel(session_options, run_options, export_dir, Status st = LoadSavedModel(session_options, run_options, export_dir,
{"missing-tag"}, &bundle); {"missing-tag"}, &bundle);
EXPECT_FALSE(st.ok()); EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains( EXPECT_TRUE(absl::StrContains(
st.error_message(), st.error_message(),
"Could not find meta graph def matching supplied tags: { missing-tag }")) "Could not find meta graph def matching supplied tags: { missing-tag }"))
<< st.error_message(); << st.error_message();
@ -152,7 +152,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
Status st = LoadSavedModel(session_options, run_options, export_dir, Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe, "missing-tag"}, &bundle); {kSavedModelTagServe, "missing-tag"}, &bundle);
EXPECT_FALSE(st.ok()); EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains( EXPECT_TRUE(absl::StrContains(
st.error_message(), st.error_message(),
"Could not find meta graph def matching supplied tags: ")) "Could not find meta graph def matching supplied tags: "))
<< st.error_message(); << st.error_message();
@ -172,7 +172,7 @@ TEST_F(LoaderTest, SessionCreationFailure) {
Status st = LoadSavedModel(session_options, run_options, export_dir, Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle); {kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok()); EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget)) EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget))
<< st.error_message(); << st.error_message();
} }

View File

@ -51,7 +51,7 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
Status FindMetaGraphDef(const SavedModel& saved_model_proto, Status FindMetaGraphDef(const SavedModel& saved_model_proto,
const std::unordered_set<string>& tags, const std::unordered_set<string>& tags,
MetaGraphDef* meta_graph_def) { MetaGraphDef* meta_graph_def) {
LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ") LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
<< " }"; << " }";
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) { for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
// Get tags from the graph_def. // Get tags from the graph_def.
@ -69,7 +69,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
error::Code::NOT_FOUND, error::Code::NOT_FOUND,
strings::StrCat( strings::StrCat(
"Could not find meta graph def matching supplied tags: { ", "Could not find meta graph def matching supplied tags: { ",
str_util::Join(tags, " "), absl::StrJoin(tags, " "),
" }. To inspect available tag-sets in the SavedModel, please " " }. To inspect available tag-sets in the SavedModel, please "
"use the SavedModel CLI: `saved_model_cli`")); "use the SavedModel CLI: `saved_model_cli`"));
} }
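
Editorial note: str_util::Join becomes absl::StrJoin in the loader and reader above; a quick sketch of the call shape with a made-up tag set:

#include <iostream>
#include <string>
#include <unordered_set>

#include "absl/strings/str_join.h"

int main() {
  // Mirrors the "tags { ... }" log lines above; the tag values are examples only.
  std::unordered_set<std::string> tags = {"serve", "gpu"};
  std::cout << "SavedModel load for tags { " << absl::StrJoin(tags, " ") << " }\n";
  return 0;
}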

View File

@ -64,7 +64,7 @@ TEST_F(ReaderTest, NoTagMatch) {
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def); &meta_graph_def);
EXPECT_FALSE(st.ok()); EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains( EXPECT_TRUE(absl::StrContains(
st.error_message(), st.error_message(),
"Could not find meta graph def matching supplied tags: { missing-tag }")) "Could not find meta graph def matching supplied tags: { missing-tag }"))
<< st.error_message(); << st.error_message();
@ -78,7 +78,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
Status st = ReadMetaGraphDefFromSavedModel( Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok()); EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains( EXPECT_TRUE(absl::StrContains(
st.error_message(), st.error_message(),
"Could not find meta graph def matching supplied tags: ")) "Could not find meta graph def matching supplied tags: "))
<< st.error_message(); << st.error_message();

View File

@ -3,10 +3,9 @@
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
) )
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"]) exports_files(["LICENSE"])
load( load(

View File

@ -167,8 +167,7 @@ namespace {
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
int32* dst) { int32* dst) {
if (tensorflow::str_util::ConsumePrefix(&arg, flag) && if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) {
tensorflow::str_util::ConsumePrefix(&arg, "=")) {
char extra; char extra;
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
} }
@ -178,7 +177,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
bool* dst) { bool* dst) {
if (tensorflow::str_util::ConsumePrefix(&arg, flag)) { if (absl::ConsumePrefix(&arg, flag)) {
if (arg.empty()) { if (arg.empty()) {
*dst = true; *dst = true;
return true; return true;
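
Editorial note: a self-contained sketch of the absl::ConsumePrefix pattern the rewritten flag parsers use; the --iters flag name and the main() driver are hypothetical:

#include <cstdint>
#include <cstdio>

#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"  // absl::ConsumePrefix

// Parses "--iters=<n>" the same way the helpers above do: strip the flag name,
// strip "=", then sscanf the remainder and reject trailing characters.
// Note: arg views a null-terminated string here, matching the argv usage above,
// so passing arg.data() to sscanf is safe.
bool ParseIters(absl::string_view arg, int32_t* dst) {
  if (absl::ConsumePrefix(&arg, "--iters") && absl::ConsumePrefix(&arg, "=")) {
    char extra;
    return std::sscanf(arg.data(), "%d%c", dst, &extra) == 1;
  }
  return false;
}

int main() {
  int32_t iters = 0;
  std::printf("%d %d\n", ParseIters("--iters=42", &iters), iters);  // prints: 1 42
  return 0;
}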

View File

@ -1,7 +1,6 @@
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = ["//visibility:private"], default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
) )
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

View File

@ -1,4 +1,12 @@
licenses(["notice"]) # Apache 2.0 package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
licenses = ["notice"], # Apache 2.0
)
package_group( package_group(
name = "internal", name = "internal",
@ -14,15 +22,6 @@ package_group(
], ],
) )
package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@ -200,6 +199,7 @@ cc_library(
"//tensorflow/core/kernels:host_constant_op", "//tensorflow/core/kernels:host_constant_op",
"//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:logging_ops",
"//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:queue_op", "//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:resource_variable_ops",
@ -257,10 +257,8 @@ cc_library(
name = "xla_launch_util", name = "xla_launch_util",
srcs = ["xla_launch_util.cc"], srcs = ["xla_launch_util.cc"],
hdrs = ["xla_launch_util.h"], hdrs = ["xla_launch_util.h"],
# TODO(skyewm): remove this once XlaAllocator is factored out.
visibility = [ visibility = [
":internal", ":internal",
"//tensorflow/compiler/xla/python:__pkg__",
], ],
deps = [ deps = [
":common", ":common",

View File

@ -244,11 +244,11 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
"resource variable op in called function"); "resource variable op in called function");
} }
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) { if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
return LogNotCompilableAndReturn(node, "operation with correctness issues"); return LogNotCompilableAndReturn(node, "operation with correctness issues");
} }
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) { if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
return LogNotCompilableAndReturn(node, "slow operation"); return LogNotCompilableAndReturn(node, "slow operation");
} }
@ -268,8 +268,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
registration.elide_assert_and_checknumerics; registration.elide_assert_and_checknumerics;
op_filter.allow_ops_producing_or_consuming_variant = op_filter.allow_ops_producing_or_consuming_variant =
registration.cluster_variant_ops; registration.cluster_variant_ops;
op_filter.allow_slow_and_inaccurate_ops = op_filter.allow_slow_ops = registration.cluster_slow_ops;
registration.cluster_slow_and_inaccurate_ops; op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
return op_filter; return op_filter;
} }

View File

@ -97,9 +97,12 @@ class RecursiveCompilabilityChecker {
// live-out DT_VARIANT values. // live-out DT_VARIANT values.
bool allow_ops_producing_or_consuming_variant; bool allow_ops_producing_or_consuming_variant;
// Whether ops known to be slow or to have correctness issues should be // Whether ops known to be slow on XLA-GPU should be considered compilable.
// auto-clustered. bool allow_slow_ops;
bool allow_slow_and_inaccurate_ops;
// Whether ops known to have numerical accuracy issues should be considered
// compilable.
bool allow_inaccurate_ops;
}; };
RecursiveCompilabilityChecker(const OperationFilter* op_filter, RecursiveCompilabilityChecker(const OperationFilter* op_filter,

View File

@ -14,10 +14,12 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/deadness_analysis.h"
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
@ -43,12 +45,12 @@ limitations under the License.
// ------------------------------------------ // ------------------------------------------
// //
// If we ignore cycles for a moment, computing predicates is fairly // If we ignore cycles for a moment, computing predicates is fairly
// straightforward. We traverse the graph in RPO, mapping each node to a // straightforward. We traverse the graph in a topological order, mapping each
// predicate based on the predicates its inputs are mapped to. For instance a // node to a predicate based on the predicates its inputs are mapped to. For
// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)). // instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
// Roughtly speaking, we abstract interpret each node on the "liveness" domain, // PredicateFor(Y)). Roughly speaking, we abstractly interpret each node on
// where values in the domain represent if a tensor carries a dead signal or // the "liveness" domain, where values in the domain represent if a tensor
// not. // carries a dead signal or not.
// //
// //
// DEALING WITH CYCLES // DEALING WITH CYCLES
@ -85,22 +87,28 @@ limitations under the License.
// true on iteration 0, 1, 2 respectively. This is made more precise in the // true on iteration 0, 1, 2 respectively. This is made more precise in the
// comment on the AndRecurrence class. // comment on the AndRecurrence class.
// //
// The general algorithm that deals with cycles does two RPO (reverse post // The general algorithm that deals with cycles does two topological-order
// order) passes over the graph. On the first pass it assigns a symbolic // iterations over the graph. On the first iteration it assigns a symbolic
// predicate to merge nodes with backedges. On the second pass it tries to // predicate to merge nodes with backedges. On the second iteration it tries
// pattern matche the predicates for the backedges of these merges and infer an // to pattern match the predicates for the backedges of these merges and infer
// AndRecurrence for the merge. // an AndRecurrence for the merge. In other words, we do a data flow analysis
// where the data-flow lattice has two elements, Symbolic and NonSymbolic with
// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
// sufficient to converge.
// //
// In other words, we do a pessimistic data flow analysis where the data-flow // We first do an optimistic analysis and, if it does not converge, we then fall
// lattice has two elements, Symbolic and NonSymbolic with Symbolic > // back to a pessimistic analysis. The optimistic analysis assigns the same
// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to // symbolic predicate to all the merge nodes whose preceding enter nodes have
// converge. We don't do an optimistic data flow analysis to make pattern // the same frame name on the first iteration. On the second iteration, if all
// matching easier: if we assigned the predicate of the initial value to the // the merge nodes are pattern matched into the same AndRecurrence predicate
// merge during the first pass, on the second pass the backedge may see a // instance, the optimistic assignment of the same symbolic predicate is correct
// simplified value that would be difficult to pattern match. // and the analyzed result is taken.
// //
// We still use symbolic predicates for merges for which we can't pattern match // Otherwise, if the optimistic analysis fails to converge, we then obtain the
// on the backedge predicate. This is conservatively correct. // result by falling back to the pessimistic analysis which assigns a unique
// symbolic predicate to each merge on the first iteration. We still use
// symbolic predicates for merges for which we can't pattern match on the
// backedge predicate. This is conservatively correct.
namespace tensorflow { namespace tensorflow {
@ -636,6 +644,35 @@ Predicate* PredicateFactory::MakeAndOrImpl(
negated_ops.insert(negated_op); negated_ops.insert(negated_op);
} }
// Simplify {S,&,X} & ~X & ... => S & ...
if (is_and) {
absl::flat_hash_set<Predicate*> to_remove;
std::vector<Predicate*> to_add;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kAndRecurrence) {
auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
if (negated_ops.contains(and_rec->step())) {
// Remove and_rec and ~X and insert S. Note that checking the
// existence of ~X through negated_ops is sufficient since it makes
// sure the predicate is in the input operands. It does not need to
// be in simplified_ops if it was already cancelled out.
to_remove.insert(and_rec);
to_remove.insert(MakeNotPredicate(and_rec->step()));
to_add.push_back(and_rec->start());
}
}
}
auto it = simplified_ops.begin();
while (it != simplified_ops.end()) {
if (to_remove.contains(*it)) {
it = simplified_ops.erase(it);
} else {
++it;
}
}
simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
}
// If all ops contain the same subop, then factor it out thanks to the // If all ops contain the same subop, then factor it out thanks to the
// distributive property. Such as: // distributive property. Such as:
// - (A & B) | (A & C) | (A & D) => A & (B | C | D) // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
@ -699,8 +736,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
explicit DeadnessAnalysisImpl(const Graph* graph) explicit DeadnessAnalysisImpl(const Graph* graph)
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {} : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate(); Status Populate(bool enable_optimistic);
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo); Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
bool* success);
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor( StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const override; Node* n, int oidx) const override;
void Print() const override; void Print() const override;
@ -742,16 +780,29 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
} }
Status HandleSwitch(Node* n, std::vector<bool>* should_revisit); Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
Status HandleMerge(Node* n, std::vector<bool>* should_revisit); Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
bool use_optimistic_mode);
Status HandleRecv(Node* n, std::vector<bool>* should_revisit); Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
Status HandleGeneric(Node* n, std::vector<bool>* should_revisit); Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
Status HandleNode(Node* n, std::vector<bool>* should_revisit); Status HandleNode(Node* n, std::vector<bool>* should_revisit,
bool use_optimistic_mode = false);
Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
bool IsRootEnter(const Node* n) const {
return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
}
bool IsRootExit(const Node* n) const {
return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
}
const Graph& graph_; const Graph& graph_;
absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_; absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_; PredicateFactory predicate_factory_;
std::vector<ControlFlowInfo> control_flow_info_; std::vector<ControlFlowInfo> control_flow_info_;
bool vlog_; bool vlog_;
absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
}; };
TensorId InputEdgeToTensorId(const Edge* e) { TensorId InputEdgeToTensorId(const Edge* e) {
@ -914,10 +965,32 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
return Status::OK(); return Status::OK();
} }
// If the node is inside some frames, get the name of the outermost non-empty
// frame. Otherwise, get an empty frame name.
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
absl::string_view* frame) {
int depth = 0;
const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
while (!cfi_iter->parent_frame->IsSource()) {
n = cfi_iter->parent_frame;
cfi_iter = &cfi_infos[n->id()];
if (depth++ > 5000) {
return errors::Internal(
"Frame of depth > 5000: Probably malformed graph or a bug in "
"BuildControlFlowInfo");
}
}
*frame = cfi_iter->frame_name;
return Status::OK();
}
} // namespace } // namespace
Status DeadnessAnalysisImpl::HandleMerge(Node* n, Status DeadnessAnalysisImpl::HandleMerge(Node* n,
std::vector<bool>* should_revisit) { std::vector<bool>* should_revisit,
bool use_optimistic_mode) {
// Merge ignores deadness of its control inputs. A merge that isn't the // target of a backedge is alive iff any of its data inputs are. The
// target of a backedge has is alive iff any of its data inputs are. The // target of a backedge has is alive iff any of its data inputs are. The
// liveness of a merge that is the target of a backedge can sometimes be // liveness of a merge that is the target of a backedge can sometimes be
@ -937,8 +1010,21 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// We're visiting this merge for the first time and it has an unvisited // We're visiting this merge for the first time and it has an unvisited
// backedge. // backedge.
Predicate* input_data_pred; Predicate* input_data_pred;
if (use_optimistic_mode) {
// In the optimistic mode, we use the first-seen Merge node per
// frame as the representative Merge node. It is just convenient and
// does not affect the result after pattern-matching into the
// AndRecurrence form.
absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
auto insert_result = frame_to_merge_node_.insert({frame_name, n});
Node* representative = insert_result.first->second;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
representative, /*output_idx=*/0, /*must_be_true=*/false,
&input_data_pred));
} else {
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
}
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
should_revisit); should_revisit);
@ -948,7 +1034,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
std::vector<Predicate*> input_preds; std::vector<Predicate*> input_preds;
TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
// We're visiting this merge for the first time and it is a acyclic merge. // We're visiting this merge for the first time and it is an acyclic merge.
Predicate* input_data_pred = Predicate* input_data_pred =
predicate_factory_.MakeOrPredicate(input_preds); predicate_factory_.MakeOrPredicate(input_preds);
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
@ -1022,11 +1108,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
} }
Status DeadnessAnalysisImpl::HandleNode(Node* n, Status DeadnessAnalysisImpl::HandleNode(Node* n,
std::vector<bool>* should_revisit) { std::vector<bool>* should_revisit,
bool use_optimistic_mode) {
if (n->IsSwitch()) { if (n->IsSwitch()) {
TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit)); TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
} else if (n->IsMerge()) { } else if (n->IsMerge()) {
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit)); TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
} else if (n->IsControlTrigger()) { } else if (n->IsControlTrigger()) {
SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
nullptr); nullptr);
@ -1040,17 +1127,129 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
return Status::OK(); return Status::OK();
} }
Status DeadnessAnalysisImpl::Populate() { // Compute a special topological order for the Graph, where nodes having the
std::vector<Node*> rpo; // same root frame are placed adjacent to each other. The traversal uses a
GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(), // variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
/*edge_filter=*/[](const Edge& edge) { // many inputs of each node are ready; a node is ready to be scheduled if all
return !edge.src()->IsNextIteration(); // of its inputs are ready.
}); // Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
return PopulateWithReversePostOrder(rpo); Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
std::vector<Node*>* order) {
absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
Node* src_node = graph_.source_node();
for (const auto* node : graph_.op_nodes()) {
const ControlFlowInfo& cf = control_flow_info_[node->id()];
if (IsRootEnter(node)) {
// Since we only care about the root-level frame, full frame names are the same
// as frame names.
++num_enters_for_frame[cf.frame_name];
} else if (IsRootExit(node)) {
++num_exits_for_frame[cf.frame_name];
}
// Edge NextIteration->Merge is counted before starting the traversal to
// break the backedges.
if (IsMerge(node)) {
for (const Edge* e : node->in_edges()) {
if (IsNextIteration(e->src())) {
++num_ready_inputs[node->id()];
}
}
}
}
// A deque is used to ensure that the nodes are first-in-first-out. This
// order guarantees that the exits in the ready queue are visited before
// nodes that will become ready in the future.
std::deque<Node*> ready;
ready.push_back(src_node);
// ready_enters_per_frame and ready_exits serve as a staging area to buffer
// the ready enters/exits before they are moved to the `ready` queue for
// controlling the start and end of a processing frame.
absl::flat_hash_map<absl::string_view, std::vector<Node*>>
ready_enters_per_frame;
// Exit nodes shall all be from the same frame, as we process a frame at a
// time. So, one vector is enough.
std::vector<Node*> ready_exits;
while (!ready.empty()) {
Node* curr_node = ready.front();
ready.pop_front();
VLOG(4) << "Visiting " << curr_node->name();
order->push_back(curr_node);
for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
int out_id = out->id();
if (IsNextIteration(curr_node) && IsMerge(out)) {
// Edge NextIteration->Merge has been counted.
continue;
}
++num_ready_inputs[out->id()];
if (!out->IsOp()) continue; // Skip Sink/Source nodes.
if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
absl::string_view frame_name = control_flow_info_[out_id].frame_name;
if (IsRootEnter(out)) {
ready_enters_per_frame[frame_name].push_back(out);
} else if (IsRootExit(out)) {
ready_exits.push_back(out);
} else {
ready.push_back(out);
}
}
if (ready.empty()) {
// Try moving nodes from ready_enters_per_frame and ready_exits to
// `ready`.
if (!ready_exits.empty()) {
// If there are nodes in ready_exits we must process them before
// processing ready_enters_per_frame to make sure all nodes in the
// currently processing frame are visited before starting processing
// other frames.
absl::string_view frame_name =
control_flow_info_[ready_exits.front()->id()].frame_name;
CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
ready_exits.clear();
} else {
// Otherwise, try moving nodes from ready_enters to `ready`.
for (auto iter = ready_enters_per_frame.begin();
iter != ready_enters_per_frame.end(); ++iter) {
absl::string_view frame_name = iter->first;
const std::vector<Node*>& ready_enters = iter->second;
if (ready_enters.size() == num_enters_for_frame[frame_name]) {
ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
ready_enters_per_frame.erase(iter);
break;
}
}
}
}
}
if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
return errors::InvalidArgument(
"Some enters/exits have never been visited in the traversal."
" Most probably the input graph is malformed.");
}
return Status::OK();
} }
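
Editorial note: GetFrameBasedTopologicalOrder above is a frame-aware variant of Kahn's algorithm. For reference, plain Kahn's over a toy adjacency-list graph (an editorial sketch, not TensorFlow's Graph API) looks like this; the frame-aware version differs mainly in staging root Enter/Exit nodes so one root frame is drained before the next begins:

#include <cstddef>
#include <deque>
#include <iostream>
#include <vector>

// Plain Kahn's algorithm: repeatedly emit nodes whose inputs are all ready.
std::vector<int> TopologicalOrder(const std::vector<std::vector<int>>& out_edges) {
  const size_t n = out_edges.size();
  std::vector<size_t> in_degree(n, 0);
  for (const auto& outs : out_edges)
    for (int dst : outs) ++in_degree[dst];

  std::deque<int> ready;
  for (size_t i = 0; i < n; ++i)
    if (in_degree[i] == 0) ready.push_back(static_cast<int>(i));

  std::vector<int> order;
  while (!ready.empty()) {
    int curr = ready.front();
    ready.pop_front();
    order.push_back(curr);
    for (int dst : out_edges[curr])
      if (--in_degree[dst] == 0) ready.push_back(dst);
  }
  return order;  // if order.size() < n, the graph had a cycle
}

int main() {
  // 0 -> 1 -> 3 and 0 -> 2 -> 3
  std::vector<std::vector<int>> g = {{1, 2}, {3}, {3}, {}};
  for (int v : TopologicalOrder(g)) std::cout << v << " ";  // e.g. 0 1 2 3
  std::cout << "\n";
  return 0;
}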
Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( // We populate the nodes along a special topological order where nodes having
absl::Span<Node* const> rpo) { // the same root frame are placed adjacent to each other. This grouping enables
// processing the graph one root frame at a time and guarantees that when a root
// frame is being processed, nodes in the downstream frames have not yet been
// processed. This property is important because we need to process an entire
// frame to know whether the optimistic mode converges or not. In other words,
// nodes in the downstream frames shall not be populated until all of their
// upstream frames are populated. In effect, this order enables processing each
// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
// (root) frame. Note that we don't separate while loops belonging to the same
// nested while, as there is no clean cut for separating them in the topological
// order.
Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
std::vector<string> unreachable_nodes; std::vector<string> unreachable_nodes;
// Compute the loop structure of the graph. // Compute the loop structure of the graph.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
@ -1069,14 +1268,63 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
absl::StrJoin(unreachable_nodes, ", ")); absl::StrJoin(unreachable_nodes, ", "));
} }
std::vector<Node*> topo;
TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
size_t frame_start = 0;
while (frame_start < topo.size()) {
// Batching nodes that have the same root frame.
absl::string_view cur_frame_name;
TF_RETURN_IF_ERROR(
GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
size_t frame_end = frame_start;
for (size_t i = frame_start + 1; i < topo.size(); ++i) {
absl::string_view i_frame_name;
TF_RETURN_IF_ERROR(
GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
if (i_frame_name == cur_frame_name) {
frame_end = i;
} else {
break;
}
}
absl::Span<Node*> sub_topo(topo.data() + frame_start,
/*length=*/frame_end - frame_start + 1);
frame_start = frame_end + 1;
// First, try the optimistic mode.
bool success = false;
if (enable_optimistic && !cur_frame_name.empty()) {
TF_RETURN_IF_ERROR(
PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
}
if (!success) {
// The optimistic mode does not converge. Let's fall back to the
// pessimistic mode.
TF_RETURN_IF_ERROR(
PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
}
VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
<< (success ? "optimistic" : "pessimistic") << " mode.";
}
return Status::OK();
}
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
bool use_optimistic_mode,
bool* success) {
CHECK(use_optimistic_mode && success != nullptr ||
!use_optimistic_mode && success == nullptr);
// This is an abstract interpretation over the deadness propagation semantics of // the graph executor.
// the graph executor. // the graph executor.
// //
// We iterate over the graph twice, each time in RPO. On the first iteration // We iterate over the graph twice, each time in a topological order. On the
// merge nodes with backedges are mapped to symbolic predicates. On the // first iteration merge nodes with backedges are mapped to symbolic
// second iteration we use the predicates assigned to the backedges in the // predicates. On the second iteration we use the predicates assigned to the
// previous iteration to infer a more precise predicate for the backedge merge // backedges in the previous iteration to infer a more precise predicate for
// nodes and all the nodes that transitively use it. // the backedge merge nodes and all the nodes that transitively use it.
// //
// We don't track the output indices for should_revisit. Instead, putting a // We don't track the output indices for should_revisit. Instead, putting a
// node in `should_revisit` denotes that the deadness flowing out from any // node in `should_revisit` denotes that the deadness flowing out from any
@ -1086,9 +1334,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
// delta should not change in the second iteration. // delta should not change in the second iteration.
std::vector<bool> should_revisit; std::vector<bool> should_revisit;
should_revisit.resize(graph_.num_node_ids()); should_revisit.resize(graph_.num_node_ids());
for (Node* n : rpo) { for (Node* n : topo) {
VLOG(4) << "Visiting " << n->name(); VLOG(4) << "Visiting " << n->name();
TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr)); TF_RETURN_IF_ERROR(
HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
if (n->IsNextIteration()) { if (n->IsNextIteration()) {
// If this is a backedge for a merge node then remember to reprocess the // If this is a backedge for a merge node then remember to reprocess the
// merge the next time we run. // merge the next time we run.
@ -1100,11 +1349,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
} }
} }
for (Node* n : rpo) { for (Node* n : topo) {
// The nodes added to should_revisit in the previous loop need to be // The nodes added to should_revisit in the previous loop need to be
// revisited now. Reprocessing these initial nodes may add *their* consumers // to should_revisit, and these newly added nodes will also be processed by
// to should_revisit, and these newly added nodes will also be processed by // to should_revisit, and these newly added nodes will also be processed by
// this very same loop. Since we're traversing the graph in reverse post // this very same loop. Since we're traversing the graph in topological
// order (producers before consumers) and HandleNode(n) can only ever add // order (producers before consumers) and HandleNode(n) can only ever add
// n's consumers to should_revisit, we won't "miss" an addition to // n's consumers to should_revisit, we won't "miss" an addition to
// should_revisit. // should_revisit.
@ -1114,6 +1363,71 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
} }
} }
// Check if the optimistic analysis converges. Specifically, check whether
// all the predicates of the merge nodes in the same frame are the same. If
// yes, report success. If not, report failure and clear the assigned
// predicates.
if (use_optimistic_mode) {
bool is_converged = true;
absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
for (Node* n : topo) {
if (!n->IsMerge()) {
continue;
}
const Edge* e;
TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
if (e == nullptr) {
// Skip acyclic merge nodes.
continue;
}
Node* merge = n;
// Note that we use frame names here instead of root frame names. In the
// case of a nested while loop, each level of while loops can have merges
// with different predicate instances, while the merge nodes on the same
// level must have the same predicate instances.
absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
auto it = predicate_map_.find(TensorId(merge->name(), 0));
Predicate* merge_pred = it->second;
if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
is_converged = false;
VLOG(2) << "Running the optimistic mode on frame " << frame_name
<< " does not converge because node " << merge->name()
<< " cannot be mapped into the AndRecurrence form.";
break;
}
auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
if (!insert_result.second) {
// If we have already seen this frame name, verify the predicate is the
// same as the previously seen one's.
Predicate* curr_andrec = merge_pred;
Predicate* prev_andrec = insert_result.first->second;
if (curr_andrec != prev_andrec) {
is_converged = false;
VLOG(2) << "Running the optimistic mode on frame " << frame_name
<< " does not converge. Seeing different Merge predicates: \n"
<< curr_andrec->ToString() << " and \n"
<< prev_andrec->ToString();
break;
}
}
}
// Clear the assigned predicates if the optimistic mode does not converge.
if (!is_converged) {
for (Node* n : topo) {
for (int oid = 0; oid < n->num_outputs(); ++oid) {
predicate_map_.erase(TensorId(n->name(), oid));
}
predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
}
}
if (success != nullptr) {
*success = is_converged;
}
}
return Status::OK(); return Status::OK();
} }
@ -1149,7 +1463,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) { const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
std::unique_ptr<DeadnessAnalysisImpl> analysis( std::unique_ptr<DeadnessAnalysisImpl> analysis(
new DeadnessAnalysisImpl(&graph)); new DeadnessAnalysisImpl(&graph));
TF_RETURN_IF_ERROR(analysis->Populate()); TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
analysis->Print(); analysis->Print();
@ -1170,22 +1484,18 @@ DeadnessAnalysisImpl::PredicateMapAsString() const {
} }
namespace deadness_analysis_internal { namespace deadness_analysis_internal {
Status ComputePredicates(const Graph& graph, Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
PredicateMapTy* out_predicate_map) { bool enable_optimistic) {
DeadnessAnalysisImpl impl(&graph); DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.Populate()); TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
*out_predicate_map = impl.PredicateMapAsString(); *out_predicate_map = impl.PredicateMapAsString();
return Status::OK(); return Status::OK();
} }
Status ComputePredicates(const Graph& graph,
absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map) {
DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
*out_predicate_map = impl.PredicateMapAsString();
return Status::OK();
}
} // namespace deadness_analysis_internal } // namespace deadness_analysis_internal
string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
return static_cast<Predicate*>(predicate.pred_)->ToString();
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -82,6 +82,8 @@ class DeadnessAnalysis {
virtual void Print() const = 0; virtual void Print() const = 0;
virtual ~DeadnessAnalysis(); virtual ~DeadnessAnalysis();
string DebugString(DeadnessPredicate predicate) const;
// Run the deadness analysis over `graph` and returns an error or a populated // Run the deadness analysis over `graph` and returns an error or a populated
// instance of DeadnessAnalysis in `result`. // instance of DeadnessAnalysis in `result`.
static Status Run(const Graph& graph, static Status Run(const Graph& graph,

View File

@ -25,15 +25,9 @@ namespace deadness_analysis_internal {
// Returns a map describing the predicate each Tensor was mapped to. For // Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only. // testing purposes only.
using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>; using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
bool enable_optimistic = true);
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only. Makes deadness analysis visit the graph in the order
// specified in `reverse_post_order` which must be a valid RPO for the graph
// minus NextIteration->Merge edges.
Status ComputePredicates(const Graph& graph,
absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map);
} // namespace deadness_analysis_internal } // namespace deadness_analysis_internal
} // namespace tensorflow } // namespace tensorflow

View File

@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
} }
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
predicate_map[ControlOutputFor(iv.induction_var)]);
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
predicate_map[ControlOutputFor(iv.induction_var)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
predicate_map[ControlOutputFor(iv.induction_var)]);
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");
@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0); CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
FixupSourceAndSinkEdges(root.graph()); FixupSourceAndSinkEdges(root.graph());
// To make deadness analysis think that dependent_iv is a loop we need an RPO
// that visits the merge before the backedge. This is a legal RPO for
// deadness analysis since it ignores NextIteration->Merge edges during RPO.
// Right now dependent_iv has an edge from Merge to NextIteration so do the
// RPO with this edge in place. Then remove this edge to get our test case.
std::vector<Node*> rpo;
GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
TF_ASSERT_OK(root.graph()->UpdateEdge( TF_ASSERT_OK(root.graph()->UpdateEdge(
iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0)); iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));
@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map)); TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"{#true,&,*iv0/cond:0}<frame>");
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"div0/iv:0"); "div0/iv:0");
@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
} }
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
"{(*iv_outer/cond:0 & "
"{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
"cond:0}<inner_loop;outer_loop>");
// enable_optimistic = true or not should produce the same results because
// of fallback. However, note that the order of iv_inner/cond:0 and
// iv_inner/iv:0 is different because the optimistic approach does not
// create predicates for all merges and it can change the predicate id and
// hence the symbol order.
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
"iv_inner/iv:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>"); "{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
"{{#true,&,(iv_outer/iv:0 & " "{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & " "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>"); "*iv_inner/cond:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
"{{#true,&,(iv_outer/iv:0 & " predicate_map[ControlOutputFor(dependent_inner_iv0)]);
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(add0)], EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
"{{#true,&,(iv_outer/iv:0 & " predicate_map[ControlOutputFor(dependent_inner_iv0)]);
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
} }
} }
@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
} }
} }
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
Scope root = Scope::NewRootScope().ExitOnError();
InductionVarInfo iv_outer =
CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
Output enter_constant_outer_loop = ops::internal::Enter(
root.WithOpName("constant_enter_outer_loop"),
ops::Const(root.WithOpName("constant"), 5), "outer_loop",
ops::internal::Enter::Attrs().IsConstant(true));
ops::Switch inner_value(root.WithOpName("outer_is_live"),
enter_constant_outer_loop, iv_outer.loop_cond);
InductionVarInfo iv_inner = CreateInductionVariable(
root, "iv_inner", "inner_loop", inner_value.output_true);
DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);
DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
root, "div0_inner", "inner_loop", iv_inner.loop_cond,
div0_outer.induction_var);
DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
root, "div1_inner", "inner_loop", iv_inner.loop_cond,
div1_outer.induction_var);
Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
"tensor_a", "sender", 0, "receiver");
Output capture_enter_outer = ops::internal::Enter(
root.WithOpName("capture_enter_outer"), captured, "outer_loop",
ops::internal::Enter::Attrs().IsConstant(true));
Output capture_enter_inner = ops::internal::Enter(
root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
ops::internal::Enter::Attrs().IsConstant(true));
Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
capture_enter_inner);
TF_ASSERT_OK(root.graph()->UpdateEdge(
mul0.node(), 0, div1_inner.latch.output_true.node(), 0));
Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
div1_inner.induction_var);
VLogGraphIfAsked(*root.graph());
{
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
}
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
Scope root = Scope::NewRootScope().ExitOnError();
InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
DependentInductionVar div0 =
CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
DependentInductionVar div1 =
CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
FixupSourceAndSinkEdges(root.graph());
TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
div0.latch.output_true.node(), 0));
TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
div1.latch.output_true.node(), 0));
VLogGraphIfAsked(*root.graph());
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
// This tests the rule {S,&,X} & ~X => S.
TensorId switch_false_out = {div1.latch.output_false.node()->name(),
div1.latch.output_false.index()};
EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
}
}
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) { TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
Scope root = Scope::NewRootScope().ExitOnError(); Scope root = Scope::NewRootScope().ExitOnError();
InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10); InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);

File diff suppressed because it is too large

View File

@ -52,16 +52,6 @@ typedef std::function<Status(
// 'group_attribute' must be a string valued-attribute that names the new // 'group_attribute' must be a string valued-attribute that names the new
// functions to introduce. // functions to introduce.
// //
// 'outside_compilation_attribute' must be a string-valued attribute that is
// used to tag nodes within a subgraph to be part of an 'outside_compilation'
// cluster within the subgraph. A cluster is formed from the set of nodes with
// the same value of outside_compilation_subgraph and group_attribute. The nodes
// in an outside_compilation cluster are left in the original graph. Edges
// crossing from the subgraph to an outside_compilation cluster nested in the
// subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
// crossing from an outside_compilation cluster into its enclosing subgraph are
// lifted into a SendFromHost/RecvFromHost pair of nodes.
//
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion. // function conversion.
// //
@ -74,10 +64,9 @@ typedef std::function<Status(
// dep from B. Originally D must run after C, post-transformation this // dep from B. Originally D must run after C, post-transformation this
// dependency is lost. // dependency is lost.
Status EncapsulateSubgraphsInFunctions( Status EncapsulateSubgraphsInFunctions(
string group_attribute, string outside_compilation_attribute, string group_attribute, const Graph& graph_in,
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out, std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate // The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via XlaLaunch operators. // subgraphs pass and that should in turn be compiled via XlaLaunch operators.

View File

@ -514,10 +514,10 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
std::unique_ptr<Graph> graph_out; std::unique_ptr<Graph> graph_out;
s = EncapsulateSubgraphsInFunctions( s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
"_encapsulate", /*outside_compilation_attribute=*/"", *graph,
/*rewrite_subgraph_fn=*/{}, /*rewrite_subgraph_fn=*/{},
/*reuse_existing_functions=*/false, &graph_out, lib_def.get()); /*reuse_existing_functions=*/false,
&graph_out, lib_def.get());
if (!s.ok()) return s; if (!s.ok()) return s;
std::unordered_map<string, XlaClusterInfo> clusters; std::unordered_map<string, XlaClusterInfo> clusters;
@ -746,7 +746,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
FunctionLibraryDefinition library(OpRegistry::Global(), {}); FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph; std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", "", graph_before_encapsulation, "_cluster", graph_before_encapsulation,
/*rewrite_subgraph_fn=*/{}, /*rewrite_subgraph_fn=*/{},
/*reuse_existing_functions=*/false, &graph, &library)); /*reuse_existing_functions=*/false, &graph, &library));
@ -798,7 +798,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
FunctionLibraryDefinition library(OpRegistry::Global(), {}); FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0; int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "", graph_before, "_encapsulate", graph_before,
/*rewrite_subgraph_fn=*/ /*rewrite_subgraph_fn=*/
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors, [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph_ptr, std::unique_ptr<Graph>* graph_ptr,
@ -843,7 +843,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
FunctionLibraryDefinition library(OpRegistry::Global(), {}); FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0; int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "", graph_before, "_encapsulate", graph_before,
/*rewrite_subgraph_fn=*/ /*rewrite_subgraph_fn=*/
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors, [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph_ptr, std::unique_ptr<Graph>* graph_ptr,
@ -1109,7 +1109,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
absl::Span<const string>( absl::Span<const string>(
{"_xla_token_arg_node", {"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}, "outside_compilation_O1_host_compute"})}},
{"F"}}, {"F", "outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"}, {{"outside_compilation_O1_host_compute"},
"XlaHostCompute", "XlaHostCompute",
{"C:o:0", "D:o:0"}, {"C:o:0", "D:o:0"},
@ -1990,7 +1990,8 @@ TEST(EncapsulateSubgraphsTest,
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>( absl::Span<const string>(
{"_xla_token_arg_node", {"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}}, "outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
}, },
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}}); {"h_0_retval_retval", "H:o:0"}});
@ -2117,7 +2118,8 @@ TEST(EncapsulateSubgraphsTest,
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>( absl::Span<const string>(
{"_xla_token_arg_node", {"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}}, "outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"}, {{"outside_compilation_O1_host_compute"},
"XlaHostCompute", "XlaHostCompute",
{"D:o:0"}, {"D:o:0"},
@ -2267,7 +2269,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"_xla_token_input_nodes", {"_xla_token_input_nodes",
absl::Span<const string>( absl::Span<const string>(
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}}, {"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
{}}, {"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O3_host_compute"}, {{"outside_compilation_O3_host_compute"},
"XlaHostCompute", "XlaHostCompute",
{"D:o:0"}, {"D:o:0"},
@ -2282,7 +2284,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
absl::Span<const string>({"_xla_token_arg_node", absl::Span<const string>({"_xla_token_arg_node",
"outside_compilation_O1_host_compute", "outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})}}, "outside_compilation_O2_host_compute"})}},
{}}}, {"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"}}},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}}); {"h_0_retval_retval", "H:o:0"}});


@ -231,9 +231,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
auto output = absl::make_unique<Graph>((*graph)->op_registry()); auto output = absl::make_unique<Graph>((*graph)->op_registry());
TF_RETURN_WITH_CONTEXT_IF_ERROR( TF_RETURN_WITH_CONTEXT_IF_ERROR(
EncapsulateSubgraphsInFunctions( EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
kXlaClusterAttr, "", **graph, RewriteSubgraph, /*reuse_existing_functions=*/true,
/*reuse_existing_functions=*/true, &output, flib_def), &output, flib_def),
"EncapsulateXlaComputationsPass failed"); "EncapsulateXlaComputationsPass failed");
graph->swap(output); graph->swap(output);
return Status::OK(); return Status::OK();


@ -393,7 +393,7 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
// Replace outside compilation function call node with XlaHostCompute node. // Replace outside compilation function call node with XlaHostCompute node.
// If the function call node has no input/output edges, we will just remove it // If the function call node has no input/output edges, we will just remove it
// and not create a XlaHostCompute node. // and not create a XlaHostCompute node.
Status ReplaceOrRemoveOutsideCompilationCallNode( xla::StatusOr<Node*> ReplaceOrRemoveOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core, Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) { const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
// If the function call node has no input/output edges, just remove it. // If the function call node has no input/output edges, just remove it.
@ -413,7 +413,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
if (!has_edge) { if (!has_edge) {
VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString(); VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString();
g->RemoveNode(call_node); g->RemoveNode(call_node);
return Status::OK(); return nullptr;
} }
// Build XlaHostCompute NodeDef. // Build XlaHostCompute NodeDef.
@ -424,7 +424,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
ReplaceNode(g, call_node, node_def)); ReplaceNode(g, call_node, node_def));
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString(); VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
return Status::OK(); return host_compute_node;
} }
// Resets "device_ordinal" attr to placeholder value for related nodes // Resets "device_ordinal" attr to placeholder value for related nodes
@ -1634,7 +1634,7 @@ Status ExtractOutsideCompilationForFunction(
RewriteOutsideCompilationSubgraphFn rewrite_fn( RewriteOutsideCompilationSubgraphFn rewrite_fn(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name); xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name);
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
outside_compilation_attr_name, "", *fbody->graph, rewrite_fn, outside_compilation_attr_name, *fbody->graph, rewrite_fn,
/*reuse_existing_functions=*/true, &graph_out, fld)); /*reuse_existing_functions=*/true, &graph_out, fld));
// Replace outside_compilation function nodes with HostCompute ops. // Replace outside_compilation function nodes with HostCompute ops.
@ -1670,10 +1670,35 @@ Status ExtractOutsideCompilationForFunction(
} }
} }
} }
std::map<string, Node*> host_compute_nodes;
for (Node* n : outside_compilation_nodes) { for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( auto host_compute_node_or = ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps)); graph_out.get(), n, host_compute_core, *cluster_deps);
TF_RETURN_IF_ERROR(host_compute_node_or.status());
Node* host_compute_node = host_compute_node_or.ValueOrDie();
if (host_compute_node) {
host_compute_nodes[host_compute_node->name()] = host_compute_node;
}
}
// For XlaHostCompute nodes with dependencies, add control edges between them
// so that XlaCompiler can handle them in the correct order.
for (auto iter : host_compute_nodes) {
Node* host_compute_node = iter.second;
std::vector<string> token_input_node_names;
TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
kXlaTokenInputNodesAttrName,
&token_input_node_names));
for (const string& node_name : token_input_node_names) {
if (node_name == kXlaTokenArgNodeName) {
continue;
}
auto iter = host_compute_nodes.find(node_name);
if (iter != host_compute_nodes.end()) {
graph_out->AddControlEdge(iter->second, host_compute_node);
}
}
} }
// Handle nodes with associated functions. // Handle nodes with associated functions.
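A minimal standalone sketch of the control-edge wiring the hunk above performs, kept outside TensorFlow on purpose: ToyNode, ToyGraph and WireHostComputeControlEdges are invented stand-ins for the real Node/Graph API, and the token attribute is modelled as a plain vector of names. It only illustrates the idea that every host-compute node receives a control edge from each token-input node that is itself a host-compute node.

// Sketch only: toy stand-ins for tensorflow::Node / Graph, not the real API.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct ToyNode {
  std::string name;
  // Names listed in the node's "_xla_token_input_nodes" attribute.
  std::vector<std::string> token_input_nodes;
};

struct ToyGraph {
  // Control edges recorded as (src, dst) name pairs.
  std::set<std::pair<std::string, std::string>> control_edges;
  void AddControlEdge(const ToyNode* src, const ToyNode* dst) {
    control_edges.insert({src->name, dst->name});
  }
};

// Mirrors the loop above: for every host-compute node, add a control edge
// from each token-input node that is itself a host-compute node, so the
// compiler sees them in dependency order.
void WireHostComputeControlEdges(
    ToyGraph* graph, const std::map<std::string, ToyNode*>& host_compute_nodes) {
  for (const auto& entry : host_compute_nodes) {
    ToyNode* consumer = entry.second;
    for (const std::string& producer_name : consumer->token_input_nodes) {
      if (producer_name == "_xla_token_arg_node") continue;  // not a real node
      auto it = host_compute_nodes.find(producer_name);
      if (it != host_compute_nodes.end()) {
        graph->AddControlEdge(it->second, consumer);
      }
    }
  }
}

int main() {
  ToyNode hc0{"outside_compilation_O1_host_compute", {"_xla_token_arg_node"}};
  ToyNode hc1{"outside_compilation_O2_host_compute",
              {"_xla_token_arg_node", "outside_compilation_O1_host_compute"}};
  std::map<std::string, ToyNode*> nodes = {{hc0.name, &hc0}, {hc1.name, &hc1}};
  ToyGraph g;
  WireHostComputeControlEdges(&g, nodes);
  for (const auto& e : g.control_edges)
    std::cout << e.first << " ---ctrl--> " << e.second << "\n";
}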


@ -990,6 +990,16 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes)); "_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
// Check there is a control edge from host_compute_0 to host_compute_1.
bool has_control_edge = false;
for (const Edge *e : host_compute_1->in_edges()) {
if (e->IsControlEdge() && e->src() == host_compute_0) {
has_control_edge = true;
break;
}
}
EXPECT_TRUE(has_control_edge);
} }
TEST_F(ExtractOutsideCompilationForFunctionTest, TEST_F(ExtractOutsideCompilationForFunctionTest,
@ -1062,5 +1072,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes)); "_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
// Check there is a control edge from host_compute_0 to host_compute_1.
bool has_control_edge = false;
for (const Edge *e : host_compute_1->in_edges()) {
if (e->IsControlEdge() && e->src() == host_compute_0) {
has_control_edge = true;
break;
}
}
EXPECT_TRUE(has_control_edge);
} }
} // namespace tensorflow } // namespace tensorflow


@ -1,9 +1,8 @@
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = [ default_visibility = [
"//tensorflow/compiler/tf2xla:internal", "//tensorflow/compiler/tf2xla:internal",
], ],
licenses = ["notice"], # Apache 2.0
) )
load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test")


@ -1,9 +1,8 @@
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = [ default_visibility = [
"//tensorflow/compiler/tf2xla:internal", "//tensorflow/compiler/tf2xla:internal",
], ],
licenses = ["notice"], # Apache 2.0
) )
cc_library( cc_library(
@ -29,6 +28,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib", "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:tf_allocator_adapter",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
], ],


@ -61,7 +61,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type(); DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr; se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr;
std::unique_ptr<XlaAllocator> xla_allocator; std::unique_ptr<se::TfAllocatorAdapter> xla_allocator;
se::DeviceMemoryAllocator* device_allocator = nullptr; se::DeviceMemoryAllocator* device_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) { if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
@ -93,7 +93,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
se::MultiPlatformManager::PlatformWithId(platform_id); se::MultiPlatformManager::PlatformWithId(platform_id);
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
xla_allocator = absl::make_unique<XlaAllocator>( xla_allocator = absl::make_unique<se::TfAllocatorAdapter>(
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({})); maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
} }


@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h" #include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
namespace tensorflow { namespace tensorflow {
@ -36,10 +37,10 @@ class XlaPlatformInfo {
public: public:
XlaPlatformInfo() : device_type_("") {} XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default; XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type, explicit XlaPlatformInfo(
se::Platform::Id platform_id, const DeviceType device_type, se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata, const XlaDevice::Metadata* xla_device_metadata,
std::unique_ptr<XlaAllocator> xla_allocator, std::unique_ptr<se::TfAllocatorAdapter> xla_allocator,
se::DeviceMemoryAllocator* device_allocator) se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type), : device_type_(device_type),
platform_id_(platform_id), platform_id_(platform_id),
@ -84,8 +85,8 @@ class XlaPlatformInfo {
// then device_allocator_ is the xla::Backend's memory allocator and // then device_allocator_ is the xla::Backend's memory allocator and
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
// then device_allocator_ is null and xla_allocator_ points to an appropriate // then device_allocator_ is null and xla_allocator_ points to an appropriate
// XlaAllocator instance. // se::TfAllocatorAdapter instance.
std::unique_ptr<XlaAllocator> xla_allocator_; std::unique_ptr<se::TfAllocatorAdapter> xla_allocator_;
se::DeviceMemoryAllocator* device_allocator_; se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);


@ -229,32 +229,18 @@ class MarkForCompilationPassImpl {
// Initialize some internal data structures. // Initialize some internal data structures.
Status Initialize(); Status Initialize();
// Runs through all the nodes in `cycles_graph_` and tries to create clusters. // Runs through the entire cluster graph in post-order and calls `fn(from,
// Returns true if any new clusters were created. // to)` on each edge. `fn(from, to)` is expected to return true if it was
StatusOr<bool> RunEdgeContractionLoopInPostOrderOnce(); // able to contract `from`->`to`.
//
// Returns true if `fn` returned true for any edge.
template <typename FnTy>
StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
// Runs through all the nodes in `cycles_graph_` and tries to contract high // If from->to is a "preferred" edge (i.e. if we have a choice, we want to
// priority edges for clusters. Returns true if any new clusters were created. // prioritize contracting from->to over contracting other edges) then
// // contracts it and returns true. Else returns false.
// There are potentially many maximal clustering results, but they will not StatusOr<bool> ContractEdgeIfPreferred(Cluster* from, Cluster* to);
// all be equally performant. Some clustering decision are likely to improve
// performance much more than others, and we cannot order contractions on this
// cost function, nor can we look at global information while deciding on
// individual edges to contract. Instead, we will make decisions on these
// important edges then make decisions on all other edges, causing the highest
// chance of all most important edges to be contracted.
//
// An example of where this might occur is with a digraph:
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
// should be clustered with A because it will prevent a potentially large
// tensor from A being computed and copied.
//
// This pass will ensure that contraction happens, which cannot be enforced in
// a single pass with the current algorithm.
// graph and prevent B->C from being clusterd in anticipation of a later A->B
// cluster.
StatusOr<bool> ContractPreferredEdges();
// Contracts as many edges as possible to create XLA clusters. After this // Contracts as many edges as possible to create XLA clusters. After this
// finishes the clustering decisions made are implicitly stored in // finishes the clustering decisions made are implicitly stored in
@ -276,10 +262,6 @@ class MarkForCompilationPassImpl {
// true if successful. // true if successful.
StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to); StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
// Tries to contract each edge from `cluster_from`. Returns true if any edges
// were contracted, false otherwise.
StatusOr<bool> TryToContractEdgesFrom(Cluster* cluster_from);
// Nodes that XLA can compile are put in `compilation_candidates_`. // Nodes that XLA can compile are put in `compilation_candidates_`.
Status FindCompilationCandidates(); Status FindCompilationCandidates();
@ -401,6 +383,13 @@ class MarkForCompilationPassImpl {
return true; return true;
} }
string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
absl::string_view reason) {
return absl::StrCat("Could not contract ", from->DebugString(*graph_),
" -> ", to->DebugString(*graph_), " because ", reason,
".");
}
DebugOptions debug_options_; DebugOptions debug_options_;
Graph* graph_; Graph* graph_;
FunctionLibraryDefinition* flib_def_; FunctionLibraryDefinition* flib_def_;
@ -611,7 +600,8 @@ Status MarkForCompilationPassImpl::Initialize() {
return BuildInitialClusterSet(); return BuildInitialClusterSet();
} }
StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() { template <typename FnTy>
StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
bool changed = false; bool changed = false;
for (int32 node : cycles_graph_.AllNodesInPostOrder()) { for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node); Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
@ -632,8 +622,18 @@ StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
continue; continue;
} }
if (cluster_to->cluster_size() == 1) { TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
Node* n = graph_->FindNodeId(cluster_to->GetIdOfOnlyNode()); changed |= contracted_edge;
}
}
return changed;
}
StatusOr<bool> MarkForCompilationPassImpl::ContractEdgeIfPreferred(
Cluster* from, Cluster* to) {
if (to->cluster_size() == 1) {
Node* n = graph_->FindNodeId(to->GetIdOfOnlyNode());
// Shape consuming operations are desirable to cluster with their // Shape consuming operations are desirable to cluster with their
// operands because they return a small set of scalar values after // operands because they return a small set of scalar values after
@ -644,43 +644,11 @@ StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
// tensor that must be computed and possible transposed/copied before // tensor that must be computed and possible transposed/copied before
// the second cluster executes. // the second cluster executes.
if (IsShapeConsumerOp(*n)) { if (IsShapeConsumerOp(*n)) {
TF_ASSIGN_OR_RETURN(bool contracted_edge, return TryToContractEdge(from, to);
TryToContractEdge(cluster_from, cluster_to));
changed |= contracted_edge;
}
}
} }
} }
return changed; return false;
}
StatusOr<bool>
MarkForCompilationPassImpl::RunEdgeContractionLoopInPostOrderOnce() {
bool changed = false;
// Iterating over the graph once in post-order is sufficient to produce a
// maximal clustering:
//
// A. We visit a cluster only after maximally clustering all its children.
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
// its children that could have been absorbed into `node` have been
// absorbed.
// C. We have an invariant that making a cluster larger does not make edges
// leaving it more contractable. That is, if we have
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
// to contract Y->Z if Y->Z was not contractible originally.
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
if (!cluster_from) {
continue;
}
TF_ASSIGN_OR_RETURN(bool contracted_one_edge,
TryToContractEdgesFrom(cluster_from));
changed |= contracted_one_edge;
}
return changed;
} }
Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
@ -694,25 +662,68 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
// without restrictions. This helps to minimize data output from clusters (and // without restrictions. This helps to minimize data output from clusters (and
// possible transpose operations before outputs) that might occur if a // possible transpose operations before outputs) that might occur if a
// ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations. // ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations.
TF_ASSIGN_OR_RETURN(bool changed, ContractPreferredEdges()); //
// There are potentially many maximal clustering results, but they will not
  // all be equally performant. Some clustering decisions are likely to improve
  // performance much more than others, and we cannot order contractions by this
  // cost function, nor can we look at global information while deciding on
  // individual edges to contract. Instead, we first make decisions on these
  // important edges and then on all other edges, giving the most important
  // edges the highest chance of being contracted.
//
// An example of where this might occur is with a digraph:
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
// should be clustered with A because it will prevent a potentially large
// tensor from A being computed and copied.
//
  // This pass ensures that the preferred contractions happen first, which
  // cannot be enforced in a single pass with the current algorithm: we have to
  // look ahead in the graph and prevent B->C from being clustered in
  // anticipation of a later A->B cluster.
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce()); TF_ASSIGN_OR_RETURN(bool changed,
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
return ContractEdgeIfPreferred(from, to);
}));
// Check that RunEdgeContractionLoopInPostOrderOnce is idempotent. Once the // Iterating over the whole graph once in post-order is sufficient to produce
// linear time post-order scheme has been battle tested we can move this to // a maximal clustering:
// happen only in debug builds. //
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce()); // A. We visit a cluster only after maximally clustering all its children.
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
// its children that could have been absorbed into `node` have been
// absorbed.
// C. We have an invariant that making a cluster larger does not make edges
// leaving it more contractable. That is, if we have
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
// to contract Y->Z if Y->Z was not contractible originally.
TF_ASSIGN_OR_RETURN(changed,
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
return TryToContractEdge(from, to);
}));
// Check that the conclusion made above (that iterating over the graph once in
// post order gives a maximal clustering) holds. Once the linear time
// post-order scheme has been battle tested we can move this to happen only in
// debug builds.
TF_ASSIGN_OR_RETURN(changed,
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
return TryToContractEdge(from, to);
}));
TF_RET_CHECK(!changed); TF_RET_CHECK(!changed);
return Status::OK(); return Status::OK();
} }
std::atomic<int64> cluster_sequence_num;
int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
Status MarkForCompilationPassImpl::CreateClusters() { Status MarkForCompilationPassImpl::CreateClusters() {
TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_); TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
clusters_created_ = true; clusters_created_ = true;
static std::atomic<int64> cluster_sequence_num;
// Names for each cluster. // Names for each cluster.
std::unordered_map<int, string> cluster_names; std::unordered_map<int, string> cluster_names;
@ -745,7 +756,7 @@ Status MarkForCompilationPassImpl::CreateClusters() {
string& name = cluster_names[cluster->cycles_graph_node_id()]; string& name = cluster_names[cluster->cycles_graph_node_id()];
if (name.empty()) { if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++); name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
} }
n->AddAttr(kXlaClusterAttr, name); n->AddAttr(kXlaClusterAttr, name);
@ -1065,8 +1076,7 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
Cluster* from, Cluster* to, absl::string_view reason) { Cluster* from, Cluster* to, absl::string_view reason) {
VLOG(3) << "Could not contract " << from->DebugString(*graph_) << " -> " VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
<< to->DebugString(*graph_) << " because " << reason << ".";
return false; return false;
} }
@ -1075,8 +1085,14 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
DCHECK(from->deadness_predicate().has_value() == DCHECK(from->deadness_predicate().has_value() ==
to->deadness_predicate().has_value()); to->deadness_predicate().has_value());
if (from->deadness_predicate() != to->deadness_predicate()) { if (from->deadness_predicate() != to->deadness_predicate()) {
return LogNotContractableAndReturnFalse( VLOG(3) << EdgeContractionFailureMsg(
from, to, "the two nodes have mismatching deadness"); from, to,
absl::StrCat(
"the two nodes have mismatching deadness: ",
deadness_analysis_->DebugString(*from->deadness_predicate()),
" and ",
deadness_analysis_->DebugString(*to->deadness_predicate())));
return false;
} }
TF_ASSIGN_OR_RETURN(bool devices_compatible, TF_ASSIGN_OR_RETURN(bool devices_compatible,
@ -1133,32 +1149,6 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
return MergeClusters(from, to); return MergeClusters(from, to);
} }
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgesFrom(
Cluster* cluster_from) {
bool changed = false;
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32> successors_copy =
cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
for (int to : successors_copy) {
iteration_count_++;
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
if (!cluster_to) {
continue;
}
TF_ASSIGN_OR_RETURN(bool contracted_edge,
TryToContractEdge(cluster_from, cluster_to));
changed |= contracted_edge;
}
return changed;
}
Status MarkForCompilationPassImpl::Run() { Status MarkForCompilationPassImpl::Run() {
// Make sure that kernels have been registered on the JIT device. // Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels(); XlaOpRegistry::RegisterCompilationKernels();
@ -1485,7 +1475,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
op_filter.allow_control_trigger = true; op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true; op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true; op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_slow_and_inaccurate_ops = true; op_filter.allow_slow_ops = true;
op_filter.allow_inaccurate_ops = true;
return RecursiveCompilabilityChecker{&op_filter, &jit_device_type} return RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
.IsCompilableCall(ndef, flr); .IsCompilableCall(ndef, flr);
@ -1522,4 +1513,8 @@ Status MarkForCompilationPass::RunForTest(
return MarkForCompilation(options, debug_options); return MarkForCompilation(options, debug_options);
} }
namespace testing {
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
} // namespace testing
} // namespace tensorflow } // namespace tensorflow
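A sketch of the restructured contraction loop above, under simplifying assumptions: cycle detection, compilability checks and the real cost model are left out, clusters are plain union-find sets, and post-order is approximated by visiting node ids in descending order. ClusterGraph and the callback signature are invented for illustration; only the two-phase ForEachEdgeInPostOrder pattern and the final idempotence check mirror the pass.

// Sketch only: illustrates the two-phase, post-order edge-contraction scheme.
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <vector>

struct ClusterGraph {
  std::vector<int> parent;                  // union-find representative
  std::map<int, std::set<int>> successors;  // out-edges on original node ids
  std::vector<bool> preferred;              // e.g. shape-consumer ops

  explicit ClusterGraph(int n) : parent(n), preferred(n, false) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Contract(int from, int to) { parent[Find(to)] = Find(from); }
};

// Calls fn(from, to) for every edge, with endpoints resolved to their current
// cluster representatives. Returns true if fn contracted any edge.
bool ForEachEdgeInPostOrder(ClusterGraph& g,
                            const std::function<bool(int, int)>& fn) {
  bool changed = false;
  for (int node = static_cast<int>(g.parent.size()) - 1; node >= 0; --node) {
    int from = g.Find(node);
    std::set<int> succ_copy = g.successors[node];  // fn may mutate the graph
    for (int raw_to : succ_copy) {
      int to = g.Find(raw_to);
      if (to == from) continue;  // already in the same cluster
      changed |= fn(from, to);
    }
  }
  return changed;
}

int main() {
  // Chain 0 -> 1 -> 2 -> 3, where node 1 plays the role of a shape consumer.
  ClusterGraph g(4);
  g.successors[0] = {1};
  g.successors[1] = {2};
  g.successors[2] = {3};
  g.preferred[1] = true;

  // Phase 1: contract only the "preferred" edges.
  ForEachEdgeInPostOrder(g, [&](int from, int to) {
    if (!g.preferred[to]) return false;
    g.Contract(from, to);
    return true;
  });
  // Phase 2: contract everything else, then verify one more pass is a no-op
  // (the idempotence check performed by the pass above).
  auto contract_any = [&](int from, int to) {
    g.Contract(from, to);
    return true;
  };
  ForEachEdgeInPostOrder(g, contract_any);
  bool changed = ForEachEdgeInPostOrder(g, contract_any);
  std::cout << "clusters:";
  for (int i = 0; i < 4; ++i) std::cout << " " << g.Find(i);
  std::cout << " (extra pass changed=" << changed << ")\n";
}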


@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass {
// function is compilable iff every operator in the function body is // function is compilable iff every operator in the function body is
// compilable. // compilable.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
namespace testing {
// DO NOT USE IN PRODUCTION.
//
// Resets some internal state to let us write reliable unit tests.
void ResetClusterSequenceNumber();
} // namespace testing
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
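A small sketch of the resettable counter pattern introduced above (a file-scope atomic plus a testing-only reset hook), written as a standalone program rather than the actual TensorFlow code; the names mirror the ones in the diff but nothing else is shared.

// Sketch only: atomic sequence counter with a test-only reset.
#include <atomic>
#include <cstdint>
#include <iostream>

namespace {
std::atomic<int64_t> cluster_sequence_num{0};
}  // namespace

int64_t GetNextClusterSequenceNumber() { return cluster_sequence_num++; }

namespace testing {
// DO NOT USE IN PRODUCTION: lets unit tests see deterministic cluster names.
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
}  // namespace testing

int main() {
  std::cout << "cluster_" << GetNextClusterSequenceNumber() << "\n";  // cluster_0
  std::cout << "cluster_" << GetNextClusterSequenceNumber() << "\n";  // cluster_1
  testing::ResetClusterSequenceNumber();
  std::cout << "cluster_" << GetNextClusterSequenceNumber() << "\n";  // cluster_0 again
}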


@ -1,7 +1,6 @@
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"], default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
) )
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")


@ -49,7 +49,7 @@ Status ShapeAnnotationsMatch(
missing.push_back(entry.first); missing.push_back(entry.first);
} }
return errors::InvalidArgument("Missing shapes for nodes: ", return errors::InvalidArgument("Missing shapes for nodes: ",
str_util::Join(missing, ",")); absl::StrJoin(missing, ","));
} }
return Status::OK(); return Status::OK();
} }


@ -60,7 +60,8 @@ Status XlaCpuDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true; registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true; registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true; registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true; registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
static XlaDeviceOpRegistrations* registrations = static XlaDeviceOpRegistrations* registrations =


@ -71,7 +71,7 @@ class XlaDeviceContext : public DeviceContext {
StatusCallback done) const override; StatusCallback done) const override;
xla::LocalClient* client() const { return client_; } xla::LocalClient* client() const { return client_; }
se::Stream* stream() const { return stream_.get(); } se::Stream* stream() const override { return stream_.get(); }
se::Stream* host_to_device_stream() const { se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get(); return host_to_device_stream_.get();
} }


@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/kernels/host_constant_op.h" #include "tensorflow/core/kernels/host_constant_op.h"
#include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/logging_ops.h"
#include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/resource_variable_ops.h"
@ -81,6 +82,11 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("Assert") \
.Device(DEVICE) \
.HostMemory("condition") \
.HostMemory("data"), \
AssertOp); \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \


@ -95,7 +95,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true; registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true; registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true; registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true; registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
static XlaDeviceOpRegistrations* registrations = static XlaDeviceOpRegistrations* registrations =


@ -63,7 +63,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true; registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true; registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true; registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true; registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration); registration);


@ -167,32 +167,6 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
return Status::OK(); return Status::OK();
} }
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
: se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
XlaAllocator::~XlaAllocator() {}
xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
AllocationAttributes attrs;
attrs.no_retry_on_failure = !retry_on_failure;
void* data = nullptr;
if (size != 0) {
data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
if (data == nullptr) {
return errors::ResourceExhausted(
"Out of memory while trying to allocate ", size, " bytes.");
}
}
return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
device_ordinal, this);
}
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
wrapped_->DeallocateRaw(mem.opaque());
return Status::OK();
}
XlaComputationLaunchContext::XlaComputationLaunchContext( XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator, xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors, bool use_multiple_streams) bool allocate_xla_tensors, bool use_multiple_streams)


@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/device_memory_allocator.h"
namespace tensorflow { namespace tensorflow {
class XlaAllocator;
// Struct that represents a possibly-absent Tensor. // Struct that represents a possibly-absent Tensor.
struct OptionalTensor { struct OptionalTensor {
@ -104,74 +103,6 @@ class VariableInfo {
Status LockVariables(absl::Span<VariableInfo> variables) Status LockVariables(absl::Span<VariableInfo> variables)
EXCLUSIVE_LOCK_FUNCTION(); EXCLUSIVE_LOCK_FUNCTION();
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public se::DeviceMemoryAllocator {
public:
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
~XlaAllocator() override;
xla::StatusOr<se::OwningDeviceMemory> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) override;
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
// compute stream to enforce a happens-before relationship between a memory
// allocation and code that reuses the same memory. If Tensorflow adds
// support for multiple GPU streams or allocators with different ordering
// requirements, this code may need to change.
// (This attribute has no effect on CPU.)
bool AllowsAsynchronousDeallocation() const override { return true; }
private:
Allocator* wrapped_;
};
// Adapter class that wraps per-device TF allocators as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation;
// see comment on `AllowsAsynchronousDeallocation()`.
class MultiDeviceAdapter : public se::DeviceMemoryAllocator {
public:
MultiDeviceAdapter(
const se::Platform* platform,
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators)
: DeviceMemoryAllocator(platform),
tf_allocators_(std::move(tf_allocators)) {
for (const auto& tf_allocator : tf_allocators_) {
per_device_allocators_.emplace_back(platform, tf_allocator.get());
}
}
xla::StatusOr<se::OwningDeviceMemory> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) override {
CHECK_LT(device_ordinal, per_device_allocators_.size());
return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size,
retry_on_failure);
}
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override {
CHECK_LT(device_ordinal, per_device_allocators_.size());
return per_device_allocators_[device_ordinal].Deallocate(device_ordinal,
mem);
}
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
// compute stream to enforce a happens-before relationship between a memory
// allocation and code that reuses the same memory. If Tensorflow adds
// support for multiple GPU streams or allocators with different ordering
// requirements, this code may need to change.
// (This attribute has no effect on CPU.)
bool AllowsAsynchronousDeallocation() const override { return true; }
private:
std::vector<tensorflow::XlaAllocator> per_device_allocators_;
// The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does
// not take ownership of its underlying Allocator).
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators_;
};
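A standalone sketch of the adapter pattern that se::TfAllocatorAdapter now provides in place of the deleted XlaAllocator, using invented HostAllocator and DeviceMemoryAllocator interfaces instead of the real TensorFlow and StreamExecutor types. It is only meant to show how a host-side allocator is surfaced through a device-memory-allocator interface, including the asynchronous-deallocation assumption discussed in the removed comment.

// Sketch only: adapting a host-style allocator to a device-allocator
// interface. Both interfaces here are invented stand-ins.
#include <cstddef>
#include <cstdlib>
#include <iostream>

// "TensorFlow-side" allocator interface (stand-in).
class HostAllocator {
 public:
  virtual ~HostAllocator() = default;
  virtual void* AllocateRaw(std::size_t alignment, std::size_t num_bytes) = 0;
  virtual void DeallocateRaw(void* ptr) = 0;
};

class MallocAllocator : public HostAllocator {
 public:
  void* AllocateRaw(std::size_t, std::size_t num_bytes) override {
    return std::malloc(num_bytes);
  }
  void DeallocateRaw(void* ptr) override { std::free(ptr); }
};

// "XLA-side" allocator interface (stand-in for se::DeviceMemoryAllocator).
class DeviceMemoryAllocator {
 public:
  virtual ~DeviceMemoryAllocator() = default;
  virtual void* Allocate(int device_ordinal, std::size_t size) = 0;
  virtual void Deallocate(int device_ordinal, void* mem) = 0;
  // Per the deleted comment above: the GPU BFC allocator tolerates host-side
  // deallocation before device execution, so the adapter can answer "true".
  virtual bool AllowsAsynchronousDeallocation() const { return true; }
};

// The adapter: forwards device-allocator calls to a wrapped host allocator.
class AllocatorAdapter : public DeviceMemoryAllocator {
 public:
  explicit AllocatorAdapter(HostAllocator* wrapped) : wrapped_(wrapped) {}
  void* Allocate(int /*device_ordinal*/, std::size_t size) override {
    return wrapped_->AllocateRaw(/*alignment=*/64, size);
  }
  void Deallocate(int /*device_ordinal*/, void* mem) override {
    wrapped_->DeallocateRaw(mem);
  }

 private:
  HostAllocator* wrapped_;  // not owned
};

int main() {
  MallocAllocator tf_allocator;
  AllocatorAdapter adapter(&tf_allocator);
  void* p = adapter.Allocate(/*device_ordinal=*/0, 1024);
  std::cout << "allocated 1024 bytes at " << p << "\n";
  adapter.Deallocate(0, p);
}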
// Helper class to perform the marshalling of TensorFlow inputs and outputs to // Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation. // ShapedBuffers suitable for passing to an XLA computation.
class XlaComputationLaunchContext { class XlaComputationLaunchContext {


@ -28,10 +28,9 @@
** Please don't remove this file - it is supporting some 3rd party plugins ** ** Please don't remove this file - it is supporting some 3rd party plugins **
""" """
licenses(["notice"])
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
licenses = ["notice"],
) )
cc_library( cc_library(


@ -954,7 +954,7 @@ tf_xla_py_test(
tf_xla_py_test( tf_xla_py_test(
name = "ternary_ops_test", name = "ternary_ops_test",
size = "small", size = "medium",
srcs = ["ternary_ops_test.py"], srcs = ["ternary_ops_test.py"],
deps = [ deps = [
":xla_test", ":xla_test",


@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import xla from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase):
def f(): def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.cond( output = control_flow_ops.cond(
constant_op.constant( constant_op.constant(True),
True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack() return output.stack()
@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase):
xla_context.Exit() xla_context.Exit()
def testCondAndTensorArrayInDefun_constFolding(self):
g = ops.Graph()
with session.Session(graph=g), g.as_default(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@function.defun
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.cond(
constant_op.constant(False),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
output_t = f()
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testCondAndTensorArray_xlaCompile(self):
self.skipTest("b/127846988")
# Fails with "Uninitialized arguments" in XlaIfOp::Compile
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.cond(
constant_op.constant(True),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
output_t, = xla.compile(f)
self.assertAllEqual([5.], self.evaluate(output_t))
xla_context.Exit()
def testCondConstPropagation(self): def testCondConstPropagation(self):
with self.session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase):
xla_context.Exit() xla_context.Exit()
def testSwitchCaseAndTensorArray_xlaCompile(self):
self.skipTest("b/127846988")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.switch_case(
constant_op.constant(1), {
0: lambda: ta.write(0, 5.),
1: lambda: ta.write(0, 10.),
2: lambda: ta.write(0, 15.),
})
return output.stack()
output_t, = xla.compile(f)
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testSwitchCaseConstPropagation(self): def testSwitchCaseConstPropagation(self):
self.skipTest("b/127846988") self.skipTest("b/127846988")
with self.session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():


@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase):
padding="VALID", padding="VALID",
patches=patches) patches=patches)
def testKsize2x2Stride1x1Rate1x1ValidDepth2(self):
"""Test for 2x2 kernel with VALID padding."""
# [1, 2, 2, 2]
image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]]
# [1, 1, 1, 8]
patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[1, 1],
padding="VALID",
patches=patches)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()
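The arithmetic behind the new depth-2 test can be checked with a few lines of plain C++. ExtractPatches below is an illustrative re-implementation of the indexing only (ksize 2x2, stride 1, rate 1, VALID padding on an NHWC buffer), not the TensorFlow kernel; it reproduces the expected patch [1, 5, 2, 6, 3, 7, 4, 8].

// Sketch only: the indexing behind the depth-2 extract_image_patches case.
#include <iostream>
#include <vector>

// Extracts ksize x ksize patches with stride 1 and VALID padding from a
// single NHWC image; each patch concatenates its pixels in (row, col, depth)
// order, which is the layout the test expectation encodes.
std::vector<std::vector<int>> ExtractPatches(const std::vector<int>& image,
                                             int height, int width, int depth,
                                             int ksize) {
  std::vector<std::vector<int>> patches;
  for (int r = 0; r + ksize <= height; ++r) {
    for (int c = 0; c + ksize <= width; ++c) {
      std::vector<int> patch;
      for (int pr = 0; pr < ksize; ++pr)
        for (int pc = 0; pc < ksize; ++pc)
          for (int d = 0; d < depth; ++d)
            patch.push_back(image[((r + pr) * width + (c + pc)) * depth + d]);
      patches.push_back(patch);
    }
  }
  return patches;
}

int main() {
  // The [1, 2, 2, 2] image from the test: pixels (1,5) (2,6) / (3,7) (4,8).
  std::vector<int> image = {1, 5, 2, 6, 3, 7, 4, 8};
  auto patches = ExtractPatches(image, /*height=*/2, /*width=*/2, /*depth=*/2,
                                /*ksize=*/2);
  for (int v : patches[0]) std::cout << v << " ";  // 1 5 2 6 3 7 4 8
  std::cout << "\n";
}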


@ -72,21 +72,21 @@ class TernaryOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types: for dtype in self.numeric_types:
self._testTernary( self._testTernary(
array_ops.where, array_ops.where,
np.array(0, dtype=np.bool), np.array(False),
np.array(2, dtype=dtype), np.array(2, dtype=dtype),
np.array(7, dtype=dtype), np.array(7, dtype=dtype),
expected=np.array(7, dtype=dtype)) expected=np.array(7, dtype=dtype))
self._testTernary( self._testTernary(
array_ops.where, array_ops.where,
np.array(1, dtype=np.bool), np.array(True),
np.array([1, 2, 3, 4], dtype=dtype), np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([1, 2, 3, 4], dtype=dtype)) expected=np.array([1, 2, 3, 4], dtype=dtype))
self._testTernary( self._testTernary(
array_ops.where, array_ops.where,
np.array(0, dtype=np.bool), np.array(False),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype)) expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
@ -105,6 +105,74 @@ class TernaryOpsTest(xla_test.XLATestCase):
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype)) expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype))
def testSelectV2(self):
for dtype in self.numeric_types:
self._testTernary(
array_ops.where_v2,
np.array(False),
np.array(2, dtype=dtype),
np.array(7, dtype=dtype),
expected=np.array(7, dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array(True),
np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([1, 2, 3, 4], dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array(False),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array([0, 1, 1, 0], dtype=np.bool),
np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([5, 2, 3, 8], dtype=dtype))
# Broadcast the condition
self._testTernary(
array_ops.where_v2,
np.array([0, 1], dtype=np.bool),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 2], [9, 4], [11, 6]], dtype=dtype))
# Broadcast the then branch to the else
self._testTernary(
array_ops.where_v2,
np.array([[0, 1], [1, 0], [1, 1]], dtype=np.bool),
np.array([[1, 2]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
# Broadcast the else branch to the then
self._testTernary(
array_ops.where_v2,
np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
np.array([[1, 2]], dtype=dtype),
expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
# Broadcast the then/else branches to the condition
self._testTernary(
array_ops.where_v2,
np.array([[1, 0], [0, 1], [1, 1]], dtype=np.bool),
np.array(7, dtype=dtype),
np.array(8, dtype=dtype),
expected=np.array([[7, 8], [8, 7], [7, 7]], dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
np.array(7, dtype=dtype),
np.array([8, 9], dtype=dtype),
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
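The where_v2 cases added above all rely on NumPy-style broadcasting between the condition and the two branches. BroadcastShape below is a generic sketch of that shape rule (right-align the shapes; each dimension pair must match or contain a 1), not the XLA implementation, and the example reproduces the "broadcast the condition" case where a [2] condition meets [3, 2] values.

// Sketch only: NumPy-style broadcast-shape computation.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<int> BroadcastShape(std::vector<int> a, std::vector<int> b) {
  // Right-align the shapes by padding the shorter one with leading 1s.
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  std::vector<int> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] == b[i] || a[i] == 1 || b[i] == 1) {
      out[i] = std::max(a[i], b[i]);  // the non-1 extent wins
    } else {
      throw std::invalid_argument("shapes are not broadcast-compatible");
    }
  }
  return out;
}

int main() {
  // Condition of shape [2] broadcast against branch values of shape [3, 2].
  auto shape = BroadcastShape({2}, {3, 2});
  for (int d : shape) std::cout << d << " ";  // prints: 3 2
  std::cout << "\n";
}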
def testSlice(self): def testSlice(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
self._testTernary( self._testTernary(


@ -104,6 +104,24 @@ struct EdgePtrCompare {
} }
}; };
// TODO(laigd): instead of deciding the device here, the converter should accept
// a device name as one of the conversion parameters so users can control on
// which device they want to run the conversion.
std::pair<TfGpuId, PlatformGpuId> GetFirstValidDeviceId() {
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
TfGpuId tf_gpu_id(tf_gpu_id_value);
PlatformGpuId platform_gpu_id;
Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (s.ok()) {
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
<< platform_gpu_id.value();
return std::make_pair(tf_gpu_id, platform_gpu_id);
}
}
LOG(ERROR) << "Could not find any TF GPUs";
return std::make_pair(TfGpuId(-1), PlatformGpuId(-1));
}
// Function to get subsegment information structure. // Function to get subsegment information structure.
Status GetEngineInfo(const Graph* g, Status GetEngineInfo(const Graph* g,
const grappler::GraphProperties& graph_properties, const grappler::GraphProperties& graph_properties,
@ -128,20 +146,37 @@ Status GetEngineInfo(const Graph* g,
if (segment_nodes.count(node) == 0) continue; if (segment_nodes.count(node) == 0) continue;
auto node_device = node->requested_device(); auto node_device = node->requested_device();
if (!node_device.empty()) { if (!node_device.empty()) {
// If device is CPU, treat as if no device was assigned. Don't add CPU to // If device is set, it means device placement may have been done before,
// segment_device because that would cause a segfault in // so we need to assign a device for the TRTEngineOp to maintain the
// GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes // invariance.
// any already set device is a GPU. // If the device is CPU in this case, it tries to find the first available
// GPU and use it as the device.
DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParsedName parsed_name;
const bool parse_succeeded =
DeviceNameUtils::ParseFullName(node_device, &parsed_name); DeviceNameUtils::ParseFullName(node_device, &parsed_name);
if (parsed_name.type == "CPU") { if (!parse_succeeded || (parse_succeeded && parsed_name.type == "CPU")) {
VLOG(1) << "Node " << node->name() << " was assigned to the CPU. " string msg;
<< "Attempting to place on GPU."; if (!parse_succeeded) {
msg = StrCat("Failed to parse assigned device of node ", node->name(),
". ");
} else {
msg = StrCat("Node ", node->name(), " was assigned to the CPU. ");
}
VLOG(1) << msg << "Attempting to place on GPU.";
TfGpuId tf_gpu_id;
PlatformGpuId platform_gpu_id;
std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
if (tf_gpu_id.value() >= 0) {
parsed_name.type = "GPU";
parsed_name.id = tf_gpu_id.value();
segment_devices.insert(DeviceNameUtils::FullName(
parsed_name.job, parsed_name.replica, parsed_name.task,
parsed_name.type, parsed_name.id));
}
} else { } else {
segment_devices.insert(node_device); segment_devices.insert(node_device);
} }
} else { } else if (node->has_assigned_device_name()) {
if (node->has_assigned_device_name()) {
// It appears that nodes will not have assigned devices at this point in // It appears that nodes will not have assigned devices at this point in
// execution. // execution.
segment_devices.insert(node->assigned_device_name()); segment_devices.insert(node->assigned_device_name());
@ -149,7 +184,6 @@ Status GetEngineInfo(const Graph* g,
VLOG(2) << "Node " << node->name() VLOG(2) << "Node " << node->name()
<< " neither have requested device nor assigned device"; << " neither have requested device nor assigned device";
} }
}
subgraph_nodes.push_back(node); subgraph_nodes.push_back(node);
const int node_id = node->id(); const int node_id = node->id();
@ -251,13 +285,11 @@ Status GetEngineInfo(const Graph* g,
info->engine_name = StrCat(scope_name, info->engine_name); info->engine_name = StrCat(scope_name, info->engine_name);
VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
<< "' to a GraphDef"; << "' to a GraphDef";
// TODO(sami): This should not happen once segmenter is updated.
if (segment_devices.size() == 1) { if (segment_devices.size() == 1) {
info->device = *segment_devices.begin(); info->device = *segment_devices.begin();
} else if (segment_devices.size() > 1) { } else if (segment_devices.size() > 1) {
LOG(WARNING) << "Detected multiple(" << segment_devices.size() LOG(WARNING) << "Detected multiple (" << segment_devices.size()
<< ") devices for the segment. Picking first one to continue " << ") devices for the segment. Picking first one to continue.";
<< "but this shouldn't have happened";
info->device = *segment_devices.begin(); info->device = *segment_devices.begin();
} else { } else {
VLOG(1) << "No device is assigned to the segment. " VLOG(1) << "No device is assigned to the segment. "
@ -543,10 +575,10 @@ Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
std::map<string, Node*> io_nodes; std::map<string, Node*> io_nodes;
int num_inputs = 0; int num_inputs = 0;
for (auto n : sgraph.op_nodes()) { for (auto n : sgraph.op_nodes()) {
if (str_util::StartsWith(n->name(), kInputPHName)) { if (absl::StartsWith(n->name(), kInputPHName)) {
num_inputs++; num_inputs++;
io_nodes.insert({n->name(), n}); io_nodes.insert({n->name(), n});
} else if (str_util::StartsWith(n->name(), kOutputPHName)) { } else if (absl::StartsWith(n->name(), kOutputPHName)) {
io_nodes.insert({n->name(), n}); io_nodes.insert({n->name(), n});
} }
} }
@ -640,24 +672,17 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr || if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
engine.device.empty()) { engine.device.empty()) {
// If device is not set, use the first found GPU device for the conversion. // If device is not set, use the first found GPU device for the conversion.
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) { TfGpuId tf_gpu_id;
TfGpuId tf_gpu_id(tf_gpu_id_value);
PlatformGpuId platform_gpu_id; PlatformGpuId platform_gpu_id;
Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id); std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
if (s.ok()) {
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
<< platform_gpu_id.value();
cuda_device_id = platform_gpu_id.value(); cuda_device_id = platform_gpu_id.value();
if (cuda_device_id >= 0) {
GPUOptions gpu_options; GPUOptions gpu_options;
// If the TF to Cuda gpu id mapping exist, the device and corresponding // If the TF to Cuda gpu id mapping exist, the device and corresponding
// allocator must have been initialized already, so the // allocator must have been initialized already, so the
// GetGPUAllocator() call won't create a new allocator. // GetGPUAllocator() call won't create a new allocator.
dev_allocator = GPUProcessState::singleton()->GetGPUAllocator( dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
gpu_options, tf_gpu_id, 1); gpu_options, tf_gpu_id, 1);
break;
}
LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist "
<< s;
} }
return std::make_pair(cuda_device_id, dev_allocator); return std::make_pair(cuda_device_id, dev_allocator);
} }
@ -750,8 +775,8 @@ Status ConvertAfterShapes(const ConversionParams& params) {
EngineInfo curr_engine; EngineInfo curr_engine;
curr_engine.engine_name = StrCat("TRTEngineOp_", t); curr_engine.engine_name = StrCat("TRTEngineOp_", t);
Status status = Status status =
GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
node_map, reverse_topo_order, &curr_engine); reverse_topo_order, &curr_engine);
if (!status.ok()) { if (!status.ok()) {
LOG(WARNING) << "Failed to get engine info for segment " << t << ": " LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
<< status; << status;
@ -776,7 +801,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
total_engine_bytes_size += engine_bytes_size.back(); total_engine_bytes_size += engine_bytes_size.back();
total_num_nodes_in_segments += curr_segment.first.size(); total_num_nodes_in_segments += curr_segment.size();
engine_segments.push_back(std::move(curr_engine)); engine_segments.push_back(std::move(curr_engine));
converted_segments.push_back(std::move(curr_segment)); converted_segments.push_back(std::move(curr_segment));
@ -806,7 +831,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
engine.max_workspace_size_bytes = engine.max_workspace_size_bytes =
params.max_workspace_size_bytes * params.max_workspace_size_bytes *
(engine_bytes_size.at(i) / total_engine_bytes_size + (engine_bytes_size.at(i) / total_engine_bytes_size +
converted_segments.at(i).first.size() / total_num_nodes_in_segments) / converted_segments.at(i).size() / total_num_nodes_in_segments) /
2.0; 2.0;
VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to " VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
<< engine.engine_name; << engine.engine_name;
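Editor's note: the per-engine workspace budget above averages each segment's share of total serialized bytes with its share of total node count. A self-contained arithmetic sketch of that split; the byte sizes, node counts, and 1 GiB budget are made-up example values:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const double max_workspace_size_bytes = 1 << 30;      // 1 GiB budget (example)
      const std::vector<double> engine_bytes = {2e6, 6e6};  // serialized segment sizes
      const std::vector<double> engine_nodes = {10, 30};    // nodes per segment
      const double total_bytes = 2e6 + 6e6;
      const double total_nodes = 10 + 30;

      for (size_t i = 0; i < engine_bytes.size(); ++i) {
        // Average of the byte-share and the node-share, as in the diff above.
        const double share =
            (engine_bytes[i] / total_bytes + engine_nodes[i] / total_nodes) / 2.0;
        std::cout << "engine " << i << " gets "
                  << static_cast<int64_t>(max_workspace_size_bytes * share)
                  << " bytes\n";
      }
    }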
@ -828,9 +853,9 @@ Status ConvertAfterShapes(const ConversionParams& params) {
CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph, CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
alloc.get(), &engine_nodes); alloc.get(), &engine_nodes);
string msg = StrCat("TensorRT node ", engine.engine_name, string msg =
" added for segment ", i, " consisting of ", StrCat("TensorRT node ", engine.engine_name, " added for segment ", i,
converted_segments.at(i).first.size(), " nodes"); " consisting of ", converted_segments.at(i).size(), " nodes");
if (status.ok()) { if (status.ok()) {
LOG(INFO) << msg << " succeeded."; LOG(INFO) << msg << " succeeded.";
} else { } else {
@ -839,7 +864,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
} }
if (VLOG_IS_ON(1)) { if (VLOG_IS_ON(1)) {
msg = "Segment consists of nodes: "; msg = "Segment consists of nodes: ";
for (const Node* node : converted_segments.at(i).first) { for (const Node* node : converted_segments.at(i)) {
StrAppend(&msg, node->name(), ", "); StrAppend(&msg, node->name(), ", ");
} }
VLOG(1) << msg; VLOG(1) << msg;
@ -848,7 +873,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
// If status is ok, we successfully added the node to the graph and can // If status is ok, we successfully added the node to the graph and can
// remove segment ops. Otherwise graph is not modified. // remove segment ops. Otherwise graph is not modified.
if (status.ok()) { if (status.ok()) {
for (const Node* node : converted_segments.at(i).first) { for (const Node* node : converted_segments.at(i)) {
graph.RemoveNode(const_cast<Node*>(node)); graph.RemoveNode(const_cast<Node*>(node));
} }
} }



@ -239,7 +239,7 @@ class ConvertAfterShapesTest : public ::testing::Test {
params.output_names = &output_names; params.output_names = &output_names;
params.max_workspace_size_bytes = 8 << 20; params.max_workspace_size_bytes = 8 << 20;
params.output_graph_def = output_graph_def; params.output_graph_def = output_graph_def;
params.minimum_segment_size = 2; params.minimum_segment_size = 1;
params.graph_properties = &graph_properties; params.graph_properties = &graph_properties;
params.use_calibration = false; params.use_calibration = false;


@ -385,11 +385,10 @@ string DebugString(const nvinfer1::ITensor& tensor) {
", dims=", DebugString(tensor.getDimensions()), ")"); ", dims=", DebugString(tensor.getDimensions()), ")");
} }
Status Converter::GetTrtBroadcastShape( Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims, nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims) const { nvinfer1::Dims* operand_r_new_dims) {
// ***************************************************************************
// TensorRT Elementwise op supports broadcast but requires both tensor to be // TensorRT Elementwise op supports broadcast but requires both tensor to be
// of Identical rank // of Identical rank
// //
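Editor's note: the comment above is the key constraint behind GetTrtBroadcastShape: TensorRT element-wise layers only broadcast between operands of identical rank, so both shapes are right-aligned and the shorter one is padded with 1s. A standalone sketch of that alignment using plain ints (no TensorRT types, and the batch-dimension handling done by the real function is ignored):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    // Right-align two shapes to the same rank by prepending 1s, the usual
    // NumPy/TF broadcasting convention the converter has to reproduce before
    // handing both operands to an element-wise layer.
    std::vector<int> PadToRank(std::vector<int> dims, size_t rank) {
      dims.insert(dims.begin(), rank - dims.size(), 1);
      return dims;
    }

    // Per-dimension compatibility check: dims must match or one of them be 1.
    bool BroadcastCompatible(const std::vector<int>& a, const std::vector<int>& b) {
      for (size_t i = 0; i < a.size(); ++i) {
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1) return false;  // infeasible
      }
      return true;
    }

    int main() {
      std::vector<int> l = {3, 4};     // e.g. weights of shape [3, 4]
      std::vector<int> r = {2, 1, 4};  // e.g. a tensor of shape [2, 1, 4]
      const size_t rank = std::max(l.size(), r.size());
      l = PadToRank(l, rank);          // -> {1, 3, 4}
      r = PadToRank(r, rank);          // -> {2, 1, 4}
      std::cout << std::boolalpha << BroadcastCompatible(l, r) << "\n";  // true
    }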
@ -473,14 +472,13 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
nvinfer1::Weights trt_weights = weights.GetTrtWeights(); nvinfer1::Weights trt_weights = weights.GetTrtWeights();
nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights);
if (!layer) return nullptr; if (!layer) return nullptr;
const nvinfer1::DataType trt_dtype = trt_weights.type;
nvinfer1::ITensor* trt_tensor = layer->getOutput(0); nvinfer1::ITensor* trt_tensor = layer->getOutput(0);
#if !IS_TRT_VERSION_GE(5, 1, 3, 0) #if !IS_TRT_VERSION_GE(5, 1, 3, 0)
// TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set
// the data type below, it will always be kFLOAT regardless what the data type // the data type below, it will always be kFLOAT regardless what the data type
// of the weights is. Once NVIDIA fixes this bug, we should remove the data // of the weights is. Once NVIDIA fixes this bug, we should remove the data
// type setting logic below and test should still pass. // type setting logic below and test should still pass.
trt_tensor->setType(trt_dtype); trt_tensor->setType(trt_weights.type);
#endif #endif
return trt_tensor; return trt_tensor;
} }
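Editor's note: the setType() workaround above is compiled only for TensorRT builds older than 5.1.3 via the IS_TRT_VERSION_GE guard. A minimal sketch of how such a compile-time version gate can be expressed; the SKETCH_* macro names and the 5.0.2.6 version are placeholders, not the actual NvInfer.h macros:

    #include <iostream>

    // Placeholder version components; the real values would come from the
    // TensorRT headers (NV_TENSORRT_MAJOR and friends).
    #define SKETCH_TRT_MAJOR 5
    #define SKETCH_TRT_MINOR 0
    #define SKETCH_TRT_PATCH 2
    #define SKETCH_TRT_BUILD 6

    // True when the compile-time TRT version is >= major.minor.patch.build.
    #define SKETCH_TRT_VERSION_GE(major, minor, patch, build)              \
      ((SKETCH_TRT_MAJOR > (major)) ||                                     \
       (SKETCH_TRT_MAJOR == (major) && SKETCH_TRT_MINOR > (minor)) ||      \
       (SKETCH_TRT_MAJOR == (major) && SKETCH_TRT_MINOR == (minor) &&      \
        SKETCH_TRT_PATCH > (patch)) ||                                     \
       (SKETCH_TRT_MAJOR == (major) && SKETCH_TRT_MINOR == (minor) &&      \
        SKETCH_TRT_PATCH == (patch) && SKETCH_TRT_BUILD >= (build)))

    int main() {
    #if !SKETCH_TRT_VERSION_GE(5, 1, 3, 0)
      std::cout << "apply the kFLOAT workaround\n";  // taken for 5.0.2.6
    #else
      std::cout << "workaround compiled out\n";
    #endif
    }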
@ -1677,190 +1675,6 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights,
return Status::OK(); return Status::OK();
} }
// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the
// right operand. If swapped_inputs is true, those two are swapped.
//
// TODO(jie): broadcast is needed yet not implemented.
// Only implemented channel wise for the time being.
Status BinaryTensorOpWeight(OpConverterParams* params,
nvinfer1::ITensor* tensor,
TRT_ShapedWeights weights, bool swapped_inputs) {
static const std::unordered_set<string> supported_ops = {"Sub", "Add", "Mul",
"Div", "RealDiv"};
const auto& node_def = params->node_def;
if (!supported_ops.count(node_def.op())) {
return errors::Unimplemented(node_def.op(), " is not supported, at ",
node_def.name());
}
// Check scale mode.
auto dims_w = weights.shape_;
const auto dims_t = tensor->getDimensions();
// TODO(jie): addScale checks for input tensor dimension
if (dims_t.nbDims != 3) {
return errors::InvalidArgument("addScale requires tensor with rank 3, at ",
node_def.name());
}
// Default to element-wise
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
// TODO(jie): maybe use a permutation instead to support more cases;
bool need_to_permute = false;
if (weights.count() == 1) {
scale_mode = nvinfer1::ScaleMode::kUNIFORM;
} else {
VLOG(2) << "weights dims: " << DebugString(dims_w)
<< "; tensor dims: " << DebugString(dims_t);
// Make sure no broadcasting on batch dimension.
if (dims_w.nbDims == dims_t.nbDims + 1) {
if (dims_w.d[0] == 1) {
for (int i = 1; i < dims_w.nbDims; i++) {
dims_w.d[i - 1] = dims_w.d[i];
}
dims_w.nbDims--;
} else {
return errors::InvalidArgument("Binary op cannot operate on batch, at ",
node_def.name());
}
}
if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) {
scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
// Default is element-wise
for (int i = 1; i < dims_w.nbDims; i++) {
if (dims_w.d[i] != dims_t.d[i]) {
// If dimension does not match, switch back to per-channel
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
break;
}
}
// If the mode is per-channel, since channel dimension is assumed to be
// the third to last dimension, we need to make sure all other dimensions
// have size 1.
if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
for (int i = 1; i < dims_w.nbDims; i++) {
if (dims_w.d[i] != 1)
return errors::InvalidArgument(
"Weight dims not compatible for channel-wise broadcast at ",
node_def.name());
}
}
} else if (dims_w.nbDims == 1 &&
dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) {
// Channel wise and broadcast required. We compare the last dimension of
// the tensor shape because of tensorflow default broadcasting rules.
need_to_permute = true;
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
} else {
return errors::InvalidArgument("Weight dims not compatible at ",
node_def.name());
}
}
// TODO(laigd): we should add validation_only support in TransposeTensor() and
// PrepareTensorForShape().
if (params->validation_only) return Status::OK();
// Transpose last dimension.
std::vector<int> permutation(dims_t.nbDims + 1);
if (need_to_permute) {
// We swap the last dimension into channel for trt, because of tensorflow
// default broadcasting rules.
for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
permutation[i] = i;
}
permutation[1] = dims_t.nbDims;
permutation[dims_t.nbDims] = 1;
TF_RETURN_IF_ERROR(
params->converter->TransposeTensor(tensor, permutation, &tensor));
}
// Prepare weights
TRT_ShapedWeights shift_weights(weights.TrtDType());
TRT_ShapedWeights scale_weights(weights.TrtDType());
TRT_ShapedWeights power_weights(weights.TrtDType());
if (node_def.op() == "Sub") {
if (swapped_inputs) {
shift_weights = weights;
nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
*tensor, nvinfer1::UnaryOperation::kNEG);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
// Since quantization ranges are symmetric, the same range as the input
// will work for the negation of the input.
params->converter->MarkQuantizationRangesAsInferrable(
tensor, layer->getOutput(0));
tensor = layer->getOutput(0);
} else {
TRT_ShapedWeights neg_weights =
params->weight_store->GetTempWeights(weights);
LambdaFactory unary_op;
unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
shift_weights = neg_weights;
}
} else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
if (swapped_inputs) {
// We need to infer the quantization range for this intermediate tensor.
//
// x -> [Recip] -> 1/x -> [Scale] -> s/x
// ^
// need range for this
//
// We have the quantization scales for x and s/x - can we divide the scale
// for s/x by s? Only if it is a scalar.
//
// Because of this issue, fall back to BinaryTensorOpTensor if we are
// doing INT8 with no calibration. There is most likely no performance
// penalty by falling back here.
if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
!params->converter->use_calibration()) {
return errors::Unimplemented(
"Intermediate quantization range cannot be determined without"
" calibration. Falling back to BinaryTensorOpTensor for ",
node_def.op(), ", at ", node_def.name());
}
scale_weights = weights;
nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
*tensor, nvinfer1::UnaryOperation::kRECIP);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
tensor = layer->getOutput(0);
} else {
TRT_ShapedWeights recip_weights =
params->weight_store->GetTempWeights(weights);
LambdaFactory unary_op;
unary_op.op = LambdaFactory::OP_CATEGORY::RECIP;
TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op));
scale_weights = recip_weights;
}
} else if (node_def.op() == "Mul") {
scale_weights = weights;
} else if (node_def.op() == "Add") {
shift_weights = weights;
} else {
// This should not happen.
return errors::Unimplemented("Binary op not supported at ", node_def.op());
}
nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
*tensor, scale_mode, shift_weights.GetTrtWeights(),
scale_weights.GetTrtWeights(), power_weights.GetTrtWeights());
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
// Transpose back dimension
if (need_to_permute) {
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
output_tensor, permutation, &output_tensor));
}
// Pass the output
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
Status ConvertConv2DHelper(OpConverterParams* params, int group, Status ConvertConv2DHelper(OpConverterParams* params, int group,
bool is_conv2d_backprop_input) { bool is_conv2d_backprop_input) {
const auto& inputs = params->inputs; const auto& inputs = params->inputs;
@ -1951,7 +1765,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
kernel_size.h() = weights.shape_.d[2]; kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3]; kernel_size.w() = weights.shape_.d[3];
// Add padding. // Before TRT 5.1.3, we have to calculate padding ourselves.
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
std::vector<std::pair<int, int>> padding; std::vector<std::pair<int, int>> padding;
if (attrs.get<string>("padding") == "SAME") { if (attrs.get<string>("padding") == "SAME") {
nvinfer1::DimsHW effective_kernel_size = kernel_size; nvinfer1::DimsHW effective_kernel_size = kernel_size;
@ -1978,12 +1793,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
padding = {{0, 0}, {0, 0}}; padding = {{0, 0}, {0, 0}};
} }
// TensorRT 5.1 added support for asymmetric padding. Due to a bug in 5.1.2, we // Handle asymmetric padding. TensorRT 5.1 added support for asymmetric
// can only use asymmetric padding in convolutions with 5.1.3+. // padding via setPrePadding and setPostPadding. Due to a bug in 5.1.2, we can
#if !IS_TRT_VERSION_GE(5, 1, 3, 0) // only use asymmetric padding in convolutions with 5.1.3+. But in 5.1.3, we
// will always use setPaddingMode for simplicity.
if (padding[0].first != padding[0].second || if (padding[0].first != padding[0].second ||
padding[1].first != padding[1].second) { padding[1].first != padding[1].second) {
// Handle asymmetric padding.
auto pad_layer = params->converter->network()->addPadding( auto pad_layer = params->converter->network()->addPadding(
*tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second)); nvinfer1::DimsHW(padding[0].second, padding[1].second));
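Editor's note: before TRT 5.1.3 the converter computes TensorFlow-style SAME padding itself and, when the split is asymmetric, inserts an explicit padding layer as shown above. A standalone sketch of the SAME-padding arithmetic for one spatial dimension (ceil-divide for the output size, remainder split with the extra pixel on the trailing side, matching the "post padding is preferred" comment); the input/kernel/stride values are chosen purely for illustration:

    #include <algorithm>
    #include <iostream>
    #include <utility>

    // TensorFlow SAME padding for one spatial dimension: output size is
    // ceil(input / stride); total padding is whatever makes the kernel fit,
    // with any odd pixel going to the end (post) side.
    std::pair<int, int> SamePadding(int input, int kernel, int stride) {
      const int output = (input + stride - 1) / stride;
      const int total_pad = std::max(0, (output - 1) * stride + kernel - input);
      const int pad_before = total_pad / 2;
      const int pad_after = total_pad - pad_before;
      return {pad_before, pad_after};
    }

    int main() {
      auto [pre, post] = SamePadding(/*input=*/7, /*kernel=*/3, /*stride=*/2);
      std::cout << "pre=" << pre << " post=" << post << "\n";  // pre=1 post=1
      auto [p2, q2] = SamePadding(/*input=*/10, /*kernel=*/2, /*stride=*/2);
      std::cout << "pre=" << p2 << " post=" << q2 << "\n";     // pre=0 post=0
    }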
@ -2006,20 +1821,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
layer->setStride(stride); layer->setStride(stride);
// TensorRT 5.1.3 added support for padding modes. // TensorRT 5.1.3 added support for padding modes.
#if IS_TRT_VERSION_GE(5, 1, 3, 0) #if IS_TRT_VERSION_GE(5, 1, 3, 0)
// VALID padding is the default TRT behavior.
if (attrs.get<string>("padding") == "SAME") { if (attrs.get<string>("padding") == "SAME") {
VLOG(2) << "Using SAME padding";
// SAME_UPPER means that post padding is preferred. // SAME_UPPER means that post padding is preferred.
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
} }
// For VALID padding, we need to manually set the padding.
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
layer->setPostPadding(
nvinfer1::DimsHW{padding[0].second, padding[1].second});
VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
<< " and post-padding to: " << DebugString(layer->getPostPadding());
#else #else
layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
#endif #endif
layer->setName(node_def.name().c_str()); layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups); layer->setNbGroups(num_groups);
@ -2033,17 +1841,10 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
layer->setStride(stride); layer->setStride(stride);
#if IS_TRT_VERSION_GE(5, 1, 3, 0) #if IS_TRT_VERSION_GE(5, 1, 3, 0)
if (attrs.get<string>("padding") == "SAME") { if (attrs.get<string>("padding") == "SAME") {
VLOG(2) << "Using SAME padding";
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
} }
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
layer->setPostPadding(
nvinfer1::DimsHW{padding[0].second, padding[1].second});
VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
<< " and post-padding to: " << DebugString(layer->getPostPadding());
#else #else
layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
#endif #endif
layer->setName(node_def.name().c_str()); layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups); layer->setNbGroups(num_groups);
@ -2061,74 +1862,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
return Status::OK(); return Status::OK();
} }
Status BinaryTensorOpTensor(OpConverterParams* params,
const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r) {
const auto& node_def = params->node_def;
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
{"Add", nvinfer1::ElementWiseOperation::kSUM},
{"Mul", nvinfer1::ElementWiseOperation::kPROD},
{"Sub", nvinfer1::ElementWiseOperation::kSUB},
{"Div", nvinfer1::ElementWiseOperation::kDIV},
{"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
{"Pow", nvinfer1::ElementWiseOperation::kPOW},
};
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end()) {
return errors::Unimplemented("Binary op ", node_def.op(),
" not supported at: ", node_def.name());
}
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
Status status = params->converter->GetTrtBroadcastShape(
operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r);
if (!status.ok()) {
return errors::InvalidArgument(
"Unsupported binary op broadcast scheme for op ", node_def.name(), ": ",
status.error_message());
}
TFAttrs attrs(node_def);
nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
if (dtype == nvinfer1::DataType::kINT32) {
return errors::Unimplemented("Binary op ", node_def.op(),
" does not support INT32, at ",
node_def.name());
}
if (params->validation_only) return Status::OK();
nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
status = params->converter->PrepareTensorForShape(
operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l);
if (status.ok()) {
status = params->converter->PrepareTensorForShape(
operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r);
}
if (!status.ok()) {
return errors::Internal("Failed to convert binary op ", node_def.name(),
": ", status.error_message());
}
// Check type consistency.
TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype)
<< DebugString(tensor_l->getType()) << " vs " << DebugString(dtype);
TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype)
<< DebugString(tensor_r->getType()) << " vs " << DebugString(dtype);
// Add ElementWise layer.
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(*tensor_l, *tensor_r,
op_pair->second);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
// Pass the output
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
Status ConvertPlugin(OpConverterParams* params) { Status ConvertPlugin(OpConverterParams* params) {
const auto& inputs = params->inputs; const auto& inputs = params->inputs;
const auto& node_def = params->node_def; const auto& node_def = params->node_def;
@ -2777,6 +2510,8 @@ Status ConvertPool(OpConverterParams* params) {
const auto tf_kernel = attrs.get<std::vector<int64>>("ksize"); const auto tf_kernel = attrs.get<std::vector<int64>>("ksize");
const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
// Before TRT 5.1.3, we have to calculate padding ourselves.
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
auto tensor_dim = tensor->getDimensions(); auto tensor_dim = tensor->getDimensions();
std::vector<std::pair<int, int>> padding; std::vector<std::pair<int, int>> padding;
if (padding_type == "SAME") { if (padding_type == "SAME") {
@ -2789,13 +2524,13 @@ Status ConvertPool(OpConverterParams* params) {
} else if (padding_type == "VALID") { } else if (padding_type == "VALID") {
padding = {{0, 0}, {0, 0}}; padding = {{0, 0}, {0, 0}};
} }
#endif
// TensorRT 5.1 added support for asymmetric padding. // TensorRT 5.1 added support for asymmetric padding. Before that, we need an
// extra padding layer.
#if !IS_TRT_VERSION_GE(5, 1, 0, 0) #if !IS_TRT_VERSION_GE(5, 1, 0, 0)
// Asymmetric padding case.
if (padding[0].first != padding[0].second || if (padding[0].first != padding[0].second ||
padding[1].first != padding[1].second) { padding[1].first != padding[1].second) {
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
<< padding[1].first << padding[1].second;
auto pad_layer = params->converter->network()->addPadding( auto pad_layer = params->converter->network()->addPadding(
*tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second)); nvinfer1::DimsHW(padding[0].second, padding[1].second));
@ -2817,16 +2552,13 @@ Status ConvertPool(OpConverterParams* params) {
layer->getOutput(0)); layer->getOutput(0));
layer->setStride(stride); layer->setStride(stride);
// TensorRT 5.1.3 added support for padding modes.
#if IS_TRT_VERSION_GE(5, 1, 3, 0) #if IS_TRT_VERSION_GE(5, 1, 3, 0)
// VALID padding is the default TRT behavior.
if (attrs.get<string>("padding") == "SAME") { if (attrs.get<string>("padding") == "SAME") {
// SAME_UPPER means that post padding is preferred. // SAME_UPPER means that post padding is preferred.
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
} }
#endif #elif IS_TRT_VERSION_GE(5, 1, 0, 0)
// TensorRT 5.1 has support for asymmetric padding.
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
// If padding mode is not SAME, then these values will be used instead.
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second}); layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second});
#else #else
@ -3350,9 +3082,6 @@ Status ConvertIdentity(OpConverterParams* params) {
Status ConvertBinary(OpConverterParams* params) { Status ConvertBinary(OpConverterParams* params) {
const auto& inputs = params->inputs; const auto& inputs = params->inputs;
const auto& node_def = params->node_def; const auto& node_def = params->node_def;
// TODO(tmorris): Enable once false is updated to mean either tensor or weight
// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
// false}}));
if (inputs.size() != 2) { if (inputs.size() != 2) {
return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
" inputs but expected 2, at ", " inputs but expected 2, at ",
@ -3368,33 +3097,45 @@ Status ConvertBinary(OpConverterParams* params) {
"both input as constant at: ", "both input as constant at: ",
node_def.name()); node_def.name());
} }
const TRT_TensorOrWeights& operand_l = inputs.at(0);
const TRT_TensorOrWeights& operand_r = inputs.at(1);
// TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
// IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For {"Add", nvinfer1::ElementWiseOperation::kSUM},
// now, the performance will be slightly better with IScaleLayer because it {"Mul", nvinfer1::ElementWiseOperation::kPROD},
// can be fused in more situations. However, most of the benefits of {"Sub", nvinfer1::ElementWiseOperation::kSUB},
// IScaleLayer are when the layer performs both a shift and a scale, which we {"Div", nvinfer1::ElementWiseOperation::kDIV},
// don't do except for convolutions. {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
// {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
// Try to convert into Scale layer first (for better performance). {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
// Since scale layer supports restricted broadcast policy and op types, we {"Pow", nvinfer1::ElementWiseOperation::kPOW},
// allow failure and try to handle it through Elementwise op };
// (BinaryTensorOpTensor). auto op_pair = ops.find(node_def.op());
Status status = Status::OK(); if (op_pair == ops.end()) {
if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { return errors::Unimplemented("Binary op ", node_def.op(),
status = BinaryTensorOpWeight(params, inputs.at(0).tensor(), " not supported at: ", node_def.name());
inputs.at(1).weights(), false);
} else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) {
status = BinaryTensorOpWeight(params, inputs.at(1).tensor(),
inputs.at(0).weights(), true);
} }
// If both input are tensors, or one of them is weights but the conversion
// above failed, try the conversion using BinaryTensorOpTensor. nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
if (!status.ok()) VLOG(2) << status; operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r));
status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1));
} nvinfer1::ITensor* tensor_l = nullptr;
return status; nvinfer1::ITensor* tensor_r = nullptr;
// This will also convert constants to tensors, and set quantization ranges.
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
operand_l, broadcasted_dims_l, params->validation_only, &tensor_l));
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
operand_r, broadcasted_dims_r, params->validation_only, &tensor_r));
if (params->validation_only) return Status::OK();
// Add ElementWise layer.
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(*tensor_l, *tensor_r,
op_pair->second);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
return Status::OK();
} }
Status ConvertRsqrt(OpConverterParams* params) { Status ConvertRsqrt(OpConverterParams* params) {
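Editor's note: the rewritten ConvertBinary above collapses the old scale-layer/element-wise split into a single path: look the TF op up in a static table, broadcast both operands, then emit one element-wise layer. A minimal standalone sketch of that dispatch shape; the enum and the "add layer" step are stand-ins, not TensorRT types:

    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Stand-in for nvinfer1::ElementWiseOperation.
    enum class ElementWiseOp { kSUM, kPROD, kSUB, kDIV, kMIN, kMAX, kPOW };

    // TF op name -> element-wise op, mirroring the table in ConvertBinary.
    const std::unordered_map<std::string, ElementWiseOp>& BinaryOpMap() {
      static const auto* m = new std::unordered_map<std::string, ElementWiseOp>{
          {"Add", ElementWiseOp::kSUM},     {"Mul", ElementWiseOp::kPROD},
          {"Sub", ElementWiseOp::kSUB},     {"Div", ElementWiseOp::kDIV},
          {"RealDiv", ElementWiseOp::kDIV}, {"Minimum", ElementWiseOp::kMIN},
          {"Maximum", ElementWiseOp::kMAX}, {"Pow", ElementWiseOp::kPOW},
      };
      return *m;
    }

    bool ConvertBinarySketch(const std::string& tf_op) {
      auto it = BinaryOpMap().find(tf_op);
      if (it == BinaryOpMap().end()) {
        std::cout << "Binary op " << tf_op << " not supported\n";
        return false;  // the real converter returns errors::Unimplemented here
      }
      // Real code would broadcast both operands, then call addElementWise().
      std::cout << tf_op << " -> element-wise op #"
                << static_cast<int>(it->second) << "\n";
      return true;
    }

    int main() {
      ConvertBinarySketch("RealDiv");   // supported
      ConvertBinarySketch("FloorMod");  // rejected
    }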
@ -4547,7 +4288,7 @@ Status ConvertSquaredDifference(OpConverterParams* params) {
const auto& node_def = params->node_def; const auto& node_def = params->node_def;
// Broadcast inputs. // Broadcast inputs.
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape( TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r)); inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r));
nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr; nvinfer1::ITensor* tensor_r = nullptr;
@ -4692,8 +4433,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) {
TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name()); TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name());
// Create plugin // Create plugin
nvinfer1::IPluginV2* plugin = TrtUniquePtrType<nvinfer1::IPluginV2> plugin(
creator->createPlugin(node_def.name().c_str(), &fc); creator->createPlugin(node_def.name().c_str(), &fc));
TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name()); TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name());
// Set plugin inputs // Set plugin inputs
@ -4875,7 +4616,8 @@ static void RegisterValidatableOpConverters(
for (auto pool_op_type : {"AvgPool", "MaxPool"}) { for (auto pool_op_type : {"AvgPool", "MaxPool"}) {
(*registration)[pool_op_type] = ConvertPool; (*registration)[pool_op_type] = ConvertPool;
} }
for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { for (auto normalization_op_type :
{"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"}) {
(*registration)[normalization_op_type] = ConvertFusedBatchNorm; (*registration)[normalization_op_type] = ConvertFusedBatchNorm;
} }
for (auto unary_op_pair : *UnaryOperationMap()) { for (auto unary_op_pair : *UnaryOperationMap()) {


@ -512,13 +512,6 @@ class Converter {
const bool validation_only, const bool validation_only,
nvinfer1::ITensor** tensor); nvinfer1::ITensor** tensor);
// Return OK if the broadcast scheme is supported and compute the shapes after
// broadcasting.
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims) const;
// Creates an IConstantLayer using 'weights' whose dimensions are specified by // Creates an IConstantLayer using 'weights' whose dimensions are specified by
// 'dims', and returns the output ITensor. // 'dims', and returns the output ITensor.
nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights,
@ -592,6 +585,13 @@ class Converter {
friend class OpConverterTest; friend class OpConverterTest;
}; };
// Return OK if the broadcast scheme is supported and compute the shapes after
// broadcasting.
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims);
// Map of all supported UnaryOperations // Map of all supported UnaryOperations
const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap(); const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
// Map of all supported ActivationTypes // Map of all supported ActivationTypes


@ -988,18 +988,16 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
operand_2_shape, operand_2_is_tensor, operand_2_batch_size); operand_2_shape, operand_2_is_tensor, operand_2_batch_size);
// operand_1 broadcast operand_2 // operand_1 broadcast operand_2
ExpectStatus( ExpectStatus(GetTrtBroadcastShape(operand_1, operand_2, &operand_1_new_dims,
this->converter_->GetTrtBroadcastShape( &operand_2_new_dims),
operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims),
expected_code, expected_error_msg_substr); expected_code, expected_error_msg_substr);
if (expected_code == error::OK) { if (expected_code == error::OK) {
ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
} }
// operand_2 broadcast operand_1 // operand_2 broadcast operand_1
ExpectStatus( ExpectStatus(GetTrtBroadcastShape(operand_2, operand_1, &operand_2_new_dims,
this->converter_->GetTrtBroadcastShape( &operand_1_new_dims),
operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims),
expected_code, expected_error_msg_substr); expected_code, expected_error_msg_substr);
if (expected_code == error::OK) { if (expected_code == error::OK) {
ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
@ -1033,18 +1031,29 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
error::INVALID_ARGUMENT, error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported " "Broadcasting beyond batch dimension is not supported "
"(tensor #dims 4 vs broadcast #dims 5)"); "(tensor #dims 4 vs broadcast #dims 5)");
symmetric_test({3}, {1, 1, 3}, kIsTensor, kIsNotTensor, {}, {},
error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims 2 vs broadcast #dims 3)",
/*operand_1_batch_size=*/2);
// Both inputs are tensors. // Both inputs are tensors.
symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {}, symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT, error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported " "Broadcasting beyond batch dimension is not supported "
"(tensor #dims 3 vs broadcast #dims 4)"); "(tensor #dims 3 vs broadcast #dims 4)");
symmetric_test({1, 3}, {3}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims 2 vs broadcast #dims 3)");
symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4}, symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
{2, 1, 4}); {2, 1, 4});
symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {}, symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT, error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported " "Broadcasting beyond batch dimension is not supported "
"(tensor #dims 4 vs broadcast #dims 5)"); "(tensor #dims 4 vs broadcast #dims 5)");
symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
} }
TEST_F(ConverterTest, CreateConstantLayer) { TEST_F(ConverterTest, CreateConstantLayer) {
@ -1070,7 +1079,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
int batch_size = -1; int batch_size = -1;
for (const NodeDef& node : gdef.node()) { for (const NodeDef& node : gdef.node()) {
absl::string_view node_name(node.name()); absl::string_view node_name(node.name());
if (str_util::ConsumePrefix(&node_name, kInputPHName)) { if (absl::ConsumePrefix(&node_name, kInputPHName)) {
int port = -1; int port = -1;
EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name(); EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
if (input_shapes.size() < port + 1) input_shapes.resize(port + 1); if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
@ -1351,10 +1360,6 @@ class OpConverterTest : public ::testing::Test {
} }
} }
void TestMatMulHelper(
const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
const std::string& op_name);
// Expose quantization_ranges_ for tests // Expose quantization_ranges_ for tests
std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() { std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
return converter_->quantization_ranges_; return converter_->quantization_ranges_;
@ -1682,59 +1687,60 @@ TEST_F(OpConverterTest, ConvertReshape) {
// Helper function for testing MatMul and BatchMatMul // Helper function for testing MatMul and BatchMatMul
// get_matmul corresponds to the function used to generate the node. It should // get_matmul corresponds to the function used to generate the node. It should
// accept (DataType, transpose_a, transpose_b) as parameters. // accept (DataType, transpose_a, transpose_b) as parameters.
void OpConverterTest::TestMatMulHelper( void TestMatMulHelper(
OpConverterTest* test,
const std::function<NodeDef(DataType, bool, bool)>& get_matmul, const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
const std::string& op_name) { const std::string& op_name) {
// HACK: This needs to be done in a better way. // HACK: This needs to be done in a better way.
const bool is_batch_matmul = op_name == "BatchMatMul"; const bool is_batch_matmul = op_name == "BatchMatMul";
{ {
// Unsupported data type. // Unsupported data type.
Reset(); test->Reset();
NodeDef node_def = get_matmul(DT_INT32, false, false); NodeDef node_def = get_matmul(DT_INT32, false, false);
AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); test->AddTestTensor("input", {2}, /*batch_size=*/1,
AddTestWeights<int32>("weights", {2, 1}, {3, 5}); nvinfer1::DataType::kINT32);
RunValidationAndConversion( test->AddTestWeights<int32>("weights", {2, 1}, {3, 5});
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED, node_def, error::UNIMPLEMENTED,
("Data type int32 is not supported for " + op_name + StrCat("Data type int32 is not supported for ", op_name,
", " ", must be one of [float, half], at my_matmul")
"must be one of [float, half], at my_matmul")
.c_str()); .c_str());
} }
// OK. // OK.
for (bool transpose_a : {false, true}) { for (bool transpose_a : {false, true}) {
for (bool transpose_b : {false, true}) { for (bool transpose_b : {false, true}) {
Reset(); test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b); NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b);
AddTestTensor("input", {2}, /*batch_size=*/1); test->AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3}); test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) { if (is_batch_matmul) {
if (transpose_a || transpose_b) { if (transpose_a || transpose_b) {
RunValidationAndConversion( test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT, node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension for " "Input weight attempts to broadcast across batch dimension for "
"BatchMatMul, at my_matmul"); "BatchMatMul, at my_matmul");
} else { } else {
RunValidationAndConversion( test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT, node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension"); "Input weight attempts to broadcast across batch dimension");
} }
continue; continue;
} else if (transpose_a) { } else if (transpose_a) {
RunValidationAndConversion( test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT, node_def, error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 " "Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions"); "non-batch dimensions");
continue; continue;
} }
RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output; TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor()); ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}}; const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}}; DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
BuildAndRun(input_data, &output_data); test->BuildAndRun(input_data, &output_data);
if (transpose_b) { if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else { } else {
@ -1744,31 +1750,31 @@ void OpConverterTest::TestMatMulHelper(
} }
// OK, 3D inputs // OK, 3D inputs
for (bool transpose_b : {false, true}) { for (bool transpose_b : {false, true}) {
Reset(); test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b); NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b);
AddTestTensor("input", {2}, /*batch_size=*/1); test->AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3}); test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) { if (is_batch_matmul) {
if (transpose_b) { if (transpose_b) {
RunValidationAndConversion( test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT, node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension for " "Input weight attempts to broadcast across batch dimension for "
"BatchMatMul, at my_matmul"); "BatchMatMul, at my_matmul");
} else { } else {
RunValidationAndConversion( test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT, node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension"); "Input weight attempts to broadcast across batch dimension");
} }
continue; continue;
} }
RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output; TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor()); ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}}; const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}}; DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
BuildAndRun(input_data, &output_data); test->BuildAndRun(input_data, &output_data);
if (transpose_b) { if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else { } else {
@ -1832,7 +1838,7 @@ TEST_F(OpConverterTest, ConvertMatMul) {
node_def, error::INVALID_ARGUMENT, node_def, error::INVALID_ARGUMENT,
"Cannot currently transpose constant input if it is not 2 dimensional"); "Cannot currently transpose constant input if it is not 2 dimensional");
} }
TestMatMulHelper(get_matmul_nodedef, "MatMul"); TestMatMulHelper(this, get_matmul_nodedef, "MatMul");
} }
TEST_F(OpConverterTest, ConvertBatchMatMul) { TEST_F(OpConverterTest, ConvertBatchMatMul) {
@ -1889,7 +1895,7 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) {
} }
} }
TestMatMulHelper(get_batch_matmul_nodedef, "BatchMatMul"); TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul");
} }
template <DataType dtype> template <DataType dtype>
@ -2010,250 +2016,82 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) {
} }
template <typename OpType, DataType dtype> template <typename OpType, DataType dtype>
void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor,
typedef typename EnumToDataType<dtype>::Type CType; bool operand_2_is_tensor) {
for (auto swap_inputs : {false, true}) {
test->Reset();
NodeDef node_def;
if (swap_inputs) {
node_def = GetBinaryOpNodeDef<OpType>("weights", "input", dtype);
} else {
node_def = GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
}
const std::vector<CType> operand1{CType(3), CType(7.5)};
const std::vector<CType> operand2{CType(2), CType(3)};
// It requires the dims to be at least of rank 3 to apply an IScaleLayer.
test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>("weights", /*dims=*/{1, 1, 2},
/*values=*/swap_inputs ? operand1 : operand2);
test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
CheckAddedLayers(test, /*expect_scale_layer=*/true);
// Check the dims of the output ITensor.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
const DataVec input_data{
{"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
test->BuildAndRun(
input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
if (node_def.op() == "Add") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(5), CType(10.5)));
} else if (node_def.op() == "Sub") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1), CType(4.5)));
} else if (node_def.op() == "Mul") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(6), CType(22.5)));
} else if (node_def.op() == "Div") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(2.5)));
} else if (node_def.op() == "RealDiv") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(2.5)));
} else {
ASSERT_TRUE(false);
}
}
}
template <DataType dtype>
void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const NodeDef node_def =
GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
const std::vector<CType> weights{CType(10), CType(20)};
// There are two types of valid dim pairs which requires channel-wise
// broadcasting:
// - input dims (X Y Z) vs weights dims (X 1 1)
// - input dims (X Y Z) vs weights dims (Z)
// Here X=Z=2 and Y=1.
for (auto weights_dims : std::vector<std::vector<int>>{{2, 1, 1}, {2}}) {
test->Reset();
test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>("weights", weights_dims, weights);
test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
CheckAddedLayers(test, /*expect_scale_layer=*/true);
// Check the dims of the output ITensor.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
test->BuildAndRun(input_data, &output_data);
if (weights_dims.size() == 1) {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(22), CType(13), CType(24)));
} else {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(12), CType(23), CType(24)));
}
}
}
template <DataType dtype>
void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const NodeDef node_def =
GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
const std::vector<CType> weights{CType(10)};
test->Reset();
test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>("weights", {1, 1, 1, 1}, weights);
test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
CheckAddedLayers(test, /*expect_scale_layer=*/true);
// Check the dims of the output ITensor.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
test->BuildAndRun(input_data, &output_data);
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(12), CType(13), CType(14)));
}
template <typename OpType>
void TestBinaryTensorOpWeightFallback(OpConverterTest* test,
const std::vector<int32>& input_dims,
const std::vector<int>& weights_dims,
error::Code code = error::OK,
const char* error_msg_substr = nullptr,
const int input_batch_size = 1) {
const DataType dtype = DT_FLOAT;
typedef typename EnumToDataType<dtype>::Type CType;
const size_t num_inputs = TrtTensorDimsNumElements(GetTestDims(input_dims));
const size_t num_weights =
TrtWeightDimsNumElements(GetTestDims(weights_dims));
test->Reset();
const NodeDef node_def =
GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>(
"weights", /*dims=*/weights_dims,
/*values=*/std::vector<CType>(num_weights, CType(1)));
test->RunValidationAndConversion(node_def, code, error_msg_substr);
if (code != error::OK) return;
// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
CheckAddedLayers(test, /*expect_scale_layer=*/false);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
// Check the dims of the output ITensor.
std::vector<int> expected_output_dims = input_dims;
for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1;
i >= 0 && j >= 0; --i, --j) {
if (expected_output_dims[i] == 1) {
expected_output_dims[i] = weights_dims[j];
}
}
ExpectTrtDimsEqualsArray(expected_output_dims,
output.tensor()->getDimensions());
// Check the result of running the engine.
const int expected_num_outputs =
TrtTensorDimsNumElements(GetTestDims(expected_output_dims));
const DataVec input_data{
{"input", ConstructTensor<CType>(num_inputs, CType(2))}};
DataVec output_data{
{"my_binary", ConstructTensor<CType>(expected_num_outputs)}};
test->BuildAndRun(input_data, &output_data);
if (node_def.op() == "Add") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(3))));
} else if (node_def.op() == "Minimum") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(1))));
} else {
ASSERT_TRUE(false);
}
}
template <typename OpType, DataType dtype>
void TestBinaryTensorOpTensor(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType; typedef typename EnumToDataType<dtype>::Type CType;
test->Reset(); test->Reset();
const NodeDef node_def = const NodeDef node_def =
GetBinaryOpNodeDef<OpType>("input1", "input2", dtype); GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, if (operand_1_is_tensor) {
test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/2,
TfDataTypeToTrt(dtype)); TfDataTypeToTrt(dtype));
test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, } else {
test->AddTestWeights("input1", /*dims=*/{1, 2},
/*values=*/std::vector<CType>{CType(3), CType(6)});
}
if (operand_2_is_tensor) {
test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/2,
TfDataTypeToTrt(dtype)); TfDataTypeToTrt(dtype));
} else {
test->AddTestWeights("input2", /*dims=*/{2, 1},
/*values=*/std::vector<CType>{CType(2), CType(3)});
}
test->RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. DataVec input_data;
CheckAddedLayers(test, /*expect_scale_layer=*/false); if (operand_1_is_tensor) {
input_data.push_back(
{"input1",
test::AsTensor<CType>({CType(3), CType(6), CType(3), CType(6)})});
}
if (operand_2_is_tensor) {
input_data.push_back(
{"input2",
test::AsTensor<CType>({CType(2), CType(3), CType(2), CType(3)})});
}
DataVec output_data{{"my_binary", ConstructTensor<CType>(8)}};
// Check output dims. // Check output dims.
TRT_TensorOrWeights output; TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor()); ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
const DataVec input_data{
{"input1", test::AsTensor<CType>({CType(3), CType(6)})},
{"input2", test::AsTensor<CType>({CType(2), CType(3)})}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
// After broadcasting first input becomes {3, 6, 3, 6} and second input // After broadcasting first input becomes {3, 6, 3, 6} and second input
// becomes {2, 3, 2, 3}. // becomes {2, 3, 2, 3}.
test->BuildAndRun( test->BuildAndRun(
input_data, &output_data, input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32,
/*batch_size=*/2);
if (node_def.op() == "Add") { if (node_def.op() == "Add") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]), EXPECT_THAT(
ElementsAre(CType(5), CType(8), CType(6), CType(9))); GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({5, 8, 6, 9, 5, 8, 6, 9})));
} else if (node_def.op() == "Sub") { } else if (node_def.op() == "Sub") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]), EXPECT_THAT(
ElementsAre(CType(1), CType(4), CType(0), CType(3))); GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({1, 4, 0, 3, 1, 4, 0, 3})));
} else if (node_def.op() == "Mul") { } else if (node_def.op() == "Mul") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]), EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(6), CType(12), CType(9), CType(18))); ElementsAreArray(
CastTestVector<int, CType>({6, 12, 9, 18, 6, 12, 9, 18})));
} else if (node_def.op() == "Div") { } else if (node_def.op() == "Div") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]), EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); ElementsAreArray(CastTestVector<float, CType>(
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
} else if (node_def.op() == "RealDiv") { } else if (node_def.op() == "RealDiv") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]), EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); ElementsAreArray(CastTestVector<float, CType>(
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
} else if (node_def.op() == "Minimum") { } else if (node_def.op() == "Minimum") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]), EXPECT_THAT(
ElementsAre(CType(2), CType(2), CType(3), CType(3))); GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({2, 2, 3, 3, 2, 2, 3, 3})));
} else if (node_def.op() == "Maximum") { } else if (node_def.op() == "Maximum") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]), EXPECT_THAT(
ElementsAre(CType(3), CType(6), CType(3), CType(6))); GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({3, 6, 3, 6, 3, 6, 3, 6})));
} else if (node_def.op() == "Pow") { } else if (node_def.op() == "Pow") {
ExpectArrayNear( ExpectArrayNear(
std::vector<CType>{CType(9), CType(36), CType(27), CType(216)}, CastTestVector<int, CType>({9, 36, 27, 216, 9, 36, 27, 216}),
GetSpanForData<CType>(output_data[0])); GetSpanForData<CType>(output_data[0]));
} else { } else {
ASSERT_TRUE(false); ASSERT_TRUE(false);
@ -2287,58 +2125,48 @@ TEST_F(OpConverterTest, ConvertBinary) {
"both input as constant at: my_add"); "both input as constant at: my_add");
} }
// Test BinaryTensorOpWeight() without broadcasting. // Test combinations of tensor vs weight inputs (except when both inputs are
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_FLOAT>(this); // weights).
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_FLOAT>(this); for (const bool operand_1_is_tensor : {true, false}) {
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this); for (const bool operand_2_is_tensor : {true, false}) {
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this); if (!operand_1_is_tensor && !operand_2_is_tensor) continue;
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this); // FP32 tests
TestBinaryOp<ops::Add, DT_FLOAT>(this, operand_1_is_tensor,
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this); operand_2_is_tensor);
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this); TestBinaryOp<ops::Sub, DT_FLOAT>(this, operand_1_is_tensor,
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this); operand_2_is_tensor);
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this); TestBinaryOp<ops::Mul, DT_FLOAT>(this, operand_1_is_tensor,
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this); operand_2_is_tensor);
TestBinaryOp<ops::Div, DT_FLOAT>(this, operand_1_is_tensor,
// Test BinaryTensorOpWeight() with channel-wise broadcasting. operand_2_is_tensor);
TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this); TestBinaryOp<ops::RealDiv, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
// Test BinaryTensorOpWeight() with uniformly broadcasting. TestBinaryOp<ops::Minimum, DT_FLOAT>(this, operand_1_is_tensor,
TestBinaryTensorOpWeightWithUniformlyBroadcast<DT_FLOAT>(this); operand_2_is_tensor);
TestBinaryOp<ops::Maximum, DT_FLOAT>(this, operand_1_is_tensor,
// Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor(). operand_2_is_tensor);
// Unsupported op. TestBinaryOp<ops::Pow, DT_FLOAT>(this, operand_1_is_tensor,
TestBinaryTensorOpWeightFallback<ops::Minimum>(this, {1, 1, 1}, {1}); operand_2_is_tensor);
// Rank of input tensor dimension <3. // FP16 tests
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1}, {1}); // TODO(tmorris): Use templates to avoid duplication.
// Broadcast on batch dimension, should fail. TestBinaryOp<ops::Add, DT_HALF>(this, operand_1_is_tensor,
TestBinaryTensorOpWeightFallback<ops::Add>( operand_2_is_tensor);
this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT, TestBinaryOp<ops::Sub, DT_HALF>(this, operand_1_is_tensor,
"Unsupported binary op broadcast scheme for op my_binary", operand_2_is_tensor);
/*input_batch_size=*/2); TestBinaryOp<ops::Mul, DT_HALF>(this, operand_1_is_tensor,
// Incompatible dims with per-channel mode. operand_2_is_tensor);
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1, 1}, {1, 2, 1}); TestBinaryOp<ops::Div, DT_HALF>(this, operand_1_is_tensor,
// Incompatible dims. operand_2_is_tensor);
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 2, 1}, {2}); TestBinaryOp<ops::RealDiv, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
// Test BinaryTensorOpTensor() with broadcasting. TestBinaryOp<ops::Minimum, DT_HALF>(this, operand_1_is_tensor,
TestBinaryTensorOpTensor<ops::Add, DT_FLOAT>(this); operand_2_is_tensor);
TestBinaryTensorOpTensor<ops::Sub, DT_FLOAT>(this); TestBinaryOp<ops::Maximum, DT_HALF>(this, operand_1_is_tensor,
TestBinaryTensorOpTensor<ops::Mul, DT_FLOAT>(this); operand_2_is_tensor);
TestBinaryTensorOpTensor<ops::Div, DT_FLOAT>(this); TestBinaryOp<ops::Pow, DT_HALF>(this, operand_1_is_tensor,
TestBinaryTensorOpTensor<ops::RealDiv, DT_FLOAT>(this); operand_2_is_tensor);
TestBinaryTensorOpTensor<ops::Minimum, DT_FLOAT>(this); }
TestBinaryTensorOpTensor<ops::Maximum, DT_FLOAT>(this); }
TestBinaryTensorOpTensor<ops::Pow, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Add, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Sub, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Mul, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Div, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::RealDiv, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Minimum, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Maximum, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Pow, DT_HALF>(this);
} }
TEST_F(OpConverterTest, ConvertQuantize) { TEST_F(OpConverterTest, ConvertQuantize) {
@ -2583,7 +2411,6 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
// implementation that, the extra output classes that are outside of the // implementation that, the extra output classes that are outside of the
// range specified by valid_detections[i] are not zeros but -1s. // range specified by valid_detections[i] are not zeros but -1s.
TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}}; TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}};
const int batch_size = 1;
for (int i = 0; i < kCombinedNMSOKCases; ++i) { for (int i = 0; i < kCombinedNMSOKCases; ++i) {
Reset(); Reset();


@ -14,6 +14,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h" #include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
@ -32,9 +34,9 @@ namespace tensorflow {
namespace tensorrt { namespace tensorrt {
namespace convert { namespace convert {
// TODO(sami): Remove VLOG messages once the code matures // TODO(sami): Remove VLOG messages once the code matures
using absl::AsciiStrToUpper;
using absl::StrAppend; using absl::StrAppend;
using absl::StrCat; using absl::StrCat;
using str_util::Uppercase;
Status TRTOptimizationPass::Init( Status TRTOptimizationPass::Init(
const RewriterConfig_CustomGraphOptimizer* config) { const RewriterConfig_CustomGraphOptimizer* config) {
@ -67,7 +69,7 @@ Status TRTOptimizationPass::Init(
} }
if (params.count("precision_mode")) { if (params.count("precision_mode")) {
TF_RETURN_IF_ERROR(TrtPrecisionModeFromName( TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
Uppercase(params.at("precision_mode").s()), &precision_mode_)); AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_));
} }
if (params.count("use_calibration")) { if (params.count("use_calibration")) {
use_calibration_ = params.at("use_calibration").b(); use_calibration_ = params.at("use_calibration").b();


@ -15,6 +15,7 @@ limitations under the License.
#include <dirent.h> #include <dirent.h>
#include <string.h> #include <string.h>
#include <fstream> #include <fstream>
#include <vector> #include <vector>
@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
TF_ASSERT_OK(RunOpKernel()); TF_ASSERT_OK(RunOpKernel());
// Verify the result. // Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work. // string type output will remain on CPU, so we're not using GetOutput() here.
Tensor* output = context_->mutable_output(0); EXPECT_EQ("my_serialized_str",
EXPECT_EQ("my_serialized_str", output->scalar<string>()()); context_->mutable_output(0)->scalar<string>()());
} }
} // namespace tensorrt } // namespace tensorrt


@ -87,15 +87,10 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
TF_ASSERT_OK(OpsTestBase::RunOpKernel()); TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Verify the result. // Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work. Tensor* output = OpsTestBase::GetOutput(0);
Tensor* output = OpsTestBase::context_->mutable_output(0); EXPECT_THAT(
const auto& tensor_map = output->flat<TypeParam>(); absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
std::vector<TypeParam> output_data(tensor_map.size()); output->NumElements()),
ASSERT_EQ(0, cudaDeviceSynchronize());
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
sizeof(TypeParam) * tensor_map.size(),
cudaMemcpyDeviceToHost));
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f))); ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
} }


@ -681,31 +681,33 @@ Status SegmentGraph(const Graph* tf_graph,
<< " with parent=" << segment_root << ":" << s; << " with parent=" << segment_root << ":" << s;
} }
// Don't use small segments. const int num_effective_nodes = std::count_if(
if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) { segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
static auto noops =
new std::set<string>{"Identity", "Snapshot", "StopGradient"};
return noops->count(node->type_string()) == 0;
});
// Don't use segments whose number of effective nodes is small.
if (num_effective_nodes < options.minimum_segment_size) {
VLOG(1) << "Segment " << segments->size() << " has only " VLOG(1) << "Segment " << segments->size() << " has only "
<< segment_nodes.size() << " nodes, dropping"; << num_effective_nodes << " effective nodes, dropping";
continue; continue;
} }
// TODO(sami): Make segmenter placement aware once trtscopes are in place
const auto& dev_itr = device_maps.find(segment_root); const auto& dev_itr = device_maps.find(segment_root);
if (dev_itr == device_maps.end() || dev_itr->second.empty()) { if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
VLOG(1) << "No device assigned to segment " << segments->size(); VLOG(1) << "No device assigned to segment " << segments->size();
segments->emplace_back(std::make_pair(segment_nodes, string()));
} else if (dev_itr->second.size() > 1) { } else if (dev_itr->second.size() > 1) {
string s("Segment "); string s = StrCat("Segment ", segments->size(),
StrAppend(&s, segments->size(), " has multiple devices attached: "); " has multiple devices attached: ");
for (const auto& dev : dev_itr->second) { for (const auto& dev : dev_itr->second) {
StrAppend(&s, dev, ", "); StrAppend(&s, dev, ", ");
} }
LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin()); LOG(WARNING) << s;
segments->emplace_back(
std::make_pair(segment_nodes, *(dev_itr->second.begin())));
} else {
segments->emplace_back(
std::make_pair(segment_nodes, *(dev_itr->second.begin())));
} }
segments->emplace_back(segment_nodes);
} }
if (VLOG_IS_ON(1)) { if (VLOG_IS_ON(1)) {
for (const auto& d : device_maps) { for (const auto& d : device_maps) {


@ -31,10 +31,8 @@ namespace tensorflow {
namespace tensorrt { namespace tensorrt {
namespace segment { namespace segment {
// Vector of segments, each entry contains a set of node pointers and a device // Vector of segments, each entry contains a set of node pointers.
// name in the segment. using SegmentNodesVector = std::vector<std::set<const Node*>>;
using SegmentNodesVector =
std::vector<std::pair<std::set<const Node*>, string>>;
struct SegmentOptions { struct SegmentOptions {
// Segment must contain at least this many nodes. // Segment must contain at least this many nodes.


@ -77,7 +77,7 @@ class SegmentTest : public ::testing::Test {
EXPECT_EQ(expected_segments.size(), segments.size()); EXPECT_EQ(expected_segments.size(), segments.size());
for (int i = 0; i < segments.size(); ++i) { for (int i = 0; i < segments.size(); ++i) {
std::set<string> segment_node_names; std::set<string> segment_node_names;
for (const Node* node : segments[i].first) { for (const Node* node : segments[i]) {
segment_node_names.insert(node->name()); segment_node_names.insert(node->name());
} }
const auto& expected = expected_segments[i]; const auto& expected = expected_segments[i];
@ -262,6 +262,23 @@ TEST_F(SegmentTest, BigIfElse) {
{{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
} }
TEST_F(SegmentTest, IdentityOps) {
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
auto identity0 = ops::Identity(s.WithOpName("identity0"), feed);
auto identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
auto identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
auto identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
Graph g(OpRegistry::Global());
TF_EXPECT_OK(s.ToGraph(&g));
const std::set<string> all_identities = {"identity0", "identity1",
"identity2", "identity3"};
// Identity ops are not counted as effective ops in the segment, so no segment
// will be formed in this case.
RunTest(&g, all_identities, all_identities, all_identities, {});
}
} // namespace test } // namespace test
} // namespace segment } // namespace segment
} // namespace tensorrt } // namespace tensorrt
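
The effective-node rule introduced in the segmenter above (and exercised by the IdentityOps test) is essentially a count_if that ignores trivial pass-through ops before comparing against minimum_segment_size. A minimal standalone sketch of that filter, using a stand-in Node type rather than the TensorFlow class:

#include <algorithm>
#include <iostream>
#include <set>
#include <string>

// Stand-in for tensorflow::Node, just enough to demonstrate the filter.
struct Node {
  std::string type;
  const std::string& type_string() const { return type; }
};

// Pass-through ops do not count toward the minimum segment size.
int CountEffectiveNodes(const std::set<const Node*>& segment_nodes) {
  static const std::set<std::string> kNoOps = {"Identity", "Snapshot",
                                               "StopGradient"};
  return static_cast<int>(
      std::count_if(segment_nodes.begin(), segment_nodes.end(),
                    [](const Node* node) {
                      return kNoOps.count(node->type_string()) == 0;
                    }));
}

int main() {
  Node a{"Identity"}, b{"Conv2D"}, c{"StopGradient"};
  const std::set<const Node*> segment = {&a, &b, &c};
  // Only Conv2D is effective, so a minimum_segment_size of 2 would drop
  // this segment, matching the IdentityOps test expectation above.
  std::cout << CountEffectiveNodes(segment) << "\n";  // prints 1
}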


@ -1,7 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
package(
default_visibility = [":internal"],
licenses = ["notice"], # Apache 2.0
)
package_group( package_group(
name = "internal", name = "internal",
packages = [ packages = [
@ -23,15 +26,12 @@ package_group(
], ],
) )
package(
default_visibility = [":internal"],
)
load( load(
"//tensorflow/core:platform/default/cuda_build_defs.bzl", "//tensorflow/core:platform/default/cuda_build_defs.bzl",
"if_cuda_is_configured", "if_cuda_is_configured",
) )
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
cc_library( cc_library(
name = "tf2xla_supported_ops_lib", name = "tf2xla_supported_ops_lib",
@ -67,6 +67,19 @@ xla_proto_library(
], ],
) )
# A proto library that is minimal in size and dependencies for platforms like Android.
tf_portable_proto_library(
name = "portable_tf2xla_proto",
config_string = "allow_all:true",
header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"],
portable_deps = ["//tensorflow/core:android_proto_lib"],
proto_deps = [
":tf2xla_proto",
"//tensorflow/core:protos_all_cc",
],
visibility = ["//visibility:public"],
)
xla_py_proto_library( xla_py_proto_library(
name = "tf2xla_py", name = "tf2xla_py",
has_services = False, has_services = False,


@ -1,9 +1,8 @@
package( package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"], default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
) )
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
tf_gen_op_wrapper_cc( tf_gen_op_wrapper_cc(


@ -918,10 +918,16 @@ string Conditional::name() const {
Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
int port) { int port) {
NodeBuilder id_builder(replacee->name(), "Identity");
id_builder.Input(if_node, port);
string outside_compilation;
if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName,
&outside_compilation)
.ok()) {
id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
}
Node* id; Node* id;
TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
.Input(if_node, port)
.Finalize(graph_, &id));
state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
return Status::OK(); return Status::OK();


@ -247,8 +247,8 @@ Status FunctionalizeControlFlowPass::Run(
// multiple times, and we want to avoid functionalize it again. // multiple times, and we want to avoid functionalize it again.
static std::map<string, string>* kNodeTypeToFunctionAttrMapping = static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
new std::map<string, string>{ new std::map<string, string>{
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass. // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
{"TPUReplicate", "computation"}, {"_TPUReplicate", "computation"},
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass. // XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
{"XlaLaunch", "function"}, {"XlaLaunch", "function"},
}; };


@ -1,9 +1,8 @@
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"], default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
) )
tf_kernel_library( tf_kernel_library(
@ -195,6 +194,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:training_ops",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )


@ -43,7 +43,7 @@ class AssertOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(AssertOp); TF_DISALLOW_COPY_AND_ASSIGN(AssertOp);
}; };
REGISTER_XLA_OP(Name("Assert"), AssertOp); REGISTER_XLA_OP(Name("Assert").CompilationOnly(), AssertOp);
} // anonymous namespace } // anonymous namespace
} // namespace tensorflow } // namespace tensorflow


@ -39,7 +39,10 @@ class FusedBatchNormOp : public XlaOpKernel {
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
} }
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override { CompileImpl(ctx); }
protected:
virtual void CompileImpl(XlaOpKernelContext* ctx) {
xla::PrimitiveType input_type; xla::PrimitiveType input_type;
OP_REQUIRES_OK(ctx, OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
@ -116,8 +119,29 @@ class FusedBatchNormOp : public XlaOpKernel {
bool is_on_gpu_; bool is_on_gpu_;
}; };
class FusedBatchNormOpV3 : public FusedBatchNormOp {
public:
explicit FusedBatchNormOpV3(OpKernelConstruction* ctx)
: FusedBatchNormOp(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
FusedBatchNormOp::CompileImpl(ctx);
if (!ctx->status().ok()) {
return;
}
ctx->SetConstantOutput(5, Tensor());
}
private:
float epsilon_;
TensorFormat data_format_;
bool is_training_;
bool is_on_gpu_;
};
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
class FusedBatchNormGradOp : public XlaOpKernel { class FusedBatchNormGradOp : public XlaOpKernel {
public: public:
@ -233,6 +257,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV3"), FusedBatchNormGradOp);
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow


@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/tensor_format.h"
@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel {
xla::XlaOp conv = xla::XlaOp conv =
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
lhs_dilation, rhs_dilation, dims, depth); lhs_dilation, rhs_dilation, dims, depth);
// Feature group convolution will end up with the kernel_size changing more
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
conv_dims.back() = depth;
conv_dims.push_back(kernel_size);
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
conv_dims.pop_back();
conv_dims.back() *= kernel_size;
conv = xla::Reshape(conv, conv_dims);
ctx->SetOutput(0, conv); ctx->SetOutput(0, conv);
} }
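
The reshape/transpose/reshape added above only reorders the minor dimension of the grouped convolution's output. Assuming, as the comment states, that the kernel index varies faster than the depth in that dimension, the per-position element movement is equivalent to this small transpose (an illustration, not the XLA lowering):

#include <vector>

// Reorder a flat minor dimension from [depth][kernel] layout (kernel index
// fastest) to [kernel][depth] layout (depth fastest), which is what the
// Reshape -> TransposeInMinorDims -> Reshape sequence achieves.
std::vector<float> ReorderMinorDim(const std::vector<float>& in, int depth,
                                   int kernel_size) {
  std::vector<float> out(in.size());
  for (int d = 0; d < depth; ++d) {
    for (int k = 0; k < kernel_size; ++k) {
      out[k * depth + d] = in[d * kernel_size + k];
    }
  }
  return out;
}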


@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <algorithm>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
@ -20,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -148,15 +152,22 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
class GatherOp : public XlaOpKernel { class GatherOp : public XlaOpKernel {
public: public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
// Set batch_dims_ to 0 if the attribute does not exist.
if (context->HasAttr("batch_dims")) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
} else {
batch_dims_ = 0;
}
}
void Compile(XlaOpKernelContext* context) override { void Compile(XlaOpKernelContext* context) override {
xla::XlaBuilder* builder = context->builder();
auto input = context->Input(0); auto input = context->Input(0);
auto input_shape = context->InputShape(0); auto input_shape = context->InputShape(0);
auto indices = context->Input(1); auto indices = context->Input(1);
auto indices_shape = context->InputShape(1); auto indices_shape = context->InputShape(1);
int64 axis = 0;
absl::optional<int64> axis;
if (context->num_inputs() == 3) { if (context->num_inputs() == 3) {
const TensorShape axis_shape = context->InputShape(2); const TensorShape axis_shape = context->InputShape(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
@ -165,31 +176,73 @@ class GatherOp : public XlaOpKernel {
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
errors::InvalidArgument("axis must be int32 or int64")); errors::InvalidArgument("axis must be int32 or int64"));
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); int64 axis_input;
OP_REQUIRES_OK(context,
context->ConstantInputAsIntScalar(2, &axis_input));
const auto params_dims = input_shape.dims(); const auto params_dims = input_shape.dims();
OP_REQUIRES( OP_REQUIRES(context,
context, -params_dims <= axis && axis < params_dims, -params_dims <= axis_input && axis_input < params_dims,
errors::InvalidArgument("Expected axis in the range [", -params_dims, errors::InvalidArgument("Expected axis in the range [",
", ", params_dims, "), but got ", axis)); -params_dims, ", ", params_dims,
if (axis < 0) { "), but got ", axis_input));
axis += params_dims; if (axis_input < 0) {
axis_input += params_dims;
} }
axis = axis_input;
} }
if (batch_dims_ != 0) {
if (batch_dims_ < 0) {
batch_dims_ = indices_shape.dims() + batch_dims_;
}
axis = axis.value_or(batch_dims_);
OP_REQUIRES(context,
batch_dims_ >= -indices_shape.dims() &&
batch_dims_ < indices_shape.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices_shape.dims(), ", ",
indices_shape.dims(), "), but got ",
batch_dims_));
OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than rank(input) (",
input_shape.dims(), ")."));
OP_REQUIRES(context, *axis >= batch_dims_,
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than or equal to ",
"axis (", *axis, ")."));
}
axis = axis.value_or(0);
DataType index_type = input_type(1); DataType index_type = input_type(1);
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64")); errors::InvalidArgument("indices must be int32 or int64"));
xla::XlaOp gather; xla::XlaOp gather;
if (batch_dims_ > 0) {
gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
} else {
// XlaGather() manages degenerate cases, like empty-indices, which are
// error conditions and caught above if batch_dims is not 0.
OP_REQUIRES_OK( OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, axis, context, XlaGather(input, input_shape, indices, indices_shape, *axis,
/*indices_are_nd=*/false, input_type(0), index_type, /*indices_are_nd=*/false, input_type(0),
builder, &gather)); index_type, context->builder(), &gather));
}
context->SetOutput(0, gather); context->SetOutput(0, gather);
} }
private: private:
TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
// The number of batch dimensions, as passed in the batch_dims attribute.
// It must be less than rank(indices).
int32 batch_dims_ = 0;
}; };
REGISTER_XLA_OP(Name("Gather"), GatherOp); REGISTER_XLA_OP(Name("Gather"), GatherOp);


@ -81,20 +81,21 @@ class InTopKOp : public XlaOpKernel {
xla::CreateScalarAddComputation(xla::F32, xla_builder), {1}); xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});
// Calculate in each row of `predictions`, how many values are larger than // Calculate in each row of `predictions`, how many values are larger than
// the value of target class. Then return the result whether the count <= k, // the value of target class. Then return the result whether the count < k,
// which indicates the target is in topk. // which indicates the target is in topk.
xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0}); xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0});
xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32); xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes()); xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32); xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes()); xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2); xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2);
xla::XlaOp num_ge_r1 = xla::Reduce( xla::XlaOp num_gt_r1 = xla::Reduce(
one_hot_r2, zero_r0, one_hot_r2, zero_r0,
xla::CreateScalarAddComputation(xla::S32, xla_builder), {1}); xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});
xla::XlaOp result = xla::XlaOp result =
xla::Le(num_ge_r1, xla::ConstantR0<int32>(xla_builder, k)); xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)),
xla::IsFinite(targets_values_r1));
context->SetOutput(0, result); context->SetOutput(0, result);
} }
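
The change above replaces the Ge-based count and "count <= k" test with a strict comparison plus a finiteness check. A rough scalar reference for the intended semantics, not the XLA kernel itself:

#include <cmath>
#include <vector>

// A target is in the top k iff fewer than k predictions are strictly
// greater than the target's own score, and that score is finite.
bool InTopK(const std::vector<float>& predictions, int target, int k) {
  const float target_value = predictions[target];
  if (!std::isfinite(target_value)) return false;
  int num_greater = 0;
  for (float p : predictions) {
    if (p > target_value) ++num_greater;
  }
  return num_greater < k;
}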


@ -67,9 +67,9 @@ class MatMulOp : public XlaOpKernel {
OP_REQUIRES(ctx, OP_REQUIRES(ctx,
a_shape.dim_size(first_index) == b_shape.dim_size(second_index), a_shape.dim_size(first_index) == b_shape.dim_size(second_index),
errors::InvalidArgument("Matrix size-compatible: In[0]: ", errors::InvalidArgument(
a_shape.DebugString(), ", In[1]: ", "Matrix size-incompatible: In[0]: ", a_shape.DebugString(),
b_shape.DebugString())); ", In[1]: ", b_shape.DebugString()));
xla::XlaOp a = ctx->Input(0); xla::XlaOp a = ctx->Input(0);
xla::XlaOp b = ctx->Input(1); xla::XlaOp b = ctx->Input(1);


@ -15,6 +15,7 @@ limitations under the License.
#include <numeric> #include <numeric>
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -22,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -77,5 +79,58 @@ class SelectOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Select"), SelectOp); REGISTER_XLA_OP(Name("Select"), SelectOp);
class SelectOpV2 : public XlaOpKernel {
public:
explicit SelectOpV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape cond_shape = ctx->InputShape(0);
const TensorShape then_shape = ctx->InputShape(1);
const TensorShape else_shape = ctx->InputShape(2);
// Compute the output shape from the broadcast of the two data inputs, with
// the broadcast of the conditional.
// Then Broadcast all three inputs to the output shape and emit a select.
BCast bcast_then_else(BCast::FromShape(then_shape),
BCast::FromShape(else_shape),
/*fewer_dims_optimization=*/false);
if (!bcast_then_else.IsValid()) {
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ", then_shape.DebugString(), " vs. ",
else_shape.DebugString()));
return;
}
BCast bcast(bcast_then_else.output_shape(), BCast::FromShape(cond_shape),
/*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ",
BCast::ToShape(bcast_then_else.output_shape()).DebugString(), " vs. ",
cond_shape.DebugString()));
return;
}
auto bcasted_cond = BroadcastTo(ctx->Input(0), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_cond.status());
auto cond_handle = bcasted_cond.ValueOrDie();
auto bcasted_then = BroadcastTo(ctx->Input(1), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_then.status());
auto then_handle = bcasted_then.ValueOrDie();
auto bcasted_else = BroadcastTo(ctx->Input(2), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_else.status());
auto else_handle = bcasted_else.ValueOrDie();
ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOpV2);
};
REGISTER_XLA_OP(Name("SelectV2"), SelectOpV2);
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow
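
SelectV2's shape handling above performs two broadcasts: then against else first, and that result against the condition. A simplified sketch of the shape computation using plain vectors instead of TensorFlow's BCast (the function names here are illustrative):

#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Numpy-style broadcast of two shapes; std::nullopt if incompatible.
std::optional<std::vector<int64_t>> BroadcastShapes(std::vector<int64_t> a,
                                                    std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);
  std::vector<int64_t> out = a;
  const std::size_t offset = a.size() - b.size();
  for (std::size_t i = 0; i < b.size(); ++i) {
    const int64_t x = a[offset + i];
    const int64_t y = b[i];
    if (x == y || y == 1) continue;   // keep x
    if (x != 1) return std::nullopt;  // incompatible sizes
    out[offset + i] = y;
  }
  return out;
}

// Output shape of SelectV2: broadcast(then, else), then broadcast the result
// with the condition shape, mirroring the two BCast calls above.
std::optional<std::vector<int64_t>> SelectV2OutputShape(
    const std::vector<int64_t>& cond_shape,
    const std::vector<int64_t>& then_shape,
    const std::vector<int64_t>& else_shape) {
  const auto data_shape = BroadcastShapes(then_shape, else_shape);
  if (!data_shape) return std::nullopt;
  return BroadcastShapes(*data_shape, cond_shape);
}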


@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for softmax. // XLA-specific Ops for softmax.
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -145,23 +146,36 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
: XlaOpKernel(ctx) {} : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape logits_shape = ctx->InputShape(0);
const TensorShape labels_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
errors::InvalidArgument(
"logits and labels must be same size: logits_size=",
logits_shape.DebugString(),
" labels_size=", labels_shape.DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional"));
// As we already tested that both inputs have the same shape no need to
// check that "labels" is a matrix too.
const DataType type = input_type(0); const DataType type = input_type(0);
const xla::PrimitiveType xla_type = ctx->input_xla_type(0); const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0); auto logits = ctx->Input(0);
auto labels = ctx->Input(1); auto labels = ctx->Input(1);
const TensorShape logits_shape = ctx->InputShape(0);
const TensorShape labels_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_shape),
errors::InvalidArgument("labels must be 2-dimensional"));
// Confirm that any necessary broadcasting to make the shapes the same will
// succeed.
for (int dim = 0; dim < 2; dim++) {
OP_REQUIRES(
ctx,
labels_shape.dim_size(dim) == 1 ||
logits_shape.dim_size(dim) == labels_shape.dim_size(dim),
errors::InvalidArgument("logits and labels must be same size after "
"broadcasting of labels: logits_size=",
logits_shape.DebugString(),
" labels_size=", labels_shape.DebugString()));
}
if (!logits_shape.IsSameSize(labels_shape)) {
auto labels_or = BroadcastTo(labels, logits_shape.dim_sizes());
OP_REQUIRES_OK(ctx, labels_or.status());
labels = labels_or.ConsumeValueOrDie();
}
xla::XlaOp loss, backprop; xla::XlaOp loss, backprop;
std::tie(loss, backprop) = std::tie(loss, backprop) =
CrossEntropyWithLogits(ctx, type, xla_type, logits, labels); CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
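
The relaxed check above accepts labels whose shape broadcasts to the logits shape dimension by dimension (each labels dimension is either 1 or an exact match) before BroadcastTo materializes the broadcast. As a standalone predicate, roughly:

#include <array>
#include <cstdint>

// Labels of shape [L0, L1] broadcast to logits of shape [G0, G1] iff each
// labels dimension is 1 or equals the corresponding logits dimension.
bool LabelsBroadcastToLogits(const std::array<int64_t, 2>& logits_shape,
                             const std::array<int64_t, 2>& labels_shape) {
  for (int dim = 0; dim < 2; ++dim) {
    if (labels_shape[dim] != 1 && labels_shape[dim] != logits_shape[dim]) {
      return false;
    }
  }
  return true;
}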


@ -1,9 +1,8 @@
# Utilities for building XLA computations. # Utilities for building XLA computations.
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = ["//tensorflow/compiler/tf2xla:friends"], default_visibility = ["//tensorflow/compiler/tf2xla:friends"],
licenses = ["notice"], # Apache 2.0
) )
# Filegroup used to collect source files for dependency checking. # Filegroup used to collect source files for dependency checking.


@ -1,9 +1,8 @@
package( package(
default_visibility = ["//tensorflow:internal"], default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
) )
licenses(["notice"]) # Apache 2.0
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
"tf_custom_op_library", "tf_custom_op_library",


@ -1,9 +1,8 @@
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = [ default_visibility = [
"//visibility:public", "//visibility:public",
], ],
licenses = ["notice"], # Apache 2.0
) )
load( load(


@ -550,6 +550,7 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
}; };
GraphOptimizer::Options graph_optimizer_options; GraphOptimizer::Options graph_optimizer_options;
graph_optimizer_options.cf_consider_fn = cf_consider_fn; graph_optimizer_options.cf_consider_fn = cf_consider_fn;
graph_optimizer_options.inline_multi_device_functions = true;
optimizer.Optimize(flib_runtime_, flib_runtime_->env(), optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
/*device=*/nullptr, &graph, graph_optimizer_options); /*device=*/nullptr, &graph, graph_optimizer_options);


@ -116,9 +116,12 @@ class XlaOpRegistry {
// If we should cluster operations returning DT_VARIANT. // If we should cluster operations returning DT_VARIANT.
bool cluster_variant_ops = false; bool cluster_variant_ops = false;
// Whether ops known to be slow or to have correctness issues should be // Whether ops known to be slow should be auto-clustered.
bool cluster_slow_ops = false;
// Whether ops known to have numerical accuracy issues should be
// auto-clustered. // auto-clustered.
bool cluster_slow_and_inaccurate_ops = false; bool cluster_inaccurate_ops = false;
}; };
// Registers an XLA backend. `compilation_device_name` is the name of the // Registers an XLA backend. `compilation_device_name` is the name of the


@ -1,6 +1,7 @@
licenses(["notice"]) # Apache 2.0 package(
default_visibility = ["//tensorflow:internal"],
package(default_visibility = ["//tensorflow:internal"]) licenses = ["notice"], # Apache 2.0
)
package_group( package_group(
name = "friends", name = "friends",
@ -575,6 +576,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":types", ":types",
"@com_google_absl//absl/strings",
], ],
) )
@ -881,6 +883,26 @@ tf_cc_test(
], ],
) )
cc_library(
name = "refcounting_hash_map",
hdrs = ["refcounting_hash_map.h"],
deps = [
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "refcounting_hash_map_test",
srcs = ["refcounting_hash_map_test.cc"],
deps = [
":refcounting_hash_map",
":test",
"//tensorflow/core:test_main",
],
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.

Some files were not shown because too many files have changed in this diff.