Merge branch 'tensorflow-master'

commit 33dffe53fb

.bazelrc (44 lines changed)
@@ -39,32 +39,46 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0

build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true
build:download_clang --action_env TF_DOWNLOAD_CLANG=1
# Instruct clang to use LLD for linking.
# This only works with GPU builds currently, since Bazel sets -B/usr/bin in
# auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over
# the downloaded one.
build:download_clang_use_lld --linkopt='-fuse-ld=lld'

build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain

# This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true

# This config refers to building CUDA op kernels with clang.
build:cuda_clang --config=using_cuda
build:cuda_clang --define=using_cuda_clang=true
build:cuda_clang --define=using_clang=true

build:tensorrt --action_env TF_NEED_TENSORRT=1

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true

build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
build:rocm --action_env TF_NEED_ROCM=1

build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl --define=using_sycl=true --define=using_trisycl=false
build:sycl --define=using_sycl=true
build:sycl --action_env TF_NEED_OPENCL_SYCL=1

build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
build:sycl_nodouble --config=sycl
build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE

build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
build:sycl_nodouble --config=sycl
build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address

build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true
build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true

# Options extracted from configure script
build:gdr --define=with_gdr_support=true
@@ -87,6 +101,9 @@ build --spawn_strategy=standalone
build --strategy=Genrule=standalone
build -c opt

# Make Bazel print out all options from rc files.
build --announce_rc

# Other build flags.
build --define=grpc_no_ares=true

@@ -97,8 +114,7 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
# Build TF with C++ 17 features.
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --cxxopt=-std=c++1z
build:c++1z --cxxopt=-stdlib=libc++
build:c++1z --config=c++17

# Default paths for TF_SYSTEM_LIBS
build --define=PREFIX=/usr

@@ -38,7 +38,13 @@ working on getting your pull request submitted to our internal repository. After
the change has been submitted internally, your pull request will be merged
automatically on GitHub.

If you want to contribute but you're not sure where to start, take a look at the
If you want to contribute, start working through the TensorFlow codebase,
navigate to the
[Github "issues" tab](https://github.com/tensorflow/tensorflow/issues) and start
looking through interesting issues. If you are not sure of where to start, then
start by trying one of the smaller/easier issues here i.e.
[issues with the "good first issue" label](https://github.com/tensorflow/tensorflow/labels/good%20first%20issue)
and then take a look at the
[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
These are issues that we believe are particularly well suited for outside
contributions, often because we probably won't get to them right now. If you

configure.py (60 lines changed)
@@ -403,7 +403,8 @@ def set_action_env_var(environ_cp,
|
||||
enabled_by_default,
|
||||
question=None,
|
||||
yes_reply=None,
|
||||
no_reply=None):
|
||||
no_reply=None,
|
||||
bazel_config_name=None):
|
||||
"""Set boolean action_env variable.
|
||||
|
||||
Ask user if query_item will be enabled. Default is used if no input is given.
|
||||
@ -418,12 +419,16 @@ def set_action_env_var(environ_cp,
|
||||
question: optional string for how to ask for user input.
|
||||
yes_reply: optional string for reply when feature is enabled.
|
||||
no_reply: optional string for reply when feature is disabled.
|
||||
bazel_config_name: adding config to .bazelrc instead of action_env.
|
||||
"""
|
||||
var = int(
|
||||
get_var(environ_cp, var_name, query_item, enabled_by_default, question,
|
||||
yes_reply, no_reply))
|
||||
|
||||
if not bazel_config_name:
|
||||
write_action_env_to_bazelrc(var_name, var)
|
||||
elif var:
|
||||
write_to_bazelrc('build --config=%s' % bazel_config_name)
|
||||
environ_cp[var_name] = str(var)
|
||||
|
||||
|
||||
@ -543,7 +548,8 @@ def set_tf_cuda_clang(environ_cp):
|
||||
False,
|
||||
question=question,
|
||||
yes_reply=yes_reply,
|
||||
no_reply=no_reply)
|
||||
no_reply=no_reply,
|
||||
bazel_config_name='cuda_clang')
|
||||
|
||||
|
||||
def set_tf_download_clang(environ_cp):
|
||||
@ -558,7 +564,8 @@ def set_tf_download_clang(environ_cp):
|
||||
False,
|
||||
question=question,
|
||||
yes_reply=yes_reply,
|
||||
no_reply=no_reply)
|
||||
no_reply=no_reply,
|
||||
bazel_config_name='download_clang')
|
||||
|
||||
|
||||
def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
|
||||
@ -782,8 +789,8 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path):
|
||||
print('WARNING: The NDK version in %s is %s, which is not '
|
||||
'supported by Bazel (officially supported versions: %s). Please use '
|
||||
'another version. Compiling Android targets may result in confusing '
|
||||
'errors.\n' % (android_ndk_home_path, ndk_version,
|
||||
_SUPPORTED_ANDROID_NDK_VERSIONS))
|
||||
'errors.\n' %
|
||||
(android_ndk_home_path, ndk_version, _SUPPORTED_ANDROID_NDK_VERSIONS))
|
||||
|
||||
# Now grab the NDK API level to use. Note that this is different from the
|
||||
# SDK API level, as the NDK API level is effectively the *min* target SDK
|
||||
@ -952,6 +959,7 @@ def set_tf_nccl_version(environ_cp):
|
||||
ask_nccl_version, '')
|
||||
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
|
||||
|
||||
|
||||
def get_native_cuda_compute_capabilities(environ_cp):
|
||||
"""Get native cuda compute capabilities.
|
||||
|
||||
@ -1293,9 +1301,6 @@ def configure_ios():
|
||||
"""
|
||||
if not is_macos():
|
||||
return
|
||||
if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000:
|
||||
print(
|
||||
'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.')
|
||||
for filepath in APPLE_BAZEL_FILES:
|
||||
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
|
||||
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
|
||||
@ -1386,7 +1391,7 @@ def main():
|
||||
# environment variables.
|
||||
environ_cp = dict(os.environ)
|
||||
|
||||
current_bazel_version = check_bazel_version('0.24.1', '0.25.2')
|
||||
current_bazel_version = check_bazel_version('0.24.1', '0.25.3')
|
||||
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
|
||||
|
||||
reset_tf_configure_bazelrc()
|
||||
@ -1422,7 +1427,12 @@ def main():
|
||||
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
|
||||
xla_enabled_by_default, 'xla')
|
||||
|
||||
set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
|
||||
set_action_env_var(
|
||||
environ_cp,
|
||||
'TF_NEED_OPENCL_SYCL',
|
||||
'OpenCL SYCL',
|
||||
False,
|
||||
bazel_config_name='sycl')
|
||||
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
|
||||
set_host_cxx_compiler(environ_cp)
|
||||
set_host_c_compiler(environ_cp)
|
||||
@ -1432,30 +1442,44 @@ def main():
|
||||
else:
|
||||
set_trisycl_include_dir(environ_cp)
|
||||
|
||||
set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
|
||||
set_action_env_var(
|
||||
environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
|
||||
if (environ_cp.get('TF_NEED_ROCM') == '1' and
|
||||
'LD_LIBRARY_PATH' in environ_cp and
|
||||
environ_cp.get('LD_LIBRARY_PATH') != '1'):
|
||||
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
|
||||
environ_cp.get('LD_LIBRARY_PATH'))
|
||||
|
||||
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
|
||||
environ_cp['TF_NEED_CUDA'] = str(
|
||||
int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
|
||||
if (environ_cp.get('TF_NEED_CUDA') == '1' and
|
||||
'TF_CUDA_CONFIG_REPO' not in environ_cp):
|
||||
|
||||
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
|
||||
set_action_env_var(
|
||||
environ_cp,
|
||||
'TF_NEED_TENSORRT',
|
||||
'TensorRT',
|
||||
False,
|
||||
bazel_config_name='tensorrt')
|
||||
|
||||
environ_save = dict(environ_cp)
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
|
||||
if validate_cuda_config(environ_cp):
|
||||
cuda_env_names = [
|
||||
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
|
||||
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
|
||||
'TF_CUDA_VERSION',
|
||||
'TF_CUBLAS_VERSION',
|
||||
'TF_CUDNN_VERSION',
|
||||
'TF_TENSORRT_VERSION',
|
||||
'TF_NCCL_VERSION',
|
||||
'TF_CUDA_PATHS',
|
||||
# Items below are for backwards compatibility when not using
|
||||
# TF_CUDA_PATHS.
|
||||
'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
|
||||
'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
|
||||
'CUDA_TOOLKIT_PATH',
|
||||
'CUDNN_INSTALL_PATH',
|
||||
'NCCL_INSTALL_PATH',
|
||||
'NCCL_HDR_PATH',
|
||||
'TENSORRT_INSTALL_PATH'
|
||||
]
|
||||
# Note: set_action_env_var above already writes to bazelrc.
|
||||
for name in cuda_env_names:
|
||||
@ -1506,8 +1530,6 @@ def main():
|
||||
# CUDA not required. Ask whether we should download the clang toolchain and
|
||||
# use it for the CPU build.
|
||||
set_tf_download_clang(environ_cp)
|
||||
if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
|
||||
write_to_bazelrc('build --config=download_clang')
|
||||
|
||||
# SYCL / ROCm / CUDA are mutually exclusive.
|
||||
# At most 1 GPU platform can be configured.
|
||||
|
@ -59,7 +59,7 @@ except ImportError:
|
||||
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
|
||||
_CONTRIB_WARNING = """
|
||||
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
|
||||
The TensorFlow contrib module will not be included in TensorFlow 2.0.
|
||||
For more information, please see:
|
||||
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
|
||||
* https://github.com/tensorflow/addons
|
||||
|
@ -21,6 +21,9 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
"tf_attrtype.h",
|
||||
"tf_datatype.h",
|
||||
"tf_status.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
@ -51,6 +54,8 @@ tf_cuda_library(
|
||||
hdrs = [
|
||||
"c_api.h",
|
||||
"c_api_internal.h",
|
||||
"tf_datatype.h",
|
||||
"tf_status.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
@ -61,6 +66,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_attrtype",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -71,16 +77,26 @@ tf_cuda_library(
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_attrtype",
|
||||
hdrs = ["tf_attrtype.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api",
|
||||
hdrs = [
|
||||
"c_api.h",
|
||||
"tf_attrtype.h",
|
||||
"tf_datatype.h",
|
||||
"tf_status.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c_api_no_xla",
|
||||
":c_api_internal",
|
||||
":tf_attrtype",
|
||||
] + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
@ -96,14 +112,21 @@ tf_cuda_library(
|
||||
"c_api.cc",
|
||||
"c_api_function.cc",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
hdrs = [
|
||||
"c_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [":c_api_internal"] + select({
|
||||
deps = [
|
||||
":c_api_internal",
|
||||
":tf_attrtype",
|
||||
":tf_datatype",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc:gradients",
|
||||
@ -124,6 +147,37 @@ tf_cuda_library(
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_status",
|
||||
srcs = ["tf_status.cc"],
|
||||
hdrs = ["tf_status.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_datatype",
|
||||
srcs = ["tf_datatype.cc"],
|
||||
hdrs = ["tf_datatype.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_experimental",
|
||||
srcs = [
|
||||
@ -137,6 +191,7 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":checkpoint_reader",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
@ -151,15 +206,6 @@ tf_cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c_api_headers",
|
||||
hdrs = [
|
||||
"c_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"version_script.lds",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
// Required for IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||
|
||||
@ -97,7 +98,6 @@ using tensorflow::TensorId;
|
||||
using tensorflow::TensorShape;
|
||||
using tensorflow::TensorShapeProto;
|
||||
using tensorflow::VersionDef;
|
||||
using tensorflow::error::Code;
|
||||
using tensorflow::errors::FailedPrecondition;
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
using tensorflow::gtl::ArraySlice;
|
||||
@ -108,34 +108,6 @@ extern "C" {
|
||||
// --------------------------------------------------------------------------
|
||||
const char* TF_Version() { return TF_VERSION_STRING; }
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
size_t TF_DataTypeSize(TF_DataType dt) {
|
||||
return static_cast<size_t>(
|
||||
tensorflow::DataTypeSize(static_cast<DataType>(dt)));
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
TF_Status* TF_NewStatus() { return new TF_Status; }
|
||||
|
||||
void TF_DeleteStatus(TF_Status* s) { delete s; }
|
||||
|
||||
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
|
||||
if (code == TF_OK) {
|
||||
s->status = Status::OK();
|
||||
return;
|
||||
}
|
||||
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
|
||||
}
|
||||
|
||||
TF_Code TF_GetCode(const TF_Status* s) {
|
||||
return static_cast<TF_Code>(s->status.code());
|
||||
}
|
||||
|
||||
const char* TF_Message(const TF_Status* s) {
|
||||
return s->status.error_message().c_str();
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
@ -1697,7 +1669,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
|
||||
if (metadata.list_size == 0) {
|
||||
for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
|
||||
const auto& a = oper->node.op_def().attr(i);
|
||||
if (a.name().compare(attr_name) != 0) continue;
|
||||
if (a.name() != attr_name) continue;
|
||||
const string& typestr = a.type();
|
||||
if (typestr == "list(string)") {
|
||||
metadata.type = TF_ATTR_STRING;
|
||||
@ -2517,8 +2489,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
|
||||
// used in this graph
|
||||
for (const auto& pair : g->name_map) {
|
||||
const string& name = pair.first;
|
||||
if (name.compare(prefix) == 0 ||
|
||||
tensorflow::str_util::StartsWith(name, prefix_cmp)) {
|
||||
if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
|
||||
status->status = InvalidArgument(
|
||||
"prefix [", prefix,
|
||||
"] conflicts with existing node in the graph named [", name, "]");
|
||||
@ -2548,8 +2519,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
|
||||
// Adding the gradients to the graph can alter the prefix to prevent
|
||||
// name collisions only if this prefix has not been provided explicitly
|
||||
// by the user. If it was provided, assert that it remained intact.
|
||||
if (prefix != nullptr &&
|
||||
!tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
|
||||
if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
|
||||
status->status = tensorflow::errors::Internal(
|
||||
"BUG: The gradients prefix have been unexpectedly altered when "
|
||||
"adding the nodes to the graph. This is a bug. Please file an "
|
||||
|
@ -19,6 +19,10 @@ limitations under the License.
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_attrtype.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for TensorFlow.
|
||||
//
|
||||
@ -69,7 +73,7 @@ limitations under the License.
|
||||
// .dylib, .dll).
|
||||
// This duplicates the TF_EXPORT macro definition in
|
||||
// tensorflow/core/platform/macros.h in order to keep this .h file independent
|
||||
// of any other includes.$a
|
||||
// of any other includes.
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
@ -93,89 +97,6 @@ extern "C" {
|
||||
// TensorFlow library. TensorFlow using semantic versioning.
|
||||
TF_CAPI_EXPORT extern const char* TF_Version(void);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
|
||||
// The enum values here are identical to corresponding values in types.proto.
|
||||
typedef enum TF_DataType {
|
||||
TF_FLOAT = 1,
|
||||
TF_DOUBLE = 2,
|
||||
TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
|
||||
TF_UINT8 = 4,
|
||||
TF_INT16 = 5,
|
||||
TF_INT8 = 6,
|
||||
TF_STRING = 7,
|
||||
TF_COMPLEX64 = 8, // Single-precision complex
|
||||
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility
|
||||
TF_INT64 = 9,
|
||||
TF_BOOL = 10,
|
||||
TF_QINT8 = 11, // Quantized int8
|
||||
TF_QUINT8 = 12, // Quantized uint8
|
||||
TF_QINT32 = 13, // Quantized int32
|
||||
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
|
||||
TF_QINT16 = 15, // Quantized int16
|
||||
TF_QUINT16 = 16, // Quantized uint16
|
||||
TF_UINT16 = 17,
|
||||
TF_COMPLEX128 = 18, // Double-precision complex
|
||||
TF_HALF = 19,
|
||||
TF_RESOURCE = 20,
|
||||
TF_VARIANT = 21,
|
||||
TF_UINT32 = 22,
|
||||
TF_UINT64 = 23,
|
||||
} TF_DataType;
|
||||
|
||||
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
|
||||
// to the given TF_DataType enum value. Returns 0 for variable length types
|
||||
// (eg. TF_STRING) or on failure.
|
||||
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TF_Code holds an error code. The enum values here are identical to
|
||||
// corresponding values in error_codes.proto.
|
||||
typedef enum TF_Code {
|
||||
TF_OK = 0,
|
||||
TF_CANCELLED = 1,
|
||||
TF_UNKNOWN = 2,
|
||||
TF_INVALID_ARGUMENT = 3,
|
||||
TF_DEADLINE_EXCEEDED = 4,
|
||||
TF_NOT_FOUND = 5,
|
||||
TF_ALREADY_EXISTS = 6,
|
||||
TF_PERMISSION_DENIED = 7,
|
||||
TF_UNAUTHENTICATED = 16,
|
||||
TF_RESOURCE_EXHAUSTED = 8,
|
||||
TF_FAILED_PRECONDITION = 9,
|
||||
TF_ABORTED = 10,
|
||||
TF_OUT_OF_RANGE = 11,
|
||||
TF_UNIMPLEMENTED = 12,
|
||||
TF_INTERNAL = 13,
|
||||
TF_UNAVAILABLE = 14,
|
||||
TF_DATA_LOSS = 15,
|
||||
} TF_Code;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TF_Status holds error information. It either has an OK code, or
|
||||
// else an error code with an associated error message.
|
||||
typedef struct TF_Status TF_Status;
|
||||
|
||||
// Return a new status object.
|
||||
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
|
||||
|
||||
// Delete a previously created status object.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
|
||||
|
||||
// Record <code, msg> in *s. Any previous information is lost.
|
||||
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
|
||||
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
|
||||
const char* msg);
|
||||
|
||||
// Return the code record in *s.
|
||||
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);
|
||||
|
||||
// Return a pointer to the (null-terminated) error message in *s. The
|
||||
// return value points to memory that is only usable until the next
|
||||
// mutation to *s. Always returns an empty string if TF_GetCode(s) is
|
||||
// TF_OK.
|
||||
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TF_Buffer holds a pointer to a block of data and its associated length.
|
||||
// Typically, the data consists of a serialized protocol buffer, but other data
|
||||
@ -686,19 +607,6 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs(
|
||||
TF_Operation* oper, TF_Operation** control_outputs,
|
||||
int max_control_outputs);
|
||||
|
||||
// TF_AttrType describes the type of the value of an attribute on an operation.
|
||||
typedef enum TF_AttrType {
|
||||
TF_ATTR_STRING = 0,
|
||||
TF_ATTR_INT = 1,
|
||||
TF_ATTR_FLOAT = 2,
|
||||
TF_ATTR_BOOL = 3,
|
||||
TF_ATTR_TYPE = 4,
|
||||
TF_ATTR_SHAPE = 5,
|
||||
TF_ATTR_TENSOR = 6,
|
||||
TF_ATTR_PLACEHOLDER = 7,
|
||||
TF_ATTR_FUNC = 8,
|
||||
} TF_AttrType;
|
||||
|
||||
// TF_AttrMetadata describes the value of an attribute on an operation.
|
||||
typedef struct TF_AttrMetadata {
|
||||
// A boolean: 1 if the attribute value is a list, 0 otherwise.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
@ -37,6 +38,7 @@ using tensorflow::FunctionDef;
|
||||
using tensorflow::Node;
|
||||
using tensorflow::NodeBuilder;
|
||||
using tensorflow::Status;
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
|
||||
namespace {
|
||||
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
|
||||
@ -149,7 +151,7 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
|
||||
}
|
||||
|
||||
char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
|
||||
const auto& debug_str = func->fdef.DebugString();
|
||||
const auto& debug_str = DebugString(func->fdef);
|
||||
*len = debug_str.size();
|
||||
char* ret = static_cast<char*>(malloc(*len + 1));
|
||||
memcpy(ret, debug_str.c_str(), *len + 1);
|
||||
@ -576,6 +578,73 @@ void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
|
||||
status->status = tensorflow::errors::Internal(errMsg);
|
||||
}
|
||||
|
||||
struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
|
||||
using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
|
||||
std::vector<std::string> variable_list;
|
||||
};
|
||||
|
||||
TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
|
||||
TF_Status* status) {
|
||||
TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
const auto& m = reader->GetVariableToDataTypeMap();
|
||||
for (auto it = m.begin(); it != m.end(); ++it)
|
||||
reader->variable_list.push_back(it->first);
|
||||
std::sort(reader->variable_list.begin(), reader->variable_list.end());
|
||||
return reader;
|
||||
}
|
||||
|
||||
void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }
|
||||
|
||||
int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
|
||||
const char* name) {
|
||||
return reader->HasTensor(name);
|
||||
}
|
||||
|
||||
const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
|
||||
int index) {
|
||||
return reader->variable_list[index].c_str();
|
||||
}
|
||||
|
||||
int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
|
||||
return reader->variable_list.size();
|
||||
}
|
||||
|
||||
TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
|
||||
const char* name) {
|
||||
const auto& m = reader->GetVariableToDataTypeMap();
|
||||
return static_cast<TF_DataType>(m.at(name));
|
||||
}
|
||||
|
||||
TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
|
||||
const char* name, TF_Status* status) {
|
||||
std::unique_ptr<tensorflow::Tensor> tensor;
|
||||
reader->GetTensor(name, &tensor, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return tensorflow::TF_TensorFromTensor(*tensor.get(), status);
|
||||
}
|
||||
|
||||
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
|
||||
const char* name, int64_t* dims,
|
||||
int num_dims, TF_Status* status) {
|
||||
const auto& shape = reader->GetVariableToShapeMap().at(name);
|
||||
int rank = shape.dims();
|
||||
if (num_dims != rank) {
|
||||
status->status = InvalidArgument("Expected rank is ", num_dims,
|
||||
" but actual rank is ", rank);
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < num_dims; i++) {
|
||||
dims[i] = shape.dim_size(i);
|
||||
}
|
||||
}
|
||||
|
||||
int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
|
||||
const char* name) {
|
||||
const auto& m = reader->GetVariableToShapeMap();
|
||||
return m.at(name).dims();
|
||||
}
|
||||
|
||||
// This builder is used in the eager API to build a NodeDef.
|
||||
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
|
||||
using tensorflow::AttrBuilder::AttrBuilder;
|
||||
|
@ -208,6 +208,34 @@ TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
|
||||
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
|
||||
const char* errMsg);
|
||||
|
||||
// TF_NewCheckpointReader() return the CheckpointReader that can be use to
|
||||
// investigate or load the variable from the checkpoint file
|
||||
typedef struct TF_CheckpointReader TF_CheckpointReader;
|
||||
TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader(
|
||||
const char* filename, TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader(
|
||||
TF_CheckpointReader* reader);
|
||||
TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor(
|
||||
TF_CheckpointReader* reader, const char* name);
|
||||
// Get the variable name at the given index
|
||||
TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable(
|
||||
TF_CheckpointReader* reader, int index);
|
||||
// Get the number of variable in the checkpoint
|
||||
TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader);
|
||||
// Get the DataType of a variable
|
||||
TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType(
|
||||
TF_CheckpointReader* reader, const char* name);
|
||||
// Read the shape of a variable and write to `dims`
|
||||
TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape(
|
||||
TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims,
|
||||
TF_Status* status);
|
||||
// Get the number of dimension of a variable
|
||||
TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims(
|
||||
TF_CheckpointReader* reader, const char* name);
|
||||
// Load the weight of a variable
|
||||
TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor(
|
||||
TF_CheckpointReader* reader, const char* name, TF_Status* status);
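
The block above is the full experimental checkpoint-reader C API. As a rough usage sketch (the checkpoint prefix below is a hypothetical path, and error handling is trimmed), a caller might enumerate variables like this:

#include <stdio.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Prints every variable stored in a checkpoint, with its rank and dtype.
static void ListCheckpointVariables(const char* ckpt_prefix /* hypothetical path */) {
  TF_Status* status = TF_NewStatus();
  TF_CheckpointReader* reader = TF_NewCheckpointReader(ckpt_prefix, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "open failed: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return;
  }
  int num_vars = TF_CheckpointReaderSize(reader);
  for (int i = 0; i < num_vars; ++i) {
    const char* name = TF_CheckpointReaderGetVariable(reader, i);
    int rank = TF_CheckpointReaderGetVariableNumDims(reader, name);
    TF_DataType dtype = TF_CheckpointReaderGetVariableDataType(reader, name);
    printf("%s: rank=%d dtype=%d\n", name, rank, (int)dtype);
  }
  TF_DeleteCheckpointReader(reader);
  TF_DeleteStatus(status);
}

TF_CheckpointReaderGetVariableShape and TF_CheckpointReaderGetTensor follow the same pattern, reporting failures through the TF_Status argument.
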
// TF_NewAttrBuilder() returns an object that you can set attributes on as
|
||||
// though it were an op. This allows querying properties of that op for
|
||||
// type-checking purposes like if the op will run on a particular device type.
|
||||
|
@ -62,8 +62,8 @@ protocol: "grpc"
|
||||
TF_Buffer* null_result =
|
||||
TFE_GetServerDef(malformed_text_proto.c_str(), status);
|
||||
EXPECT_NE(TF_GetCode(status), TF_OK);
|
||||
EXPECT_TRUE(tensorflow::str_util::StrContains(
|
||||
TF_Message(status), "Invalid text proto for ServerDef"));
|
||||
EXPECT_TRUE(absl::StrContains(TF_Message(status),
|
||||
"Invalid text proto for ServerDef"));
|
||||
EXPECT_EQ(null_result, nullptr);
|
||||
|
||||
// Cleanup
|
||||
|
@ -253,7 +253,7 @@ class CApiFunctionTest : public ::testing::Test {
|
||||
const std::unordered_set<string>& nodes) {
|
||||
ASSERT_EQ(nodes.size(), fdef.node_def_size())
|
||||
<< "Got unexpected number of nodes. Expected: ["
|
||||
<< str_util::Join(nodes, ", ")
|
||||
<< absl::StrJoin(nodes, ", ")
|
||||
<< "] Actual nodes in fdef: " << fdef.DebugString();
|
||||
for (const NodeDef& node_def : fdef.node_def()) {
|
||||
ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
|
||||
|
@ -56,7 +56,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
namespace {
|
||||
|
||||
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_TRUE(str_util::StrContains(s, expected))
|
||||
EXPECT_TRUE(absl::StrContains(s, expected))
|
||||
<< "'" << s << "' does not contain '" << expected << "'";
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
# Experimental extensions to the C API for eager execution of kernels.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
|
@ -30,7 +30,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
@ -135,11 +137,12 @@ tensorflow::Status CreateRemoteContexts(
|
||||
const std::vector<string>& remote_workers, int64 rendezvous_id,
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
|
||||
const tensorflow::eager::CreateContextRequest& base_request,
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
|
||||
for (int i = 0; i < remote_workers.size(); i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
|
||||
tensorflow::eager::CreateContextRequest request;
|
||||
tensorflow::eager::CreateContextRequest request(base_request);
|
||||
tensorflow::eager::CreateContextResponse response;
|
||||
request.set_rendezvous_id(rendezvous_id);
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
@ -221,6 +224,23 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
remote_workers, grpc_server->master_env()->worker_cache,
|
||||
&remote_device_mgr));
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
|
||||
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
|
||||
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
||||
&local_device_attributes);
|
||||
|
||||
// This request make sure that we can create Rendevzous properly between
|
||||
// Local and Remote context.
|
||||
tensorflow::eager::CreateContextRequest base_request;
|
||||
for (const auto& da : cluster_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
for (const auto& da : local_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
|
||||
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
|
||||
grpc_server->channel_cache();
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
|
||||
@ -230,14 +250,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
remote_workers, rendezvous_id, keep_alive_secs, server_def,
|
||||
remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
|
||||
remote_eager_workers.get(), ctx->context->Async(), base_request,
|
||||
&remote_contexts));
|
||||
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
|
||||
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
|
||||
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, true));
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -250,9 +272,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
|
||||
return ctx->context->InitializeRemote(
|
||||
std::move(server), std::move(remote_eager_workers),
|
||||
std::move(remote_device_mgr), remote_contexts, r, device_mgr,
|
||||
keep_alive_secs);
|
||||
std::move(server), grpc_server->worker_env(), worker_session,
|
||||
std::move(remote_eager_workers), std::move(remote_device_mgr),
|
||||
remote_contexts, r, device_mgr, keep_alive_secs,
|
||||
worker_session->cluster_flr.get());
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
@ -970,6 +993,23 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
|
||||
return t;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
// TensorHandles created by PyFuncOp lack context and therefore could
|
||||
// not be copied.
|
||||
if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
h->handle, h->handle->Context(), "CPU:0", &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
|
@ -462,6 +462,9 @@ class Tensor;
|
||||
|
||||
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
|
||||
TFE_TensorHandle* h, TF_Status* status);
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
|
||||
#endif
|
||||
|
||||
|
@ -78,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
status->status = tensorflow::Status::OK();
|
||||
} else {
|
||||
VLOG(3) << "Fully padded shape of ["
|
||||
<< tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
|
||||
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
||||
<< padded_shape.DebugString();
|
||||
}
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static bool HasSubstr(absl::string_view base, absl::string_view substr) {
|
||||
bool ok = str_util::StrContains(base, substr);
|
||||
bool ok = absl::StrContains(base, substr);
|
||||
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
|
||||
return ok;
|
||||
}
|
||||
|
@ -1408,6 +1408,10 @@ void FunctionDefAndExecute(bool async) {
|
||||
status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
for (bool clear_cache : {true, false, true}) {
|
||||
if (clear_cache) {
|
||||
TFE_ContextClearCaches(ctx);
|
||||
}
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* retval[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -1431,6 +1435,7 @@ void FunctionDefAndExecute(bool async) {
|
||||
EXPECT_EQ(10, product[1]);
|
||||
EXPECT_EQ(15, product[2]);
|
||||
EXPECT_EQ(22, product[3]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
@ -1,7 +1,9 @@
|
||||
# Description:
|
||||
# Experimental C APIs for TensorFlow.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
|
@ -6,10 +6,9 @@ load(
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
tf_kernel_library(
|
||||
name = "bitcast_op",
|
||||
prefix = "bitcast_op",
|
||||
|
tensorflow/c/tf_attrtype.h (new file, 39 lines)
@@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_
#define TENSORFLOW_C_TF_ATTRTYPE_H_

#ifdef __cplusplus
extern "C" {
#endif

// TF_AttrType describes the type of the value of an attribute on an operation.
typedef enum TF_AttrType {
  TF_ATTR_STRING = 0,
  TF_ATTR_INT = 1,
  TF_ATTR_FLOAT = 2,
  TF_ATTR_BOOL = 3,
  TF_ATTR_TYPE = 4,
  TF_ATTR_SHAPE = 5,
  TF_ATTR_TENSOR = 6,
  TF_ATTR_PLACEHOLDER = 7,
  TF_ATTR_FUNC = 8,
} TF_AttrType;

#ifdef __cplusplus
} /* end extern "C" */
#endif

#endif  // TENSORFLOW_C_TF_ATTRTYPE_H_
tensorflow/c/tf_datatype.cc (new file, 23 lines)
@@ -0,0 +1,23 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/tf_datatype.h"

#include "tensorflow/core/framework/types.h"

size_t TF_DataTypeSize(TF_DataType dt) {
  return static_cast<size_t>(
      tensorflow::DataTypeSize(static_cast<tensorflow::DataType>(dt)));
}
tensorflow/c/tf_datatype.h (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_TF_DATATYPE_H_
|
||||
#define TENSORFLOW_C_TF_DATATYPE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
// Macro to control visibility of exported symbols in the shared library (.so,
|
||||
// .dylib, .dll).
|
||||
// This duplicates the TF_EXPORT macro definition in
|
||||
// tensorflow/core/platform/macros.h in order to keep this .h file independent
|
||||
// of any other includes.
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
|
||||
// The enum values here are identical to corresponding values in types.proto.
|
||||
typedef enum TF_DataType {
|
||||
TF_FLOAT = 1,
|
||||
TF_DOUBLE = 2,
|
||||
TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
|
||||
TF_UINT8 = 4,
|
||||
TF_INT16 = 5,
|
||||
TF_INT8 = 6,
|
||||
TF_STRING = 7,
|
||||
TF_COMPLEX64 = 8, // Single-precision complex
|
||||
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility
|
||||
TF_INT64 = 9,
|
||||
TF_BOOL = 10,
|
||||
TF_QINT8 = 11, // Quantized int8
|
||||
TF_QUINT8 = 12, // Quantized uint8
|
||||
TF_QINT32 = 13, // Quantized int32
|
||||
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
|
||||
TF_QINT16 = 15, // Quantized int16
|
||||
TF_QUINT16 = 16, // Quantized uint16
|
||||
TF_UINT16 = 17,
|
||||
TF_COMPLEX128 = 18, // Double-precision complex
|
||||
TF_HALF = 19,
|
||||
TF_RESOURCE = 20,
|
||||
TF_VARIANT = 21,
|
||||
TF_UINT32 = 22,
|
||||
TF_UINT64 = 23,
|
||||
} TF_DataType;
|
||||
|
||||
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
|
||||
// to the given TF_DataType enum value. Returns 0 for variable length types
|
||||
// (eg. TF_STRING) or on failure.
|
||||
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);
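
As a hedged illustration of the contract stated above (not part of the patch itself): fixed-width types report their sizeof(), while variable-length types such as TF_STRING report 0.

#include <assert.h>

#include "tensorflow/c/tf_datatype.h"

static void CheckDataTypeSizes(void) {
  assert(TF_DataTypeSize(TF_FLOAT) == 4);   // sizeof(float)
  assert(TF_DataTypeSize(TF_DOUBLE) == 8);  // sizeof(double)
  assert(TF_DataTypeSize(TF_STRING) == 0);  // variable-length: reported as 0
}
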
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_TF_DATATYPE_H_
|
tensorflow/c/tf_status.cc (new file, 42 lines)
@@ -0,0 +1,42 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::error::Code;
|
||||
|
||||
TF_Status* TF_NewStatus() { return new TF_Status; }
|
||||
|
||||
void TF_DeleteStatus(TF_Status* s) { delete s; }
|
||||
|
||||
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
|
||||
if (code == TF_OK) {
|
||||
s->status = Status::OK();
|
||||
return;
|
||||
}
|
||||
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
|
||||
}
|
||||
|
||||
TF_Code TF_GetCode(const TF_Status* s) {
|
||||
return static_cast<TF_Code>(s->status.code());
|
||||
}
|
||||
|
||||
const char* TF_Message(const TF_Status* s) {
|
||||
return s->status.error_message().c_str();
|
||||
}
|
tensorflow/c/tf_status.h (new file, 88 lines)
@@ -0,0 +1,88 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_TF_STATUS_H_
|
||||
#define TENSORFLOW_C_TF_STATUS_H_
|
||||
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct TF_Status TF_Status;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TF_Code holds an error code. The enum values here are identical to
|
||||
// corresponding values in error_codes.proto.
|
||||
typedef enum TF_Code {
|
||||
TF_OK = 0,
|
||||
TF_CANCELLED = 1,
|
||||
TF_UNKNOWN = 2,
|
||||
TF_INVALID_ARGUMENT = 3,
|
||||
TF_DEADLINE_EXCEEDED = 4,
|
||||
TF_NOT_FOUND = 5,
|
||||
TF_ALREADY_EXISTS = 6,
|
||||
TF_PERMISSION_DENIED = 7,
|
||||
TF_UNAUTHENTICATED = 16,
|
||||
TF_RESOURCE_EXHAUSTED = 8,
|
||||
TF_FAILED_PRECONDITION = 9,
|
||||
TF_ABORTED = 10,
|
||||
TF_OUT_OF_RANGE = 11,
|
||||
TF_UNIMPLEMENTED = 12,
|
||||
TF_INTERNAL = 13,
|
||||
TF_UNAVAILABLE = 14,
|
||||
TF_DATA_LOSS = 15,
|
||||
} TF_Code;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Return a new status object.
|
||||
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
|
||||
|
||||
// Delete a previously created status object.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
|
||||
|
||||
// Record <code, msg> in *s. Any previous information is lost.
|
||||
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
|
||||
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
|
||||
const char* msg);
|
||||
|
||||
// Return the code record in *s.
|
||||
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);
|
||||
|
||||
// Return a pointer to the (null-terminated) error message in *s. The
|
||||
// return value points to memory that is only usable until the next
|
||||
// mutation to *s. Always returns an empty string if TF_GetCode(s) is
|
||||
// TF_OK.
|
||||
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);
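
Taken together, the declarations above imply a simple create/set/inspect/delete lifecycle; a small sketch (not from the patch itself) is:

#include <assert.h>
#include <string.h>

#include "tensorflow/c/tf_status.h"

static void StatusRoundTrip(void) {
  TF_Status* s = TF_NewStatus();
  TF_SetStatus(s, TF_INVALID_ARGUMENT, "bad shape");
  assert(TF_GetCode(s) == TF_INVALID_ARGUMENT);
  assert(strcmp(TF_Message(s), "bad shape") == 0);
  TF_SetStatus(s, TF_OK, "");        // common idiom to clear a status
  assert(TF_GetCode(s) == TF_OK);
  assert(TF_Message(s)[0] == '\0');  // message is empty once the code is TF_OK
  TF_DeleteStatus(s);
}
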
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_TF_STATUS_H_
|
@ -4,10 +4,9 @@
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
filegroup(
|
||||
name = "srcs",
|
||||
srcs = [
|
||||
@ -638,6 +637,7 @@ cc_library(
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -657,6 +657,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "tensorflow/core/framework/api_def.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
@ -133,7 +135,7 @@ string MakeComment(StringPiece text, StringPiece indent) {
|
||||
}
|
||||
|
||||
string PrintString(const string& str) {
|
||||
return strings::StrCat("\"", str_util::CEscape(str), "\"");
|
||||
return strings::StrCat("\"", absl::CEscape(str), "\"");
|
||||
}
|
||||
|
||||
string PrintTensorShape(const TensorShapeProto& shape_proto) {
|
||||
@ -191,7 +193,7 @@ string PrintTensor(const TensorProto& tensor_proto) {
|
||||
string ret;
|
||||
for (int64 i = 0; i < num_elts; ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, " ");
|
||||
strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
|
||||
strings::StrAppend(&ret, absl::CEscape(t.flat<string>()(i)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -62,12 +62,12 @@ op {
|
||||
)";
|
||||
|
||||
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_TRUE(str_util::StrContains(s, expected))
|
||||
EXPECT_TRUE(absl::StrContains(s, expected))
|
||||
<< "'" << s << "' does not contain '" << expected << "'";
|
||||
}
|
||||
|
||||
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_FALSE(str_util::StrContains(s, expected))
|
||||
EXPECT_FALSE(absl::StrContains(s, expected))
|
||||
<< "'" << s << "' contains '" << expected << "'";
|
||||
}
|
||||
|
||||
|
@ -275,7 +275,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
|
||||
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
|
||||
for (const string& entry : node_constraints) {
|
||||
StringPiece s(entry);
|
||||
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
|
||||
if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {
|
||||
current_constraints.emplace(s);
|
||||
}
|
||||
}
|
||||
|
@ -3,10 +3,9 @@
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load(
|
||||
|
@ -308,7 +308,7 @@ Status LoadSavedModel(const SessionOptions& session_options,
|
||||
const Status status = LoadSavedModelInternal(session_options, run_options,
|
||||
export_dir, tags, bundle);
|
||||
auto log_and_count = [&](const string& status_str) {
|
||||
LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
|
||||
LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
|
||||
<< " }; Status: " << status_str << ". Took "
|
||||
<< GetLatencyMicroseconds(start_microseconds) << " microseconds.";
|
||||
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
|
||||
|
@ -136,7 +136,7 @@ TEST_F(LoaderTest, NoTagMatch) {
|
||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||
{"missing-tag"}, &bundle);
|
||||
EXPECT_FALSE(st.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
st.error_message(),
|
||||
"Could not find meta graph def matching supplied tags: { missing-tag }"))
|
||||
<< st.error_message();
|
||||
@ -152,7 +152,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
|
||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe, "missing-tag"}, &bundle);
|
||||
EXPECT_FALSE(st.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
st.error_message(),
|
||||
"Could not find meta graph def matching supplied tags: "))
|
||||
<< st.error_message();
|
||||
@ -172,7 +172,7 @@ TEST_F(LoaderTest, SessionCreationFailure) {
|
||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle);
|
||||
EXPECT_FALSE(st.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget))
|
||||
EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget))
|
||||
<< st.error_message();
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
|
||||
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
||||
const std::unordered_set<string>& tags,
|
||||
MetaGraphDef* meta_graph_def) {
|
||||
LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
|
||||
LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
|
||||
<< " }";
|
||||
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
|
||||
// Get tags from the graph_def.
|
||||
@ -69,7 +69,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
||||
error::Code::NOT_FOUND,
|
||||
strings::StrCat(
|
||||
"Could not find meta graph def matching supplied tags: { ",
|
||||
str_util::Join(tags, " "),
|
||||
absl::StrJoin(tags, " "),
|
||||
" }. To inspect available tag-sets in the SavedModel, please "
|
||||
"use the SavedModel CLI: `saved_model_cli`"));
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ TEST_F(ReaderTest, NoTagMatch) {
|
||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||
&meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
st.error_message(),
|
||||
"Could not find meta graph def matching supplied tags: { missing-tag }"))
|
||||
<< st.error_message();
|
||||
@ -78,7 +78,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||
Status st = ReadMetaGraphDefFromSavedModel(
|
||||
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
EXPECT_TRUE(str_util::StrContains(
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
st.error_message(),
|
||||
"Could not find meta graph def matching supplied tags: "))
|
||||
<< st.error_message();
|
||||
|
@ -3,10 +3,9 @@
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load(
|
||||
|
@ -167,8 +167,7 @@ namespace {
|
||||
|
||||
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||
int32* dst) {
|
||||
if (tensorflow::str_util::ConsumePrefix(&arg, flag) &&
|
||||
tensorflow::str_util::ConsumePrefix(&arg, "=")) {
|
||||
if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) {
|
||||
char extra;
|
||||
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
|
||||
}
|
||||
@ -178,7 +177,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||
|
||||
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||
bool* dst) {
|
||||
if (tensorflow::str_util::ConsumePrefix(&arg, flag)) {
|
||||
if (absl::ConsumePrefix(&arg, flag)) {
|
||||
if (arg.empty()) {
|
||||
*dst = true;
|
||||
return true;
|
||||
|
@ -1,7 +1,6 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
|
@ -1,4 +1,12 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
package(
|
||||
default_visibility = [
|
||||
":internal",
|
||||
# BEGIN-GOOGLE-INTERNAL
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
# END-GOOGLE-INTERNAL
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
@ -14,15 +22,6 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
":internal",
|
||||
# BEGIN-GOOGLE-INTERNAL
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
# END-GOOGLE-INTERNAL
|
||||
],
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
@ -200,6 +199,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:host_constant_op",
|
||||
"//tensorflow/core/kernels:identity_n_op",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:logging_ops",
|
||||
"//tensorflow/core/kernels:no_op",
|
||||
"//tensorflow/core/kernels:queue_op",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
@ -257,10 +257,8 @@ cc_library(
|
||||
name = "xla_launch_util",
|
||||
srcs = ["xla_launch_util.cc"],
|
||||
hdrs = ["xla_launch_util.h"],
|
||||
# TODO(skyewm): remove this once XlaAllocator is factored out.
|
||||
visibility = [
|
||||
":internal",
|
||||
"//tensorflow/compiler/xla/python:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
|
@ -244,11 +244,11 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
"resource variable op in called function");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
|
||||
if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
|
||||
return LogNotCompilableAndReturn(node, "operation with correctness issues");
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
|
||||
if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
|
||||
return LogNotCompilableAndReturn(node, "slow operation");
|
||||
}
|
||||
|
||||
@ -268,8 +268,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
registration.elide_assert_and_checknumerics;
|
||||
op_filter.allow_ops_producing_or_consuming_variant =
|
||||
registration.cluster_variant_ops;
|
||||
op_filter.allow_slow_and_inaccurate_ops =
|
||||
registration.cluster_slow_and_inaccurate_ops;
|
||||
op_filter.allow_slow_ops = registration.cluster_slow_ops;
|
||||
op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
|
||||
return op_filter;
|
||||
}
|
||||
|
||||
|
@ -97,9 +97,12 @@ class RecursiveCompilabilityChecker {
|
||||
// live-out DT_VARIANT values.
|
||||
bool allow_ops_producing_or_consuming_variant;
|
||||
|
||||
// Whether ops known to be slow or to have correctness issues should be
|
||||
// auto-clustered.
|
||||
bool allow_slow_and_inaccurate_ops;
|
||||
  // Whether ops known to be slow on XLA-GPU should be considered compilable.
|
||||
bool allow_slow_ops;
|
||||
|
||||
// Whether ops known to have numerical accuracy issues should be considered
|
||||
  // compilable.
|
||||
bool allow_inaccurate_ops;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker(const OperationFilter* op_filter,
|
||||
|
@ -14,10 +14,12 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -43,12 +45,12 @@ limitations under the License.
|
||||
// ------------------------------------------
|
||||
//
|
||||
// If we ignore cycles for a moment, computing predicates is fairly
|
||||
// straightforward. We traverse the graph in RPO, mapping each node to a
|
||||
// predicate based on the predicates its inputs are mapped to. For instance a
|
||||
// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)).
|
||||
// Roughtly speaking, we abstract interpret each node on the "liveness" domain,
|
||||
// where values in the domain represent if a tensor carries a dead signal or
|
||||
// not.
|
||||
// straightforward. We traverse the graph in a topological order, mapping each
|
||||
// node to a predicate based on the predicates its inputs are mapped to. For
|
||||
// instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
|
||||
// PredicateFor(Y)).  Roughly speaking, we abstractly interpret each node on
|
||||
// the "liveness" domain, where values in the domain represent if a tensor
|
||||
// carries a dead signal or not.
|
||||
//
|
||||
//
|
||||
// DEALING WITH CYCLES
|
||||
@ -85,22 +87,28 @@ limitations under the License.
|
||||
// true on iteration 0, 1, 2 respectively. This is made more precise in the
|
||||
// comment on the AndRecurrence class.
|
||||
//
|
||||
// The general algorithm that deals with cycles does two RPO (reverse post
|
||||
// order) passes over the graph. On the first pass it assigns a symbolic
|
||||
// predicate to merge nodes with backedges. On the second pass it tries to
|
||||
// pattern matche the predicates for the backedges of these merges and infer an
|
||||
// AndRecurrence for the merge.
|
||||
// The general algorithm that deals with cycles does two topological-order
|
||||
// iterations over the graph. On the first iteration it assigns a symbolic
|
||||
// predicate to merge nodes with backedges. On the second iteration it tries
|
||||
// to pattern match the predicates for the backedges of these merges and infer
|
||||
// an AndRecurrence for the merge. In other words, we do a data flow analysis
|
||||
// where the data-flow lattice has two elements, Symbolic and NonSymbolic with
|
||||
// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
|
||||
// sufficient to converge.
|
||||
//
|
||||
// In other words, we do a pessimistic data flow analysis where the data-flow
|
||||
// lattice has two elements, Symbolic and NonSymbolic with Symbolic >
|
||||
// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to
|
||||
// converge. We don't do an optimistic data flow analysis to make pattern
|
||||
// matching easier: if we assigned the predicate of the initial value to the
|
||||
// merge during the first pass, on the second pass the backedge may see a
|
||||
// simplified value that would be difficult to pattern match.
|
||||
// We first do an optimistic analysis and, if it does not converge, we then fall
|
||||
// back to a pessimistic analysis. The optimistic analysis assigns the same
|
||||
// symbolic predicate to all the merge nodes whose preceding enter nodes have
|
||||
// the same frame name on the first iteration. On the second iteration, if all
|
||||
// the merge nodes are pattern matched into the same AndRecurrence predicate
|
||||
// instance, the optimistic assignment of the same symbolic predicate is correct
|
||||
// and the analyzed result is taken.
|
||||
//
|
||||
// We still use symbolic predicates for merges for which we can't pattern match
|
||||
// on the backedge predicate. This is conservatively correct.
|
||||
// Otherwise, if the optimistic analysis fails to converge, we then obtain the
|
||||
// result by falling back to the pessimistic analysis which assigns a unique
|
||||
// symbolic predicate to each merge on the first iteration. We still use
|
||||
// symbolic predicates for merges for which we can't pattern match on the
|
||||
// backedge predicate. This is conservatively correct.
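To make the optimistic-then-pessimistic scheme described above concrete, here is a small self-contained sketch added for illustration (not TensorFlow code); the type and function names are assumptions, not APIs from this file:

#include <functional>

// The two-element lattice referred to above: Symbolic > NonSymbolic. A value
// can only move up the lattice, which is why the analysis converges quickly.
enum class PredKind { kNonSymbolic, kSymbolic };

// Least upper bound (join) on the two-element lattice.
PredKind Join(PredKind a, PredKind b) {
  return (a == PredKind::kSymbolic || b == PredKind::kSymbolic)
             ? PredKind::kSymbolic
             : PredKind::kNonSymbolic;
}

// Driver in the spirit of the comment above: take the optimistic result if it
// converged, otherwise fall back to the always-sound pessimistic analysis.
// Returns true iff the optimistic result was used.
bool AnalyzeWithFallback(const std::function<bool()>& run_optimistic,
                         const std::function<void()>& run_pessimistic) {
  if (run_optimistic()) return true;
  run_pessimistic();
  return false;
}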
namespace tensorflow {
|
||||
|
||||
@ -636,6 +644,35 @@ Predicate* PredicateFactory::MakeAndOrImpl(
|
||||
negated_ops.insert(negated_op);
|
||||
}
|
||||
|
||||
// Simplify {S,&,X} & ~X & ... => S & ...
|
||||
if (is_and) {
|
||||
absl::flat_hash_set<Predicate*> to_remove;
|
||||
std::vector<Predicate*> to_add;
|
||||
for (Predicate* op : simplified_ops) {
|
||||
if (op->kind() == Predicate::Kind::kAndRecurrence) {
|
||||
auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
|
||||
if (negated_ops.contains(and_rec->step())) {
|
||||
// Remove and_rec and ~X and insert S. Note that checking the
|
||||
// existence of ~X through negated_ops is sufficient since it makes
|
||||
// sure the predicate is in the input operands. It does not need to
|
||||
// be in simplified_ops if it was already cancelled out.
|
||||
to_remove.insert(and_rec);
|
||||
to_remove.insert(MakeNotPredicate(and_rec->step()));
|
||||
to_add.push_back(and_rec->start());
|
||||
}
|
||||
}
|
||||
}
|
||||
auto it = simplified_ops.begin();
|
||||
while (it != simplified_ops.end()) {
|
||||
if (to_remove.contains(*it)) {
|
||||
it = simplified_ops.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
|
||||
}
|
||||
|
||||
// If all ops contain the same subop, then factor it out thanks to the
|
||||
// distributive property. Such as:
|
||||
// - (A & B) | (A & C) | (A & D) => A & (B | C | D)
|
||||
@ -699,8 +736,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
||||
explicit DeadnessAnalysisImpl(const Graph* graph)
|
||||
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
|
||||
|
||||
Status Populate();
|
||||
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
|
||||
Status Populate(bool enable_optimistic);
|
||||
Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
|
||||
bool* success);
|
||||
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
|
||||
Node* n, int oidx) const override;
|
||||
void Print() const override;
|
||||
@ -742,16 +780,29 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
||||
}
|
||||
|
||||
Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode);
|
||||
Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleNode(Node* n, std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode = false);
|
||||
|
||||
Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
|
||||
|
||||
bool IsRootEnter(const Node* n) const {
|
||||
return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
|
||||
}
|
||||
|
||||
bool IsRootExit(const Node* n) const {
|
||||
return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
|
||||
}
|
||||
|
||||
const Graph& graph_;
|
||||
absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
|
||||
PredicateFactory predicate_factory_;
|
||||
std::vector<ControlFlowInfo> control_flow_info_;
|
||||
bool vlog_;
|
||||
absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
|
||||
};
|
||||
|
||||
TensorId InputEdgeToTensorId(const Edge* e) {
|
||||
@ -914,10 +965,32 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If the node is inside some frames, get the name of the outermost non-empty
|
||||
// frame. Otherwise, get an empty frame name.
|
||||
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
|
||||
absl::string_view* frame) {
|
||||
int depth = 0;
|
||||
const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
|
||||
while (!cfi_iter->parent_frame->IsSource()) {
|
||||
n = cfi_iter->parent_frame;
|
||||
cfi_iter = &cfi_infos[n->id()];
|
||||
|
||||
if (depth++ > 5000) {
|
||||
return errors::Internal(
|
||||
"Frame of depth > 5000: Probably malformed graph or a bug in "
|
||||
"BuildControlFlowInfo");
|
||||
}
|
||||
}
|
||||
|
||||
*frame = cfi_iter->frame_name;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
||||
std::vector<bool>* should_revisit) {
|
||||
std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode) {
|
||||
// Merge ignores deadness of its control inputs. A merge that isn't the
|
||||
  // target of a backedge is alive iff any of its data inputs are.  The
|
||||
// liveness of a merge that is the target of a backedge can sometimes be
|
||||
@ -937,8 +1010,21 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
||||
// We're visiting this merge for the first time and it has an unvisited
|
||||
// backedge.
|
||||
Predicate* input_data_pred;
|
||||
if (use_optimistic_mode) {
|
||||
// In the optimistic mode, we use the first-seen Merge node per
|
||||
// frame as the representative Merge node. It is just convenient and
|
||||
// does not affect the result after pattern-matching into the
|
||||
// AndRecurrence form.
|
||||
absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
|
||||
auto insert_result = frame_to_merge_node_.insert({frame_name, n});
|
||||
Node* representative = insert_result.first->second;
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
representative, /*output_idx=*/0, /*must_be_true=*/false,
|
||||
&input_data_pred));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
|
||||
}
|
||||
|
||||
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
||||
should_revisit);
|
||||
@ -948,7 +1034,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
||||
std::vector<Predicate*> input_preds;
|
||||
TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
|
||||
|
||||
// We're visiting this merge for the first time and it is a acyclic merge.
|
||||
// We're visiting this merge for the first time and it is an acyclic merge.
|
||||
Predicate* input_data_pred =
|
||||
predicate_factory_.MakeOrPredicate(input_preds);
|
||||
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
||||
@ -1022,11 +1108,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleNode(Node* n,
|
||||
std::vector<bool>* should_revisit) {
|
||||
std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode) {
|
||||
if (n->IsSwitch()) {
|
||||
TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
|
||||
} else if (n->IsMerge()) {
|
||||
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
|
||||
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
|
||||
} else if (n->IsControlTrigger()) {
|
||||
SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
|
||||
nullptr);
|
||||
@ -1040,17 +1127,129 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::Populate() {
|
||||
std::vector<Node*> rpo;
|
||||
GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
|
||||
/*edge_filter=*/[](const Edge& edge) {
|
||||
return !edge.src()->IsNextIteration();
|
||||
});
|
||||
return PopulateWithReversePostOrder(rpo);
|
||||
// Compute a special topological order for the Graph, where nodes having the
|
||||
// same root frame are placed adjacent to each other. The traversal uses a
|
||||
// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
|
||||
// many inputs of each node are ready; a node is ready to be scheduled if all
|
||||
// of its inputs are ready.
|
||||
// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
|
||||
Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
|
||||
std::vector<Node*>* order) {
|
||||
absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
|
||||
absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
|
||||
std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
|
||||
Node* src_node = graph_.source_node();
|
||||
for (const auto* node : graph_.op_nodes()) {
|
||||
const ControlFlowInfo& cf = control_flow_info_[node->id()];
|
||||
if (IsRootEnter(node)) {
|
||||
      // Since we only care about the root-level frame, full frame names are the same
|
||||
// as frame names.
|
||||
++num_enters_for_frame[cf.frame_name];
|
||||
} else if (IsRootExit(node)) {
|
||||
++num_exits_for_frame[cf.frame_name];
|
||||
}
|
||||
    // Edge NextIteration->Merge is counted before starting the traversal to
|
||||
// break the backedges.
|
||||
if (IsMerge(node)) {
|
||||
for (const Edge* e : node->in_edges()) {
|
||||
if (IsNextIteration(e->src())) {
|
||||
++num_ready_inputs[node->id()];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
absl::Span<Node* const> rpo) {
|
||||
// dequeue is used to ensure that the nodes are first-in-first-out. This
|
||||
// order guarantees that the exits in the ready queue are visited before
|
||||
// nodes that will become ready in the future.
|
||||
std::deque<Node*> ready;
|
||||
ready.push_back(src_node);
|
||||
// ready_enters_per_frame and ready_exits serve as a staging area to buffer
|
||||
// the ready enters/exits before they are moved to the `ready` queue for
|
||||
// controlling the start and end of a processing frame.
|
||||
absl::flat_hash_map<absl::string_view, std::vector<Node*>>
|
||||
ready_enters_per_frame;
|
||||
// Exit nodes shall all be from the same frame, as we process a frame at a
|
||||
// time. So, one vector is enough.
|
||||
std::vector<Node*> ready_exits;
|
||||
while (!ready.empty()) {
|
||||
Node* curr_node = ready.front();
|
||||
ready.pop_front();
|
||||
|
||||
VLOG(4) << "Visiting " << curr_node->name();
|
||||
order->push_back(curr_node);
|
||||
|
||||
for (const Edge* out_edge : curr_node->out_edges()) {
|
||||
Node* out = out_edge->dst();
|
||||
int out_id = out->id();
|
||||
if (IsNextIteration(curr_node) && IsMerge(out)) {
|
||||
// Edge NextIteration->Merge has been counted.
|
||||
continue;
|
||||
}
|
||||
++num_ready_inputs[out->id()];
|
||||
if (!out->IsOp()) continue; // Skip Sink/Source nodes.
|
||||
if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
|
||||
|
||||
absl::string_view frame_name = control_flow_info_[out_id].frame_name;
|
||||
if (IsRootEnter(out)) {
|
||||
ready_enters_per_frame[frame_name].push_back(out);
|
||||
} else if (IsRootExit(out)) {
|
||||
ready_exits.push_back(out);
|
||||
} else {
|
||||
ready.push_back(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (ready.empty()) {
|
||||
// Try moving nodes from ready_enters_per_frame and ready_exits to
|
||||
// `ready`.
|
||||
if (!ready_exits.empty()) {
|
||||
// If there are nodes in ready_exits we must process them before
|
||||
// processing ready_enters_per_frame to make sure all nodes in the
|
||||
// currently processing frame are visited before starting processing
|
||||
// other frames.
|
||||
absl::string_view frame_name =
|
||||
control_flow_info_[ready_exits.front()->id()].frame_name;
|
||||
CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
|
||||
ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
|
||||
ready_exits.clear();
|
||||
} else {
|
||||
// Otherwise, try moving nodes from ready_enters to `ready`.
|
||||
for (auto iter = ready_enters_per_frame.begin();
|
||||
iter != ready_enters_per_frame.end(); ++iter) {
|
||||
absl::string_view frame_name = iter->first;
|
||||
const std::vector<Node*>& ready_enters = iter->second;
|
||||
if (ready_enters.size() == num_enters_for_frame[frame_name]) {
|
||||
ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
|
||||
ready_enters_per_frame.erase(iter);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Some enters/exits have never been visited in the traversal."
|
||||
" Most probably the input graph is malformed.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
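The comment before GetFrameBasedTopologicalOrder describes a variant of Kahn's algorithm with per-frame staging. For reference, a minimal generic form of Kahn's algorithm, without the frame-specific staging, is sketched below (illustration added by the editor, not TensorFlow code; the adjacency-list representation is an assumption):

#include <deque>
#include <vector>

// Generic Kahn's algorithm. adj[u] lists the successors of node u. Returns
// the nodes in a topological order; if the graph has a cycle, the returned
// order contains fewer than adj.size() nodes.
std::vector<int> TopologicalOrder(const std::vector<std::vector<int>>& adj) {
  const int n = static_cast<int>(adj.size());
  std::vector<int> in_degree(n, 0);
  for (int u = 0; u < n; ++u) {
    for (int v : adj[u]) ++in_degree[v];
  }

  // Nodes with no unvisited predecessors are ready to be emitted.
  std::deque<int> ready;
  for (int u = 0; u < n; ++u) {
    if (in_degree[u] == 0) ready.push_back(u);
  }

  std::vector<int> order;
  while (!ready.empty()) {
    int u = ready.front();
    ready.pop_front();
    order.push_back(u);
    // Emitting u satisfies one more input of each of its successors.
    for (int v : adj[u]) {
      if (--in_degree[v] == 0) ready.push_back(v);
    }
  }
  return order;
}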
// We populate the nodes along a special topological order where nodes having
|
||||
// the same root frame are placed adjacent to each other. This grouping enables
|
||||
// processing the graph per root frame at a time and guarantees that when a root
|
||||
// frame is being processed, nodes in the downstream frames have not yet been
|
||||
// processed. This property is important because we need to process an entire
|
||||
// frame to know whether the optimistic mode converges or not. In other words,
|
||||
// nodes in the downstream frames shall not be populated until all of their
|
||||
// upstream frames are populated. In effect, this order enables processing each
|
||||
// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
|
||||
// (root) frame. Note that we don't separate while loops belonging to the same
|
||||
// nested while, as there is no clean cut for separating them in the topological
|
||||
// order.
|
||||
Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
|
||||
std::vector<string> unreachable_nodes;
|
||||
// Compute the loop structure of the graph.
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -1069,14 +1268,63 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
absl::StrJoin(unreachable_nodes, ", "));
|
||||
}
|
||||
|
||||
std::vector<Node*> topo;
|
||||
TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
|
||||
|
||||
size_t frame_start = 0;
|
||||
while (frame_start < topo.size()) {
|
||||
    // Batch together the nodes that have the same root frame.
|
||||
absl::string_view cur_frame_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
|
||||
size_t frame_end = frame_start;
|
||||
for (size_t i = frame_start + 1; i < topo.size(); ++i) {
|
||||
absl::string_view i_frame_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
|
||||
if (i_frame_name == cur_frame_name) {
|
||||
frame_end = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
absl::Span<Node*> sub_topo(topo.data() + frame_start,
|
||||
/*length=*/frame_end - frame_start + 1);
|
||||
frame_start = frame_end + 1;
|
||||
|
||||
// First, try the optimistic mode.
|
||||
bool success = false;
|
||||
if (enable_optimistic && !cur_frame_name.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
|
||||
}
|
||||
if (!success) {
|
||||
// The optimistic mode does not converge. Let's fall back to the
|
||||
// pessimistic mode.
|
||||
TF_RETURN_IF_ERROR(
|
||||
PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
|
||||
}
|
||||
VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
|
||||
<< (success ? "optimistic" : "pessimistic") << " mode.";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
|
||||
bool use_optimistic_mode,
|
||||
bool* success) {
|
||||
CHECK(use_optimistic_mode && success != nullptr ||
|
||||
!use_optimistic_mode && success == nullptr);
|
||||
|
||||
// This an abstract interpretation over the deadness propagation semantics of
|
||||
// the graph executor.
|
||||
//
|
||||
// We iterate over the graph twice, each time in RPO. On the first iteration
|
||||
// merge nodes with backedges are mapped to symbolic predicates. On the
|
||||
// second iteration we use the predicates assigned to the backedges in the
|
||||
// previous iteration to infer a more precise predicate for the backedge merge
|
||||
// nodes and all the nodes that transitively use it.
|
||||
// We iterate over the graph twice, each time in a topological order. On the
|
||||
// first iteration merge nodes with backedges are mapped to symbolic
|
||||
// predicates. On the second iteration we use the predicates assigned to the
|
||||
// backedges in the previous iteration to infer a more precise predicate for
|
||||
// the backedge merge nodes and all the nodes that transitively use it.
|
||||
//
|
||||
// We don't track the output indices for should_revisit. Instead, putting a
|
||||
// node in `should_revisit` denotes that the deadness flowing out from any
|
||||
@ -1086,9 +1334,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
// delta should not change in the second iteration.
|
||||
std::vector<bool> should_revisit;
|
||||
should_revisit.resize(graph_.num_node_ids());
|
||||
for (Node* n : rpo) {
|
||||
for (Node* n : topo) {
|
||||
VLOG(4) << "Visiting " << n->name();
|
||||
TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(
|
||||
HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
|
||||
if (n->IsNextIteration()) {
|
||||
// If this is a backedge for a merge node then remember to reprocess the
|
||||
// merge the next time we run.
|
||||
@ -1100,11 +1349,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* n : rpo) {
|
||||
for (Node* n : topo) {
|
||||
// The nodes added to should_revisit in the previous loop need to be
|
||||
    // revisited now.  Reprocessing these initial nodes may add *their* consumers
|
||||
// to should_revisit, and these newly added nodes will also be processed by
|
||||
// this very same loop. Since we're traversing the graph in reverse post
|
||||
// this very same loop. Since we're traversing the graph in topological
|
||||
// order (producers before consumers) and HandleNode(n) can only ever add
|
||||
// n's consumers to should_revisit, we won't "miss" an addition to
|
||||
// should_revisit.
|
||||
@ -1114,6 +1363,71 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the optimistic analysis converges. Specifically, check whether
|
||||
// all the predicates of the merge nodes in the same frame are the same. If
|
||||
// yes, report success. If not, report failure and clear the assigned
|
||||
// predicates.
|
||||
if (use_optimistic_mode) {
|
||||
bool is_converged = true;
|
||||
absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
|
||||
for (Node* n : topo) {
|
||||
if (!n->IsMerge()) {
|
||||
continue;
|
||||
}
|
||||
const Edge* e;
|
||||
TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
|
||||
if (e == nullptr) {
|
||||
// Skip acyclic merge nodes.
|
||||
continue;
|
||||
}
|
||||
Node* merge = n;
|
||||
      // Note that we use frame names here instead of root frame names.  In the
|
||||
// case of a nested while loop, each level of while loops can have merges
|
||||
// with different predicate instances, while the merge nodes on the same
|
||||
// level must have the same predicate instances.
|
||||
absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
|
||||
auto it = predicate_map_.find(TensorId(merge->name(), 0));
|
||||
Predicate* merge_pred = it->second;
|
||||
if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
|
||||
is_converged = false;
|
||||
VLOG(2) << "Running the optimistic mode on frame " << frame_name
|
||||
<< " does not converge because node " << merge->name()
|
||||
<< " cannot be mapped into the AndRecurrence form.";
|
||||
break;
|
||||
}
|
||||
|
||||
auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
|
||||
if (!insert_result.second) {
|
||||
// If we have already seen this frame name, verify the predicate is the
|
||||
// same as the previously seen one's.
|
||||
Predicate* curr_andrec = merge_pred;
|
||||
Predicate* prev_andrec = insert_result.first->second;
|
||||
if (curr_andrec != prev_andrec) {
|
||||
is_converged = false;
|
||||
VLOG(2) << "Running the optimistic mode on frame " << frame_name
|
||||
<< " does not converge. Seeing different Merge predicates: \n"
|
||||
<< curr_andrec->ToString() << " and \n"
|
||||
<< prev_andrec->ToString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the assigned predicates if the optimistic mode does not converge.
|
||||
if (!is_converged) {
|
||||
for (Node* n : topo) {
|
||||
for (int oid = 0; oid < n->num_outputs(); ++oid) {
|
||||
predicate_map_.erase(TensorId(n->name(), oid));
|
||||
}
|
||||
predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
|
||||
}
|
||||
}
|
||||
|
||||
if (success != nullptr) {
|
||||
*success = is_converged;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1149,7 +1463,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
|
||||
const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
|
||||
std::unique_ptr<DeadnessAnalysisImpl> analysis(
|
||||
new DeadnessAnalysisImpl(&graph));
|
||||
TF_RETURN_IF_ERROR(analysis->Populate());
|
||||
TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
analysis->Print();
|
||||
@ -1170,22 +1484,18 @@ DeadnessAnalysisImpl::PredicateMapAsString() const {
|
||||
}
|
||||
|
||||
namespace deadness_analysis_internal {
|
||||
Status ComputePredicates(const Graph& graph,
|
||||
PredicateMapTy* out_predicate_map) {
|
||||
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
|
||||
bool enable_optimistic) {
|
||||
DeadnessAnalysisImpl impl(&graph);
|
||||
TF_RETURN_IF_ERROR(impl.Populate());
|
||||
TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
|
||||
*out_predicate_map = impl.PredicateMapAsString();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ComputePredicates(const Graph& graph,
|
||||
absl::Span<Node* const> reverse_post_order,
|
||||
PredicateMapTy* out_predicate_map) {
|
||||
DeadnessAnalysisImpl impl(&graph);
|
||||
TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
|
||||
*out_predicate_map = impl.PredicateMapAsString();
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace deadness_analysis_internal
|
||||
|
||||
string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
|
||||
return static_cast<Predicate*>(predicate.pred_)->ToString();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -82,6 +82,8 @@ class DeadnessAnalysis {
|
||||
virtual void Print() const = 0;
|
||||
virtual ~DeadnessAnalysis();
|
||||
|
||||
string DebugString(DeadnessPredicate predicate) const;
|
||||
|
||||
// Run the deadness analysis over `graph` and returns an error or a populated
|
||||
// instance of DeadnessAnalysis in `result`.
|
||||
static Status Run(const Graph& graph,
|
||||
|
@ -25,15 +25,9 @@ namespace deadness_analysis_internal {
|
||||
// Returns a map describing the predicate each Tensor was mapped to. For
|
||||
// testing purposes only.
|
||||
using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
|
||||
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
|
||||
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
|
||||
bool enable_optimistic = true);
|
||||
|
||||
// Returns a map describing the predicate each Tensor was mapped to. For
|
||||
// testing purposes only. Makes deadness analysis visit the graph in the order
|
||||
// specified in `reverse_post_order` which must be a valid RPO for the graph
|
||||
// minus NextIteration->Merge edges.
|
||||
Status ComputePredicates(const Graph& graph,
|
||||
absl::Span<Node* const> reverse_post_order,
|
||||
PredicateMapTy* out_predicate_map);
|
||||
} // namespace deadness_analysis_internal
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
|
||||
predicate_map[ControlOutputFor(iv.induction_var)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
|
||||
predicate_map[ControlOutputFor(iv.induction_var)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
|
||||
predicate_map[ControlOutputFor(iv.induction_var)]);
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
|
||||
CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
|
||||
FixupSourceAndSinkEdges(root.graph());
|
||||
|
||||
// To make deadness analysis think that dependent_iv is a loop we need an RPO
|
||||
// that visits the merge before the backedge. This is a legal RPO for
|
||||
// deadness analysis since it ignores NextIteration->Merge edges during RPO.
|
||||
// Right now dependent_iv has an edge from Merge to NextIteration so do the
|
||||
// RPO with this edge in place. Then remove this edge to get our test case.
|
||||
std::vector<Node*> rpo;
|
||||
GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
|
||||
/*edge_filter=*/[](const Edge& edge) {
|
||||
return !edge.src()->IsNextIteration();
|
||||
});
|
||||
TF_ASSERT_OK(root.graph()->UpdateEdge(
|
||||
iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));
|
||||
|
||||
@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
|
||||
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map));
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<frame>");
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
|
||||
"div0/iv:0");
|
||||
@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
|
||||
"{(*iv_outer/cond:0 & "
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
|
||||
"cond:0}<inner_loop;outer_loop>");
|
||||
|
||||
    // enable_optimistic = true or false should produce the same results because
|
||||
// of fallback. However, note that the order of iv_inner/cond:0 and
|
||||
// iv_inner/iv:0 is different because the optimistic approach does not
|
||||
// create predicates for all merges and it can change the predicate id and
|
||||
// hence the symbol order.
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
|
||||
"iv_inner/iv:0)}<inner_loop;outer_loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>");
|
||||
@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
|
||||
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
|
||||
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
|
||||
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
InductionVarInfo iv_outer =
|
||||
CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
|
||||
Output enter_constant_outer_loop = ops::internal::Enter(
|
||||
root.WithOpName("constant_enter_outer_loop"),
|
||||
ops::Const(root.WithOpName("constant"), 5), "outer_loop",
|
||||
ops::internal::Enter::Attrs().IsConstant(true));
|
||||
ops::Switch inner_value(root.WithOpName("outer_is_live"),
|
||||
enter_constant_outer_loop, iv_outer.loop_cond);
|
||||
InductionVarInfo iv_inner = CreateInductionVariable(
|
||||
root, "iv_inner", "inner_loop", inner_value.output_true);
|
||||
|
||||
DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
|
||||
root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
|
||||
DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
|
||||
root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);
|
||||
|
||||
DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
|
||||
root, "div0_inner", "inner_loop", iv_inner.loop_cond,
|
||||
div0_outer.induction_var);
|
||||
DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
|
||||
root, "div1_inner", "inner_loop", iv_inner.loop_cond,
|
||||
div1_outer.induction_var);
|
||||
|
||||
Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
|
||||
"tensor_a", "sender", 0, "receiver");
|
||||
Output capture_enter_outer = ops::internal::Enter(
|
||||
root.WithOpName("capture_enter_outer"), captured, "outer_loop",
|
||||
ops::internal::Enter::Attrs().IsConstant(true));
|
||||
Output capture_enter_inner = ops::internal::Enter(
|
||||
root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
|
||||
ops::internal::Enter::Attrs().IsConstant(true));
|
||||
Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
|
||||
capture_enter_inner);
|
||||
TF_ASSERT_OK(root.graph()->UpdateEdge(
|
||||
mul0.node(), 0, div1_inner.latch.output_true.node(), 0));
|
||||
|
||||
Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
|
||||
div1_inner.induction_var);
|
||||
|
||||
VLogGraphIfAsked(*root.graph());
|
||||
|
||||
{
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add0.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
|
||||
DependentInductionVar div0 =
|
||||
CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
|
||||
DependentInductionVar div1 =
|
||||
CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
|
||||
FixupSourceAndSinkEdges(root.graph());
|
||||
TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
|
||||
div0.latch.output_true.node(), 0));
|
||||
TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
|
||||
div1.latch.output_true.node(), 0));
|
||||
|
||||
VLogGraphIfAsked(*root.graph());
|
||||
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
|
||||
// This tests the rule {S,&,X} & ~X => S.
|
||||
TensorId switch_false_out = {div1.latch.output_false.node()->name(),
|
||||
div1.latch.output_false.index()};
|
||||
EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*enable_optimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
File diff suppressed because it is too large
@ -52,16 +52,6 @@ typedef std::function<Status(
|
||||
// 'group_attribute' must be a string valued-attribute that names the new
|
||||
// functions to introduce.
|
||||
//
|
||||
// 'outside_compilation_attribute' must be a string-valued attribute that is
|
||||
// used to tag nodes within a subgraph to be part of an 'outside_compilation'
|
||||
// cluster within the subgraph. A cluster is formed from the set of nodes with
|
||||
// the same value of outside_compilation_subgraph and group_attribute. The nodes
|
||||
// in an outside_compilation cluster are left in the original graph. Edges
|
||||
// crossing from the subgraph to an outside_compilation cluster nested in the
|
||||
// subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
|
||||
// crossing from an outside_compilation cluster into its enclosing subgraph are
|
||||
// lifted into a SendFromHost/RecvFromHost pair of nodes.
|
||||
//
|
||||
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
|
||||
// function conversion.
|
||||
//
|
||||
@ -74,10 +64,9 @@ typedef std::function<Status(
|
||||
// dep from B. Originally D must run after C, post-transformation this
|
||||
// dependency is lost.
|
||||
Status EncapsulateSubgraphsInFunctions(
|
||||
string group_attribute, string outside_compilation_attribute,
|
||||
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
|
||||
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
|
||||
FunctionLibraryDefinition* library);
|
||||
string group_attribute, const Graph& graph_in,
|
||||
const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
|
||||
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
|
||||
|
||||
// The attribute that marks function calls produced by the encapsulate
|
||||
// subgraphs pass and that should in turn be compiled via XlaLaunch operators.
|
||||
|
@ -514,10 +514,10 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
|
||||
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
|
||||
std::unique_ptr<Graph> graph_out;
|
||||
s = EncapsulateSubgraphsInFunctions(
|
||||
"_encapsulate", /*outside_compilation_attribute=*/"", *graph,
|
||||
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
|
||||
/*rewrite_subgraph_fn=*/{},
|
||||
/*reuse_existing_functions=*/false, &graph_out, lib_def.get());
|
||||
/*reuse_existing_functions=*/false,
|
||||
&graph_out, lib_def.get());
|
||||
if (!s.ok()) return s;
|
||||
|
||||
std::unordered_map<string, XlaClusterInfo> clusters;
|
||||
@ -746,7 +746,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_cluster", "", graph_before_encapsulation,
|
||||
"_cluster", graph_before_encapsulation,
|
||||
/*rewrite_subgraph_fn=*/{},
|
||||
/*reuse_existing_functions=*/false, &graph, &library));
|
||||
|
||||
@ -798,7 +798,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
int guaranteed_consts = 0;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_encapsulate", "", graph_before,
|
||||
"_encapsulate", graph_before,
|
||||
/*rewrite_subgraph_fn=*/
|
||||
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
|
||||
std::unique_ptr<Graph>* graph_ptr,
|
||||
@ -843,7 +843,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
int guaranteed_consts = 0;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_encapsulate", "", graph_before,
|
||||
"_encapsulate", graph_before,
|
||||
/*rewrite_subgraph_fn=*/
|
||||
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
|
||||
std::unique_ptr<Graph>* graph_ptr,
|
||||
@ -1109,7 +1109,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute"})}},
|
||||
{"F"}},
|
||||
{"F", "outside_compilation_O1_host_compute"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"XlaHostCompute",
|
||||
{"C:o:0", "D:o:0"},
|
||||
@ -1990,7 +1990,8 @@ TEST(EncapsulateSubgraphsTest,
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute"})}}},
|
||||
"outside_compilation_O1_host_compute"})}},
|
||||
{"outside_compilation_O1_host_compute"}},
|
||||
},
|
||||
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
||||
{"h_0_retval_retval", "H:o:0"}});
|
||||
@ -2117,7 +2118,8 @@ TEST(EncapsulateSubgraphsTest,
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute"})}}},
|
||||
"outside_compilation_O1_host_compute"})}},
|
||||
{"outside_compilation_O1_host_compute"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"XlaHostCompute",
|
||||
{"D:o:0"},
|
||||
@ -2267,7 +2269,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
|
||||
{"_xla_token_input_nodes",
|
||||
absl::Span<const string>(
|
||||
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
|
||||
{}},
|
||||
{"outside_compilation_O1_host_compute"}},
|
||||
{{"outside_compilation_O3_host_compute"},
|
||||
"XlaHostCompute",
|
||||
{"D:o:0"},
|
||||
@ -2282,7 +2284,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
|
||||
absl::Span<const string>({"_xla_token_arg_node",
|
||||
"outside_compilation_O1_host_compute",
|
||||
"outside_compilation_O2_host_compute"})}},
|
||||
{}}},
|
||||
{"outside_compilation_O1_host_compute",
|
||||
"outside_compilation_O2_host_compute"}}},
|
||||
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
|
||||
{"h_0_retval_retval", "H:o:0"}});
|
||||
|
||||
|
@ -231,9 +231,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
|
||||
auto output = absl::make_unique<Graph>((*graph)->op_registry());
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
EncapsulateSubgraphsInFunctions(
|
||||
kXlaClusterAttr, "", **graph, RewriteSubgraph,
|
||||
/*reuse_existing_functions=*/true, &output, flib_def),
|
||||
EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
|
||||
/*reuse_existing_functions=*/true,
|
||||
&output, flib_def),
|
||||
"EncapsulateXlaComputationsPass failed");
|
||||
graph->swap(output);
|
||||
return Status::OK();
|
||||
|
@ -393,7 +393,7 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
|
||||
// Replace outside compilation function call node with XlaHostCompute node.
|
||||
// If the function call node has no input/output edges, we will just remove it
|
||||
// and not create a XlaHostCompute node.
|
||||
Status ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
xla::StatusOr<Node*> ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
|
||||
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
|
||||
// If the function call node has no input/output edges, just remove it.
|
||||
@ -413,7 +413,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
if (!has_edge) {
|
||||
VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString();
|
||||
g->RemoveNode(call_node);
|
||||
return Status::OK();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Build XlaHostCompute NodeDef.
|
||||
@ -424,7 +424,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
ReplaceNode(g, call_node, node_def));
|
||||
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
|
||||
|
||||
return Status::OK();
|
||||
return host_compute_node;
|
||||
}
|
||||
|
||||
// Resets "device_ordinal" attr to placeholder value for related nodes
|
||||
@ -1634,7 +1634,7 @@ Status ExtractOutsideCompilationForFunction(
|
||||
RewriteOutsideCompilationSubgraphFn rewrite_fn(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name);
|
||||
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
|
||||
outside_compilation_attr_name, "", *fbody->graph, rewrite_fn,
|
||||
outside_compilation_attr_name, *fbody->graph, rewrite_fn,
|
||||
/*reuse_existing_functions=*/true, &graph_out, fld));
|
||||
|
||||
// Replace outside_compilation function nodes with HostCompute ops.
|
||||
@ -1670,10 +1670,35 @@ Status ExtractOutsideCompilationForFunction(
|
||||
}
|
||||
}
|
||||
}
|
||||
std::map<string, Node*> host_compute_nodes;
|
||||
for (Node* n : outside_compilation_nodes) {
|
||||
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
|
||||
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
graph_out.get(), n, host_compute_core, *cluster_deps));
|
||||
auto host_compute_node_or = ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
graph_out.get(), n, host_compute_core, *cluster_deps);
|
||||
TF_RETURN_IF_ERROR(host_compute_node_or.status());
|
||||
Node* host_compute_node = host_compute_node_or.ValueOrDie();
|
||||
if (host_compute_node) {
|
||||
host_compute_nodes[host_compute_node->name()] = host_compute_node;
|
||||
}
|
||||
}
|
||||
// For XlaHostCompute nodes with dependencies, add control edges between them
|
||||
  // so XlaCompiler can handle them in the correct order.
|
||||
for (auto iter : host_compute_nodes) {
|
||||
Node* host_compute_node = iter.second;
|
||||
std::vector<string> token_input_node_names;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
|
||||
kXlaTokenInputNodesAttrName,
|
||||
&token_input_node_names));
|
||||
for (const string& node_name : token_input_node_names) {
|
||||
if (node_name == kXlaTokenArgNodeName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto iter = host_compute_nodes.find(node_name);
|
||||
if (iter != host_compute_nodes.end()) {
|
||||
graph_out->AddControlEdge(iter->second, host_compute_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle nodes with associated functions.
|
||||
|
@ -990,6 +990,16 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
|
||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
||||
"_xla_token_input_nodes", &token_input_nodes));
|
||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
||||
|
||||
// Check there is a control edge from host_compute_0 to host_compute_1.
|
||||
bool has_control_edge = false;
|
||||
for (const Edge *e : host_compute_1->in_edges()) {
|
||||
if (e->IsControlEdge() && e->src() == host_compute_0) {
|
||||
has_control_edge = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(has_control_edge);
|
||||
}
|
||||
|
||||
TEST_F(ExtractOutsideCompilationForFunctionTest,
|
||||
@ -1062,5 +1072,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
|
||||
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
|
||||
"_xla_token_input_nodes", &token_input_nodes));
|
||||
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
|
||||
|
||||
// Check there is a control edge from host_compute_0 to host_compute_1.
|
||||
bool has_control_edge = false;
|
||||
for (const Edge *e : host_compute_1->in_edges()) {
|
||||
if (e->IsControlEdge() && e->src() == host_compute_0) {
|
||||
has_control_edge = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(has_control_edge);
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -1,9 +1,8 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
@ -1,9 +1,8 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -29,6 +28,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
|
@ -61,7 +61,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
DeviceType device_type = ctx->device_type();
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
std::unique_ptr<XlaAllocator> xla_allocator;
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator;
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||
|
||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
||||
@ -93,7 +93,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
se::MultiPlatformManager::PlatformWithId(platform_id);
|
||||
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
|
||||
|
||||
xla_allocator = absl::make_unique<XlaAllocator>(
|
||||
xla_allocator = absl::make_unique<se::TfAllocatorAdapter>(
|
||||
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -36,10 +37,10 @@ class XlaPlatformInfo {
|
||||
public:
|
||||
XlaPlatformInfo() : device_type_("") {}
|
||||
XlaPlatformInfo(XlaPlatformInfo&&) = default;
|
||||
explicit XlaPlatformInfo(const DeviceType device_type,
|
||||
se::Platform::Id platform_id,
|
||||
explicit XlaPlatformInfo(
|
||||
const DeviceType device_type, se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
std::unique_ptr<XlaAllocator> xla_allocator,
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator,
|
||||
se::DeviceMemoryAllocator* device_allocator)
|
||||
: device_type_(device_type),
|
||||
platform_id_(platform_id),
|
||||
@ -84,8 +85,8 @@ class XlaPlatformInfo {
|
||||
// then device_allocator_ is the xla::Backend's memory allocator and
|
||||
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
|
||||
// then device_allocator_ is null and xla_allocator_ points to an appropriate
|
||||
// XlaAllocator instance.
|
||||
std::unique_ptr<XlaAllocator> xla_allocator_;
|
||||
// se::TfAllocatorAdapter instance.
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator_;
|
||||
se::DeviceMemoryAllocator* device_allocator_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
|
||||
|
@ -229,32 +229,18 @@ class MarkForCompilationPassImpl {
|
||||
// Initialize some internal data structures.
|
||||
Status Initialize();
|
||||
|
||||
// Runs through all the nodes in `cycles_graph_` and tries to create clusters.
|
||||
// Returns true if any new clusters were created.
|
||||
StatusOr<bool> RunEdgeContractionLoopInPostOrderOnce();
|
||||
// Runs through the entire cluster graph in post-order and calls `fn(from,
|
||||
// to)` on each edge. `fn(from, to)` is expected to return true if it was
|
||||
// able to contract `from`->`to`.
|
||||
//
|
||||
// Returns true if `fn` returned true for any edge.
|
||||
template <typename FnTy>
|
||||
StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
|
||||
|
||||
// Runs through all the nodes in `cycles_graph_` and tries to contract high
|
||||
// priority edges for clusters. Returns true if any new clusters were created.
|
||||
//
|
||||
// There are potentially many maximal clustering results, but they will not
|
||||
// all be equally performant. Some clustering decisions are likely to improve
|
||||
// performance much more than others, and we cannot order contractions on this
|
||||
// cost function, nor can we look at global information while deciding on
|
||||
// individual edges to contract. Instead, we will make decisions on these
|
||||
// important edges first and then on all other edges, maximizing the
|
||||
// chance that the most important edges are contracted.
|
||||
//
|
||||
// An example of where this might occur is with a digraph:
|
||||
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
|
||||
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
|
||||
// should be clustered with A because it will prevent a potentially large
|
||||
// tensor from A being computed and copied.
|
||||
//
|
||||
// This pass will ensure that contraction happens, which cannot be enforced in
|
||||
// a single pass with the current algorithm, since one pass cannot see the whole
|
||||
// graph and prevent B->C from being clustered in anticipation of a later A->B
|
||||
// cluster.
|
||||
StatusOr<bool> ContractPreferredEdges();
|
||||
// If from->to is a "preferred" edge (i.e. if we have a choice, we want to
|
||||
// prioritize contracting from->to over contracting other edges) then
|
||||
// contracts it and returns true. Else returns false.
|
||||
StatusOr<bool> ContractEdgeIfPreferred(Cluster* from, Cluster* to);
|
||||
|
||||
// Contracts as many edges as possible to create XLA clusters. After this
|
||||
// finishes the clustering decisions made are implicitly stored in
|
||||
@ -276,10 +262,6 @@ class MarkForCompilationPassImpl {
|
||||
// true if successful.
|
||||
StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
|
||||
|
||||
// Tries to contract each edge from `cluster_from`. Returns true if any edges
|
||||
// were contracted, false otherwise.
|
||||
StatusOr<bool> TryToContractEdgesFrom(Cluster* cluster_from);
|
||||
|
||||
// Nodes that XLA can compile are put in `compilation_candidates_`.
|
||||
Status FindCompilationCandidates();
|
||||
|
||||
@ -401,6 +383,13 @@ class MarkForCompilationPassImpl {
|
||||
return true;
|
||||
}
|
||||
|
||||
string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
|
||||
absl::string_view reason) {
|
||||
return absl::StrCat("Could not contract ", from->DebugString(*graph_),
|
||||
" -> ", to->DebugString(*graph_), " because ", reason,
|
||||
".");
|
||||
}
|
||||
|
||||
DebugOptions debug_options_;
|
||||
Graph* graph_;
|
||||
FunctionLibraryDefinition* flib_def_;
|
||||
@ -611,7 +600,8 @@ Status MarkForCompilationPassImpl::Initialize() {
|
||||
return BuildInitialClusterSet();
|
||||
}
|
||||
|
||||
StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
|
||||
template <typename FnTy>
|
||||
StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
|
||||
bool changed = false;
|
||||
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
|
||||
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
||||
@ -632,8 +622,18 @@ StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cluster_to->cluster_size() == 1) {
|
||||
Node* n = graph_->FindNodeId(cluster_to->GetIdOfOnlyNode());
|
||||
TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
|
||||
changed |= contracted_edge;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
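ForEachEdgeInPostOrder above factors the post-order edge walk out of the two contraction passes; the per-edge policy is supplied as a callable that reports whether it contracted the edge. A standalone sketch of the same pattern over a toy adjacency-list graph (the graph model and node ids are illustrative, not the cycles-graph API):

// Standalone sketch of the ForEachEdgeInPostOrder pattern: one traversal
// routine, parameterized by a callable that decides per edge whether to act.
#include <iostream>
#include <map>
#include <vector>

template <typename FnTy>
bool ForEachEdgeInPostOrder(const std::vector<int>& post_order,
                            const std::map<int, std::vector<int>>& successors,
                            FnTy fn) {
  bool changed = false;
  for (int from : post_order) {
    auto it = successors.find(from);
    if (it == successors.end()) continue;
    for (int to : it->second) {
      changed |= fn(from, to);  // fn returns true if it "contracted" the edge
    }
  }
  return changed;
}

int main() {
  std::map<int, std::vector<int>> succ = {{0, {1}}, {1, {2}}};
  std::vector<int> post_order = {2, 1, 0};

  // First pass: only act on "preferred" edges (here: edges into node 2).
  bool preferred = ForEachEdgeInPostOrder(
      post_order, succ, [](int, int to) { return to == 2; });
  // Second pass: act on every edge.
  bool all = ForEachEdgeInPostOrder(post_order, succ,
                                    [](int, int) { return true; });
  std::cout << preferred << " " << all << "\n";  // prints "1 1"
}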
|
||||
|
||||
StatusOr<bool> MarkForCompilationPassImpl::ContractEdgeIfPreferred(
|
||||
Cluster* from, Cluster* to) {
|
||||
if (to->cluster_size() == 1) {
|
||||
Node* n = graph_->FindNodeId(to->GetIdOfOnlyNode());
|
||||
|
||||
// Shape consuming operations are desirable to cluster with their
|
||||
// operands because they return a small set of scalar values after
|
||||
@ -644,43 +644,11 @@ StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
|
||||
// tensor that must be computed and possible transposed/copied before
|
||||
// the second cluster executes.
|
||||
if (IsShapeConsumerOp(*n)) {
|
||||
TF_ASSIGN_OR_RETURN(bool contracted_edge,
|
||||
TryToContractEdge(cluster_from, cluster_to));
|
||||
changed |= contracted_edge;
|
||||
}
|
||||
}
|
||||
return TryToContractEdge(from, to);
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
StatusOr<bool>
|
||||
MarkForCompilationPassImpl::RunEdgeContractionLoopInPostOrderOnce() {
|
||||
bool changed = false;
|
||||
// Iterating over the graph once in post-order is sufficient to produce a
|
||||
// maximal clustering:
|
||||
//
|
||||
// A. We visit a cluster only after maximally clustering all its children.
|
||||
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
|
||||
// its children that could have been absorbed into `node` have been
|
||||
// absorbed.
|
||||
// C. We have an invariant that making a cluster larger does not make edges
|
||||
// leaving it more contractable. That is, if we have
|
||||
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
|
||||
// to contract Y->Z if Y->Z was not contractible originally.
|
||||
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
|
||||
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
||||
if (!cluster_from) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool contracted_one_edge,
|
||||
TryToContractEdgesFrom(cluster_from));
|
||||
changed |= contracted_one_edge;
|
||||
}
|
||||
|
||||
return changed;
|
||||
return false;
|
||||
}
|
||||
|
||||
Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
|
||||
@ -694,25 +662,68 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
|
||||
// without restrictions. This helps to minimize data output from clusters (and
|
||||
// possible transpose operations before outputs) that might occur if a
|
||||
// ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations.
|
||||
TF_ASSIGN_OR_RETURN(bool changed, ContractPreferredEdges());
|
||||
//
|
||||
// There are potentially many maximal clustering results, but they will not
|
||||
// all be equally performant. Some clustering decisions are likely to improve
|
||||
// performance much more than others, and we cannot order contractions on this
|
||||
// cost function, nor can we look at global information while deciding on
|
||||
// individual edges to contract. Instead, we will make decisions on these
|
||||
// important edges first and then on all other edges, maximizing the
|
||||
// chance that the most important edges are contracted.
|
||||
//
|
||||
// An example of where this might occur is with a digraph:
|
||||
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
|
||||
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
|
||||
// should be clustered with A because it will prevent a potentially large
|
||||
// tensor from A being computed and copied.
|
||||
//
|
||||
// This pass will ensure that contraction happens, which cannot be enforced in
|
||||
// a single pass with the current algorithm, since one pass cannot see the whole
|
||||
// graph and prevent B->C from being clustered in anticipation of a later A->B
|
||||
// cluster.
|
||||
|
||||
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce());
|
||||
TF_ASSIGN_OR_RETURN(bool changed,
|
||||
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
|
||||
return ContractEdgeIfPreferred(from, to);
|
||||
}));
|
||||
|
||||
// Check that RunEdgeContractionLoopInPostOrderOnce is idempotent. Once the
|
||||
// linear time post-order scheme has been battle tested we can move this to
|
||||
// happen only in debug builds.
|
||||
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce());
|
||||
// Iterating over the whole graph once in post-order is sufficient to produce
|
||||
// a maximal clustering:
|
||||
//
|
||||
// A. We visit a cluster only after maximally clustering all its children.
|
||||
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
|
||||
// its children that could have been absorbed into `node` have been
|
||||
// absorbed.
|
||||
// C. We have an invariant that making a cluster larger does not make edges
|
||||
// leaving it more contractable. That is, if we have
|
||||
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
|
||||
// to contract Y->Z if Y->Z was not contractible originally.
|
||||
TF_ASSIGN_OR_RETURN(changed,
|
||||
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
|
||||
return TryToContractEdge(from, to);
|
||||
}));
|
||||
|
||||
// Check that the conclusion made above (that iterating over the graph once in
|
||||
// post order gives a maximal clustering) holds. Once the linear time
|
||||
// post-order scheme has been battle tested we can move this to happen only in
|
||||
// debug builds.
|
||||
TF_ASSIGN_OR_RETURN(changed,
|
||||
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
|
||||
return TryToContractEdge(from, to);
|
||||
}));
|
||||
TF_RET_CHECK(!changed);
|
||||
|
||||
return Status::OK();
|
||||
}
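The TF_RET_CHECK(!changed) above encodes an idempotence check: the post-order contraction pass is expected to be maximal, so re-running it must contract nothing. A toy standalone sketch of that check, with a made-up contraction routine standing in for the real edge contraction:

// Sketch of the idempotence check: run the contraction pass once, then run it
// again and require that the second run makes no progress.
#include <cassert>
#include <vector>

bool TryContractAll(std::vector<int>& clusters) {
  // Toy "contraction": merge runs of adjacent equal values into one entry.
  bool changed = false;
  for (std::size_t i = 0; i + 1 < clusters.size();) {
    if (clusters[i] == clusters[i + 1]) {
      clusters.erase(clusters.begin() + i + 1);
      changed = true;
    } else {
      ++i;
    }
  }
  return changed;
}

int main() {
  std::vector<int> clusters = {1, 1, 2, 2, 2, 3};
  TryContractAll(clusters);               // first pass does the real work
  bool changed = TryContractAll(clusters);
  assert(!changed);                       // second pass must be a no-op
  return changed ? 1 : 0;
}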
|
||||
|
||||
std::atomic<int64> cluster_sequence_num;
|
||||
|
||||
int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
|
||||
|
||||
Status MarkForCompilationPassImpl::CreateClusters() {
|
||||
TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
|
||||
clusters_created_ = true;
|
||||
|
||||
static std::atomic<int64> cluster_sequence_num;
|
||||
|
||||
// Names for each cluster.
|
||||
std::unordered_map<int, string> cluster_names;
|
||||
|
||||
@ -745,7 +756,7 @@ Status MarkForCompilationPassImpl::CreateClusters() {
|
||||
string& name = cluster_names[cluster->cycles_graph_node_id()];
|
||||
|
||||
if (name.empty()) {
|
||||
name = absl::StrCat("cluster_", cluster_sequence_num++);
|
||||
name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
|
||||
}
|
||||
|
||||
n->AddAttr(kXlaClusterAttr, name);
|
||||
@ -1065,8 +1076,7 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
|
||||
|
||||
bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
|
||||
Cluster* from, Cluster* to, absl::string_view reason) {
|
||||
VLOG(3) << "Could not contract " << from->DebugString(*graph_) << " -> "
|
||||
<< to->DebugString(*graph_) << " because " << reason << ".";
|
||||
VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1075,8 +1085,14 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
|
||||
DCHECK(from->deadness_predicate().has_value() ==
|
||||
to->deadness_predicate().has_value());
|
||||
if (from->deadness_predicate() != to->deadness_predicate()) {
|
||||
return LogNotContractableAndReturnFalse(
|
||||
from, to, "the two nodes have mismatching deadness");
|
||||
VLOG(3) << EdgeContractionFailureMsg(
|
||||
from, to,
|
||||
absl::StrCat(
|
||||
"the two nodes have mismatching deadness: ",
|
||||
deadness_analysis_->DebugString(*from->deadness_predicate()),
|
||||
" and ",
|
||||
deadness_analysis_->DebugString(*to->deadness_predicate())));
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool devices_compatible,
|
||||
@ -1133,32 +1149,6 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
|
||||
return MergeClusters(from, to);
|
||||
}
|
||||
|
||||
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgesFrom(
|
||||
Cluster* cluster_from) {
|
||||
bool changed = false;
|
||||
|
||||
// Make a copy of the set of successors because we may modify the graph in
|
||||
// TryToContractEdge.
|
||||
std::vector<int32> successors_copy =
|
||||
cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
|
||||
|
||||
for (int to : successors_copy) {
|
||||
iteration_count_++;
|
||||
|
||||
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
|
||||
if (!cluster_to) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool contracted_edge,
|
||||
TryToContractEdge(cluster_from, cluster_to));
|
||||
|
||||
changed |= contracted_edge;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
Status MarkForCompilationPassImpl::Run() {
|
||||
// Make sure that kernels have been registered on the JIT device.
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
@ -1485,7 +1475,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
|
||||
op_filter.allow_control_trigger = true;
|
||||
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
|
||||
op_filter.allow_ops_producing_or_consuming_variant = true;
|
||||
op_filter.allow_slow_and_inaccurate_ops = true;
|
||||
op_filter.allow_slow_ops = true;
|
||||
op_filter.allow_inaccurate_ops = true;
|
||||
|
||||
return RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
|
||||
.IsCompilableCall(ndef, flr);
|
||||
@ -1522,4 +1513,8 @@ Status MarkForCompilationPass::RunForTest(
|
||||
|
||||
return MarkForCompilation(options, debug_options);
|
||||
}
|
||||
|
||||
namespace testing {
|
||||
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
|
||||
} // namespace testing
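The cluster-naming counter above is a file-scope atomic with a getter, plus a test-only reset hook exposed through a `testing` namespace. A standalone sketch of the same pattern (the symbol names mirror the diff; the `main` scaffolding and the plain `int64_t` type are illustrative):

// File-scope atomic counter for unique cluster names, with a test-only reset.
#include <atomic>
#include <cstdint>
#include <iostream>

namespace {
std::atomic<int64_t> cluster_sequence_num{0};
int64_t GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
}  // namespace

namespace testing {
// DO NOT USE IN PRODUCTION: lets unit tests see deterministic cluster names.
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
}  // namespace testing

int main() {
  std::cout << "cluster_" << GetNextClusterSequenceNumber() << "\n";  // cluster_0
  std::cout << "cluster_" << GetNextClusterSequenceNumber() << "\n";  // cluster_1
  testing::ResetClusterSequenceNumber();
  std::cout << "cluster_" << GetNextClusterSequenceNumber() << "\n";  // cluster_0
}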
|
||||
} // namespace tensorflow
|
||||
|
@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable.
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
|
||||
|
||||
namespace testing {
|
||||
// DO NOT USE IN PRODUCTION.
|
||||
//
|
||||
// Resets some internal state to let us write reliable unit tests.
|
||||
void ResetClusterSequenceNumber();
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
|
||||
|
@ -1,7 +1,6 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
|
@ -49,7 +49,7 @@ Status ShapeAnnotationsMatch(
|
||||
missing.push_back(entry.first);
|
||||
}
|
||||
return errors::InvalidArgument("Missing shapes for nodes: ",
|
||||
str_util::Join(missing, ","));
|
||||
absl::StrJoin(missing, ","));
|
||||
}
|
||||
return Status::OK();
|
||||
}
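The str_util::Join call is replaced with absl::StrJoin, which takes a range and a separator and returns the joined string. A quick standalone usage check (requires linking against Abseil):

// Joining a list of node names with a comma, as the error message above does.
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"

int main() {
  std::vector<std::string> missing = {"node_a", "node_b", "node_c"};
  std::cout << "Missing shapes for nodes: " << absl::StrJoin(missing, ",")
            << "\n";  // Missing shapes for nodes: node_a,node_b,node_c
}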
|
||||
|
@ -60,7 +60,8 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
registration.cluster_slow_ops = true;
|
||||
registration.cluster_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -71,7 +71,7 @@ class XlaDeviceContext : public DeviceContext {
|
||||
StatusCallback done) const override;
|
||||
|
||||
xla::LocalClient* client() const { return client_; }
|
||||
se::Stream* stream() const { return stream_.get(); }
|
||||
se::Stream* stream() const override { return stream_.get(); }
|
||||
se::Stream* host_to_device_stream() const {
|
||||
return host_to_device_stream_.get();
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/host_constant_op.h"
|
||||
#include "tensorflow/core/kernels/identity_n_op.h"
|
||||
#include "tensorflow/core/kernels/identity_op.h"
|
||||
#include "tensorflow/core/kernels/logging_ops.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
#include "tensorflow/core/kernels/queue_op.h"
|
||||
#include "tensorflow/core/kernels/resource_variable_ops.h"
|
||||
@ -81,6 +82,11 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
|
||||
|
||||
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Assert") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("condition") \
|
||||
.HostMemory("data"), \
|
||||
AssertOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
|
@ -95,7 +95,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
registration.cluster_slow_ops = true;
|
||||
registration.cluster_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -63,7 +63,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
registration.cluster_slow_and_inaccurate_ops = true;
|
||||
registration.cluster_slow_ops = true;
|
||||
registration.cluster_inaccurate_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
|
||||
registration);
|
||||
|
||||
|
@ -167,32 +167,6 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
|
||||
: se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
|
||||
|
||||
XlaAllocator::~XlaAllocator() {}
|
||||
|
||||
xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) {
|
||||
AllocationAttributes attrs;
|
||||
attrs.no_retry_on_failure = !retry_on_failure;
|
||||
void* data = nullptr;
|
||||
if (size != 0) {
|
||||
data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
|
||||
if (data == nullptr) {
|
||||
return errors::ResourceExhausted(
|
||||
"Out of memory while trying to allocate ", size, " bytes.");
|
||||
}
|
||||
}
|
||||
return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
|
||||
device_ordinal, this);
|
||||
}
|
||||
|
||||
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
|
||||
wrapped_->DeallocateRaw(mem.opaque());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
|
||||
bool allocate_xla_tensors, bool use_multiple_streams)
|
||||
|
@ -32,7 +32,6 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class XlaAllocator;
|
||||
|
||||
// Struct that represents a possibly-absent Tensor.
|
||||
struct OptionalTensor {
|
||||
@ -104,74 +103,6 @@ class VariableInfo {
|
||||
Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
EXCLUSIVE_LOCK_FUNCTION();
|
||||
|
||||
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
|
||||
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
|
||||
// see comment on `AllowsAsynchronousDeallocation()`.
|
||||
class XlaAllocator : public se::DeviceMemoryAllocator {
|
||||
public:
|
||||
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
|
||||
~XlaAllocator() override;
|
||||
xla::StatusOr<se::OwningDeviceMemory> Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) override;
|
||||
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
|
||||
|
||||
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
|
||||
// before GPU execution takes place. Tensorflow uses the ordering of the main
|
||||
// compute stream to enforce a happens-before relationship between a memory
|
||||
// allocation and code that reuses the same memory. If Tensorflow adds
|
||||
// support for multiple GPU streams or allocators with different ordering
|
||||
// requirements, this code may need to change.
|
||||
// (This attribute has no effect on CPU.)
|
||||
bool AllowsAsynchronousDeallocation() const override { return true; }
|
||||
|
||||
private:
|
||||
Allocator* wrapped_;
|
||||
};
|
||||
|
||||
// Adapter class that wraps per-device TF allocators as an XLA allocator.
|
||||
// Assumes that the Tensorflow allocator permits asynchronous deallocation;
|
||||
// see comment on `AllowsAsynchronousDeallocation()`.
|
||||
class MultiDeviceAdapter : public se::DeviceMemoryAllocator {
|
||||
public:
|
||||
MultiDeviceAdapter(
|
||||
const se::Platform* platform,
|
||||
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators)
|
||||
: DeviceMemoryAllocator(platform),
|
||||
tf_allocators_(std::move(tf_allocators)) {
|
||||
for (const auto& tf_allocator : tf_allocators_) {
|
||||
per_device_allocators_.emplace_back(platform, tf_allocator.get());
|
||||
}
|
||||
}
|
||||
|
||||
xla::StatusOr<se::OwningDeviceMemory> Allocate(
|
||||
int device_ordinal, uint64 size, bool retry_on_failure) override {
|
||||
CHECK_LT(device_ordinal, per_device_allocators_.size());
|
||||
return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size,
|
||||
retry_on_failure);
|
||||
}
|
||||
|
||||
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override {
|
||||
CHECK_LT(device_ordinal, per_device_allocators_.size());
|
||||
return per_device_allocators_[device_ordinal].Deallocate(device_ordinal,
|
||||
mem);
|
||||
}
|
||||
|
||||
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
|
||||
// before GPU execution takes place. Tensorflow uses the ordering of the main
|
||||
// compute stream to enforce a happens-before relationship between a memory
|
||||
// allocation and code that reuses the same memory. If Tensorflow adds
|
||||
// support for multiple GPU streams or allocators with different ordering
|
||||
// requirements, this code may need to change.
|
||||
// (This attribute has no effect on CPU.)
|
||||
bool AllowsAsynchronousDeallocation() const override { return true; }
|
||||
|
||||
private:
|
||||
std::vector<tensorflow::XlaAllocator> per_device_allocators_;
|
||||
// The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does
|
||||
// not take ownership of its underlying Allocator).
|
||||
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators_;
|
||||
};
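MultiDeviceAdapter above owns one wrapped allocator per device and forwards each Allocate/Deallocate call to the entry selected by the device ordinal. A simplified standalone sketch of that delegation pattern; the allocator interface here is a plain stand-in, not the StreamExecutor API:

// Per-device delegation: one adapter, a vector of wrapped allocators,
// dispatch by ordinal.
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

class SimpleAllocator {
 public:
  void* Allocate(std::size_t size) { return std::malloc(size); }
  void Deallocate(void* ptr) { std::free(ptr); }
};

class MultiDeviceAdapter {
 public:
  explicit MultiDeviceAdapter(
      std::vector<std::unique_ptr<SimpleAllocator>> per_device)
      : per_device_(std::move(per_device)) {}

  void* Allocate(int device_ordinal, std::size_t size) {
    assert(device_ordinal >= 0 &&
           device_ordinal < static_cast<int>(per_device_.size()));
    return per_device_[device_ordinal]->Allocate(size);
  }
  void Deallocate(int device_ordinal, void* ptr) {
    assert(device_ordinal >= 0 &&
           device_ordinal < static_cast<int>(per_device_.size()));
    per_device_[device_ordinal]->Deallocate(ptr);
  }

 private:
  std::vector<std::unique_ptr<SimpleAllocator>> per_device_;
};

int main() {
  std::vector<std::unique_ptr<SimpleAllocator>> allocs;
  allocs.push_back(std::make_unique<SimpleAllocator>());
  allocs.push_back(std::make_unique<SimpleAllocator>());
  MultiDeviceAdapter adapter(std::move(allocs));
  void* p = adapter.Allocate(/*device_ordinal=*/1, 128);
  adapter.Deallocate(1, p);
  std::cout << "ok\n";
}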
|
||||
|
||||
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
|
||||
// ShapedBuffers suitable for passing to an XLA computation.
|
||||
class XlaComputationLaunchContext {
|
||||
|
@ -28,10 +28,9 @@
|
||||
** Please don't remove this file - it is supporting some 3rd party plugins **
|
||||
"""
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -954,7 +954,7 @@ tf_xla_py_test(
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "ternary_ops_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["ternary_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
|
@ -19,11 +19,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase):
|
||||
def f():
|
||||
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
|
||||
output = control_flow_ops.cond(
|
||||
constant_op.constant(
|
||||
True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
|
||||
constant_op.constant(True),
|
||||
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
|
||||
|
||||
return output.stack()
|
||||
|
||||
@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase):
|
||||
|
||||
xla_context.Exit()
|
||||
|
||||
def testCondAndTensorArrayInDefun_constFolding(self):
|
||||
g = ops.Graph()
|
||||
with session.Session(graph=g), g.as_default(), self.test_scope():
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
|
||||
@function.defun
|
||||
def f():
|
||||
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
|
||||
output = control_flow_ops.cond(
|
||||
constant_op.constant(False),
|
||||
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
|
||||
|
||||
return output.stack()
|
||||
|
||||
output_t = f()
|
||||
self.assertAllEqual([10.], self.evaluate(output_t))
|
||||
|
||||
xla_context.Exit()
|
||||
|
||||
def testCondAndTensorArray_xlaCompile(self):
|
||||
self.skipTest("b/127846988")
|
||||
# Fails with "Uninitialized arguments" in XlaIfOp::Compile
|
||||
with self.session(), self.test_scope():
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
|
||||
def f():
|
||||
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
|
||||
output = control_flow_ops.cond(
|
||||
constant_op.constant(True),
|
||||
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
|
||||
|
||||
return output.stack()
|
||||
|
||||
output_t, = xla.compile(f)
|
||||
self.assertAllEqual([5.], self.evaluate(output_t))
|
||||
|
||||
xla_context.Exit()
|
||||
|
||||
def testCondConstPropagation(self):
|
||||
with self.session() as sess, self.test_scope():
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase):
|
||||
|
||||
xla_context.Exit()
|
||||
|
||||
def testSwitchCaseAndTensorArray_xlaCompile(self):
|
||||
self.skipTest("b/127846988")
|
||||
with self.session(), self.test_scope():
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
|
||||
def f():
|
||||
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
|
||||
output = control_flow_ops.switch_case(
|
||||
constant_op.constant(1), {
|
||||
0: lambda: ta.write(0, 5.),
|
||||
1: lambda: ta.write(0, 10.),
|
||||
2: lambda: ta.write(0, 15.),
|
||||
})
|
||||
|
||||
return output.stack()
|
||||
|
||||
output_t, = xla.compile(f)
|
||||
self.assertAllEqual([10.], self.evaluate(output_t))
|
||||
|
||||
xla_context.Exit()
|
||||
|
||||
def testSwitchCaseConstPropagation(self):
|
||||
self.skipTest("b/127846988")
|
||||
with self.session() as sess, self.test_scope():
|
||||
|
@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase):
|
||||
padding="VALID",
|
||||
patches=patches)
|
||||
|
||||
def testKsize2x2Stride1x1Rate1x1ValidDepth2(self):
|
||||
"""Test for 2x2 kernel with VALID padding."""
|
||||
# [1, 2, 2, 2]
|
||||
image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]]
|
||||
# [1, 1, 1, 8]
|
||||
patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]]
|
||||
self._VerifyValues(
|
||||
image,
|
||||
ksizes=[2, 2],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
patches=patches)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -72,21 +72,21 @@ class TernaryOpsTest(xla_test.XLATestCase):
|
||||
for dtype in self.numeric_types:
|
||||
self._testTernary(
|
||||
array_ops.where,
|
||||
np.array(0, dtype=np.bool),
|
||||
np.array(False),
|
||||
np.array(2, dtype=dtype),
|
||||
np.array(7, dtype=dtype),
|
||||
expected=np.array(7, dtype=dtype))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.where,
|
||||
np.array(1, dtype=np.bool),
|
||||
np.array(True),
|
||||
np.array([1, 2, 3, 4], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([1, 2, 3, 4], dtype=dtype))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.where,
|
||||
np.array(0, dtype=np.bool),
|
||||
np.array(False),
|
||||
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
|
||||
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
|
||||
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
|
||||
@ -105,6 +105,74 @@ class TernaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
|
||||
expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype))
|
||||
|
||||
def testSelectV2(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array(False),
|
||||
np.array(2, dtype=dtype),
|
||||
np.array(7, dtype=dtype),
|
||||
expected=np.array(7, dtype=dtype))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array(True),
|
||||
np.array([1, 2, 3, 4], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([1, 2, 3, 4], dtype=dtype))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array(False),
|
||||
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
|
||||
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
|
||||
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
|
||||
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array([0, 1, 1, 0], dtype=np.bool),
|
||||
np.array([1, 2, 3, 4], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array([5, 2, 3, 8], dtype=dtype))
|
||||
|
||||
# Broadcast the condition
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array([0, 1], dtype=np.bool),
|
||||
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
|
||||
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
|
||||
expected=np.array([[7, 2], [9, 4], [11, 6]], dtype=dtype))
|
||||
|
||||
# Broadcast the then branch to the else
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array([[0, 1], [1, 0], [1, 1]], dtype=np.bool),
|
||||
np.array([[1, 2]], dtype=dtype),
|
||||
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
|
||||
expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
|
||||
|
||||
# Broadcast the else branch to the then
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
|
||||
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
|
||||
np.array([[1, 2]], dtype=dtype),
|
||||
expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
|
||||
|
||||
# Broadcast the then/else branches to the condition
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array([[1, 0], [0, 1], [1, 1]], dtype=np.bool),
|
||||
np.array(7, dtype=dtype),
|
||||
np.array(8, dtype=dtype),
|
||||
expected=np.array([[7, 8], [8, 7], [7, 7]], dtype=dtype))
|
||||
self._testTernary(
|
||||
array_ops.where_v2,
|
||||
np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
|
||||
np.array(7, dtype=dtype),
|
||||
np.array([8, 9], dtype=dtype),
|
||||
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
|
||||
|
||||
def testSlice(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testTernary(
|
||||
|
@ -104,6 +104,24 @@ struct EdgePtrCompare {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(laigd): instead of deciding the device here, the converter should accept
|
||||
// a device name as one of the conversion parameter so users can control on
|
||||
// which device they want to run the conversion.
|
||||
std::pair<TfGpuId, PlatformGpuId> GetFirstValidDeviceId() {
|
||||
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
|
||||
TfGpuId tf_gpu_id(tf_gpu_id_value);
|
||||
PlatformGpuId platform_gpu_id;
|
||||
Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
|
||||
if (s.ok()) {
|
||||
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
|
||||
<< platform_gpu_id.value();
|
||||
return std::make_pair(tf_gpu_id, platform_gpu_id);
|
||||
}
|
||||
}
|
||||
LOG(ERROR) << "Could not find any TF GPUs";
|
||||
return std::make_pair(TfGpuId(-1), PlatformGpuId(-1));
|
||||
}
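GetFirstValidDeviceId probes a bounded range of TF GPU ids and returns the first one with a valid platform mapping, falling back to a (-1, -1) sentinel. A standalone sketch of the same probing pattern, with a plain map standing in for GpuIdManager::TfToPlatformGpuId:

// Probe candidate ids in order; return the first valid (tf_id, platform_id)
// pair, or a sentinel if none exists.
#include <iostream>
#include <map>
#include <utility>

std::pair<int, int> GetFirstValidDeviceId(
    const std::map<int, int>& tf_to_platform) {
  for (int tf_gpu_id = 0; tf_gpu_id < 100; ++tf_gpu_id) {
    auto it = tf_to_platform.find(tf_gpu_id);
    if (it != tf_to_platform.end()) {
      return {tf_gpu_id, it->second};  // first TF id with a valid mapping
    }
  }
  return {-1, -1};  // no GPU found; callers must check for the sentinel
}

int main() {
  std::map<int, int> mapping = {{2, 0}, {3, 1}};  // TF id -> platform id
  auto result = GetFirstValidDeviceId(mapping);
  std::cout << result.first << " -> " << result.second << "\n";  // 2 -> 0
}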
|
||||
|
||||
// Function to get subsegment information structure.
|
||||
Status GetEngineInfo(const Graph* g,
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
@ -128,20 +146,37 @@ Status GetEngineInfo(const Graph* g,
|
||||
if (segment_nodes.count(node) == 0) continue;
|
||||
auto node_device = node->requested_device();
|
||||
if (!node_device.empty()) {
|
||||
// If device is CPU, treat as if no device was assigned. Don't add CPU to
|
||||
// segment_device because that would cause a segfault in
|
||||
// GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes
|
||||
// any already set device is a GPU.
|
||||
// If device is set, it means device placement may have been done before,
|
||||
// so we need to assign a device for the TRTEngineOp to maintain the
|
||||
// invariant.
|
||||
// If the device is CPU in this case, it tries to find the first available
|
||||
// GPU and uses it as the device.
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
const bool parse_succeeded =
|
||||
DeviceNameUtils::ParseFullName(node_device, &parsed_name);
|
||||
if (parsed_name.type == "CPU") {
|
||||
VLOG(1) << "Node " << node->name() << " was assigned to the CPU. "
|
||||
<< "Attempting to place on GPU.";
|
||||
if (!parse_succeeded || (parse_succeeded && parsed_name.type == "CPU")) {
|
||||
string msg;
|
||||
if (!parse_succeeded) {
|
||||
msg = StrCat("Failed to parse assigned device of node ", node->name(),
|
||||
". ");
|
||||
} else {
|
||||
msg = StrCat("Node ", node->name(), " was assigned to the CPU. ");
|
||||
}
|
||||
VLOG(1) << msg << "Attempting to place on GPU.";
|
||||
TfGpuId tf_gpu_id;
|
||||
PlatformGpuId platform_gpu_id;
|
||||
std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
|
||||
if (tf_gpu_id.value() >= 0) {
|
||||
parsed_name.type = "GPU";
|
||||
parsed_name.id = tf_gpu_id.value();
|
||||
segment_devices.insert(DeviceNameUtils::FullName(
|
||||
parsed_name.job, parsed_name.replica, parsed_name.task,
|
||||
parsed_name.type, parsed_name.id));
|
||||
}
|
||||
} else {
|
||||
segment_devices.insert(node_device);
|
||||
}
|
||||
} else {
|
||||
if (node->has_assigned_device_name()) {
|
||||
} else if (node->has_assigned_device_name()) {
|
||||
// It appears that nodes will not have assigned devices at this point in
|
||||
// execution.
|
||||
segment_devices.insert(node->assigned_device_name());
|
||||
@ -149,7 +184,6 @@ Status GetEngineInfo(const Graph* g,
|
||||
VLOG(2) << "Node " << node->name()
|
||||
<< " neither have requested device nor assigned device";
|
||||
}
|
||||
}
|
||||
subgraph_nodes.push_back(node);
|
||||
|
||||
const int node_id = node->id();
|
||||
@ -251,13 +285,11 @@ Status GetEngineInfo(const Graph* g,
|
||||
info->engine_name = StrCat(scope_name, info->engine_name);
|
||||
VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
|
||||
<< "' to a GraphDef";
|
||||
// TODO(sami): This should not happen once segmenter is updated.
|
||||
if (segment_devices.size() == 1) {
|
||||
info->device = *segment_devices.begin();
|
||||
} else if (segment_devices.size() > 1) {
|
||||
LOG(WARNING) << "Detected multiple (" << segment_devices.size()
|
||||
<< ") devices for the segment. Picking first one to continue "
|
||||
<< "but this shouldn't have happened";
|
||||
<< ") devices for the segment. Picking first one to continue.";
|
||||
info->device = *segment_devices.begin();
|
||||
} else {
|
||||
VLOG(1) << "No device is assigned to the segment. "
|
||||
@ -543,10 +575,10 @@ Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
|
||||
std::map<string, Node*> io_nodes;
|
||||
int num_inputs = 0;
|
||||
for (auto n : sgraph.op_nodes()) {
|
||||
if (str_util::StartsWith(n->name(), kInputPHName)) {
|
||||
if (absl::StartsWith(n->name(), kInputPHName)) {
|
||||
num_inputs++;
|
||||
io_nodes.insert({n->name(), n});
|
||||
} else if (str_util::StartsWith(n->name(), kOutputPHName)) {
|
||||
} else if (absl::StartsWith(n->name(), kOutputPHName)) {
|
||||
io_nodes.insert({n->name(), n});
|
||||
}
|
||||
}
|
||||
@ -640,24 +672,17 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
|
||||
if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
|
||||
engine.device.empty()) {
|
||||
// If device is not set, use the first found GPU device for the conversion.
|
||||
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
|
||||
TfGpuId tf_gpu_id(tf_gpu_id_value);
|
||||
TfGpuId tf_gpu_id;
|
||||
PlatformGpuId platform_gpu_id;
|
||||
Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
|
||||
if (s.ok()) {
|
||||
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
|
||||
<< platform_gpu_id.value();
|
||||
std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
|
||||
cuda_device_id = platform_gpu_id.value();
|
||||
if (cuda_device_id >= 0) {
|
||||
GPUOptions gpu_options;
|
||||
// If the TF to Cuda gpu id mapping exist, the device and corresponding
|
||||
// allocator must have been initialized already, so the
|
||||
// GetGPUAllocator() call won't create a new allocator.
|
||||
dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
|
||||
gpu_options, tf_gpu_id, 1);
|
||||
break;
|
||||
}
|
||||
LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist "
|
||||
<< s;
|
||||
}
|
||||
return std::make_pair(cuda_device_id, dev_allocator);
|
||||
}
|
||||
@ -750,8 +775,8 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
EngineInfo curr_engine;
|
||||
curr_engine.engine_name = StrCat("TRTEngineOp_", t);
|
||||
Status status =
|
||||
GetEngineInfo(&graph, *params.graph_properties, curr_segment.first,
|
||||
node_map, reverse_topo_order, &curr_engine);
|
||||
GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
|
||||
reverse_topo_order, &curr_engine);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
|
||||
<< status;
|
||||
@ -776,7 +801,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
|
||||
engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
|
||||
total_engine_bytes_size += engine_bytes_size.back();
|
||||
total_num_nodes_in_segments += curr_segment.first.size();
|
||||
total_num_nodes_in_segments += curr_segment.size();
|
||||
engine_segments.push_back(std::move(curr_engine));
|
||||
converted_segments.push_back(std::move(curr_segment));
|
||||
|
||||
@ -806,7 +831,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
engine.max_workspace_size_bytes =
|
||||
params.max_workspace_size_bytes *
|
||||
(engine_bytes_size.at(i) / total_engine_bytes_size +
|
||||
converted_segments.at(i).first.size() / total_num_nodes_in_segments) /
|
||||
converted_segments.at(i).size() / total_num_nodes_in_segments) /
|
||||
2.0;
|
||||
VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
|
||||
<< engine.engine_name;
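The per-engine workspace budget above is the total budget scaled by the average of two fractions: the engine's share of serialized graph bytes and its share of segment nodes. A small numeric check of that formula with made-up example values (assuming floating-point fractions):

// Numeric check of: budget * (bytes_i/total_bytes + nodes_i/total_nodes) / 2
#include <cstdint>
#include <iostream>

int main() {
  const double max_workspace_size_bytes = 1 << 30;                // 1 GiB budget
  const double engine_bytes = 200.0, total_bytes = 800.0;         // 25% of bytes
  const double engine_nodes = 30.0, total_nodes = 40.0;           // 75% of nodes

  const double share =
      max_workspace_size_bytes *
      (engine_bytes / total_bytes + engine_nodes / total_nodes) / 2.0;

  std::cout << static_cast<int64_t>(share) << " bytes\n";  // 536870912 (50%)
}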
|
||||
@ -828,9 +853,9 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
|
||||
alloc.get(), &engine_nodes);
|
||||
|
||||
string msg = StrCat("TensorRT node ", engine.engine_name,
|
||||
" added for segment ", i, " consisting of ",
|
||||
converted_segments.at(i).first.size(), " nodes");
|
||||
string msg =
|
||||
StrCat("TensorRT node ", engine.engine_name, " added for segment ", i,
|
||||
" consisting of ", converted_segments.at(i).size(), " nodes");
|
||||
if (status.ok()) {
|
||||
LOG(INFO) << msg << " succeeded.";
|
||||
} else {
|
||||
@ -839,7 +864,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
}
|
||||
if (VLOG_IS_ON(1)) {
|
||||
msg = "Segment consists of nodes: ";
|
||||
for (const Node* node : converted_segments.at(i).first) {
|
||||
for (const Node* node : converted_segments.at(i)) {
|
||||
StrAppend(&msg, node->name(), ", ");
|
||||
}
|
||||
VLOG(1) << msg;
|
||||
@ -848,7 +873,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
// If status is ok, we successfully added the node to the graph and can
|
||||
// remove segment ops. Otherwise graph is not modified.
|
||||
if (status.ok()) {
|
||||
for (const Node* node : converted_segments.at(i).first) {
|
||||
for (const Node* node : converted_segments.at(i)) {
|
||||
graph.RemoveNode(const_cast<Node*>(node));
|
||||
}
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ class ConvertAfterShapesTest : public ::testing::Test {
|
||||
params.output_names = &output_names;
|
||||
params.max_workspace_size_bytes = 8 << 20;
|
||||
params.output_graph_def = output_graph_def;
|
||||
params.minimum_segment_size = 2;
|
||||
params.minimum_segment_size = 1;
|
||||
params.graph_properties = &graph_properties;
|
||||
params.use_calibration = false;
|
||||
|
||||
|
@ -385,11 +385,10 @@ string DebugString(const nvinfer1::ITensor& tensor) {
|
||||
", dims=", DebugString(tensor.getDimensions()), ")");
|
||||
}
|
||||
|
||||
Status Converter::GetTrtBroadcastShape(
|
||||
const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r,
|
||||
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
|
||||
const TRT_TensorOrWeights& operand_r,
|
||||
nvinfer1::Dims* operand_l_new_dims,
|
||||
nvinfer1::Dims* operand_r_new_dims) const {
|
||||
// ***************************************************************************
|
||||
nvinfer1::Dims* operand_r_new_dims) {
|
||||
// TensorRT Elementwise op supports broadcast but requires both tensors to be
|
||||
// of identical rank.
|
||||
//
|
||||
@ -473,14 +472,13 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
|
||||
nvinfer1::Weights trt_weights = weights.GetTrtWeights();
|
||||
nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights);
|
||||
if (!layer) return nullptr;
|
||||
const nvinfer1::DataType trt_dtype = trt_weights.type;
|
||||
nvinfer1::ITensor* trt_tensor = layer->getOutput(0);
|
||||
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
|
||||
// TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set
|
||||
// the data type below, it will always be kFLOAT regardless what the data type
|
||||
// of the weights is. Once NVIDIA fixes this bug, we should remove the data
|
||||
// type setting logic below and test should still pass.
|
||||
trt_tensor->setType(trt_dtype);
|
||||
trt_tensor->setType(trt_weights.type);
|
||||
#endif
|
||||
return trt_tensor;
|
||||
}
|
||||
@ -1677,190 +1675,6 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the
|
||||
// right operand. If swapped_inputs is true, those two are swapped.
|
||||
//
|
||||
// TODO(jie): broadcast is needed yet not implemented.
|
||||
// Only implemented channel wise for the time being.
|
||||
Status BinaryTensorOpWeight(OpConverterParams* params,
|
||||
nvinfer1::ITensor* tensor,
|
||||
TRT_ShapedWeights weights, bool swapped_inputs) {
|
||||
static const std::unordered_set<string> supported_ops = {"Sub", "Add", "Mul",
|
||||
"Div", "RealDiv"};
|
||||
const auto& node_def = params->node_def;
|
||||
if (!supported_ops.count(node_def.op())) {
|
||||
return errors::Unimplemented(node_def.op(), " is not supported, at ",
|
||||
node_def.name());
|
||||
}
|
||||
|
||||
// Check scale mode.
|
||||
auto dims_w = weights.shape_;
|
||||
const auto dims_t = tensor->getDimensions();
|
||||
|
||||
// TODO(jie): addScale checks for input tensor dimension
|
||||
if (dims_t.nbDims != 3) {
|
||||
return errors::InvalidArgument("addScale requires tensor with rank 3, at ",
|
||||
node_def.name());
|
||||
}
|
||||
|
||||
// Default to element-wise
|
||||
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
|
||||
|
||||
// TODO(jie): maybe use a permutation instead to support more cases;
|
||||
bool need_to_permute = false;
|
||||
|
||||
if (weights.count() == 1) {
|
||||
scale_mode = nvinfer1::ScaleMode::kUNIFORM;
|
||||
} else {
|
||||
VLOG(2) << "weights dims: " << DebugString(dims_w)
|
||||
<< "; tensor dims: " << DebugString(dims_t);
|
||||
// Make sure no broadcasting on batch dimension.
|
||||
if (dims_w.nbDims == dims_t.nbDims + 1) {
|
||||
if (dims_w.d[0] == 1) {
|
||||
for (int i = 1; i < dims_w.nbDims; i++) {
|
||||
dims_w.d[i - 1] = dims_w.d[i];
|
||||
}
|
||||
dims_w.nbDims--;
|
||||
} else {
|
||||
return errors::InvalidArgument("Binary op cannot operate on batch, at ",
|
||||
node_def.name());
|
||||
}
|
||||
}
|
||||
|
||||
if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) {
|
||||
scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
|
||||
// Default is element-wise
|
||||
for (int i = 1; i < dims_w.nbDims; i++) {
|
||||
if (dims_w.d[i] != dims_t.d[i]) {
|
||||
// If dimension does not match, switch back to per-channel
|
||||
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If the mode is per-channel, since channel dimension is assumed to be
|
||||
// the third to last dimension, we need to make sure all other dimensions
|
||||
// have size 1.
|
||||
if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
|
||||
for (int i = 1; i < dims_w.nbDims; i++) {
|
||||
if (dims_w.d[i] != 1)
|
||||
return errors::InvalidArgument(
|
||||
"Weight dims not compatible for channel-wise broadcast at ",
|
||||
node_def.name());
|
||||
}
|
||||
}
|
||||
} else if (dims_w.nbDims == 1 &&
|
||||
dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) {
|
||||
// Channel wise and broadcast required. We compare the last dimension of
|
||||
// the tensor shape because of tensorflow default broadcasting rules.
|
||||
need_to_permute = true;
|
||||
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
|
||||
} else {
|
||||
return errors::InvalidArgument("Weight dims not compatible at ",
|
||||
node_def.name());
|
||||
}
|
||||
}
|
||||
// TODO(laigd): we should add validation_only support in TransposeTensor() and
|
||||
// PrepareTensorForShape().
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// Transpose last dimension.
|
||||
std::vector<int> permutation(dims_t.nbDims + 1);
|
||||
if (need_to_permute) {
|
||||
// We swap the last dimension into channel for trt, because of tensorflow
|
||||
// default broadcasting rules.
|
||||
for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
|
||||
permutation[i] = i;
|
||||
}
|
||||
permutation[1] = dims_t.nbDims;
|
||||
permutation[dims_t.nbDims] = 1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
params->converter->TransposeTensor(tensor, permutation, &tensor));
|
||||
}
|
||||
|
||||
// Prepare weights
|
||||
TRT_ShapedWeights shift_weights(weights.TrtDType());
|
||||
TRT_ShapedWeights scale_weights(weights.TrtDType());
|
||||
TRT_ShapedWeights power_weights(weights.TrtDType());
|
||||
|
||||
if (node_def.op() == "Sub") {
|
||||
if (swapped_inputs) {
|
||||
shift_weights = weights;
|
||||
nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
|
||||
*tensor, nvinfer1::UnaryOperation::kNEG);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
// Since quantization ranges are symmetric, the same range as the input
|
||||
// will work for the negation of the input.
|
||||
params->converter->MarkQuantizationRangesAsInferrable(
|
||||
tensor, layer->getOutput(0));
|
||||
tensor = layer->getOutput(0);
|
||||
} else {
|
||||
TRT_ShapedWeights neg_weights =
|
||||
params->weight_store->GetTempWeights(weights);
|
||||
LambdaFactory unary_op;
|
||||
unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
|
||||
TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
|
||||
shift_weights = neg_weights;
|
||||
}
|
||||
} else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
|
||||
if (swapped_inputs) {
|
||||
// We need to infer the quantization range for this intermediate tensor.
|
||||
//
|
||||
// x -> [Recip] -> 1/x -> [Scale] -> s/x
|
||||
// ^
|
||||
// need range for this
|
||||
//
|
||||
// We have the quantization scales for x and s/x - can we divide the scale
|
||||
// for s/x by s? Only if it is a scalar.
|
||||
//
|
||||
// Because of this issue, fall back to BinaryTensorOpTensor if we are
|
||||
// doing INT8 with no calibration. There is most likely no performance
|
||||
// penalty by falling back here.
|
||||
if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
|
||||
!params->converter->use_calibration()) {
|
||||
return errors::Unimplemented(
|
||||
"Intermediate quantization range cannot be determined without"
|
||||
" calibration. Falling back to BinaryTensorOpTensor for ",
|
||||
node_def.op(), ", at ", node_def.name());
|
||||
}
|
||||
scale_weights = weights;
|
||||
nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
|
||||
*tensor, nvinfer1::UnaryOperation::kRECIP);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
tensor = layer->getOutput(0);
|
||||
} else {
|
||||
TRT_ShapedWeights recip_weights =
|
||||
params->weight_store->GetTempWeights(weights);
|
||||
LambdaFactory unary_op;
|
||||
unary_op.op = LambdaFactory::OP_CATEGORY::RECIP;
|
||||
TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op));
|
||||
scale_weights = recip_weights;
|
||||
}
|
||||
} else if (node_def.op() == "Mul") {
|
||||
scale_weights = weights;
|
||||
} else if (node_def.op() == "Add") {
|
||||
shift_weights = weights;
|
||||
} else {
|
||||
// This should not happen.
|
||||
return errors::Unimplemented("Binary op not supported at ", node_def.op());
|
||||
}
|
||||
|
||||
nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
|
||||
*tensor, scale_mode, shift_weights.GetTrtWeights(),
|
||||
scale_weights.GetTrtWeights(), power_weights.GetTrtWeights());
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
|
||||
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
|
||||
// Transpose back dimension
|
||||
if (need_to_permute) {
|
||||
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
|
||||
output_tensor, permutation, &output_tensor));
|
||||
}
|
||||
|
||||
// Pass the output
|
||||
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
bool is_conv2d_backprop_input) {
|
||||
const auto& inputs = params->inputs;
|
||||
@ -1951,7 +1765,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
kernel_size.h() = weights.shape_.d[2];
|
||||
kernel_size.w() = weights.shape_.d[3];
|
||||
|
||||
// Add padding.
|
||||
// Before TRT 5.1.3, we have to calculate padding ourselves.
|
||||
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
|
||||
std::vector<std::pair<int, int>> padding;
|
||||
if (attrs.get<string>("padding") == "SAME") {
|
||||
nvinfer1::DimsHW effective_kernel_size = kernel_size;
|
||||
@ -1978,12 +1793,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
padding = {{0, 0}, {0, 0}};
|
||||
}
|
||||
|
||||
// TensorRT 5.1 added support for asymmetric padding. Due to a bug in 5.1.2, we
|
||||
// can only use asymmetric padding in convolutions with 5.1.3+.
|
||||
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
|
||||
// Handle asymmetric padding. TensorRT 5.1 added support for asymmetric
|
||||
// padding via setPrePadding and setPostPadding. Due to a bug in 5.1.2, we can
|
||||
// only use asymmetric padding in convolutions with 5.1.3+. But in 5.1.3, we
|
||||
// will always use setPaddingMode for simplicity.
|
||||
if (padding[0].first != padding[0].second ||
|
||||
padding[1].first != padding[1].second) {
|
||||
// Handle asymmetric padding.
|
||||
auto pad_layer = params->converter->network()->addPadding(
|
||||
*tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
|
||||
nvinfer1::DimsHW(padding[0].second, padding[1].second));
|
||||
@ -2006,20 +1821,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
layer->setStride(stride);
|
||||
// TensorRT 5.1.3 added support for padding modes.
|
||||
#if IS_TRT_VERSION_GE(5, 1, 3, 0)
|
||||
// VALID padding is the default TRT behavior.
|
||||
if (attrs.get<string>("padding") == "SAME") {
|
||||
VLOG(2) << "Using SAME padding";
|
||||
// SAME_UPPER means that post padding is preferred.
|
||||
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
|
||||
}
|
||||
// For VALID padding, we need to manually set the padding.
|
||||
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
|
||||
layer->setPostPadding(
|
||||
nvinfer1::DimsHW{padding[0].second, padding[1].second});
|
||||
VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
|
||||
<< " and post-padding to: " << DebugString(layer->getPostPadding());
|
||||
#else
|
||||
layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
|
||||
VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
|
||||
#endif
|
||||
layer->setName(node_def.name().c_str());
|
||||
layer->setNbGroups(num_groups);
|
||||
@ -2033,17 +1841,10 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
layer->setStride(stride);
|
||||
#if IS_TRT_VERSION_GE(5, 1, 3, 0)
|
||||
if (attrs.get<string>("padding") == "SAME") {
|
||||
VLOG(2) << "Using SAME padding";
|
||||
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
|
||||
}
|
||||
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
|
||||
layer->setPostPadding(
|
||||
nvinfer1::DimsHW{padding[0].second, padding[1].second});
|
||||
VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
|
||||
<< " and post-padding to: " << DebugString(layer->getPostPadding());
|
||||
#else
|
||||
layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
|
||||
VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
|
||||
#endif
|
||||
layer->setName(node_def.name().c_str());
|
||||
layer->setNbGroups(num_groups);
|
||||
@ -2061,74 +1862,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
return Status::OK();
}
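
Editor's note: for TensorRT builds older than 5.1.3, the hunks above compute TensorFlow-style SAME padding by hand and fall back to an explicit padding layer when the split is asymmetric. A minimal sketch of that padding arithmetic along one axis, assuming the standard SAME formula (output size = ceil(input / stride)); the names here are illustrative, not the converter's API:

    #include <algorithm>
    #include <utility>

    // Illustrative only: pre/post padding for SAME semantics along one axis.
    // pad_total = max((out - 1) * stride + kernel - in, 0), split low/high,
    // with the extra element going to the post side (SAME_UPPER behavior).
    std::pair<int, int> SamePadding(int input, int kernel, int stride) {
      const int output = (input + stride - 1) / stride;  // ceil division
      const int pad_total =
          std::max((output - 1) * stride + kernel - input, 0);
      const int pad_before = pad_total / 2;
      return {pad_before, pad_total - pad_before};
    }

For example, input 4, kernel 2, stride 3 yields a (0, 1) split, which is the asymmetric case that requires the extra padding layer on pre-5.1.3 TensorRT.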
Status BinaryTensorOpTensor(OpConverterParams* params,
const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r) {
const auto& node_def = params->node_def;
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
{"Add", nvinfer1::ElementWiseOperation::kSUM},
{"Mul", nvinfer1::ElementWiseOperation::kPROD},
{"Sub", nvinfer1::ElementWiseOperation::kSUB},
{"Div", nvinfer1::ElementWiseOperation::kDIV},
{"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
{"Pow", nvinfer1::ElementWiseOperation::kPOW},
};
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end()) {
return errors::Unimplemented("Binary op ", node_def.op(),
" not supported at: ", node_def.name());
}

nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
Status status = params->converter->GetTrtBroadcastShape(
operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r);
if (!status.ok()) {
return errors::InvalidArgument(
"Unsupported binary op broadcast scheme for op ", node_def.name(), ": ",
status.error_message());
}
TFAttrs attrs(node_def);
nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
if (dtype == nvinfer1::DataType::kINT32) {
return errors::Unimplemented("Binary op ", node_def.op(),
" does not support INT32, at ",
node_def.name());
}
if (params->validation_only) return Status::OK();

nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
status = params->converter->PrepareTensorForShape(
operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l);
if (status.ok()) {
status = params->converter->PrepareTensorForShape(
operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r);
}
if (!status.ok()) {
return errors::Internal("Failed to convert binary op ", node_def.name(),
": ", status.error_message());
}

// Check type consistency.
TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype)
<< DebugString(tensor_l->getType()) << " vs " << DebugString(dtype);
TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype)
<< DebugString(tensor_r->getType()) << " vs " << DebugString(dtype);

// Add ElementWise layer.
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(*tensor_l, *tensor_r,
op_pair->second);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);

// Pass the output
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
|
||||
|
||||
Status ConvertPlugin(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
@ -2777,6 +2510,8 @@ Status ConvertPool(OpConverterParams* params) {
|
||||
const auto tf_kernel = attrs.get<std::vector<int64>>("ksize");
|
||||
const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
|
||||
|
||||
// Before TRT 5.1.3, we have to calculate padding ourselves.
|
||||
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
|
||||
auto tensor_dim = tensor->getDimensions();
|
||||
std::vector<std::pair<int, int>> padding;
|
||||
if (padding_type == "SAME") {
|
||||
@ -2789,13 +2524,13 @@ Status ConvertPool(OpConverterParams* params) {
|
||||
} else if (padding_type == "VALID") {
|
||||
padding = {{0, 0}, {0, 0}};
|
||||
}
|
||||
|
||||
// TensorRT 5.1 added support for asymmetric padding.
|
||||
#endif
|
||||
// TensorRT 5.1 added support for asymmetric padding. Before that, we need an
|
||||
// extra padding layer.
|
||||
#if !IS_TRT_VERSION_GE(5, 1, 0, 0)
|
||||
// Asymmetric padding case.
|
||||
if (padding[0].first != padding[0].second ||
|
||||
padding[1].first != padding[1].second) {
|
||||
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
|
||||
<< padding[1].first << padding[1].second;
|
||||
auto pad_layer = params->converter->network()->addPadding(
|
||||
*tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
|
||||
nvinfer1::DimsHW(padding[0].second, padding[1].second));
|
||||
@ -2817,16 +2552,13 @@ Status ConvertPool(OpConverterParams* params) {
|
||||
layer->getOutput(0));
|
||||
|
||||
layer->setStride(stride);
|
||||
// TensorRT 5.1.3 added support for padding modes.
|
||||
#if IS_TRT_VERSION_GE(5, 1, 3, 0)
|
||||
// VALID padding is the default TRT behavior.
|
||||
if (attrs.get<string>("padding") == "SAME") {
|
||||
// SAME_UPPER means that post padding is preferred.
|
||||
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
|
||||
}
|
||||
#endif
|
||||
// TensorRT 5.1 has support for asymmetric padding.
|
||||
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
|
||||
// If padding mode is not SAME, then these values will be used instead.
|
||||
#elif IS_TRT_VERSION_GE(5, 1, 0, 0)
|
||||
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
|
||||
layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second});
|
||||
#else
|
||||
@ -3350,9 +3082,6 @@ Status ConvertIdentity(OpConverterParams* params) {
|
||||
Status ConvertBinary(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
// TODO(tmorris): Enable once false is updated to mean either tensor or weight
|
||||
// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
|
||||
// false}}));
|
||||
if (inputs.size() != 2) {
|
||||
return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
|
||||
" inputs but expected 2, at ",
|
||||
@ -3368,33 +3097,45 @@ Status ConvertBinary(OpConverterParams* params) {
|
||||
"both input as constant at: ",
|
||||
node_def.name());
|
||||
}
|
||||
const TRT_TensorOrWeights& operand_l = inputs.at(0);
|
||||
const TRT_TensorOrWeights& operand_r = inputs.at(1);
|
||||
|
||||
// TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with
|
||||
// IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For
|
||||
// now, the performance will be slightly better with IScaleLayer because it
|
||||
// can be fused in more situations. However, most of the benefits of
|
||||
// IScaleLayer are when the layer performs both a shift and a scale, which we
|
||||
// don't do except for convolutions.
|
||||
//
|
||||
// Try to convert into Scale layer first (for better performance).
|
||||
// Since scale layer supports restricted broadcast policy and op types, we
|
||||
// allow failure and try to handle it through Elementwise op
|
||||
// (BinaryTensorOpTensor).
|
||||
Status status = Status::OK();
|
||||
if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) {
|
||||
status = BinaryTensorOpWeight(params, inputs.at(0).tensor(),
|
||||
inputs.at(1).weights(), false);
|
||||
} else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) {
|
||||
status = BinaryTensorOpWeight(params, inputs.at(1).tensor(),
|
||||
inputs.at(0).weights(), true);
|
||||
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
|
||||
{"Add", nvinfer1::ElementWiseOperation::kSUM},
|
||||
{"Mul", nvinfer1::ElementWiseOperation::kPROD},
|
||||
{"Sub", nvinfer1::ElementWiseOperation::kSUB},
|
||||
{"Div", nvinfer1::ElementWiseOperation::kDIV},
|
||||
{"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
|
||||
{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
|
||||
{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
|
||||
{"Pow", nvinfer1::ElementWiseOperation::kPOW},
|
||||
};
|
||||
auto op_pair = ops.find(node_def.op());
|
||||
if (op_pair == ops.end()) {
|
||||
return errors::Unimplemented("Binary op ", node_def.op(),
|
||||
" not supported at: ", node_def.name());
|
||||
}
|
||||
// If both input are tensors, or one of them is weights but the conversion
|
||||
// above failed, try the conversion using BinaryTensorOpTensor.
|
||||
if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) {
|
||||
if (!status.ok()) VLOG(2) << status;
|
||||
status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1));
|
||||
}
|
||||
return status;
|
||||
|
||||
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
|
||||
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
|
||||
operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r));
|
||||
|
||||
nvinfer1::ITensor* tensor_l = nullptr;
|
||||
nvinfer1::ITensor* tensor_r = nullptr;
|
||||
// This will also convert constants to tensors, and set quantization ranges.
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
operand_l, broadcasted_dims_l, params->validation_only, &tensor_l));
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
operand_r, broadcasted_dims_r, params->validation_only, &tensor_r));
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// Add ElementWise layer.
|
||||
nvinfer1::IElementWiseLayer* layer =
|
||||
params->converter->network()->addElementWise(*tensor_l, *tensor_r,
|
||||
op_pair->second);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertRsqrt(OpConverterParams* params) {
|
||||
@ -4547,7 +4288,7 @@ Status ConvertSquaredDifference(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
// Broadcast inputs.
|
||||
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
|
||||
TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape(
|
||||
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
|
||||
inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r));
|
||||
nvinfer1::ITensor* tensor_l = nullptr;
|
||||
nvinfer1::ITensor* tensor_r = nullptr;
|
||||
@ -4692,8 +4433,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) {
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name());
|
||||
|
||||
// Create plugin
|
||||
nvinfer1::IPluginV2* plugin =
|
||||
creator->createPlugin(node_def.name().c_str(), &fc);
|
||||
TrtUniquePtrType<nvinfer1::IPluginV2> plugin(
|
||||
creator->createPlugin(node_def.name().c_str(), &fc));
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name());
|
||||
|
||||
// Set plugin inputs
|
||||
@ -4875,7 +4616,8 @@ static void RegisterValidatableOpConverters(
|
||||
for (auto pool_op_type : {"AvgPool", "MaxPool"}) {
|
||||
(*registration)[pool_op_type] = ConvertPool;
|
||||
}
|
||||
for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) {
|
||||
for (auto normalization_op_type :
|
||||
{"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"}) {
|
||||
(*registration)[normalization_op_type] = ConvertFusedBatchNorm;
|
||||
}
|
||||
for (auto unary_op_pair : *UnaryOperationMap()) {
|
||||
|
@ -512,13 +512,6 @@ class Converter {
|
||||
const bool validation_only,
|
||||
nvinfer1::ITensor** tensor);
|
||||
|
||||
// Return OK if the broadcast scheme is supported and compute the shapes after
|
||||
// broadcasting.
|
||||
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
|
||||
const TRT_TensorOrWeights& operand_r,
|
||||
nvinfer1::Dims* operand_l_new_dims,
|
||||
nvinfer1::Dims* operand_r_new_dims) const;
|
||||
|
||||
// Creates an IConstantLayer using 'weights' whose dimensions are specified by
|
||||
// 'dims', and returns the output ITensor.
|
||||
nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights,
|
||||
@ -592,6 +585,13 @@ class Converter {
friend class OpConverterTest;
};

// Return OK if the broadcast scheme is supported and compute the shapes after
// broadcasting.
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims);
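
Editor's note: GetTrtBroadcastShape is now a free function; per the comment it validates the broadcast and right-aligns both shapes. A simplified sketch of NumPy-style right-aligned broadcasting over explicit dimension vectors is below; the real converter additionally accounts for the implicit batch dimension and rejects broadcasts that would reach beyond it (see the "Broadcasting beyond batch dimension" test cases later in this diff), which this sketch omits:

    #include <algorithm>
    #include <vector>

    // Illustrative only: right-align two shapes to a common rank, pad with 1s,
    // and check element-wise compatibility (equal, or one of them is 1).
    bool BroadcastShapes(std::vector<int> l, std::vector<int> r,
                         std::vector<int>* l_out, std::vector<int>* r_out) {
      const size_t rank = std::max(l.size(), r.size());
      l.insert(l.begin(), rank - l.size(), 1);
      r.insert(r.begin(), rank - r.size(), 1);
      for (size_t i = 0; i < rank; ++i) {
        if (l[i] != r[i] && l[i] != 1 && r[i] != 1) return false;  // infeasible
      }
      *l_out = l;
      *r_out = r;
      return true;
    }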
// Map of all supported UnaryOperations
|
||||
const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
|
||||
// Map of all supported ActivationTypes
|
||||
|
@ -988,18 +988,16 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
|
||||
operand_2_shape, operand_2_is_tensor, operand_2_batch_size);
|
||||
|
||||
// operand_1 broadcast operand_2
|
||||
ExpectStatus(
|
||||
this->converter_->GetTrtBroadcastShape(
|
||||
operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims),
|
||||
ExpectStatus(GetTrtBroadcastShape(operand_1, operand_2, &operand_1_new_dims,
|
||||
&operand_2_new_dims),
|
||||
expected_code, expected_error_msg_substr);
|
||||
if (expected_code == error::OK) {
|
||||
ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
|
||||
ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
|
||||
}
|
||||
// operand_2 broadcast operand_1
|
||||
ExpectStatus(
|
||||
this->converter_->GetTrtBroadcastShape(
|
||||
operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims),
|
||||
ExpectStatus(GetTrtBroadcastShape(operand_2, operand_1, &operand_2_new_dims,
|
||||
&operand_1_new_dims),
|
||||
expected_code, expected_error_msg_substr);
|
||||
if (expected_code == error::OK) {
|
||||
ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
|
||||
@ -1033,18 +1031,29 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
|
||||
error::INVALID_ARGUMENT,
|
||||
"Broadcasting beyond batch dimension is not supported "
|
||||
"(tensor #dims 4 vs broadcast #dims 5)");
|
||||
symmetric_test({3}, {1, 1, 3}, kIsTensor, kIsNotTensor, {}, {},
|
||||
error::INVALID_ARGUMENT,
|
||||
"Broadcasting beyond batch dimension is not supported "
|
||||
"(tensor #dims 2 vs broadcast #dims 3)",
|
||||
/*operand_1_batch_size=*/2);
|
||||
|
||||
// Both inputs are tensors.
|
||||
symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
|
||||
error::INVALID_ARGUMENT,
|
||||
"Broadcasting beyond batch dimension is not supported "
|
||||
"(tensor #dims 3 vs broadcast #dims 4)");
|
||||
symmetric_test({1, 3}, {3}, kIsTensor, kIsTensor, {}, {},
|
||||
error::INVALID_ARGUMENT,
|
||||
"Broadcasting beyond batch dimension is not supported "
|
||||
"(tensor #dims 2 vs broadcast #dims 3)");
|
||||
symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
|
||||
{2, 1, 4});
|
||||
symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
|
||||
error::INVALID_ARGUMENT,
|
||||
"Broadcasting beyond batch dimension is not supported "
|
||||
"(tensor #dims 4 vs broadcast #dims 5)");
|
||||
symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {},
|
||||
error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, CreateConstantLayer) {
|
||||
@ -1070,7 +1079,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
|
||||
int batch_size = -1;
|
||||
for (const NodeDef& node : gdef.node()) {
|
||||
absl::string_view node_name(node.name());
|
||||
if (str_util::ConsumePrefix(&node_name, kInputPHName)) {
|
||||
if (absl::ConsumePrefix(&node_name, kInputPHName)) {
|
||||
int port = -1;
|
||||
EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
|
||||
if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
|
||||
@ -1351,10 +1360,6 @@ class OpConverterTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
|
||||
void TestMatMulHelper(
|
||||
const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
|
||||
const std::string& op_name);
|
||||
|
||||
// Expose quantization_ranges_ for tests
|
||||
std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
|
||||
return converter_->quantization_ranges_;
|
||||
@ -1682,59 +1687,60 @@ TEST_F(OpConverterTest, ConvertReshape) {
|
||||
// Helper function for testing MatMul and BatchMatMul
|
||||
// get_matmul corresponds to the function used to generate the node. It should
|
||||
// accept (DataType, transpose_a, transpose_b) as parameters.
|
||||
void OpConverterTest::TestMatMulHelper(
|
||||
void TestMatMulHelper(
|
||||
OpConverterTest* test,
|
||||
const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
|
||||
const std::string& op_name) {
|
||||
// HACK: This needs to be done in a better way.
|
||||
const bool is_batch_matmul = op_name == "BatchMatMul";
|
||||
{
|
||||
// Unsupported data type.
|
||||
Reset();
|
||||
test->Reset();
|
||||
NodeDef node_def = get_matmul(DT_INT32, false, false);
|
||||
AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
|
||||
AddTestWeights<int32>("weights", {2, 1}, {3, 5});
|
||||
RunValidationAndConversion(
|
||||
test->AddTestTensor("input", {2}, /*batch_size=*/1,
|
||||
nvinfer1::DataType::kINT32);
|
||||
test->AddTestWeights<int32>("weights", {2, 1}, {3, 5});
|
||||
test->RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
("Data type int32 is not supported for " + op_name +
|
||||
", "
|
||||
"must be one of [float, half], at my_matmul")
|
||||
StrCat("Data type int32 is not supported for ", op_name,
|
||||
", must be one of [float, half], at my_matmul")
|
||||
.c_str());
|
||||
}
|
||||
// OK.
|
||||
for (bool transpose_a : {false, true}) {
|
||||
for (bool transpose_b : {false, true}) {
|
||||
Reset();
|
||||
test->Reset();
|
||||
NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b);
|
||||
AddTestTensor("input", {2}, /*batch_size=*/1);
|
||||
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
|
||||
test->AddTestTensor("input", {2}, /*batch_size=*/1);
|
||||
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
|
||||
if (is_batch_matmul) {
|
||||
if (transpose_a || transpose_b) {
|
||||
RunValidationAndConversion(
|
||||
test->RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"Input weight attempts to broadcast across batch dimension for "
|
||||
"BatchMatMul, at my_matmul");
|
||||
} else {
|
||||
RunValidationAndConversion(
|
||||
test->RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"Input weight attempts to broadcast across batch dimension");
|
||||
}
|
||||
continue;
|
||||
} else if (transpose_a) {
|
||||
RunValidationAndConversion(
|
||||
test->RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"Cannot transpose first input if it is a tensor with fewer than 2 "
|
||||
"non-batch dimensions");
|
||||
continue;
|
||||
}
|
||||
RunValidationAndConversion(node_def);
|
||||
test->RunValidationAndConversion(node_def);
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
|
||||
|
||||
const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
|
||||
DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
|
||||
BuildAndRun(input_data, &output_data);
|
||||
test->BuildAndRun(input_data, &output_data);
|
||||
if (transpose_b) {
|
||||
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
|
||||
} else {
|
||||
@ -1744,31 +1750,31 @@ void OpConverterTest::TestMatMulHelper(
|
||||
}
|
||||
// OK, 3D inputs
|
||||
for (bool transpose_b : {false, true}) {
|
||||
Reset();
|
||||
test->Reset();
|
||||
NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b);
|
||||
AddTestTensor("input", {2}, /*batch_size=*/1);
|
||||
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
|
||||
test->AddTestTensor("input", {2}, /*batch_size=*/1);
|
||||
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
|
||||
if (is_batch_matmul) {
|
||||
if (transpose_b) {
|
||||
RunValidationAndConversion(
|
||||
test->RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"Input weight attempts to broadcast across batch dimension for "
|
||||
"BatchMatMul, at my_matmul");
|
||||
} else {
|
||||
RunValidationAndConversion(
|
||||
test->RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"Input weight attempts to broadcast across batch dimension");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
RunValidationAndConversion(node_def);
|
||||
test->RunValidationAndConversion(node_def);
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
|
||||
const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
|
||||
DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
|
||||
BuildAndRun(input_data, &output_data);
|
||||
test->BuildAndRun(input_data, &output_data);
|
||||
if (transpose_b) {
|
||||
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
|
||||
} else {
|
||||
@ -1832,7 +1838,7 @@ TEST_F(OpConverterTest, ConvertMatMul) {
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"Cannot currently transpose constant input if it is not 2 dimensional");
|
||||
}
|
||||
TestMatMulHelper(get_matmul_nodedef, "MatMul");
|
||||
TestMatMulHelper(this, get_matmul_nodedef, "MatMul");
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertBatchMatMul) {
|
||||
@ -1889,7 +1895,7 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) {
|
||||
}
|
||||
}
|
||||
|
||||
TestMatMulHelper(get_batch_matmul_nodedef, "BatchMatMul");
|
||||
TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul");
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
@ -2010,250 +2016,82 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) {
|
||||
}
|
||||
|
||||
template <typename OpType, DataType dtype>
|
||||
void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
for (auto swap_inputs : {false, true}) {
|
||||
test->Reset();
|
||||
NodeDef node_def;
|
||||
if (swap_inputs) {
|
||||
node_def = GetBinaryOpNodeDef<OpType>("weights", "input", dtype);
|
||||
} else {
|
||||
node_def = GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
|
||||
}
|
||||
|
||||
const std::vector<CType> operand1{CType(3), CType(7.5)};
|
||||
const std::vector<CType> operand2{CType(2), CType(3)};
|
||||
|
||||
// It requires the dims to be at least of rank 3 to apply an IScaleLayer.
|
||||
test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->AddTestWeights<CType>("weights", /*dims=*/{1, 1, 2},
|
||||
/*values=*/swap_inputs ? operand1 : operand2);
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
|
||||
CheckAddedLayers(test, /*expect_scale_layer=*/true);
|
||||
|
||||
// Check the dims of the output ITensor.
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
|
||||
|
||||
const DataVec input_data{
|
||||
{"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
|
||||
DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
|
||||
test->BuildAndRun(
|
||||
input_data, &output_data,
|
||||
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||
if (node_def.op() == "Add") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(5), CType(10.5)));
|
||||
} else if (node_def.op() == "Sub") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(1), CType(4.5)));
|
||||
} else if (node_def.op() == "Mul") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(6), CType(22.5)));
|
||||
} else if (node_def.op() == "Div") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(1.5), CType(2.5)));
|
||||
} else if (node_def.op() == "RealDiv") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(1.5), CType(2.5)));
|
||||
} else {
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const NodeDef node_def =
|
||||
GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
|
||||
const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
|
||||
const std::vector<CType> weights{CType(10), CType(20)};
|
||||
// There are two types of valid dim pairs which requires channel-wise
|
||||
// broadcasting:
|
||||
// - input dims (X Y Z) vs weights dims (X 1 1)
|
||||
// - input dims (X Y Z) vs weights dims (Z)
|
||||
// Here X=Z=2 and Y=1.
|
||||
for (auto weights_dims : std::vector<std::vector<int>>{{2, 1, 1}, {2}}) {
|
||||
test->Reset();
|
||||
test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->AddTestWeights<CType>("weights", weights_dims, weights);
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
|
||||
CheckAddedLayers(test, /*expect_scale_layer=*/true);
|
||||
|
||||
// Check the dims of the output ITensor.
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
|
||||
|
||||
const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
|
||||
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
|
||||
test->BuildAndRun(input_data, &output_data);
|
||||
if (weights_dims.size() == 1) {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(11), CType(22), CType(13), CType(24)));
|
||||
} else {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(11), CType(12), CType(23), CType(24)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const NodeDef node_def =
|
||||
GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
|
||||
const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
|
||||
const std::vector<CType> weights{CType(10)};
|
||||
test->Reset();
|
||||
test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->AddTestWeights<CType>("weights", {1, 1, 1, 1}, weights);
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
|
||||
CheckAddedLayers(test, /*expect_scale_layer=*/true);
|
||||
|
||||
// Check the dims of the output ITensor.
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
|
||||
|
||||
const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
|
||||
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
|
||||
test->BuildAndRun(input_data, &output_data);
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(11), CType(12), CType(13), CType(14)));
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
void TestBinaryTensorOpWeightFallback(OpConverterTest* test,
|
||||
const std::vector<int32>& input_dims,
|
||||
const std::vector<int>& weights_dims,
|
||||
error::Code code = error::OK,
|
||||
const char* error_msg_substr = nullptr,
|
||||
const int input_batch_size = 1) {
|
||||
const DataType dtype = DT_FLOAT;
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const size_t num_inputs = TrtTensorDimsNumElements(GetTestDims(input_dims));
|
||||
const size_t num_weights =
|
||||
TrtWeightDimsNumElements(GetTestDims(weights_dims));
|
||||
|
||||
test->Reset();
|
||||
const NodeDef node_def =
|
||||
GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
|
||||
test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->AddTestWeights<CType>(
|
||||
"weights", /*dims=*/weights_dims,
|
||||
/*values=*/std::vector<CType>(num_weights, CType(1)));
|
||||
test->RunValidationAndConversion(node_def, code, error_msg_substr);
|
||||
if (code != error::OK) return;
|
||||
|
||||
// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
|
||||
CheckAddedLayers(test, /*expect_scale_layer=*/false);
|
||||
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
|
||||
// Check the dims of the output ITensor.
|
||||
std::vector<int> expected_output_dims = input_dims;
|
||||
for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1;
|
||||
i >= 0 && j >= 0; --i, --j) {
|
||||
if (expected_output_dims[i] == 1) {
|
||||
expected_output_dims[i] = weights_dims[j];
|
||||
}
|
||||
}
|
||||
ExpectTrtDimsEqualsArray(expected_output_dims,
|
||||
output.tensor()->getDimensions());
|
||||
|
||||
// Check the result of running the engine.
|
||||
const int expected_num_outputs =
|
||||
TrtTensorDimsNumElements(GetTestDims(expected_output_dims));
|
||||
const DataVec input_data{
|
||||
{"input", ConstructTensor<CType>(num_inputs, CType(2))}};
|
||||
DataVec output_data{
|
||||
{"my_binary", ConstructTensor<CType>(expected_num_outputs)}};
|
||||
test->BuildAndRun(input_data, &output_data);
|
||||
if (node_def.op() == "Add") {
|
||||
EXPECT_THAT(
|
||||
GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(3))));
|
||||
} else if (node_def.op() == "Minimum") {
|
||||
EXPECT_THAT(
|
||||
GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(1))));
|
||||
} else {
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OpType, DataType dtype>
|
||||
void TestBinaryTensorOpTensor(OpConverterTest* test) {
|
||||
void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor,
|
||||
bool operand_2_is_tensor) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
test->Reset();
|
||||
const NodeDef node_def =
|
||||
GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
|
||||
test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1,
|
||||
if (operand_1_is_tensor) {
|
||||
test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/2,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1,
|
||||
} else {
|
||||
test->AddTestWeights("input1", /*dims=*/{1, 2},
|
||||
/*values=*/std::vector<CType>{CType(3), CType(6)});
|
||||
}
|
||||
if (operand_2_is_tensor) {
|
||||
test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/2,
|
||||
TfDataTypeToTrt(dtype));
|
||||
} else {
|
||||
test->AddTestWeights("input2", /*dims=*/{2, 1},
|
||||
/*values=*/std::vector<CType>{CType(2), CType(3)});
|
||||
}
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
|
||||
CheckAddedLayers(test, /*expect_scale_layer=*/false);
|
||||
|
||||
DataVec input_data;
|
||||
if (operand_1_is_tensor) {
|
||||
input_data.push_back(
|
||||
{"input1",
|
||||
test::AsTensor<CType>({CType(3), CType(6), CType(3), CType(6)})});
|
||||
}
|
||||
if (operand_2_is_tensor) {
|
||||
input_data.push_back(
|
||||
{"input2",
|
||||
test::AsTensor<CType>({CType(2), CType(3), CType(2), CType(3)})});
|
||||
}
|
||||
DataVec output_data{{"my_binary", ConstructTensor<CType>(8)}};
|
||||
// Check output dims.
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
|
||||
|
||||
const DataVec input_data{
|
||||
{"input1", test::AsTensor<CType>({CType(3), CType(6)})},
|
||||
{"input2", test::AsTensor<CType>({CType(2), CType(3)})}};
|
||||
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
|
||||
// After broadcasting first input becomes {3, 6, 3, 6} and second input
|
||||
// becomes {2, 3, 2, 3}.
|
||||
test->BuildAndRun(
|
||||
input_data, &output_data,
|
||||
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32,
|
||||
/*batch_size=*/2);
|
||||
if (node_def.op() == "Add") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(5), CType(8), CType(6), CType(9)));
|
||||
EXPECT_THAT(
|
||||
GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(CastTestVector<int, CType>({5, 8, 6, 9, 5, 8, 6, 9})));
|
||||
} else if (node_def.op() == "Sub") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(1), CType(4), CType(0), CType(3)));
|
||||
EXPECT_THAT(
|
||||
GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(CastTestVector<int, CType>({1, 4, 0, 3, 1, 4, 0, 3})));
|
||||
} else if (node_def.op() == "Mul") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(6), CType(12), CType(9), CType(18)));
|
||||
ElementsAreArray(
|
||||
CastTestVector<int, CType>({6, 12, 9, 18, 6, 12, 9, 18})));
|
||||
} else if (node_def.op() == "Div") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
|
||||
ElementsAreArray(CastTestVector<float, CType>(
|
||||
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
|
||||
} else if (node_def.op() == "RealDiv") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
|
||||
ElementsAreArray(CastTestVector<float, CType>(
|
||||
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
|
||||
} else if (node_def.op() == "Minimum") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(2), CType(2), CType(3), CType(3)));
|
||||
EXPECT_THAT(
|
||||
GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(CastTestVector<int, CType>({2, 2, 3, 3, 2, 2, 3, 3})));
|
||||
} else if (node_def.op() == "Maximum") {
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAre(CType(3), CType(6), CType(3), CType(6)));
|
||||
EXPECT_THAT(
|
||||
GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(CastTestVector<int, CType>({3, 6, 3, 6, 3, 6, 3, 6})));
|
||||
} else if (node_def.op() == "Pow") {
|
||||
ExpectArrayNear(
|
||||
std::vector<CType>{CType(9), CType(36), CType(27), CType(216)},
|
||||
CastTestVector<int, CType>({9, 36, 27, 216, 9, 36, 27, 216}),
|
||||
GetSpanForData<CType>(output_data[0]));
|
||||
} else {
|
||||
ASSERT_TRUE(false);
|
||||
@ -2287,58 +2125,48 @@ TEST_F(OpConverterTest, ConvertBinary) {
|
||||
"both input as constant at: my_add");
|
||||
}
|
||||
|
||||
// Test BinaryTensorOpWeight() without broadcasting.
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
|
||||
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
|
||||
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
|
||||
|
||||
// Test BinaryTensorOpWeight() with channel-wise broadcasting.
|
||||
TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
|
||||
|
||||
// Test BinaryTensorOpWeight() with uniformly broadcasting.
|
||||
TestBinaryTensorOpWeightWithUniformlyBroadcast<DT_FLOAT>(this);
|
||||
|
||||
// Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor().
|
||||
// Unsupported op.
|
||||
TestBinaryTensorOpWeightFallback<ops::Minimum>(this, {1, 1, 1}, {1});
|
||||
// Rank of input tensor dimension <3.
|
||||
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1}, {1});
|
||||
// Broadcast on batch dimension, should fail.
|
||||
TestBinaryTensorOpWeightFallback<ops::Add>(
|
||||
this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT,
|
||||
"Unsupported binary op broadcast scheme for op my_binary",
|
||||
/*input_batch_size=*/2);
|
||||
// Incompatible dims with per-channel mode.
|
||||
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1, 1}, {1, 2, 1});
|
||||
// Incompatible dims.
|
||||
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 2, 1}, {2});
|
||||
|
||||
// Test BinaryTensorOpTensor() with broadcasting.
|
||||
TestBinaryTensorOpTensor<ops::Add, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpTensor<ops::Sub, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpTensor<ops::Mul, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpTensor<ops::Div, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpTensor<ops::RealDiv, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpTensor<ops::Minimum, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpTensor<ops::Maximum, DT_FLOAT>(this);
|
||||
TestBinaryTensorOpTensor<ops::Pow, DT_FLOAT>(this);
|
||||
|
||||
TestBinaryTensorOpTensor<ops::Add, DT_HALF>(this);
|
||||
TestBinaryTensorOpTensor<ops::Sub, DT_HALF>(this);
|
||||
TestBinaryTensorOpTensor<ops::Mul, DT_HALF>(this);
|
||||
TestBinaryTensorOpTensor<ops::Div, DT_HALF>(this);
|
||||
TestBinaryTensorOpTensor<ops::RealDiv, DT_HALF>(this);
|
||||
TestBinaryTensorOpTensor<ops::Minimum, DT_HALF>(this);
|
||||
TestBinaryTensorOpTensor<ops::Maximum, DT_HALF>(this);
|
||||
TestBinaryTensorOpTensor<ops::Pow, DT_HALF>(this);
|
||||
// Test combinations of tensor vs weight inputs (except when both inputs are
|
||||
// weights).
|
||||
for (const bool operand_1_is_tensor : {true, false}) {
|
||||
for (const bool operand_2_is_tensor : {true, false}) {
|
||||
if (!operand_1_is_tensor && !operand_2_is_tensor) continue;
|
||||
// FP32 tests
|
||||
TestBinaryOp<ops::Add, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Sub, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Mul, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Div, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::RealDiv, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Minimum, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Maximum, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Pow, DT_FLOAT>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
// FP16 tests
|
||||
// TODO(tmorris): Use templates to avoid duplication.
|
||||
TestBinaryOp<ops::Add, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Sub, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Mul, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Div, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::RealDiv, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Minimum, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Maximum, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
TestBinaryOp<ops::Pow, DT_HALF>(this, operand_1_is_tensor,
|
||||
operand_2_is_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertQuantize) {
|
||||
@ -2583,7 +2411,6 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
|
||||
// implementation that, the extra output classes that are outside of the
|
||||
// range specified by valid_detections[i] are not zeros but -1s.
|
||||
TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}};
|
||||
const int batch_size = 1;
|
||||
|
||||
for (int i = 0; i < kCombinedNMSOKCases; ++i) {
|
||||
Reset();
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
@ -32,9 +34,9 @@ namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace convert {
|
||||
// TODO(sami): Remove VLOG messages once the code matures
|
||||
using absl::AsciiStrToUpper;
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
using str_util::Uppercase;
|
||||
|
||||
Status TRTOptimizationPass::Init(
|
||||
const RewriterConfig_CustomGraphOptimizer* config) {
|
||||
@ -67,7 +69,7 @@ Status TRTOptimizationPass::Init(
|
||||
}
|
||||
if (params.count("precision_mode")) {
|
||||
TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
|
||||
Uppercase(params.at("precision_mode").s()), &precision_mode_));
|
||||
AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_));
|
||||
}
|
||||
if (params.count("use_calibration")) {
|
||||
use_calibration_ = params.at("use_calibration").b();
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include <dirent.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Verify the result.
|
||||
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
|
||||
Tensor* output = context_->mutable_output(0);
|
||||
EXPECT_EQ("my_serialized_str", output->scalar<string>()());
|
||||
// string type output will remain on CPU, so we're not using GetOutput() here.
|
||||
EXPECT_EQ("my_serialized_str",
|
||||
context_->mutable_output(0)->scalar<string>()());
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
|
@ -87,15 +87,10 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
|
||||
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
|
||||
|
||||
// Verify the result.
|
||||
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
|
||||
Tensor* output = OpsTestBase::context_->mutable_output(0);
|
||||
const auto& tensor_map = output->flat<TypeParam>();
|
||||
std::vector<TypeParam> output_data(tensor_map.size());
|
||||
ASSERT_EQ(0, cudaDeviceSynchronize());
|
||||
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
|
||||
sizeof(TypeParam) * tensor_map.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
|
||||
Tensor* output = OpsTestBase::GetOutput(0);
|
||||
EXPECT_THAT(
|
||||
absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
|
||||
output->NumElements()),
|
||||
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
|
||||
}
|
||||
|
||||
|
@ -681,31 +681,33 @@ Status SegmentGraph(const Graph* tf_graph,
|
||||
<< " with parent=" << segment_root << ":" << s;
|
||||
}
|
||||
|
||||
// Don't use small segments.
if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) {
const int num_effective_nodes = std::count_if(
segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
static auto noops =
new std::set<string>{"Identity", "Snapshot", "StopGradient"};
return noops->count(node->type_string()) == 0;
});

// Don't use segments whose number of effective nodes is small.
if (num_effective_nodes < options.minimum_segment_size) {
VLOG(1) << "Segment " << segments->size() << " has only "
<< segment_nodes.size() << " nodes, dropping";
<< num_effective_nodes << " effective nodes, dropping";
continue;
}
|
||||
|
||||
// TODO(sami): Make segmenter placement aware once trtscopes are in place
|
||||
const auto& dev_itr = device_maps.find(segment_root);
|
||||
if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
|
||||
VLOG(1) << "No device assigned to segment " << segments->size();
|
||||
segments->emplace_back(std::make_pair(segment_nodes, string()));
|
||||
} else if (dev_itr->second.size() > 1) {
|
||||
string s("Segment ");
|
||||
StrAppend(&s, segments->size(), " has multiple devices attached: ");
|
||||
string s = StrCat("Segment ", segments->size(),
|
||||
" has multiple devices attached: ");
|
||||
for (const auto& dev : dev_itr->second) {
|
||||
StrAppend(&s, dev, ", ");
|
||||
}
|
||||
LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin());
|
||||
segments->emplace_back(
|
||||
std::make_pair(segment_nodes, *(dev_itr->second.begin())));
|
||||
} else {
|
||||
segments->emplace_back(
|
||||
std::make_pair(segment_nodes, *(dev_itr->second.begin())));
|
||||
LOG(WARNING) << s;
|
||||
}
|
||||
|
||||
segments->emplace_back(segment_nodes);
|
||||
}
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (const auto& d : device_maps) {
|
||||
|
@ -31,10 +31,8 @@ namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace segment {
|
||||
|
||||
// Vector of segments, each entry contains a set of node pointers and a device
|
||||
// name in the segment.
|
||||
using SegmentNodesVector =
|
||||
std::vector<std::pair<std::set<const Node*>, string>>;
|
||||
// Vector of segments, each entry contains a set of node pointers.
|
||||
using SegmentNodesVector = std::vector<std::set<const Node*>>;
|
||||
|
||||
struct SegmentOptions {
|
||||
// Segment must contain at least this many nodes.
|
||||
|
@ -77,7 +77,7 @@ class SegmentTest : public ::testing::Test {
|
||||
EXPECT_EQ(expected_segments.size(), segments.size());
|
||||
for (int i = 0; i < segments.size(); ++i) {
|
||||
std::set<string> segment_node_names;
|
||||
for (const Node* node : segments[i].first) {
|
||||
for (const Node* node : segments[i]) {
|
||||
segment_node_names.insert(node->name());
|
||||
}
|
||||
const auto& expected = expected_segments[i];
|
||||
@ -262,6 +262,23 @@ TEST_F(SegmentTest, BigIfElse) {
|
||||
{{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
|
||||
}
|
||||
|
||||
TEST_F(SegmentTest, IdentityOps) {
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
auto identity0 = ops::Identity(s.WithOpName("identity0"), feed);
auto identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
auto identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
auto identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
Graph g(OpRegistry::Global());
TF_EXPECT_OK(s.ToGraph(&g));

const std::set<string> all_identities = {"identity0", "identity1",
"identity2", "identity3"};
// Identity ops are not counted as effective ops in the segment, so no segment
// will be formed in this case.
RunTest(&g, all_identities, all_identities, all_identities, {});
}
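
Editor's note: this test exercises the new segment filter from the segment.cc hunk above: a chain made solely of Identity ops has zero effective nodes, so it can never reach minimum_segment_size. A standalone sketch of that filter, assuming a simple list of op-type strings instead of Graph nodes:

    #include <algorithm>
    #include <set>
    #include <string>
    #include <vector>

    // Illustrative only: count nodes that do real work; Identity-like ops are
    // excluded so they cannot justify building a TensorRT engine on their own.
    int CountEffectiveNodes(const std::vector<std::string>& op_types) {
      static const std::set<std::string> kNoOps = {"Identity", "Snapshot",
                                                   "StopGradient"};
      return static_cast<int>(std::count_if(
          op_types.begin(), op_types.end(),
          [](const std::string& op) { return kNoOps.count(op) == 0; }));
    }

For instance, CountEffectiveNodes({"Identity", "Identity"}) returns 0, so such a segment is dropped regardless of its raw node count.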
} // namespace test
|
||||
} // namespace segment
|
||||
} // namespace tensorrt
|
||||
|
@ -1,7 +1,10 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
|
||||
|
||||
package(
|
||||
default_visibility = [":internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
@ -23,15 +26,12 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [":internal"],
|
||||
)
|
||||
|
||||
load(
|
||||
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
|
||||
"if_cuda_is_configured",
|
||||
)
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
|
||||
|
||||
cc_library(
|
||||
name = "tf2xla_supported_ops_lib",
|
||||
@ -67,6 +67,19 @@ xla_proto_library(
|
||||
],
|
||||
)
|
||||
|
||||
# A proto library that is minimal in size and dependencies for platforms like Android.
|
||||
tf_portable_proto_library(
|
||||
name = "portable_tf2xla_proto",
|
||||
config_string = "allow_all:true",
|
||||
header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"],
|
||||
portable_deps = ["//tensorflow/core:android_proto_lib"],
|
||||
proto_deps = [
|
||||
":tf2xla_proto",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
xla_py_proto_library(
|
||||
name = "tf2xla_py",
|
||||
has_services = False,
|
||||
|
@ -1,9 +1,8 @@
|
||||
package(
|
||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
|
||||
|
||||
tf_gen_op_wrapper_cc(
|
||||
|
@ -918,10 +918,16 @@ string Conditional::name() const {
|
||||
|
||||
Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
|
||||
int port) {
|
||||
NodeBuilder id_builder(replacee->name(), "Identity");
|
||||
id_builder.Input(if_node, port);
|
||||
string outside_compilation;
|
||||
if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName,
|
||||
&outside_compilation)
|
||||
.ok()) {
|
||||
id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
|
||||
}
|
||||
Node* id;
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
|
||||
.Input(if_node, port)
|
||||
.Finalize(graph_, &id));
|
||||
TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
|
||||
state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
|
||||
state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
|
||||
return Status::OK();
|
||||
|
@ -247,8 +247,8 @@ Status FunctionalizeControlFlowPass::Run(
|
||||
// multiple times, and we want to avoid functionalize it again.
|
||||
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
|
||||
new std::map<string, string>{
|
||||
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
|
||||
{"TPUReplicate", "computation"},
|
||||
// _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
|
||||
{"_TPUReplicate", "computation"},
|
||||
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
|
||||
{"XlaLaunch", "function"},
|
||||
};
|
||||
|
@ -1,9 +1,8 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -195,6 +194,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core/kernels:training_ops",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ class AssertOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(AssertOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("Assert"), AssertOp);
|
||||
REGISTER_XLA_OP(Name("Assert").CompilationOnly(), AssertOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -39,7 +39,10 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
void Compile(XlaOpKernelContext* ctx) override { CompileImpl(ctx); }
|
||||
|
||||
protected:
|
||||
virtual void CompileImpl(XlaOpKernelContext* ctx) {
|
||||
xla::PrimitiveType input_type;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
|
||||
@ -116,8 +119,29 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
bool is_on_gpu_;
|
||||
};
|
||||
|
||||
class FusedBatchNormOpV3 : public FusedBatchNormOp {
|
||||
public:
|
||||
explicit FusedBatchNormOpV3(OpKernelConstruction* ctx)
|
||||
: FusedBatchNormOp(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
FusedBatchNormOp::CompileImpl(ctx);
|
||||
if (!ctx->status().ok()) {
|
||||
return;
|
||||
}
|
||||
ctx->SetConstantOutput(5, Tensor());
|
||||
}
|
||||
|
||||
private:
|
||||
float epsilon_;
|
||||
TensorFormat data_format_;
|
||||
bool is_training_;
|
||||
bool is_on_gpu_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
|
||||
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
|
||||
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
|
||||
|
||||
class FusedBatchNormGradOp : public XlaOpKernel {
|
||||
public:
|
||||
@ -233,6 +257,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
|
||||
|
||||
REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
|
||||
REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
|
||||
REGISTER_XLA_OP(Name("FusedBatchNormGradV3"), FusedBatchNormGradOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel {
|
||||
xla::XlaOp conv =
|
||||
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dims, depth);
|
||||
// Feature group convolution, will end up with the kernel_size change more
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
conv_dims.back() = depth;
conv_dims.push_back(kernel_size);
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
conv_dims.pop_back();
conv_dims.back() *= kernel_size;
conv = xla::Reshape(conv, conv_dims);
ctx->SetOutput(0, conv);
}
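
Editor's note: the reshape/transpose/reshape above reorders the minor dimension of the feature-group convolution output. A small index-arithmetic sketch of the same permutation on a flat row-major buffer, assuming (as the reshape above suggests) that the convolution produces the kernel index as the fastest-varying component and the desired output has depth fastest; this only illustrates the reordering, not the XLA calls:

    #include <vector>

    // Illustrative only: reorder the last dimension of a row-major buffer from
    // [..., depth, kernel_size] (kernel index fastest) to
    // [..., kernel_size, depth] (depth fastest).
    std::vector<float> ReorderMinorDim(const std::vector<float>& in,
                                       int kernel_size, int depth) {
      const int rows = static_cast<int>(in.size()) / (kernel_size * depth);
      std::vector<float> out(in.size());
      for (int r = 0; r < rows; ++r)
        for (int k = 0; k < kernel_size; ++k)
          for (int d = 0; d < depth; ++d)
            out[(r * kernel_size + k) * depth + d] =
                in[(r * depth + d) * kernel_size + k];
      return out;
    }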
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"

@@ -20,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

@@ -148,15 +152,22 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,

class GatherOp : public XlaOpKernel {
public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
// Set batch_dims_ to 0 if the attribute does not exist.
if (context->HasAttr("batch_dims")) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
} else {
batch_dims_ = 0;
}
}

void Compile(XlaOpKernelContext* context) override {
xla::XlaBuilder* builder = context->builder();
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
auto indices = context->Input(1);
auto indices_shape = context->InputShape(1);
int64 axis = 0;

absl::optional<int64> axis;
if (context->num_inputs() == 3) {
const TensorShape axis_shape = context->InputShape(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),

@@ -165,31 +176,73 @@ class GatherOp : public XlaOpKernel {
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
errors::InvalidArgument("axis must be int32 or int64"));

OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
int64 axis_input;
OP_REQUIRES_OK(context,
context->ConstantInputAsIntScalar(2, &axis_input));

const auto params_dims = input_shape.dims();
OP_REQUIRES(
context, -params_dims <= axis && axis < params_dims,
errors::InvalidArgument("Expected axis in the range [", -params_dims,
", ", params_dims, "), but got ", axis));
if (axis < 0) {
axis += params_dims;
OP_REQUIRES(context,
-params_dims <= axis_input && axis_input < params_dims,
errors::InvalidArgument("Expected axis in the range [",
-params_dims, ", ", params_dims,
"), but got ", axis_input));
if (axis_input < 0) {
axis_input += params_dims;
}
axis = axis_input;
}

if (batch_dims_ != 0) {
if (batch_dims_ < 0) {
batch_dims_ = indices_shape.dims() + batch_dims_;
}

axis = axis.value_or(batch_dims_);

OP_REQUIRES(context,
batch_dims_ >= -indices_shape.dims() &&
batch_dims_ < indices_shape.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices_shape.dims(), ", ",
indices_shape.dims(), "), but got ",
batch_dims_));

OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than rank(input) (",
input_shape.dims(), ")."));

OP_REQUIRES(context, *axis >= batch_dims_,
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than or equal to ",
"axis (", *axis, ")."));
}

axis = axis.value_or(0);
DataType index_type = input_type(1);
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64"));

xla::XlaOp gather;
if (batch_dims_ > 0) {
gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
} else {
// XlaGather() manages degenerate cases, like empty-indices, which are
// error conditions and caught above if batch_dims is not 0.
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, axis,
/*indices_are_nd=*/false, input_type(0), index_type,
builder, &gather));
context, XlaGather(input, input_shape, indices, indices_shape, *axis,
/*indices_are_nd=*/false, input_type(0),
index_type, context->builder(), &gather));
}
context->SetOutput(0, gather);
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);

// The number of batch dimensions, as passed in the batch_dims attribute.
// It must be less than rank(indices).
int32 batch_dims_ = 0;
};

REGISTER_XLA_OP(Name("Gather"), GatherOp);
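The new batch_dims path resolves a possibly negative axis and batch_dims, validates them, and lowers to a batched index-select. A minimal sketch of the resulting semantics for the common batch_dims = 1, axis = 1 case, in plain C++ on nested vectors rather than the XLA lowering itself:

#include <cstdio>
#include <vector>

// out[b][j] = params[b][indices[b][j]]: gather along axis 1 with
// batch_dims = 1, i.e. each batch row selects from its own parameter row.
std::vector<std::vector<int>> BatchedGather(
    const std::vector<std::vector<int>>& params,
    const std::vector<std::vector<int>>& indices) {
  std::vector<std::vector<int>> out(params.size());
  for (size_t b = 0; b < params.size(); ++b) {
    for (int idx : indices[b]) {
      out[b].push_back(params[b][idx]);
    }
  }
  return out;
}

// Mirrors the normalization in the kernel: a negative axis counts from the
// end of the params shape before the range checks run.
int NormalizeAxis(int axis, int params_rank) {
  return axis < 0 ? axis + params_rank : axis;
}

int main() {
  std::vector<std::vector<int>> params = {{10, 11, 12}, {20, 21, 22}};
  std::vector<std::vector<int>> indices = {{2, 0}, {1, 1}};
  for (const auto& row : BatchedGather(params, indices)) {
    for (int v : row) std::printf("%d ", v);
    std::printf("\n");
  }
  // Prints:
  // 12 10
  // 21 21
  std::printf("normalized axis: %d\n", NormalizeAxis(-1, 2));  // 1
}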
@@ -81,20 +81,21 @@ class InTopKOp : public XlaOpKernel {
xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});

// Calculate in each row of `predictions`, how many values are larger than
// the value of target class. Then return the result whether the count <= k,
// the value of target class. Then return the result whether the count < k,
// which indicates the target is in topk.
xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0});
xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0});
xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2);
xla::XlaOp num_ge_r1 = xla::Reduce(
xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2);
xla::XlaOp num_gt_r1 = xla::Reduce(
one_hot_r2, zero_r0,
xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});

xla::XlaOp result =
xla::Le(num_ge_r1, xla::ConstantR0<int32>(xla_builder, k));
xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)),
xla::IsFinite(targets_values_r1));

context->SetOutput(0, result);
}
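After this change the kernel computes, per row: count the predictions strictly greater than the target value, and report "in top k" when that count is less than k and the target value is finite. A self-contained C++ sketch of that per-row predicate, with illustrative data only:

#include <cmath>
#include <cstdio>
#include <vector>

// A row is "in top k" when fewer than k predictions are strictly greater than
// the value at the target class, and that value is finite. Counting strict
// inequalities (Gt) instead of Ge means ties do not push the target out.
bool InTopKRow(const std::vector<float>& predictions, int target, int k) {
  const float target_value = predictions[target];
  if (!std::isfinite(target_value)) return false;
  int num_greater = 0;
  for (float p : predictions) {
    if (p > target_value) ++num_greater;
  }
  return num_greater < k;
}

int main() {
  std::vector<float> row = {0.1f, 0.4f, 0.4f, 0.2f};
  std::printf("%d\n", InTopKRow(row, 2, 1));  // 1: nothing strictly greater.
  std::printf("%d\n", InTopKRow(row, 0, 2));  // 0: three values are greater.
}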
@@ -67,9 +67,9 @@ class MatMulOp : public XlaOpKernel {

OP_REQUIRES(ctx,
a_shape.dim_size(first_index) == b_shape.dim_size(second_index),
errors::InvalidArgument("Matrix size-compatible: In[0]: ",
a_shape.DebugString(), ", In[1]: ",
b_shape.DebugString()));
errors::InvalidArgument(
"Matrix size-incompatible: In[0]: ", a_shape.DebugString(),
", In[1]: ", b_shape.DebugString()));

xla::XlaOp a = ctx->Input(0);
xla::XlaOp b = ctx->Input(1);
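The hunk above only corrects the error message ("size-incompatible" rather than "size-compatible"); the guarded check is the usual inner-dimension match for matmul. A small sketch of that check, assuming first_index and second_index follow the standard transpose_a/transpose_b convention (those attributes are not shown in this hunk):

#include <cstdio>

// For C = matmul(A, B) with optional transposes, the contracted dimensions
// must agree: cols(op(A)) == rows(op(B)).
bool MatmulShapesCompatible(int a_rows, int a_cols, bool transpose_a,
                            int b_rows, int b_cols, bool transpose_b) {
  const int a_inner = transpose_a ? a_rows : a_cols;
  const int b_inner = transpose_b ? b_cols : b_rows;
  return a_inner == b_inner;
}

int main() {
  std::printf("%d\n", MatmulShapesCompatible(2, 3, false, 3, 4, false));  // 1
  std::printf("%d\n", MatmulShapesCompatible(2, 3, false, 4, 3, true));   // 1
  std::printf("%d\n", MatmulShapesCompatible(2, 3, true, 3, 4, false));   // 0
}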
@@ -15,6 +15,7 @@ limitations under the License.

#include <numeric>

#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

@@ -22,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {
namespace {

@@ -77,5 +79,58 @@ class SelectOp : public XlaOpKernel {

REGISTER_XLA_OP(Name("Select"), SelectOp);

class SelectOpV2 : public XlaOpKernel {
public:
explicit SelectOpV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
const TensorShape cond_shape = ctx->InputShape(0);
const TensorShape then_shape = ctx->InputShape(1);
const TensorShape else_shape = ctx->InputShape(2);

// Compute the output shape from the broadcast of the two data inputs, with
// the broadcast of the conditional.
// Then Broadcast all three inputs to the output shape and emit a select.

BCast bcast_then_else(BCast::FromShape(then_shape),
BCast::FromShape(else_shape),
/*fewer_dims_optimization=*/false);
if (!bcast_then_else.IsValid()) {
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ", then_shape.DebugString(), " vs. ",
else_shape.DebugString()));
return;
}
BCast bcast(bcast_then_else.output_shape(), BCast::FromShape(cond_shape),
/*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ",
BCast::ToShape(bcast_then_else.output_shape()).DebugString(), " vs. ",
cond_shape.DebugString()));
return;
}

auto bcasted_cond = BroadcastTo(ctx->Input(0), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_cond.status());
auto cond_handle = bcasted_cond.ValueOrDie();

auto bcasted_then = BroadcastTo(ctx->Input(1), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_then.status());
auto then_handle = bcasted_then.ValueOrDie();

auto bcasted_else = BroadcastTo(ctx->Input(2), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_else.status());
auto else_handle = bcasted_else.ValueOrDie();

ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle));
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOpV2);
};

REGISTER_XLA_OP(Name("SelectV2"), SelectOpV2);

}  // namespace
}  // namespace tensorflow
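SelectV2 above derives its output shape in two broadcast steps: then against else, and that result against the condition; all three inputs are then broadcast to the final shape before the select. A plain C++ sketch of the shape arithmetic, using the familiar trailing-dimension broadcasting rule as an approximation of what BCast computes (the real BCast helper has more machinery):

#include <cstdio>
#include <utility>
#include <vector>

// Broadcast two shapes aligned from the trailing dimension; a dimension of 1
// stretches to match. Returns an empty vector if the shapes are incompatible.
std::vector<int> BroadcastShapes(std::vector<int> a, std::vector<int> b) {
  if (a.size() < b.size()) std::swap(a, b);
  std::vector<int> out = a;
  for (size_t i = 0; i < b.size(); ++i) {
    int& da = out[out.size() - 1 - i];
    const int db = b[b.size() - 1 - i];
    if (da == db || db == 1) continue;
    if (da == 1) { da = db; continue; }
    return {};  // Incompatible.
  }
  return out;
}

int main() {
  auto then_else = BroadcastShapes({4, 1}, {1, 3});  // then vs. else -> {4, 3}
  auto output = BroadcastShapes(then_else, {4, 3});  // result vs. cond shape
  for (int d : output) std::printf("%d ", d);
  std::printf("\n");  // Prints: 4 3
}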
@@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for softmax.

#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

@@ -145,23 +146,36 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
: XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
const TensorShape logits_shape = ctx->InputShape(0);
const TensorShape labels_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
errors::InvalidArgument(
"logits and labels must be same size: logits_size=",
logits_shape.DebugString(),
" labels_size=", labels_shape.DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional"));
// As we already tested that both inputs have the same shape no need to
// check that "labels" is a matrix too.

const DataType type = input_type(0);
const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
auto labels = ctx->Input(1);

const TensorShape logits_shape = ctx->InputShape(0);
const TensorShape labels_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_shape),
errors::InvalidArgument("labels must be 2-dimensional"));

// Confirm that any necessary broadcasting to make the shapes the same will
// succeed.
for (int dim = 0; dim < 2; dim++) {
OP_REQUIRES(
ctx,
labels_shape.dim_size(dim) == 1 ||
logits_shape.dim_size(dim) == labels_shape.dim_size(dim),
errors::InvalidArgument("logits and labels must be same size after "
"broadcasting of labels: logits_size=",
logits_shape.DebugString(),
" labels_size=", labels_shape.DebugString()));
}
if (!logits_shape.IsSameSize(labels_shape)) {
auto labels_or = BroadcastTo(labels, logits_shape.dim_sizes());
OP_REQUIRES_OK(ctx, labels_or.status());
labels = labels_or.ConsumeValueOrDie();
}

xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
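The rewritten check above relaxes the old exact-shape requirement: a 2-D labels input is accepted whenever each of its dimensions is either 1 or equal to the corresponding logits dimension, and it is then broadcast up to the logits shape. A small C++ sketch of that per-dimension test, with made-up shapes:

#include <cstdio>

// Labels may be broadcast up to the logits shape only where a dimension is 1
// or already matches, checked independently for both matrix dimensions.
bool LabelsBroadcastableToLogits(const int logits_dims[2],
                                 const int labels_dims[2]) {
  for (int dim = 0; dim < 2; ++dim) {
    if (labels_dims[dim] != 1 && labels_dims[dim] != logits_dims[dim]) {
      return false;
    }
  }
  return true;
}

int main() {
  const int logits[2] = {32, 10};
  const int labels_ok[2] = {1, 10};   // One label row shared by the batch.
  const int labels_bad[2] = {32, 5};  // Class dimension mismatch.
  std::printf("%d %d\n", LabelsBroadcastableToLogits(logits, labels_ok),
              LabelsBroadcastableToLogits(logits, labels_bad));  // 1 0
}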
@@ -1,9 +1,8 @@
# Utilities for building XLA computations.

licenses(["notice"])  # Apache 2.0

package(
default_visibility = ["//tensorflow/compiler/tf2xla:friends"],
licenses = ["notice"],  # Apache 2.0
)

# Filegroup used to collect source files for dependency checking.

@@ -1,9 +1,8 @@
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"],  # Apache 2.0
)

licenses(["notice"])  # Apache 2.0

load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",

@@ -1,9 +1,8 @@
licenses(["notice"])  # Apache 2.0

package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"],  # Apache 2.0
)

load(

@@ -550,6 +550,7 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
};
GraphOptimizer::Options graph_optimizer_options;
graph_optimizer_options.cf_consider_fn = cf_consider_fn;
graph_optimizer_options.inline_multi_device_functions = true;
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
/*device=*/nullptr, &graph, graph_optimizer_options);

@@ -116,9 +116,12 @@ class XlaOpRegistry {
// If we should cluster operations returning DT_VARIANT.
bool cluster_variant_ops = false;

// Whether ops known to be slow or to have correctness issues should be
// Whether ops known to be slow should be auto-clustered.
bool cluster_slow_ops = false;

// Whether ops known to have numerical accuracy issues should be
// auto-clustered.
bool cluster_slow_and_inaccurate_ops = false;
bool cluster_inaccurate_ops = false;
};

// Registers an XLA backend. `compilation_device_name` is the name of the

@@ -1,6 +1,7 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//tensorflow:internal"])
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"],  # Apache 2.0
)

package_group(
name = "friends",

@@ -575,6 +576,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":types",
"@com_google_absl//absl/strings",
],
)

@@ -881,6 +883,26 @@ tf_cc_test(
],
)

cc_library(
name = "refcounting_hash_map",
hdrs = ["refcounting_hash_map.h"],
deps = [
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)

tf_cc_test(
name = "refcounting_hash_map_test",
srcs = ["refcounting_hash_map_test.cc"],
deps = [
":refcounting_hash_map",
":test",
"//tensorflow/core:test_main",
],
)

# -----------------------------------------------------------------------------

# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
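The new refcounting_hash_map target (backed by absl's node_hash_map, memory, and synchronization) suggests a map whose entries are created on first use and dropped once the last user releases them. A speculative sketch of that idea only, not the actual header added by this commit, using std::shared_ptr/weak_ptr in place of the absl types:

#include <cstdio>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>

// GetOrCreate hands out a shared_ptr while the map keeps only a weak_ptr, so
// an entry disappears once every caller has dropped its reference.
template <typename K, typename V>
class RefcountingMap {
 public:
  std::shared_ptr<V> GetOrCreate(const K& key,
                                 const std::function<V()>& factory) {
    std::lock_guard<std::mutex> lock(mu_);
    if (auto existing = map_[key].lock()) return existing;
    auto created = std::make_shared<V>(factory());
    map_[key] = created;
    return created;
  }

 private:
  std::mutex mu_;
  std::unordered_map<K, std::weak_ptr<V>> map_;
};

int main() {
  RefcountingMap<int, int> map;
  auto a = map.GetOrCreate(1, [] { return 41; });
  auto b = map.GetOrCreate(1, [] { return 99; });  // Reuses the live entry.
  std::printf("%d %d %d\n", *a, *b, a.get() == b.get());  // 41 41 1
}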
Some files were not shown because too many files have changed in this diff.