Merge branch 'tensorflow-master'

This commit is contained in:
minds 2019-05-27 18:47:30 +09:00
commit 33dffe53fb
1650 changed files with 82482 additions and 58113 deletions

View File

@ -39,32 +39,46 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true
build:download_clang --action_env TF_DOWNLOAD_CLANG=1
# Instruct clang to use LLD for linking.
# This only works with GPU builds currently, since Bazel sets -B/usr/bin in
# auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over
# the downloaded one.
build:download_clang_use_lld --linkopt='-fuse-ld=lld'
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
# This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true
# This config refers to building CUDA op kernels with clang.
build:cuda_clang --config=using_cuda
build:cuda_clang --define=using_cuda_clang=true
build:cuda_clang --define=using_clang=true
build:tensorrt --action_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
build:rocm --action_env TF_NEED_ROCM=1
build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl --define=using_sycl=true --define=using_trisycl=false
build:sycl --define=using_sycl=true
build:sycl --action_env TF_NEED_OPENCL_SYCL=1
build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
build:sycl_nodouble --config=sycl
build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
build:sycl_nodouble --config=sycl
build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true
build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true
# Options extracted from configure script
build:gdr --define=with_gdr_support=true
@ -87,6 +101,9 @@ build --spawn_strategy=standalone
build --strategy=Genrule=standalone
build -c opt
# Make Bazel print out all options from rc files.
build --announce_rc
# Other build flags.
build --define=grpc_no_ares=true
@ -97,8 +114,7 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
# Build TF with C++ 17 features.
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --cxxopt=-std=c++1z
build:c++1z --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
# Default paths for TF_SYSTEM_LIBS
build --define=PREFIX=/usr

View File

@ -38,7 +38,13 @@ working on getting your pull request submitted to our internal repository. After
the change has been submitted internally, your pull request will be merged
automatically on GitHub.
If you want to contribute but you're not sure where to start, take a look at the
If you want to contribute, start working through the TensorFlow codebase,
navigate to the
[Github "issues" tab](https://github.com/tensorflow/tensorflow/issues) and start
looking through interesting issues. If you are not sure of where to start, then
start by trying one of the smaller/easier issues here i.e.
[issues with the "good first issue" label](https://github.com/tensorflow/tensorflow/labels/good%20first%20issue)
and then take a look at the
[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
These are issues that we believe are particularly well suited for outside
contributions, often because we probably won't get to them right now. If you

View File

@ -403,7 +403,8 @@ def set_action_env_var(environ_cp,
enabled_by_default,
question=None,
yes_reply=None,
no_reply=None):
no_reply=None,
bazel_config_name=None):
"""Set boolean action_env variable.
Ask user if query_item will be enabled. Default is used if no input is given.
@ -418,12 +419,16 @@ def set_action_env_var(environ_cp,
question: optional string for how to ask for user input.
yes_reply: optional string for reply when feature is enabled.
no_reply: optional string for reply when feature is disabled.
bazel_config_name: adding config to .bazelrc instead of action_env.
"""
var = int(
get_var(environ_cp, var_name, query_item, enabled_by_default, question,
yes_reply, no_reply))
write_action_env_to_bazelrc(var_name, var)
if not bazel_config_name:
write_action_env_to_bazelrc(var_name, var)
elif var:
write_to_bazelrc('build --config=%s' % bazel_config_name)
environ_cp[var_name] = str(var)
@ -543,7 +548,8 @@ def set_tf_cuda_clang(environ_cp):
False,
question=question,
yes_reply=yes_reply,
no_reply=no_reply)
no_reply=no_reply,
bazel_config_name='cuda_clang')
def set_tf_download_clang(environ_cp):
@ -558,7 +564,8 @@ def set_tf_download_clang(environ_cp):
False,
question=question,
yes_reply=yes_reply,
no_reply=no_reply)
no_reply=no_reply,
bazel_config_name='download_clang')
def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
@ -782,8 +789,8 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path):
print('WARNING: The NDK version in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' % (android_ndk_home_path, ndk_version,
_SUPPORTED_ANDROID_NDK_VERSIONS))
'errors.\n' %
(android_ndk_home_path, ndk_version, _SUPPORTED_ANDROID_NDK_VERSIONS))
# Now grab the NDK API level to use. Note that this is different from the
# SDK API level, as the NDK API level is effectively the *min* target SDK
@ -952,6 +959,7 @@ def set_tf_nccl_version(environ_cp):
ask_nccl_version, '')
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@ -1293,9 +1301,6 @@ def configure_ios():
"""
if not is_macos():
return
if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000:
print(
'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.')
for filepath in APPLE_BAZEL_FILES:
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
@ -1386,7 +1391,7 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
current_bazel_version = check_bazel_version('0.24.1', '0.25.2')
current_bazel_version = check_bazel_version('0.24.1', '0.25.3')
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1422,7 +1427,12 @@ def main():
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
xla_enabled_by_default, 'xla')
set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
set_action_env_var(
environ_cp,
'TF_NEED_OPENCL_SYCL',
'OpenCL SYCL',
False,
bazel_config_name='sycl')
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
set_host_cxx_compiler(environ_cp)
set_host_c_compiler(environ_cp)
@ -1432,30 +1442,44 @@ def main():
else:
set_trisycl_include_dir(environ_cp)
set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
set_action_env_var(
environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
if (environ_cp.get('TF_NEED_ROCM') == '1' and
'LD_LIBRARY_PATH' in environ_cp and
environ_cp.get('LD_LIBRARY_PATH') != '1'):
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
environ_cp.get('LD_LIBRARY_PATH'))
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
environ_cp['TF_NEED_CUDA'] = str(
int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
set_action_env_var(
environ_cp,
'TF_NEED_TENSORRT',
'TensorRT',
False,
bazel_config_name='tensorrt')
environ_save = dict(environ_cp)
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
if validate_cuda_config(environ_cp):
cuda_env_names = [
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
'TF_CUDA_VERSION',
'TF_CUBLAS_VERSION',
'TF_CUDNN_VERSION',
'TF_TENSORRT_VERSION',
'TF_NCCL_VERSION',
'TF_CUDA_PATHS',
# Items below are for backwards compatibility when not using
# TF_CUDA_PATHS.
'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
'CUDA_TOOLKIT_PATH',
'CUDNN_INSTALL_PATH',
'NCCL_INSTALL_PATH',
'NCCL_HDR_PATH',
'TENSORRT_INSTALL_PATH'
]
# Note: set_action_env_var above already writes to bazelrc.
for name in cuda_env_names:
@ -1506,8 +1530,6 @@ def main():
# CUDA not required. Ask whether we should download the clang toolchain and
# use it for the CPU build.
set_tf_download_clang(environ_cp)
if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
write_to_bazelrc('build --config=download_clang')
# SYCL / ROCm / CUDA are mutually exclusive.
# At most 1 GPU platform can be configured.

View File

@ -59,7 +59,7 @@ except ImportError:
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons

View File

@ -21,6 +21,9 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"tf_attrtype.h",
"tf_datatype.h",
"tf_status.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
@ -51,6 +54,8 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"c_api_internal.h",
"tf_datatype.h",
"tf_status.h",
],
visibility = [
"//tensorflow:internal",
@ -61,6 +66,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_attrtype",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -71,16 +77,26 @@ tf_cuda_library(
}),
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api",
hdrs = [
"c_api.h",
"tf_attrtype.h",
"tf_datatype.h",
"tf_status.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api_no_xla",
":c_api_internal",
":tf_attrtype",
] + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -96,14 +112,21 @@ tf_cuda_library(
"c_api.cc",
"c_api_function.cc",
],
hdrs = ["c_api.h"],
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [":c_api_internal"] + select({
deps = [
":c_api_internal",
":tf_attrtype",
":tf_datatype",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_status",
"@com_google_absl//absl/strings",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
@ -124,6 +147,37 @@ tf_cuda_library(
}),
)
cc_library(
name = "tf_status",
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/c:c_api_internal",
"//tensorflow/core:lib",
],
}),
)
cc_library(
name = "tf_datatype",
srcs = ["tf_datatype.cc"],
hdrs = ["tf_datatype.h"],
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
],
}),
)
tf_cuda_library(
name = "c_api_experimental",
srcs = [
@ -137,6 +191,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
":checkpoint_reader",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/compiler/jit:flags",
@ -151,15 +206,6 @@ tf_cuda_library(
],
)
cc_library(
name = "c_api_headers",
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
)
exports_files(
[
"version_script.lds",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/strings/match.h"
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h" // NOLINT
@ -97,7 +98,6 @@ using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
using tensorflow::VersionDef;
using tensorflow::error::Code;
using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
@ -108,34 +108,6 @@ extern "C" {
// --------------------------------------------------------------------------
const char* TF_Version() { return TF_VERSION_STRING; }
// --------------------------------------------------------------------------
size_t TF_DataTypeSize(TF_DataType dt) {
return static_cast<size_t>(
tensorflow::DataTypeSize(static_cast<DataType>(dt)));
}
// --------------------------------------------------------------------------
TF_Status* TF_NewStatus() { return new TF_Status; }
void TF_DeleteStatus(TF_Status* s) { delete s; }
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
if (code == TF_OK) {
s->status = Status::OK();
return;
}
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}
TF_Code TF_GetCode(const TF_Status* s) {
return static_cast<TF_Code>(s->status.code());
}
const char* TF_Message(const TF_Status* s) {
return s->status.error_message().c_str();
}
// --------------------------------------------------------------------------
namespace {
@ -1697,7 +1669,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
if (metadata.list_size == 0) {
for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
const auto& a = oper->node.op_def().attr(i);
if (a.name().compare(attr_name) != 0) continue;
if (a.name() != attr_name) continue;
const string& typestr = a.type();
if (typestr == "list(string)") {
metadata.type = TF_ATTR_STRING;
@ -2517,8 +2489,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
// used in this graph
for (const auto& pair : g->name_map) {
const string& name = pair.first;
if (name.compare(prefix) == 0 ||
tensorflow::str_util::StartsWith(name, prefix_cmp)) {
if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
status->status = InvalidArgument(
"prefix [", prefix,
"] conflicts with existing node in the graph named [", name, "]");
@ -2548,8 +2519,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
// Adding the gradients to the graph can alter the prefix to prevent
// name collisions only if this prefix has not been provided explicitly
// by the user. If it was provided, assert that it remained intact.
if (prefix != nullptr &&
!tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
status->status = tensorflow::errors::Internal(
"BUG: The gradients prefix have been unexpectedly altered when "
"adding the nodes to the graph. This is a bug. Please file an "

View File

@ -19,6 +19,10 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/c/tf_attrtype.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
// --------------------------------------------------------------------------
// C API for TensorFlow.
//
@ -69,7 +73,7 @@ limitations under the License.
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.$a
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
@ -93,89 +97,6 @@ extern "C" {
// TensorFlow library. TensorFlow using semantic versioning.
TF_CAPI_EXPORT extern const char* TF_Version(void);
// --------------------------------------------------------------------------
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
// The enum values here are identical to corresponding values in types.proto.
typedef enum TF_DataType {
TF_FLOAT = 1,
TF_DOUBLE = 2,
TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
TF_UINT8 = 4,
TF_INT16 = 5,
TF_INT8 = 6,
TF_STRING = 7,
TF_COMPLEX64 = 8, // Single-precision complex
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility
TF_INT64 = 9,
TF_BOOL = 10,
TF_QINT8 = 11, // Quantized int8
TF_QUINT8 = 12, // Quantized uint8
TF_QINT32 = 13, // Quantized int32
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
TF_QINT16 = 15, // Quantized int16
TF_QUINT16 = 16, // Quantized uint16
TF_UINT16 = 17,
TF_COMPLEX128 = 18, // Double-precision complex
TF_HALF = 19,
TF_RESOURCE = 20,
TF_VARIANT = 21,
TF_UINT32 = 22,
TF_UINT64 = 23,
} TF_DataType;
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
// to the given TF_DataType enum value. Returns 0 for variable length types
// (eg. TF_STRING) or on failure.
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);
// --------------------------------------------------------------------------
// TF_Code holds an error code. The enum values here are identical to
// corresponding values in error_codes.proto.
typedef enum TF_Code {
TF_OK = 0,
TF_CANCELLED = 1,
TF_UNKNOWN = 2,
TF_INVALID_ARGUMENT = 3,
TF_DEADLINE_EXCEEDED = 4,
TF_NOT_FOUND = 5,
TF_ALREADY_EXISTS = 6,
TF_PERMISSION_DENIED = 7,
TF_UNAUTHENTICATED = 16,
TF_RESOURCE_EXHAUSTED = 8,
TF_FAILED_PRECONDITION = 9,
TF_ABORTED = 10,
TF_OUT_OF_RANGE = 11,
TF_UNIMPLEMENTED = 12,
TF_INTERNAL = 13,
TF_UNAVAILABLE = 14,
TF_DATA_LOSS = 15,
} TF_Code;
// --------------------------------------------------------------------------
// TF_Status holds error information. It either has an OK code, or
// else an error code with an associated error message.
typedef struct TF_Status TF_Status;
// Return a new status object.
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
// Record <code, msg> in *s. Any previous information is lost.
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
const char* msg);
// Return the code record in *s.
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);
// Return a pointer to the (null-terminated) error message in *s. The
// return value points to memory that is only usable until the next
// mutation to *s. Always returns an empty string if TF_GetCode(s) is
// TF_OK.
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);
// --------------------------------------------------------------------------
// TF_Buffer holds a pointer to a block of data and its associated length.
// Typically, the data consists of a serialized protocol buffer, but other data
@ -686,19 +607,6 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs(
TF_Operation* oper, TF_Operation** control_outputs,
int max_control_outputs);
// TF_AttrType describes the type of the value of an attribute on an operation.
typedef enum TF_AttrType {
TF_ATTR_STRING = 0,
TF_ATTR_INT = 1,
TF_ATTR_FLOAT = 2,
TF_ATTR_BOOL = 3,
TF_ATTR_TYPE = 4,
TF_ATTR_SHAPE = 5,
TF_ATTR_TENSOR = 6,
TF_ATTR_PLACEHOLDER = 7,
TF_ATTR_FUNC = 8,
} TF_AttrType;
// TF_AttrMetadata describes the value of an attribute on an operation.
typedef struct TF_AttrMetadata {
// A boolean: 1 if the attribute value is a list, 0 otherwise.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
@ -37,6 +38,7 @@ using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;
namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
@ -149,7 +151,7 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
}
char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
const auto& debug_str = func->fdef.DebugString();
const auto& debug_str = DebugString(func->fdef);
*len = debug_str.size();
char* ret = static_cast<char*>(malloc(*len + 1));
memcpy(ret, debug_str.c_str(), *len + 1);
@ -576,6 +578,73 @@ void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
std::vector<std::string> variable_list;
};
TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
TF_Status* status) {
TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
if (!status->status.ok()) return nullptr;
const auto& m = reader->GetVariableToDataTypeMap();
for (auto it = m.begin(); it != m.end(); ++it)
reader->variable_list.push_back(it->first);
std::sort(reader->variable_list.begin(), reader->variable_list.end());
return reader;
}
void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }
int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
const char* name) {
return reader->HasTensor(name);
}
const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
int index) {
return reader->variable_list[index].c_str();
}
int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
return reader->variable_list.size();
}
TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToDataTypeMap();
return static_cast<TF_DataType>(m.at(name));
}
TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
const char* name, TF_Status* status) {
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (!status->status.ok()) return nullptr;
return tensorflow::TF_TensorFromTensor(*tensor.get(), status);
}
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
const char* name, int64_t* dims,
int num_dims, TF_Status* status) {
const auto& shape = reader->GetVariableToShapeMap().at(name);
int rank = shape.dims();
if (num_dims != rank) {
status->status = InvalidArgument("Expected rank is ", num_dims,
" but actual rank is ", rank);
return;
}
for (int i = 0; i < num_dims; i++) {
dims[i] = shape.dim_size(i);
}
}
int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToShapeMap();
return m.at(name).dims();
}
// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
using tensorflow::AttrBuilder::AttrBuilder;

View File

@ -208,6 +208,34 @@ TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);
// TF_NewCheckpointReader() return the CheckpointReader that can be use to
// investigate or load the variable from the checkpoint file
typedef struct TF_CheckpointReader TF_CheckpointReader;
TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader(
const char* filename, TF_Status* status);
TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader(
TF_CheckpointReader* reader);
TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor(
TF_CheckpointReader* reader, const char* name);
// Get the variable name at the given index
TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable(
TF_CheckpointReader* reader, int index);
// Get the number of variable in the checkpoint
TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader);
// Get the DataType of a variable
TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType(
TF_CheckpointReader* reader, const char* name);
// Read the shape of a variable and write to `dims`
TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape(
TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims,
TF_Status* status);
// Get the number of dimension of a variable
TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims(
TF_CheckpointReader* reader, const char* name);
// Load the weight of a variable
TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor(
TF_CheckpointReader* reader, const char* name, TF_Status* status);
// TF_NewAttrBuilder() returns an object that you can set attributes on as
// though it were an op. This allows querying properties of that op for
// type-checking purposes like if the op will run on a particular device type.

View File

@ -62,8 +62,8 @@ protocol: "grpc"
TF_Buffer* null_result =
TFE_GetServerDef(malformed_text_proto.c_str(), status);
EXPECT_NE(TF_GetCode(status), TF_OK);
EXPECT_TRUE(tensorflow::str_util::StrContains(
TF_Message(status), "Invalid text proto for ServerDef"));
EXPECT_TRUE(absl::StrContains(TF_Message(status),
"Invalid text proto for ServerDef"));
EXPECT_EQ(null_result, nullptr);
// Cleanup

View File

@ -253,7 +253,7 @@ class CApiFunctionTest : public ::testing::Test {
const std::unordered_set<string>& nodes) {
ASSERT_EQ(nodes.size(), fdef.node_def_size())
<< "Got unexpected number of nodes. Expected: ["
<< str_util::Join(nodes, ", ")
<< absl::StrJoin(nodes, ", ")
<< "] Actual nodes in fdef: " << fdef.DebugString();
for (const NodeDef& node_def : fdef.node_def()) {
ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())

View File

@ -56,7 +56,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(str_util::StrContains(s, expected))
EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}

View File

@ -1,6 +1,8 @@
# Experimental extensions to the C API for eager execution of kernels.
licenses(["notice"]) # Apache 2.0
package(
licenses = ["notice"], # Apache 2.0
)
load(
"//tensorflow:tensorflow.bzl",

View File

@ -30,7 +30,9 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@ -135,11 +137,12 @@ tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, int64 rendezvous_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
const string& remote_worker = remote_workers[i];
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextResponse response;
request.set_rendezvous_id(rendezvous_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
@ -221,6 +224,23 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
remote_workers, grpc_server->master_env()->worker_cache,
&remote_device_mgr));
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request make sure that we can create Rendevzous properly between
// Local and Remote context.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
grpc_server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
@ -230,14 +250,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, keep_alive_secs, server_def,
remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
remote_eager_workers.get(), ctx->context->Async(), base_request,
&remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, true));
session_name, server_def, base_request.cluster_device_attributes(),
true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
@ -250,9 +272,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr;
return ctx->context->InitializeRemote(
std::move(server), std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts, r, device_mgr,
keep_alive_secs);
std::move(server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_contexts, r, device_mgr, keep_alive_secs,
worker_session->cluster_flr.get());
#undef LOG_AND_RETURN_IF_ERROR
}
#endif // !IS_MOBILE_PLATFORM
@ -970,6 +993,23 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
return t;
}
TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
TF_Status* status) {
// TensorHandles created by PyFuncOp lack context and therefore could
// not be copied.
if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
tensorflow::TensorHandle* handle;
status->status = tensorflow::EagerCopyToDevice(
h->handle, h->handle->Context(), "CPU:0", &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
} else {
return nullptr;
}
}
return h;
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);

View File

@ -462,6 +462,9 @@ class Tensor;
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* h, TF_Status* status);
TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
TF_Status* status);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
#endif

View File

@ -78,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
status->status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
<< absl::StrJoin(shape_to_log, ", ") << "] is "
<< padded_shape.DebugString();
}
}

View File

@ -33,7 +33,7 @@ namespace tensorflow {
namespace {
static bool HasSubstr(absl::string_view base, absl::string_view substr) {
bool ok = str_util::StrContains(base, substr);
bool ok = absl::StrContains(base, substr);
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
return ok;
}

View File

@ -1408,29 +1408,34 @@ void FunctionDefAndExecute(bool async) {
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, m, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(op, &retval[0], &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(m);
TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
TFE_DeleteTensorHandle(retval[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
for (bool clear_cache : {true, false, true}) {
if (clear_cache) {
TFE_ContextClearCaches(ctx);
}
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, m, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(op, &retval[0], &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(m);
TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
TFE_DeleteTensorHandle(retval[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
}
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);

View File

@ -1,7 +1,9 @@
# Description:
# Experimental C APIs for TensorFlow.
licenses(["notice"]) # Apache 2.0
package(
licenses = ["notice"], # Apache 2.0
)
load(
"//tensorflow:tensorflow.bzl",

View File

@ -6,10 +6,9 @@ load(
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
licenses(["notice"]) # Apache 2.0
tf_kernel_library(
name = "bitcast_op",
prefix = "bitcast_op",

View File

@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_
#define TENSORFLOW_C_TF_ATTRTYPE_H_
#ifdef __cplusplus
extern "C" {
#endif
// TF_AttrType describes the type of the value of an attribute on an operation.
typedef enum TF_AttrType {
TF_ATTR_STRING = 0,
TF_ATTR_INT = 1,
TF_ATTR_FLOAT = 2,
TF_ATTR_BOOL = 3,
TF_ATTR_TYPE = 4,
TF_ATTR_SHAPE = 5,
TF_ATTR_TENSOR = 6,
TF_ATTR_PLACEHOLDER = 7,
TF_ATTR_FUNC = 8,
} TF_AttrType;
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_ATTRTYPE_H_

View File

@ -0,0 +1,23 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/types.h"
size_t TF_DataTypeSize(TF_DataType dt) {
return static_cast<size_t>(
tensorflow::DataTypeSize(static_cast<tensorflow::DataType>(dt)));
}

View File

@ -0,0 +1,83 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_DATATYPE_H_
#define TENSORFLOW_C_TF_DATATYPE_H_
#include <stddef.h>
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
// The enum values here are identical to corresponding values in types.proto.
typedef enum TF_DataType {
TF_FLOAT = 1,
TF_DOUBLE = 2,
TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
TF_UINT8 = 4,
TF_INT16 = 5,
TF_INT8 = 6,
TF_STRING = 7,
TF_COMPLEX64 = 8, // Single-precision complex
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility
TF_INT64 = 9,
TF_BOOL = 10,
TF_QINT8 = 11, // Quantized int8
TF_QUINT8 = 12, // Quantized uint8
TF_QINT32 = 13, // Quantized int32
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
TF_QINT16 = 15, // Quantized int16
TF_QUINT16 = 16, // Quantized uint16
TF_UINT16 = 17,
TF_COMPLEX128 = 18, // Double-precision complex
TF_HALF = 19,
TF_RESOURCE = 20,
TF_VARIANT = 21,
TF_UINT32 = 22,
TF_UINT64 = 23,
} TF_DataType;
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
// to the given TF_DataType enum value. Returns 0 for variable length types
// (eg. TF_STRING) or on failure.
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_DATATYPE_H_

42
tensorflow/c/tf_status.cc Normal file
View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/lib/core/status.h"
using ::tensorflow::Status;
using ::tensorflow::error::Code;
TF_Status* TF_NewStatus() { return new TF_Status; }
void TF_DeleteStatus(TF_Status* s) { delete s; }
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
if (code == TF_OK) {
s->status = Status::OK();
return;
}
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}
TF_Code TF_GetCode(const TF_Status* s) {
return static_cast<TF_Code>(s->status.code());
}
const char* TF_Message(const TF_Status* s) {
return s->status.error_message().c_str();
}

88
tensorflow/c/tf_status.h Normal file
View File

@ -0,0 +1,88 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_STATUS_H_
#define TENSORFLOW_C_TF_STATUS_H_
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
typedef struct TF_Status TF_Status;
// --------------------------------------------------------------------------
// TF_Code holds an error code. The enum values here are identical to
// corresponding values in error_codes.proto.
typedef enum TF_Code {
TF_OK = 0,
TF_CANCELLED = 1,
TF_UNKNOWN = 2,
TF_INVALID_ARGUMENT = 3,
TF_DEADLINE_EXCEEDED = 4,
TF_NOT_FOUND = 5,
TF_ALREADY_EXISTS = 6,
TF_PERMISSION_DENIED = 7,
TF_UNAUTHENTICATED = 16,
TF_RESOURCE_EXHAUSTED = 8,
TF_FAILED_PRECONDITION = 9,
TF_ABORTED = 10,
TF_OUT_OF_RANGE = 11,
TF_UNIMPLEMENTED = 12,
TF_INTERNAL = 13,
TF_UNAVAILABLE = 14,
TF_DATA_LOSS = 15,
} TF_Code;
// --------------------------------------------------------------------------
// Return a new status object.
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);
// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);
// Record <code, msg> in *s. Any previous information is lost.
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
const char* msg);
// Return the code record in *s.
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);
// Return a pointer to the (null-terminated) error message in *s. The
// return value points to memory that is only usable until the next
// mutation to *s. Always returns an empty string if TF_GetCode(s) is
// TF_OK.
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_STATUS_H_

View File

@ -4,10 +4,9 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
licenses(["notice"]) # Apache 2.0
filegroup(
name = "srcs",
srcs = [
@ -638,6 +637,7 @@ cc_library(
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)
@ -657,6 +657,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)

View File

@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/cc_op_gen.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@ -133,7 +135,7 @@ string MakeComment(StringPiece text, StringPiece indent) {
}
string PrintString(const string& str) {
return strings::StrCat("\"", str_util::CEscape(str), "\"");
return strings::StrCat("\"", absl::CEscape(str), "\"");
}
string PrintTensorShape(const TensorShapeProto& shape_proto) {
@ -191,7 +193,7 @@ string PrintTensor(const TensorProto& tensor_proto) {
string ret;
for (int64 i = 0; i < num_elts; ++i) {
if (i > 0) strings::StrAppend(&ret, " ");
strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
strings::StrAppend(&ret, absl::CEscape(t.flat<string>()(i)));
}
return ret;
}

View File

@ -62,12 +62,12 @@ op {
)";
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(str_util::StrContains(s, expected))
EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
EXPECT_FALSE(str_util::StrContains(s, expected))
EXPECT_FALSE(absl::StrContains(s, expected))
<< "'" << s << "' contains '" << expected << "'";
}

View File

@ -275,7 +275,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {
current_constraints.emplace(s);
}
}

View File

@ -3,10 +3,9 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(

View File

@ -308,7 +308,7 @@ Status LoadSavedModel(const SessionOptions& session_options,
const Status status = LoadSavedModelInternal(session_options, run_options,
export_dir, tags, bundle);
auto log_and_count = [&](const string& status_str) {
LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
<< " }; Status: " << status_str << ". Took "
<< GetLatencyMicroseconds(start_microseconds) << " microseconds.";
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);

View File

@ -136,7 +136,7 @@ TEST_F(LoaderTest, NoTagMatch) {
Status st = LoadSavedModel(session_options, run_options, export_dir,
{"missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(
EXPECT_TRUE(absl::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: { missing-tag }"))
<< st.error_message();
@ -152,7 +152,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe, "missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(
EXPECT_TRUE(absl::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: "))
<< st.error_message();
@ -172,7 +172,7 @@ TEST_F(LoaderTest, SessionCreationFailure) {
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget))
EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget))
<< st.error_message();
}

View File

@ -51,7 +51,7 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
const std::unordered_set<string>& tags,
MetaGraphDef* meta_graph_def) {
LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
<< " }";
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
// Get tags from the graph_def.
@ -69,7 +69,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
error::Code::NOT_FOUND,
strings::StrCat(
"Could not find meta graph def matching supplied tags: { ",
str_util::Join(tags, " "),
absl::StrJoin(tags, " "),
" }. To inspect available tag-sets in the SavedModel, please "
"use the SavedModel CLI: `saved_model_cli`"));
}

View File

@ -64,7 +64,7 @@ TEST_F(ReaderTest, NoTagMatch) {
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(
EXPECT_TRUE(absl::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: { missing-tag }"))
<< st.error_message();
@ -78,7 +78,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(
EXPECT_TRUE(absl::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: "))
<< st.error_message();

View File

@ -3,10 +3,9 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(

View File

@ -167,8 +167,7 @@ namespace {
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
int32* dst) {
if (tensorflow::str_util::ConsumePrefix(&arg, flag) &&
tensorflow::str_util::ConsumePrefix(&arg, "=")) {
if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) {
char extra;
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
}
@ -178,7 +177,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
bool* dst) {
if (tensorflow::str_util::ConsumePrefix(&arg, flag)) {
if (absl::ConsumePrefix(&arg, flag)) {
if (arg.empty()) {
*dst = true;
return true;

View File

@ -1,7 +1,6 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

View File

@ -1,4 +1,12 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "internal",
@ -14,15 +22,6 @@ package_group(
],
)
package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@ -200,6 +199,7 @@ cc_library(
"//tensorflow/core/kernels:host_constant_op",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:logging_ops",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
@ -257,10 +257,8 @@ cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
hdrs = ["xla_launch_util.h"],
# TODO(skyewm): remove this once XlaAllocator is factored out.
visibility = [
":internal",
"//tensorflow/compiler/xla/python:__pkg__",
],
deps = [
":common",

View File

@ -244,11 +244,11 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
"resource variable op in called function");
}
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
return LogNotCompilableAndReturn(node, "operation with correctness issues");
}
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
return LogNotCompilableAndReturn(node, "slow operation");
}
@ -268,8 +268,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
registration.elide_assert_and_checknumerics;
op_filter.allow_ops_producing_or_consuming_variant =
registration.cluster_variant_ops;
op_filter.allow_slow_and_inaccurate_ops =
registration.cluster_slow_and_inaccurate_ops;
op_filter.allow_slow_ops = registration.cluster_slow_ops;
op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
return op_filter;
}

View File

@ -97,9 +97,12 @@ class RecursiveCompilabilityChecker {
// live-out DT_VARIANT values.
bool allow_ops_producing_or_consuming_variant;
// Whether ops known to be slow or to have correctness issues should be
// auto-clustered.
bool allow_slow_and_inaccurate_ops;
// Whether ops known to be slow on XLA-GPU should be considered compilable..
bool allow_slow_ops;
// Whether ops known to have numerical accuracy issues should be considered
// compilable..
bool allow_inaccurate_ops;
};
RecursiveCompilabilityChecker(const OperationFilter* op_filter,

View File

@ -14,10 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -43,12 +45,12 @@ limitations under the License.
// ------------------------------------------
//
// If we ignore cycles for a moment, computing predicates is fairly
// straightforward. We traverse the graph in RPO, mapping each node to a
// predicate based on the predicates its inputs are mapped to. For instance a
// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)).
// Roughtly speaking, we abstract interpret each node on the "liveness" domain,
// where values in the domain represent if a tensor carries a dead signal or
// not.
// straightforward. We traverse the graph in a topological order, mapping each
// node to a predicate based on the predicates its inputs are mapped to. For
// instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
// PredicateFor(Y)). Roughtly speaking, we abstractly interpret each node on
// the "liveness" domain, where values in the domain represent if a tensor
// carries a dead signal or not.
//
//
// DEALING WITH CYCLES
@ -85,22 +87,28 @@ limitations under the License.
// true on iteration 0, 1, 2 respectively. This is made more precise in the
// comment on the AndRecurrence class.
//
// The general algorithm that deals with cycles does two RPO (reverse post
// order) passes over the graph. On the first pass it assigns a symbolic
// predicate to merge nodes with backedges. On the second pass it tries to
// pattern matche the predicates for the backedges of these merges and infer an
// AndRecurrence for the merge.
// The general algorithm that deals with cycles does two topological-order
// iterations over the graph. On the first iteration it assigns a symbolic
// predicate to merge nodes with backedges. On the second iteration it tries
// to pattern match the predicates for the backedges of these merges and infer
// an AndRecurrence for the merge. In other words, we do a data flow analysis
// where the data-flow lattice has two elements, Symbolic and NonSymbolic with
// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
// sufficient to converge.
//
// In other words, we do a pessimistic data flow analysis where the data-flow
// lattice has two elements, Symbolic and NonSymbolic with Symbolic >
// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to
// converge. We don't do an optimistic data flow analysis to make pattern
// matching easier: if we assigned the predicate of the initial value to the
// merge during the first pass, on the second pass the backedge may see a
// simplified value that would be difficult to pattern match.
// We first do an optimisitc analysis and, if it does not converge, we then fall
// back to a pessimistic analysis. The optimistic analysis assigns the same
// symbolic predicate to all the merge nodes whose preceding enter nodes have
// the same frame name on the first iteration. On the second iteration, if all
// the merge nodes are pattern matched into the same AndRecurrence predicate
// instance, the optimistic assignment of the same symbolic predicate is correct
// and the analyzed result is taken.
//
// We still use symbolic predicates for merges for which we can't pattern match
// on the backedge predicate. This is conservatively correct.
// Otherwise, if the optimistic analysis fails to converge, we then obtain the
// result by falling back to the pessimistic analysis which assigns a unique
// symbolic predicate to each merge on the first iteration. We still use
// symbolic predicates for merges for which we can't pattern match on the
// backedge predicate. This is conservatively correct.
namespace tensorflow {
@ -636,6 +644,35 @@ Predicate* PredicateFactory::MakeAndOrImpl(
negated_ops.insert(negated_op);
}
// Simplify {S,&,X} & ~X & ... => S & ...
if (is_and) {
absl::flat_hash_set<Predicate*> to_remove;
std::vector<Predicate*> to_add;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kAndRecurrence) {
auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
if (negated_ops.contains(and_rec->step())) {
// Remove and_rec and ~X and insert S. Note that checking the
// existence of ~X through negated_ops is sufficient since it makes
// sure the predicate is in the input operands. It does not need to
// be in simplified_ops if it was already cancelled out.
to_remove.insert(and_rec);
to_remove.insert(MakeNotPredicate(and_rec->step()));
to_add.push_back(and_rec->start());
}
}
}
auto it = simplified_ops.begin();
while (it != simplified_ops.end()) {
if (to_remove.contains(*it)) {
it = simplified_ops.erase(it);
} else {
++it;
}
}
simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
}
// If all ops contain the same subop, then factor it out thanks to the
// distributive property. Such as:
// - (A & B) | (A & C) | (A & D) => A & (B | C | D)
@ -699,8 +736,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
explicit DeadnessAnalysisImpl(const Graph* graph)
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate();
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
Status Populate(bool enable_optimistic);
Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
bool* success);
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const override;
void Print() const override;
@ -742,16 +780,29 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
}
Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
bool use_optimistic_mode);
Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
Status HandleNode(Node* n, std::vector<bool>* should_revisit,
bool use_optimistic_mode = false);
Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
bool IsRootEnter(const Node* n) const {
return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
}
bool IsRootExit(const Node* n) const {
return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
}
const Graph& graph_;
absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_;
std::vector<ControlFlowInfo> control_flow_info_;
bool vlog_;
absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
};
TensorId InputEdgeToTensorId(const Edge* e) {
@ -914,10 +965,32 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
return Status::OK();
}
// If the node is inside some frames, get the name of the outermost non-empty
// frame. Otherwise, get an empty frame name.
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
absl::string_view* frame) {
int depth = 0;
const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
while (!cfi_iter->parent_frame->IsSource()) {
n = cfi_iter->parent_frame;
cfi_iter = &cfi_infos[n->id()];
if (depth++ > 5000) {
return errors::Internal(
"Frame of depth > 5000: Probably malformed graph or a bug in "
"BuildControlFlowInfo");
}
}
*frame = cfi_iter->frame_name;
return Status::OK();
}
} // namespace
Status DeadnessAnalysisImpl::HandleMerge(Node* n,
std::vector<bool>* should_revisit) {
std::vector<bool>* should_revisit,
bool use_optimistic_mode) {
// Merge ignores deadness of its control inputs. A merge that isn't the
// target of a backedge has is alive iff any of its data inputs are. The
// liveness of a merge that is the target of a backedge can sometimes be
@ -937,8 +1010,21 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// We're visiting this merge for the first time and it has an unvisited
// backedge.
Predicate* input_data_pred;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
if (use_optimistic_mode) {
// In the optimistic mode, we use the first-seen Merge node per
// frame as the representative Merge node. It is just convenient and
// does not affect the result after pattern-matching into the
// AndRecurrence form.
absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
auto insert_result = frame_to_merge_node_.insert({frame_name, n});
Node* representative = insert_result.first->second;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
representative, /*output_idx=*/0, /*must_be_true=*/false,
&input_data_pred));
} else {
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
}
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
should_revisit);
@ -948,7 +1034,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
std::vector<Predicate*> input_preds;
TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
// We're visiting this merge for the first time and it is a acyclic merge.
// We're visiting this merge for the first time and it is an acyclic merge.
Predicate* input_data_pred =
predicate_factory_.MakeOrPredicate(input_preds);
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
@ -1022,11 +1108,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
}
Status DeadnessAnalysisImpl::HandleNode(Node* n,
std::vector<bool>* should_revisit) {
std::vector<bool>* should_revisit,
bool use_optimistic_mode) {
if (n->IsSwitch()) {
TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
} else if (n->IsMerge()) {
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
} else if (n->IsControlTrigger()) {
SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
nullptr);
@ -1040,17 +1127,129 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
return Status::OK();
}
Status DeadnessAnalysisImpl::Populate() {
std::vector<Node*> rpo;
GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
return PopulateWithReversePostOrder(rpo);
// Compute a special topological order for the Graph, where nodes having the
// same root frame are placed adjacent to each other. The traversal uses a
// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
// many inputs of each node are ready; a node is ready to be scheduled if all
// of its inputs are ready.
// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
std::vector<Node*>* order) {
absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
Node* src_node = graph_.source_node();
for (const auto* node : graph_.op_nodes()) {
const ControlFlowInfo& cf = control_flow_info_[node->id()];
if (IsRootEnter(node)) {
// Since we care only the root-level frame, full frame names are the same
// as frame names.
++num_enters_for_frame[cf.frame_name];
} else if (IsRootExit(node)) {
++num_exits_for_frame[cf.frame_name];
}
// Edge NextIteration->Merge is counted before starting the traveral to
// break the backedges.
if (IsMerge(node)) {
for (const Edge* e : node->in_edges()) {
if (IsNextIteration(e->src())) {
++num_ready_inputs[node->id()];
}
}
}
}
// dequeue is used to ensure that the nodes are first-in-first-out. This
// order guarantees that the exits in the ready queue are visited before
// nodes that will become ready in the future.
std::deque<Node*> ready;
ready.push_back(src_node);
// ready_enters_per_frame and ready_exits serve as a staging area to buffer
// the ready enters/exits before they are moved to the `ready` queue for
// controlling the start and end of a processing frame.
absl::flat_hash_map<absl::string_view, std::vector<Node*>>
ready_enters_per_frame;
// Exit nodes shall all be from the same frame, as we process a frame at a
// time. So, one vector is enough.
std::vector<Node*> ready_exits;
while (!ready.empty()) {
Node* curr_node = ready.front();
ready.pop_front();
VLOG(4) << "Visiting " << curr_node->name();
order->push_back(curr_node);
for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
int out_id = out->id();
if (IsNextIteration(curr_node) && IsMerge(out)) {
// Edge NextIteration->Merge has been counted.
continue;
}
++num_ready_inputs[out->id()];
if (!out->IsOp()) continue; // Skip Sink/Source nodes.
if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
absl::string_view frame_name = control_flow_info_[out_id].frame_name;
if (IsRootEnter(out)) {
ready_enters_per_frame[frame_name].push_back(out);
} else if (IsRootExit(out)) {
ready_exits.push_back(out);
} else {
ready.push_back(out);
}
}
if (ready.empty()) {
// Try moving nodes from ready_enters_per_frame and ready_exits to
// `ready`.
if (!ready_exits.empty()) {
// If there are nodes in ready_exits we must process them before
// processing ready_enters_per_frame to make sure all nodes in the
// currently processing frame are visited before starting processing
// other frames.
absl::string_view frame_name =
control_flow_info_[ready_exits.front()->id()].frame_name;
CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
ready_exits.clear();
} else {
// Otherwise, try moving nodes from ready_enters to `ready`.
for (auto iter = ready_enters_per_frame.begin();
iter != ready_enters_per_frame.end(); ++iter) {
absl::string_view frame_name = iter->first;
const std::vector<Node*>& ready_enters = iter->second;
if (ready_enters.size() == num_enters_for_frame[frame_name]) {
ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
ready_enters_per_frame.erase(iter);
break;
}
}
}
}
}
if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
return errors::InvalidArgument(
"Some enters/exits have never been visited in the traversal."
" Most probably the input graph is malformed.");
}
return Status::OK();
}
Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
absl::Span<Node* const> rpo) {
// We populate the nodes along a special topological order where nodes having
// the same root frame are placed adjacent to each other. This grouping enables
// processing the graph per root frame at a time and guarantees that when a root
// frame is being processed, nodes in the downstream frames have not yet been
// processed. This property is important because we need to process an entire
// frame to know whether the optimistic mode converges or not. In other words,
// nodes in the downstream frames shall not be populated until all of its
// upstream frames are populated. In effect, this order enables processing each
// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
// (root) frame. Note that we don't separate while loops belonging to the same
// nested while, as there is no clean cut for separating them in the topological
// order.
Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
std::vector<string> unreachable_nodes;
// Compute the loop structure of the graph.
TF_RETURN_IF_ERROR(
@ -1069,14 +1268,63 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
absl::StrJoin(unreachable_nodes, ", "));
}
std::vector<Node*> topo;
TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
size_t frame_start = 0;
while (frame_start < topo.size()) {
// Batching nodes who have the same root frame.
absl::string_view cur_frame_name;
TF_RETURN_IF_ERROR(
GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
size_t frame_end = frame_start;
for (size_t i = frame_start + 1; i < topo.size(); ++i) {
absl::string_view i_frame_name;
TF_RETURN_IF_ERROR(
GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
if (i_frame_name == cur_frame_name) {
frame_end = i;
} else {
break;
}
}
absl::Span<Node*> sub_topo(topo.data() + frame_start,
/*length=*/frame_end - frame_start + 1);
frame_start = frame_end + 1;
// First, try the optimistic mode.
bool success = false;
if (enable_optimistic && !cur_frame_name.empty()) {
TF_RETURN_IF_ERROR(
PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
}
if (!success) {
// The optimistic mode does not converge. Let's fall back to the
// pessimistic mode.
TF_RETURN_IF_ERROR(
PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
}
VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
<< (success ? "optimistic" : "pessimistic") << " mode.";
}
return Status::OK();
}
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
bool use_optimistic_mode,
bool* success) {
CHECK(use_optimistic_mode && success != nullptr ||
!use_optimistic_mode && success == nullptr);
// This an abstract interpretation over the deadness propagation semantics of
// the graph executor.
//
// We iterate over the graph twice, each time in RPO. On the first iteration
// merge nodes with backedges are mapped to symbolic predicates. On the
// second iteration we use the predicates assigned to the backedges in the
// previous iteration to infer a more precise predicate for the backedge merge
// nodes and all the nodes that transitively use it.
// We iterate over the graph twice, each time in a topological order. On the
// first iteration merge nodes with backedges are mapped to symbolic
// predicates. On the second iteration we use the predicates assigned to the
// backedges in the previous iteration to infer a more precise predicate for
// the backedge merge nodes and all the nodes that transitively use it.
//
// We don't track the output indices for should_revisit. Instead, putting a
// node in `should_revisit` denotes that the deadness flowing out from any
@ -1086,9 +1334,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
// delta should not change in the second iteration.
std::vector<bool> should_revisit;
should_revisit.resize(graph_.num_node_ids());
for (Node* n : rpo) {
for (Node* n : topo) {
VLOG(4) << "Visiting " << n->name();
TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
TF_RETURN_IF_ERROR(
HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
if (n->IsNextIteration()) {
// If this is a backedge for a merge node then remember to reprocess the
// merge the next time we run.
@ -1100,11 +1349,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
}
}
for (Node* n : rpo) {
for (Node* n : topo) {
// The nodes added to should_revisit in the previous loop need to be
// revisited now. Reprocesing these initial nodes may add *their* consumers
// to should_revisit, and these newly added nodes will also be processed by
// this very same loop. Since we're traversing the graph in reverse post
// this very same loop. Since we're traversing the graph in topological
// order (producers before consumers) and HandleNode(n) can only ever add
// n's consumers to should_revisit, we won't "miss" an addition to
// should_revisit.
@ -1114,6 +1363,71 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
}
}
// Check if the optimistic analysis converges. Specifically, check whether
// all the predicates of the merge nodes in the same frame are the same. If
// yes, report success. If not, report failure and clear the assigned
// predicates.
if (use_optimistic_mode) {
bool is_converged = true;
absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
for (Node* n : topo) {
if (!n->IsMerge()) {
continue;
}
const Edge* e;
TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
if (e == nullptr) {
// Skip acyclic merge nodes.
continue;
}
Node* merge = n;
// Note that here uses frame names instead of root frame names. In the
// case of a nested while loop, each level of while loops can have merges
// with different predicate instances, while the merge nodes on the same
// level must have the same predicate instances.
absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
auto it = predicate_map_.find(TensorId(merge->name(), 0));
Predicate* merge_pred = it->second;
if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
is_converged = false;
VLOG(2) << "Running the optimistic mode on frame " << frame_name
<< " does not converge because node " << merge->name()
<< " cannot be mapped into the AndRecurrence form.";
break;
}
auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
if (!insert_result.second) {
// If we have already seen this frame name, verify the predicate is the
// same as the previously seen one's.
Predicate* curr_andrec = merge_pred;
Predicate* prev_andrec = insert_result.first->second;
if (curr_andrec != prev_andrec) {
is_converged = false;
VLOG(2) << "Running the optimistic mode on frame " << frame_name
<< " does not converge. Seeing different Merge predicates: \n"
<< curr_andrec->ToString() << " and \n"
<< prev_andrec->ToString();
break;
}
}
}
// Clear the assigned predicates if the optimistic mode does not converge.
if (!is_converged) {
for (Node* n : topo) {
for (int oid = 0; oid < n->num_outputs(); ++oid) {
predicate_map_.erase(TensorId(n->name(), oid));
}
predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
}
}
if (success != nullptr) {
*success = is_converged;
}
}
return Status::OK();
}
@ -1149,7 +1463,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
std::unique_ptr<DeadnessAnalysisImpl> analysis(
new DeadnessAnalysisImpl(&graph));
TF_RETURN_IF_ERROR(analysis->Populate());
TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));
if (VLOG_IS_ON(2)) {
analysis->Print();
@ -1170,22 +1484,18 @@ DeadnessAnalysisImpl::PredicateMapAsString() const {
}
namespace deadness_analysis_internal {
Status ComputePredicates(const Graph& graph,
PredicateMapTy* out_predicate_map) {
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
bool enable_optimistic) {
DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.Populate());
TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
*out_predicate_map = impl.PredicateMapAsString();
return Status::OK();
}
Status ComputePredicates(const Graph& graph,
absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map) {
DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
*out_predicate_map = impl.PredicateMapAsString();
return Status::OK();
}
} // namespace deadness_analysis_internal
string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
return static_cast<Predicate*>(predicate.pred_)->ToString();
}
} // namespace tensorflow

View File

@ -82,6 +82,8 @@ class DeadnessAnalysis {
virtual void Print() const = 0;
virtual ~DeadnessAnalysis();
string DebugString(DeadnessPredicate predicate) const;
// Run the deadness analysis over `graph` and returns an error or a populated
// instance of DeadnessAnalysis in `result`.
static Status Run(const Graph& graph,

View File

@ -25,15 +25,9 @@ namespace deadness_analysis_internal {
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only.
using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
bool enable_optimistic = true);
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only. Makes deadness analysis visit the graph in the order
// specified in `reverse_post_order` which must be a valid RPO for the graph
// minus NextIteration->Merge edges.
Status ComputePredicates(const Graph& graph,
absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map);
} // namespace deadness_analysis_internal
} // namespace tensorflow

View File

@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
predicate_map[ControlOutputFor(iv.induction_var)]);
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
predicate_map[ControlOutputFor(iv.induction_var)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
predicate_map[ControlOutputFor(iv.induction_var)]);
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
FixupSourceAndSinkEdges(root.graph());
// To make deadness analysis think that dependent_iv is a loop we need an RPO
// that visits the merge before the backedge. This is a legal RPO for
// deadness analysis since it ignores NextIteration->Merge edges during RPO.
// Right now dependent_iv has an edge from Merge to NextIteration so do the
// RPO with this edge in place. Then remove this edge to get our test case.
std::vector<Node*> rpo;
GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
TF_ASSERT_OK(root.graph()->UpdateEdge(
iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));
@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map));
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"{#true,&,*iv0/cond:0}<frame>");
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"div0/iv:0");
@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
"{(*iv_outer/cond:0 & "
"{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
"cond:0}<inner_loop;outer_loop>");
// enable_optimistic = true or not should produce the same results because
// of fallback. However, note that the order of iv_inner/cond:0 and
// iv_inner/iv:0 is different because the optimistic approach does not
// create predicates for all merges and it can change the predicate id and
// hence the symbol order.
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
"iv_inner/iv:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
}
}
@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
}
}
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
Scope root = Scope::NewRootScope().ExitOnError();
InductionVarInfo iv_outer =
CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
Output enter_constant_outer_loop = ops::internal::Enter(
root.WithOpName("constant_enter_outer_loop"),
ops::Const(root.WithOpName("constant"), 5), "outer_loop",
ops::internal::Enter::Attrs().IsConstant(true));
ops::Switch inner_value(root.WithOpName("outer_is_live"),
enter_constant_outer_loop, iv_outer.loop_cond);
InductionVarInfo iv_inner = CreateInductionVariable(
root, "iv_inner", "inner_loop", inner_value.output_true);
DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);
DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
root, "div0_inner", "inner_loop", iv_inner.loop_cond,
div0_outer.induction_var);
DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
root, "div1_inner", "inner_loop", iv_inner.loop_cond,
div1_outer.induction_var);
Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
"tensor_a", "sender", 0, "receiver");
Output capture_enter_outer = ops::internal::Enter(
root.WithOpName("capture_enter_outer"), captured, "outer_loop",
ops::internal::Enter::Attrs().IsConstant(true));
Output capture_enter_inner = ops::internal::Enter(
root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
ops::internal::Enter::Attrs().IsConstant(true));
Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
capture_enter_inner);
TF_ASSERT_OK(root.graph()->UpdateEdge(
mul0.node(), 0, div1_inner.latch.output_true.node(), 0));
Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
div1_inner.induction_var);
VLogGraphIfAsked(*root.graph());
{
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
}
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
Scope root = Scope::NewRootScope().ExitOnError();
InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
DependentInductionVar div0 =
CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
DependentInductionVar div1 =
CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
FixupSourceAndSinkEdges(root.graph());
TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
div0.latch.output_true.node(), 0));
TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
div1.latch.output_true.node(), 0));
VLogGraphIfAsked(*root.graph());
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
// This tests the rule {S,&,X} & ~X => S.
TensorId switch_false_out = {div1.latch.output_false.node()->name(),
div1.latch.output_false.index()};
EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
}
}
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
Scope root = Scope::NewRootScope().ExitOnError();
InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);

File diff suppressed because it is too large Load Diff

View File

@ -52,16 +52,6 @@ typedef std::function<Status(
// 'group_attribute' must be a string valued-attribute that names the new
// functions to introduce.
//
// 'outside_compilation_attribute' must be a string-valued attribute that is
// used to tag nodes within a subgraph to be part of an 'outside_compilation'
// cluster within the subgraph. A cluster is formed from the set of nodes with
// the same value of outside_compilation_subgraph and group_attribute. The nodes
// in an outside_compilation cluster are left in the original graph. Edges
// crossing from the subgraph to an outside_compilation cluster nested in the
// subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
// crossing from an outside_compilation cluster into its enclosing subgraph are
// lifted into a SendFromHost/RecvFromHost pair of nodes.
//
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
//
@ -74,10 +64,9 @@ typedef std::function<Status(
// dep from B. Originally D must run after C, post-transformation this
// dependency is lost.
Status EncapsulateSubgraphsInFunctions(
string group_attribute, string outside_compilation_attribute,
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
FunctionLibraryDefinition* library);
string group_attribute, const Graph& graph_in,
const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via XlaLaunch operators.

View File

@ -514,10 +514,10 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
std::unique_ptr<Graph> graph_out;
s = EncapsulateSubgraphsInFunctions(
"_encapsulate", /*outside_compilation_attribute=*/"", *graph,
/*rewrite_subgraph_fn=*/{},
/*reuse_existing_functions=*/false, &graph_out, lib_def.get());
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
/*rewrite_subgraph_fn=*/{},
/*reuse_existing_functions=*/false,
&graph_out, lib_def.get());
if (!s.ok()) return s;
std::unordered_map<string, XlaClusterInfo> clusters;
@ -746,7 +746,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", "", graph_before_encapsulation,
"_cluster", graph_before_encapsulation,
/*rewrite_subgraph_fn=*/{},
/*reuse_existing_functions=*/false, &graph, &library));
@ -798,7 +798,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "", graph_before,
"_encapsulate", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph_ptr,
@ -843,7 +843,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "", graph_before,
"_encapsulate", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph_ptr,
@ -1109,7 +1109,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"F"}},
{"F", "outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
@ -1990,7 +1990,8 @@ TEST(EncapsulateSubgraphsTest,
{"_xla_token_input_nodes",
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
"outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});
@ -2117,7 +2118,8 @@ TEST(EncapsulateSubgraphsTest,
{"_xla_token_input_nodes",
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}}},
"outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
@ -2267,7 +2269,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"_xla_token_input_nodes",
absl::Span<const string>(
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
{}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
{"D:o:0"},
@ -2282,7 +2284,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
absl::Span<const string>({"_xla_token_arg_node",
"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})}},
{}}},
{"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"}}},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});

View File

@ -231,9 +231,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
auto output = absl::make_unique<Graph>((*graph)->op_registry());
TF_RETURN_WITH_CONTEXT_IF_ERROR(
EncapsulateSubgraphsInFunctions(
kXlaClusterAttr, "", **graph, RewriteSubgraph,
/*reuse_existing_functions=*/true, &output, flib_def),
EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
/*reuse_existing_functions=*/true,
&output, flib_def),
"EncapsulateXlaComputationsPass failed");
graph->swap(output);
return Status::OK();

View File

@ -393,7 +393,7 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
// Replace outside compilation function call node with XlaHostCompute node.
// If the function call node has no input/output edges, we will just remove it
// and not create a XlaHostCompute node.
Status ReplaceOrRemoveOutsideCompilationCallNode(
xla::StatusOr<Node*> ReplaceOrRemoveOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
// If the function call node has no input/output edges, just remove it.
@ -413,7 +413,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
if (!has_edge) {
VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString();
g->RemoveNode(call_node);
return Status::OK();
return nullptr;
}
// Build XlaHostCompute NodeDef.
@ -424,7 +424,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
ReplaceNode(g, call_node, node_def));
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
return Status::OK();
return host_compute_node;
}
// Resets "device_ordinal" attr to placeholder value for related nodes
@ -1634,7 +1634,7 @@ Status ExtractOutsideCompilationForFunction(
RewriteOutsideCompilationSubgraphFn rewrite_fn(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name);
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
outside_compilation_attr_name, "", *fbody->graph, rewrite_fn,
outside_compilation_attr_name, *fbody->graph, rewrite_fn,
/*reuse_existing_functions=*/true, &graph_out, fld));
// Replace outside_compilation function nodes with HostCompute ops.
@ -1670,10 +1670,35 @@ Status ExtractOutsideCompilationForFunction(
}
}
}
std::map<string, Node*> host_compute_nodes;
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps));
auto host_compute_node_or = ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps);
TF_RETURN_IF_ERROR(host_compute_node_or.status());
Node* host_compute_node = host_compute_node_or.ValueOrDie();
if (host_compute_node) {
host_compute_nodes[host_compute_node->name()] = host_compute_node;
}
}
// For XlaHostCompute nodes with dependencies, add control edges between them
// so XlaCompiler can handle them in correct order.
for (auto iter : host_compute_nodes) {
Node* host_compute_node = iter.second;
std::vector<string> token_input_node_names;
TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
kXlaTokenInputNodesAttrName,
&token_input_node_names));
for (const string& node_name : token_input_node_names) {
if (node_name == kXlaTokenArgNodeName) {
continue;
}
auto iter = host_compute_nodes.find(node_name);
if (iter != host_compute_nodes.end()) {
graph_out->AddControlEdge(iter->second, host_compute_node);
}
}
}
// Handle nodes with associated functions.

View File

@ -990,6 +990,16 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
// Check there is a control edge from host_compute_0 to host_compute_1.
bool has_control_edge = false;
for (const Edge *e : host_compute_1->in_edges()) {
if (e->IsControlEdge() && e->src() == host_compute_0) {
has_control_edge = true;
break;
}
}
EXPECT_TRUE(has_control_edge);
}
TEST_F(ExtractOutsideCompilationForFunctionTest,
@ -1062,5 +1072,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
// Check there is a control edge from host_compute_0 to host_compute_1.
bool has_control_edge = false;
for (const Edge *e : host_compute_1->in_edges()) {
if (e->IsControlEdge() && e->src() == host_compute_0) {
has_control_edge = true;
break;
}
}
EXPECT_TRUE(has_control_edge);
}
} // namespace tensorflow

View File

@ -1,9 +1,8 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
],
licenses = ["notice"], # Apache 2.0
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")

View File

@ -1,9 +1,8 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:internal",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
@ -29,6 +28,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:tf_allocator_adapter",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],

View File

@ -61,7 +61,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
std::unique_ptr<XlaAllocator> xla_allocator;
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator;
se::DeviceMemoryAllocator* device_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
@ -93,7 +93,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
se::MultiPlatformManager::PlatformWithId(platform_id);
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
xla_allocator = absl::make_unique<XlaAllocator>(
xla_allocator = absl::make_unique<se::TfAllocatorAdapter>(
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
namespace tensorflow {
@ -36,11 +37,11 @@ class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
std::unique_ptr<XlaAllocator> xla_allocator,
se::DeviceMemoryAllocator* device_allocator)
explicit XlaPlatformInfo(
const DeviceType device_type, se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
@ -84,8 +85,8 @@ class XlaPlatformInfo {
// then device_allocator_ is the xla::Backend's memory allocator and
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
// then device_allocator_ is null and xla_allocator_ points to an appropriate
// XlaAllocator instance.
std::unique_ptr<XlaAllocator> xla_allocator_;
// se::TfAllocatorAdapter instance.
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator_;
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);

View File

@ -229,32 +229,18 @@ class MarkForCompilationPassImpl {
// Initialize some internal data structures.
Status Initialize();
// Runs through all the nodes in `cycles_graph_` and tries to create clusters.
// Returns true if any new clusters were created.
StatusOr<bool> RunEdgeContractionLoopInPostOrderOnce();
// Runs through the entire cluster graph in post-order and calls `fn(from,
// to)` on each edge. `fn(from, to)` is expected to return true if it was
// able to contract `from`->`to`.
//
// Returns true if `fn` returned true for any edge.
template <typename FnTy>
StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
// Runs through all the nodes in `cycles_graph_` and tries to contract high
// priority edges for clusters. Returns true if any new clusters were created.
//
// There are potentially many maximal clustering results, but they will not
// all be equally performant. Some clustering decision are likely to improve
// performance much more than others, and we cannot order contractions on this
// cost function, nor can we look at global information while deciding on
// individual edges to contract. Instead, we will make decisions on these
// important edges then make decisions on all other edges, causing the highest
// chance of all most important edges to be contracted.
//
// An example of where this might occur is with a digraph:
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
// should be clustered with A because it will prevent a potentially large
// tensor from A being computed and copied.
//
// This pass will ensure that contraction happens, which cannot be enforced in
// a single pass with the current algorithm.
// graph and prevent B->C from being clusterd in anticipation of a later A->B
// cluster.
StatusOr<bool> ContractPreferredEdges();
// If from->to is a "preferred" edge (i.e. if we have a choice, we want to
// prioritize contracting from->to over contracting other edges) then
// contracts it and returns true. Else returns false.
StatusOr<bool> ContractEdgeIfPreferred(Cluster* from, Cluster* to);
// Contracts as many edges as possible to create XLA clusters. After this
// finishes the clustering decisions made are implicitly stored in
@ -276,10 +262,6 @@ class MarkForCompilationPassImpl {
// true if successful.
StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
// Tries to contract each edge from `cluster_from`. Returns true if any edges
// were contracted, false otherwise.
StatusOr<bool> TryToContractEdgesFrom(Cluster* cluster_from);
// Nodes that XLA can compile are put in `compilation_candidates_`.
Status FindCompilationCandidates();
@ -401,6 +383,13 @@ class MarkForCompilationPassImpl {
return true;
}
string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
absl::string_view reason) {
return absl::StrCat("Could not contract ", from->DebugString(*graph_),
" -> ", to->DebugString(*graph_), " because ", reason,
".");
}
DebugOptions debug_options_;
Graph* graph_;
FunctionLibraryDefinition* flib_def_;
@ -611,7 +600,8 @@ Status MarkForCompilationPassImpl::Initialize() {
return BuildInitialClusterSet();
}
StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
template <typename FnTy>
StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
bool changed = false;
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
@ -632,55 +622,33 @@ StatusOr<bool> MarkForCompilationPassImpl::ContractPreferredEdges() {
continue;
}
if (cluster_to->cluster_size() == 1) {
Node* n = graph_->FindNodeId(cluster_to->GetIdOfOnlyNode());
// Shape consuming operations are desirable to cluster with their
// operands because they return a small set of scalar values after
// consuming a large amount of data. For example, given a graph X -> Y
// -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
// [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
// output of size will be a small tensor while Y is a potentially large
// tensor that must be computed and possible transposed/copied before
// the second cluster executes.
if (IsShapeConsumerOp(*n)) {
TF_ASSIGN_OR_RETURN(bool contracted_edge,
TryToContractEdge(cluster_from, cluster_to));
changed |= contracted_edge;
}
}
TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
changed |= contracted_edge;
}
}
return changed;
}
StatusOr<bool>
MarkForCompilationPassImpl::RunEdgeContractionLoopInPostOrderOnce() {
bool changed = false;
// Iterating over the graph once in post-order is sufficient to produce a
// maximal clustering:
//
// A. We visit a cluster only after maximally clustering all its children.
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
// its children that could have been absorbed into `node` have been
// absorbed.
// C. We have an invariant that making a cluster larger does not make edges
// leaving it more contractable. That is, if we have
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
// to contract Y->Z if Y->Z was not contractible originally.
for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
if (!cluster_from) {
continue;
}
StatusOr<bool> MarkForCompilationPassImpl::ContractEdgeIfPreferred(
Cluster* from, Cluster* to) {
if (to->cluster_size() == 1) {
Node* n = graph_->FindNodeId(to->GetIdOfOnlyNode());
TF_ASSIGN_OR_RETURN(bool contracted_one_edge,
TryToContractEdgesFrom(cluster_from));
changed |= contracted_one_edge;
// Shape consuming operations are desirable to cluster with their
// operands because they return a small set of scalar values after
// consuming a large amount of data. For example, given a graph X -> Y
// -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
// [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
// output of size will be a small tensor while Y is a potentially large
// tensor that must be computed and possible transposed/copied before
// the second cluster executes.
if (IsShapeConsumerOp(*n)) {
return TryToContractEdge(from, to);
}
}
return changed;
return false;
}
Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
@ -694,25 +662,68 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
// without restrictions. This helps to minimize data output from clusters (and
// possible transpose operations before outputs) that might occur if a
// ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations.
TF_ASSIGN_OR_RETURN(bool changed, ContractPreferredEdges());
//
// There are potentially many maximal clustering results, but they will not
// all be equally performant. Some clustering decision are likely to improve
// performance much more than others, and we cannot order contractions on this
// cost function, nor can we look at global information while deciding on
// individual edges to contract. Instead, we will make decisions on these
// important edges then make decisions on all other edges, causing the highest
// chance of all most important edges to be contracted.
//
// An example of where this might occur is with a digraph:
// {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
// not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
// should be clustered with A because it will prevent a potentially large
// tensor from A being computed and copied.
//
// This pass will ensure that contraction happens, which cannot be enforced in
// a single pass with the current algorithm.
// graph and prevent B->C from being clusterd in anticipation of a later A->B
// cluster.
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce());
TF_ASSIGN_OR_RETURN(bool changed,
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
return ContractEdgeIfPreferred(from, to);
}));
// Check that RunEdgeContractionLoopInPostOrderOnce is idempotent. Once the
// linear time post-order scheme has been battle tested we can move this to
// happen only in debug builds.
TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce());
// Iterating over the whole graph once in post-order is sufficient to produce
// a maximal clustering:
//
// A. We visit a cluster only after maximally clustering all its children.
// B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of
// its children that could have been absorbed into `node` have been
// absorbed.
// C. We have an invariant that making a cluster larger does not make edges
// leaving it more contractable. That is, if we have
// digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
// to contract Y->Z if Y->Z was not contractible originally.
TF_ASSIGN_OR_RETURN(changed,
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
return TryToContractEdge(from, to);
}));
// Check that the conclusion made above (that iterating over the graph once in
// post order gives a maximal clustering) holds. Once the linear time
// post-order scheme has been battle tested we can move this to happen only in
// debug builds.
TF_ASSIGN_OR_RETURN(changed,
ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
return TryToContractEdge(from, to);
}));
TF_RET_CHECK(!changed);
return Status::OK();
}
std::atomic<int64> cluster_sequence_num;
int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
Status MarkForCompilationPassImpl::CreateClusters() {
TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
clusters_created_ = true;
static std::atomic<int64> cluster_sequence_num;
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
@ -745,7 +756,7 @@ Status MarkForCompilationPassImpl::CreateClusters() {
string& name = cluster_names[cluster->cycles_graph_node_id()];
if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++);
name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
}
n->AddAttr(kXlaClusterAttr, name);
@ -1065,8 +1076,7 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
Cluster* from, Cluster* to, absl::string_view reason) {
VLOG(3) << "Could not contract " << from->DebugString(*graph_) << " -> "
<< to->DebugString(*graph_) << " because " << reason << ".";
VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
return false;
}
@ -1075,8 +1085,14 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
DCHECK(from->deadness_predicate().has_value() ==
to->deadness_predicate().has_value());
if (from->deadness_predicate() != to->deadness_predicate()) {
return LogNotContractableAndReturnFalse(
from, to, "the two nodes have mismatching deadness");
VLOG(3) << EdgeContractionFailureMsg(
from, to,
absl::StrCat(
"the two nodes have mismatching deadness: ",
deadness_analysis_->DebugString(*from->deadness_predicate()),
" and ",
deadness_analysis_->DebugString(*to->deadness_predicate())));
return false;
}
TF_ASSIGN_OR_RETURN(bool devices_compatible,
@ -1133,32 +1149,6 @@ StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
return MergeClusters(from, to);
}
StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdgesFrom(
Cluster* cluster_from) {
bool changed = false;
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32> successors_copy =
cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
for (int to : successors_copy) {
iteration_count_++;
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
if (!cluster_to) {
continue;
}
TF_ASSIGN_OR_RETURN(bool contracted_edge,
TryToContractEdge(cluster_from, cluster_to));
changed |= contracted_edge;
}
return changed;
}
Status MarkForCompilationPassImpl::Run() {
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
@ -1485,7 +1475,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_slow_and_inaccurate_ops = true;
op_filter.allow_slow_ops = true;
op_filter.allow_inaccurate_ops = true;
return RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
.IsCompilableCall(ndef, flr);
@ -1522,4 +1513,8 @@ Status MarkForCompilationPass::RunForTest(
return MarkForCompilation(options, debug_options);
}
namespace testing {
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
} // namespace testing
} // namespace tensorflow

View File

@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass {
// function is compilable iff every operator in the function body is
// compilable.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
namespace testing {
// DO NOT USE IN PRODUCTION.
//
// Resets some internal state to let us write reliable unit tests.
void ResetClusterSequenceNumber();
} // namespace testing
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_

View File

@ -1,7 +1,6 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
)
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")

View File

@ -49,7 +49,7 @@ Status ShapeAnnotationsMatch(
missing.push_back(entry.first);
}
return errors::InvalidArgument("Missing shapes for nodes: ",
str_util::Join(missing, ","));
absl::StrJoin(missing, ","));
}
return Status::OK();
}

View File

@ -60,7 +60,8 @@ Status XlaCpuDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true;
registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
static XlaDeviceOpRegistrations* registrations =

View File

@ -71,7 +71,7 @@ class XlaDeviceContext : public DeviceContext {
StatusCallback done) const override;
xla::LocalClient* client() const { return client_; }
se::Stream* stream() const { return stream_.get(); }
se::Stream* stream() const override { return stream_.get(); }
se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get();
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/kernels/host_constant_op.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/logging_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
@ -81,6 +82,11 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("Assert") \
.Device(DEVICE) \
.HostMemory("condition") \
.HostMemory("data"), \
AssertOp); \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
REGISTER_KERNEL_BUILDER( \

View File

@ -95,7 +95,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true;
registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
static XlaDeviceOpRegistrations* registrations =

View File

@ -63,7 +63,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_slow_and_inaccurate_ops = true;
registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration);

View File

@ -167,32 +167,6 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
return Status::OK();
}
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
: se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
XlaAllocator::~XlaAllocator() {}
xla::StatusOr<se::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
AllocationAttributes attrs;
attrs.no_retry_on_failure = !retry_on_failure;
void* data = nullptr;
if (size != 0) {
data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
if (data == nullptr) {
return errors::ResourceExhausted(
"Out of memory while trying to allocate ", size, " bytes.");
}
}
return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
device_ordinal, this);
}
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
wrapped_->DeallocateRaw(mem.opaque());
return Status::OK();
}
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors, bool use_multiple_streams)

View File

@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/stream_executor/device_memory_allocator.h"
namespace tensorflow {
class XlaAllocator;
// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
@ -104,74 +103,6 @@ class VariableInfo {
Status LockVariables(absl::Span<VariableInfo> variables)
EXCLUSIVE_LOCK_FUNCTION();
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public se::DeviceMemoryAllocator {
public:
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
~XlaAllocator() override;
xla::StatusOr<se::OwningDeviceMemory> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) override;
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
// compute stream to enforce a happens-before relationship between a memory
// allocation and code that reuses the same memory. If Tensorflow adds
// support for multiple GPU streams or allocators with different ordering
// requirements, this code may need to change.
// (This attribute has no effect on CPU.)
bool AllowsAsynchronousDeallocation() const override { return true; }
private:
Allocator* wrapped_;
};
// Adapter class that wraps per-device TF allocators as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation;
// see comment on `AllowsAsynchronousDeallocation()`.
class MultiDeviceAdapter : public se::DeviceMemoryAllocator {
public:
MultiDeviceAdapter(
const se::Platform* platform,
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators)
: DeviceMemoryAllocator(platform),
tf_allocators_(std::move(tf_allocators)) {
for (const auto& tf_allocator : tf_allocators_) {
per_device_allocators_.emplace_back(platform, tf_allocator.get());
}
}
xla::StatusOr<se::OwningDeviceMemory> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) override {
CHECK_LT(device_ordinal, per_device_allocators_.size());
return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size,
retry_on_failure);
}
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override {
CHECK_LT(device_ordinal, per_device_allocators_.size());
return per_device_allocators_[device_ordinal].Deallocate(device_ordinal,
mem);
}
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
// compute stream to enforce a happens-before relationship between a memory
// allocation and code that reuses the same memory. If Tensorflow adds
// support for multiple GPU streams or allocators with different ordering
// requirements, this code may need to change.
// (This attribute has no effect on CPU.)
bool AllowsAsynchronousDeallocation() const override { return true; }
private:
std::vector<tensorflow::XlaAllocator> per_device_allocators_;
// The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does
// not take ownership of its underlying Allocator).
std::vector<std::unique_ptr<tensorflow::Allocator>> tf_allocators_;
};
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.
class XlaComputationLaunchContext {

View File

@ -28,10 +28,9 @@
** Please don't remove this file - it is supporting some 3rd party plugins **
"""
licenses(["notice"])
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)
cc_library(

View File

@ -954,7 +954,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
size = "medium",
srcs = ["ternary_ops_test.py"],
deps = [
":xla_test",

View File

@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase):
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.cond(
constant_op.constant(
True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
constant_op.constant(True),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase):
xla_context.Exit()
def testCondAndTensorArrayInDefun_constFolding(self):
g = ops.Graph()
with session.Session(graph=g), g.as_default(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@function.defun
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.cond(
constant_op.constant(False),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
output_t = f()
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testCondAndTensorArray_xlaCompile(self):
self.skipTest("b/127846988")
# Fails with "Uninitialized arguments" in XlaIfOp::Compile
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.cond(
constant_op.constant(True),
lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
return output.stack()
output_t, = xla.compile(f)
self.assertAllEqual([5.], self.evaluate(output_t))
xla_context.Exit()
def testCondConstPropagation(self):
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase):
xla_context.Exit()
def testSwitchCaseAndTensorArray_xlaCompile(self):
self.skipTest("b/127846988")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.switch_case(
constant_op.constant(1), {
0: lambda: ta.write(0, 5.),
1: lambda: ta.write(0, 10.),
2: lambda: ta.write(0, 15.),
})
return output.stack()
output_t, = xla.compile(f)
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testSwitchCaseConstPropagation(self):
self.skipTest("b/127846988")
with self.session() as sess, self.test_scope():

View File

@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase):
padding="VALID",
patches=patches)
def testKsize2x2Stride1x1Rate1x1ValidDepth2(self):
"""Test for 2x2 kernel with VALID padding."""
# [1, 2, 2, 2]
image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]]
# [1, 1, 1, 8]
patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[1, 1],
padding="VALID",
patches=patches)
if __name__ == "__main__":
test.main()

View File

@ -72,21 +72,21 @@ class TernaryOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types:
self._testTernary(
array_ops.where,
np.array(0, dtype=np.bool),
np.array(False),
np.array(2, dtype=dtype),
np.array(7, dtype=dtype),
expected=np.array(7, dtype=dtype))
self._testTernary(
array_ops.where,
np.array(1, dtype=np.bool),
np.array(True),
np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([1, 2, 3, 4], dtype=dtype))
self._testTernary(
array_ops.where,
np.array(0, dtype=np.bool),
np.array(False),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
@ -105,6 +105,74 @@ class TernaryOpsTest(xla_test.XLATestCase):
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype))
def testSelectV2(self):
for dtype in self.numeric_types:
self._testTernary(
array_ops.where_v2,
np.array(False),
np.array(2, dtype=dtype),
np.array(7, dtype=dtype),
expected=np.array(7, dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array(True),
np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([1, 2, 3, 4], dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array(False),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array([0, 1, 1, 0], dtype=np.bool),
np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([5, 2, 3, 8], dtype=dtype))
# Broadcast the condition
self._testTernary(
array_ops.where_v2,
np.array([0, 1], dtype=np.bool),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 2], [9, 4], [11, 6]], dtype=dtype))
# Broadcast the then branch to the else
self._testTernary(
array_ops.where_v2,
np.array([[0, 1], [1, 0], [1, 1]], dtype=np.bool),
np.array([[1, 2]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
# Broadcast the else branch to the then
self._testTernary(
array_ops.where_v2,
np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
np.array([[1, 2]], dtype=dtype),
expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))
# Broadcast the then/else branches to the condition
self._testTernary(
array_ops.where_v2,
np.array([[1, 0], [0, 1], [1, 1]], dtype=np.bool),
np.array(7, dtype=dtype),
np.array(8, dtype=dtype),
expected=np.array([[7, 8], [8, 7], [7, 7]], dtype=dtype))
self._testTernary(
array_ops.where_v2,
np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool),
np.array(7, dtype=dtype),
np.array([8, 9], dtype=dtype),
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
def testSlice(self):
for dtype in self.numeric_types:
self._testTernary(

View File

@ -104,6 +104,24 @@ struct EdgePtrCompare {
}
};
// TODO(laigd): instead of deciding the device here, the converter should accept
// a device name as one of the conversion parameter so users can control on
// which device they want to run the conversion.
std::pair<TfGpuId, PlatformGpuId> GetFirstValidDeviceId() {
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
TfGpuId tf_gpu_id(tf_gpu_id_value);
PlatformGpuId platform_gpu_id;
Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (s.ok()) {
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
<< platform_gpu_id.value();
return std::make_pair(tf_gpu_id, platform_gpu_id);
}
}
LOG(ERROR) << "Could not find any TF GPUs";
return std::make_pair(TfGpuId(-1), PlatformGpuId(-1));
}
// Function to get subsegment information structure.
Status GetEngineInfo(const Graph* g,
const grappler::GraphProperties& graph_properties,
@ -128,27 +146,43 @@ Status GetEngineInfo(const Graph* g,
if (segment_nodes.count(node) == 0) continue;
auto node_device = node->requested_device();
if (!node_device.empty()) {
// If device is CPU, treat as if no device was assigned. Don't add CPU to
// segment_device because that would cause a segfault in
// GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes
// any already set device is a GPU.
// If device is set, it means device placement may have been done before,
// so we need to assign a device for the TRTEngineOp to maintain the
// invariance.
// If the device is CPU in this case, it tries to find the first available
// GPU and use it as the device.
DeviceNameUtils::ParsedName parsed_name;
DeviceNameUtils::ParseFullName(node_device, &parsed_name);
if (parsed_name.type == "CPU") {
VLOG(1) << "Node " << node->name() << " was assigned to the CPU. "
<< "Attempting to place on GPU.";
const bool parse_succeeded =
DeviceNameUtils::ParseFullName(node_device, &parsed_name);
if (!parse_succeeded || (parse_succeeded && parsed_name.type == "CPU")) {
string msg;
if (!parse_succeeded) {
msg = StrCat("Failed to parse assigned device of node ", node->name(),
". ");
} else {
msg = StrCat("Node ", node->name(), " was assigned to the CPU. ");
}
VLOG(1) << msg << "Attempting to place on GPU.";
TfGpuId tf_gpu_id;
PlatformGpuId platform_gpu_id;
std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
if (tf_gpu_id.value() >= 0) {
parsed_name.type = "GPU";
parsed_name.id = tf_gpu_id.value();
segment_devices.insert(DeviceNameUtils::FullName(
parsed_name.job, parsed_name.replica, parsed_name.task,
parsed_name.type, parsed_name.id));
}
} else {
segment_devices.insert(node_device);
}
} else if (node->has_assigned_device_name()) {
// It appears that nodes will not have assigned devices at this point in
// execution.
segment_devices.insert(node->assigned_device_name());
} else {
if (node->has_assigned_device_name()) {
// It appears that nodes will not have assigned devices at this point in
// execution.
segment_devices.insert(node->assigned_device_name());
} else {
VLOG(2) << "Node " << node->name()
<< " neither have requested device nor assigned device";
}
VLOG(2) << "Node " << node->name()
<< " neither have requested device nor assigned device";
}
subgraph_nodes.push_back(node);
@ -251,13 +285,11 @@ Status GetEngineInfo(const Graph* g,
info->engine_name = StrCat(scope_name, info->engine_name);
VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
<< "' to a GraphDef";
// TODO(sami): This should not happen once segmenter is updated.
if (segment_devices.size() == 1) {
info->device = *segment_devices.begin();
} else if (segment_devices.size() > 1) {
LOG(WARNING) << "Detected multiple(" << segment_devices.size()
<< ") devices for the segment. Picking first one to continue "
<< "but this shouldn't have happened";
LOG(WARNING) << "Detected multiple (" << segment_devices.size()
<< ") devices for the segment. Picking first one to continue.";
info->device = *segment_devices.begin();
} else {
VLOG(1) << "No device is assigned to the segment. "
@ -543,10 +575,10 @@ Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
std::map<string, Node*> io_nodes;
int num_inputs = 0;
for (auto n : sgraph.op_nodes()) {
if (str_util::StartsWith(n->name(), kInputPHName)) {
if (absl::StartsWith(n->name(), kInputPHName)) {
num_inputs++;
io_nodes.insert({n->name(), n});
} else if (str_util::StartsWith(n->name(), kOutputPHName)) {
} else if (absl::StartsWith(n->name(), kOutputPHName)) {
io_nodes.insert({n->name(), n});
}
}
@ -640,24 +672,17 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
engine.device.empty()) {
// If device is not set, use the first found GPU device for the conversion.
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
TfGpuId tf_gpu_id(tf_gpu_id_value);
PlatformGpuId platform_gpu_id;
Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (s.ok()) {
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
<< platform_gpu_id.value();
cuda_device_id = platform_gpu_id.value();
GPUOptions gpu_options;
// If the TF to Cuda gpu id mapping exist, the device and corresponding
// allocator must have been initialized already, so the
// GetGPUAllocator() call won't create a new allocator.
dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
gpu_options, tf_gpu_id, 1);
break;
}
LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist "
<< s;
TfGpuId tf_gpu_id;
PlatformGpuId platform_gpu_id;
std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
cuda_device_id = platform_gpu_id.value();
if (cuda_device_id >= 0) {
GPUOptions gpu_options;
// If the TF to Cuda gpu id mapping exist, the device and corresponding
// allocator must have been initialized already, so the
// GetGPUAllocator() call won't create a new allocator.
dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
gpu_options, tf_gpu_id, 1);
}
return std::make_pair(cuda_device_id, dev_allocator);
}
@ -750,8 +775,8 @@ Status ConvertAfterShapes(const ConversionParams& params) {
EngineInfo curr_engine;
curr_engine.engine_name = StrCat("TRTEngineOp_", t);
Status status =
GetEngineInfo(&graph, *params.graph_properties, curr_segment.first,
node_map, reverse_topo_order, &curr_engine);
GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
reverse_topo_order, &curr_engine);
if (!status.ok()) {
LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
<< status;
@ -776,7 +801,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
total_engine_bytes_size += engine_bytes_size.back();
total_num_nodes_in_segments += curr_segment.first.size();
total_num_nodes_in_segments += curr_segment.size();
engine_segments.push_back(std::move(curr_engine));
converted_segments.push_back(std::move(curr_segment));
@ -806,7 +831,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
engine.max_workspace_size_bytes =
params.max_workspace_size_bytes *
(engine_bytes_size.at(i) / total_engine_bytes_size +
converted_segments.at(i).first.size() / total_num_nodes_in_segments) /
converted_segments.at(i).size() / total_num_nodes_in_segments) /
2.0;
VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
<< engine.engine_name;
@ -828,9 +853,9 @@ Status ConvertAfterShapes(const ConversionParams& params) {
CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
alloc.get(), &engine_nodes);
string msg = StrCat("TensorRT node ", engine.engine_name,
" added for segment ", i, " consisting of ",
converted_segments.at(i).first.size(), " nodes");
string msg =
StrCat("TensorRT node ", engine.engine_name, " added for segment ", i,
" consisting of ", converted_segments.at(i).size(), " nodes");
if (status.ok()) {
LOG(INFO) << msg << " succeeded.";
} else {
@ -839,7 +864,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
}
if (VLOG_IS_ON(1)) {
msg = "Segment consists of nodes: ";
for (const Node* node : converted_segments.at(i).first) {
for (const Node* node : converted_segments.at(i)) {
StrAppend(&msg, node->name(), ", ");
}
VLOG(1) << msg;
@ -848,7 +873,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
// If status is ok, we successfully added the node to the graph and can
// remove segment ops. Otherwise graph is not modified.
if (status.ok()) {
for (const Node* node : converted_segments.at(i).first) {
for (const Node* node : converted_segments.at(i)) {
graph.RemoveNode(const_cast<Node*>(node));
}
}

View File

@ -239,7 +239,7 @@ class ConvertAfterShapesTest : public ::testing::Test {
params.output_names = &output_names;
params.max_workspace_size_bytes = 8 << 20;
params.output_graph_def = output_graph_def;
params.minimum_segment_size = 2;
params.minimum_segment_size = 1;
params.graph_properties = &graph_properties;
params.use_calibration = false;

View File

@ -385,11 +385,10 @@ string DebugString(const nvinfer1::ITensor& tensor) {
", dims=", DebugString(tensor.getDimensions()), ")");
}
Status Converter::GetTrtBroadcastShape(
const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims) const {
// ***************************************************************************
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims) {
// TensorRT Elementwise op supports broadcast but requires both tensor to be
// of Identical rank
//
@ -473,14 +472,13 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
nvinfer1::Weights trt_weights = weights.GetTrtWeights();
nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights);
if (!layer) return nullptr;
const nvinfer1::DataType trt_dtype = trt_weights.type;
nvinfer1::ITensor* trt_tensor = layer->getOutput(0);
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
// TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set
// the data type below, it will always be kFLOAT regardless what the data type
// of the weights is. Once NVIDIA fixes this bug, we should remove the data
// type setting logic below and test should still pass.
trt_tensor->setType(trt_dtype);
trt_tensor->setType(trt_weights.type);
#endif
return trt_tensor;
}
@ -1677,190 +1675,6 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights,
return Status::OK();
}
// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the
// right operand. If swapped_inputs is true, those two are swapped.
//
// TODO(jie): broadcast is needed yet not implemented.
// Only implemented channel wise for the time being.
Status BinaryTensorOpWeight(OpConverterParams* params,
nvinfer1::ITensor* tensor,
TRT_ShapedWeights weights, bool swapped_inputs) {
static const std::unordered_set<string> supported_ops = {"Sub", "Add", "Mul",
"Div", "RealDiv"};
const auto& node_def = params->node_def;
if (!supported_ops.count(node_def.op())) {
return errors::Unimplemented(node_def.op(), " is not supported, at ",
node_def.name());
}
// Check scale mode.
auto dims_w = weights.shape_;
const auto dims_t = tensor->getDimensions();
// TODO(jie): addScale checks for input tensor dimension
if (dims_t.nbDims != 3) {
return errors::InvalidArgument("addScale requires tensor with rank 3, at ",
node_def.name());
}
// Default to element-wise
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
// TODO(jie): maybe use a permutation instead to support more cases;
bool need_to_permute = false;
if (weights.count() == 1) {
scale_mode = nvinfer1::ScaleMode::kUNIFORM;
} else {
VLOG(2) << "weights dims: " << DebugString(dims_w)
<< "; tensor dims: " << DebugString(dims_t);
// Make sure no broadcasting on batch dimension.
if (dims_w.nbDims == dims_t.nbDims + 1) {
if (dims_w.d[0] == 1) {
for (int i = 1; i < dims_w.nbDims; i++) {
dims_w.d[i - 1] = dims_w.d[i];
}
dims_w.nbDims--;
} else {
return errors::InvalidArgument("Binary op cannot operate on batch, at ",
node_def.name());
}
}
if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) {
scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
// Default is element-wise
for (int i = 1; i < dims_w.nbDims; i++) {
if (dims_w.d[i] != dims_t.d[i]) {
// If dimension does not match, switch back to per-channel
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
break;
}
}
// If the mode is per-channel, since channel dimension is assumed to be
// the third to last dimension, we need to make sure all other dimensions
// have size 1.
if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
for (int i = 1; i < dims_w.nbDims; i++) {
if (dims_w.d[i] != 1)
return errors::InvalidArgument(
"Weight dims not compatible for channel-wise broadcast at ",
node_def.name());
}
}
} else if (dims_w.nbDims == 1 &&
dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) {
// Channel wise and broadcast required. We compare the last dimension of
// the tensor shape because of tensorflow default broadcasting rules.
need_to_permute = true;
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
} else {
return errors::InvalidArgument("Weight dims not compatible at ",
node_def.name());
}
}
// TODO(laigd): we should add validation_only support in TransposeTensor() and
// PrepareTensorForShape().
if (params->validation_only) return Status::OK();
// Transpose last dimension.
std::vector<int> permutation(dims_t.nbDims + 1);
if (need_to_permute) {
// We swap the last dimension into channel for trt, because of tensorflow
// default broadcasting rules.
for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
permutation[i] = i;
}
permutation[1] = dims_t.nbDims;
permutation[dims_t.nbDims] = 1;
TF_RETURN_IF_ERROR(
params->converter->TransposeTensor(tensor, permutation, &tensor));
}
// Prepare weights
TRT_ShapedWeights shift_weights(weights.TrtDType());
TRT_ShapedWeights scale_weights(weights.TrtDType());
TRT_ShapedWeights power_weights(weights.TrtDType());
if (node_def.op() == "Sub") {
if (swapped_inputs) {
shift_weights = weights;
nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
*tensor, nvinfer1::UnaryOperation::kNEG);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
// Since quantization ranges are symmetric, the same range as the input
// will work for the negation of the input.
params->converter->MarkQuantizationRangesAsInferrable(
tensor, layer->getOutput(0));
tensor = layer->getOutput(0);
} else {
TRT_ShapedWeights neg_weights =
params->weight_store->GetTempWeights(weights);
LambdaFactory unary_op;
unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
shift_weights = neg_weights;
}
} else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
if (swapped_inputs) {
// We need to infer the quantization range for this intermediate tensor.
//
// x -> [Recip] -> 1/x -> [Scale] -> s/x
// ^
// need range for this
//
// We have the quantization scales for x and s/x - can we divide the scale
// for s/x by s? Only if it is a scalar.
//
// Because of this issue, fall back to BinaryTensorOpTensor if we are
// doing INT8 with no calibration. There is most likely no performance
// penalty by falling back here.
if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
!params->converter->use_calibration()) {
return errors::Unimplemented(
"Intermediate quantization range cannot be determined without"
" calibration. Falling back to BinaryTensorOpTensor for ",
node_def.op(), ", at ", node_def.name());
}
scale_weights = weights;
nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
*tensor, nvinfer1::UnaryOperation::kRECIP);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
tensor = layer->getOutput(0);
} else {
TRT_ShapedWeights recip_weights =
params->weight_store->GetTempWeights(weights);
LambdaFactory unary_op;
unary_op.op = LambdaFactory::OP_CATEGORY::RECIP;
TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op));
scale_weights = recip_weights;
}
} else if (node_def.op() == "Mul") {
scale_weights = weights;
} else if (node_def.op() == "Add") {
shift_weights = weights;
} else {
// This should not happen.
return errors::Unimplemented("Binary op not supported at ", node_def.op());
}
nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
*tensor, scale_mode, shift_weights.GetTrtWeights(),
scale_weights.GetTrtWeights(), power_weights.GetTrtWeights());
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
// Transpose back dimension
if (need_to_permute) {
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
output_tensor, permutation, &output_tensor));
}
// Pass the output
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
Status ConvertConv2DHelper(OpConverterParams* params, int group,
bool is_conv2d_backprop_input) {
const auto& inputs = params->inputs;
@ -1951,7 +1765,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3];
// Add padding.
// Before TRT 5.1.3, we have to calculate padding ourselves.
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
std::vector<std::pair<int, int>> padding;
if (attrs.get<string>("padding") == "SAME") {
nvinfer1::DimsHW effective_kernel_size = kernel_size;
@ -1978,12 +1793,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
padding = {{0, 0}, {0, 0}};
}
// TensorRT 5.1 added support for asymmetric padding. Due to a bug in 5.1.2, we
// can only use asymmetric padding in convolutions with 5.1.3+.
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
// Handle asymmetric padding. TensorRT 5.1 added support for asymmetric
// padding via setPrePadding and setPostPadding. Due to a bug in 5.1.2, we can
// only use asymmetric padding in convolutions with 5.1.3+. But in 5.1.3, we
// will always use setPaddingMode for simplicity.
if (padding[0].first != padding[0].second ||
padding[1].first != padding[1].second) {
// Handle asymmetric padding.
auto pad_layer = params->converter->network()->addPadding(
*tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
@ -2006,20 +1821,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
layer->setStride(stride);
// TensorRT 5.1.3 added support for padding modes.
#if IS_TRT_VERSION_GE(5, 1, 3, 0)
// VALID padding is the default TRT behavior.
if (attrs.get<string>("padding") == "SAME") {
VLOG(2) << "Using SAME padding";
// SAME_UPPER means that post padding is preferred.
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
// For VALID padding, we need to manually set the padding.
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
layer->setPostPadding(
nvinfer1::DimsHW{padding[0].second, padding[1].second});
VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
<< " and post-padding to: " << DebugString(layer->getPostPadding());
#else
layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
#endif
layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups);
@ -2033,17 +1841,10 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
layer->setStride(stride);
#if IS_TRT_VERSION_GE(5, 1, 3, 0)
if (attrs.get<string>("padding") == "SAME") {
VLOG(2) << "Using SAME padding";
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
layer->setPostPadding(
nvinfer1::DimsHW{padding[0].second, padding[1].second});
VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding())
<< " and post-padding to: " << DebugString(layer->getPostPadding());
#else
layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
VLOG(2) << "Set padding to: " << DebugString(layer->getPadding());
#endif
layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups);
@ -2061,74 +1862,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
return Status::OK();
}
Status BinaryTensorOpTensor(OpConverterParams* params,
const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r) {
const auto& node_def = params->node_def;
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
{"Add", nvinfer1::ElementWiseOperation::kSUM},
{"Mul", nvinfer1::ElementWiseOperation::kPROD},
{"Sub", nvinfer1::ElementWiseOperation::kSUB},
{"Div", nvinfer1::ElementWiseOperation::kDIV},
{"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
{"Pow", nvinfer1::ElementWiseOperation::kPOW},
};
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end()) {
return errors::Unimplemented("Binary op ", node_def.op(),
" not supported at: ", node_def.name());
}
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
Status status = params->converter->GetTrtBroadcastShape(
operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r);
if (!status.ok()) {
return errors::InvalidArgument(
"Unsupported binary op broadcast scheme for op ", node_def.name(), ": ",
status.error_message());
}
TFAttrs attrs(node_def);
nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
if (dtype == nvinfer1::DataType::kINT32) {
return errors::Unimplemented("Binary op ", node_def.op(),
" does not support INT32, at ",
node_def.name());
}
if (params->validation_only) return Status::OK();
nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
status = params->converter->PrepareTensorForShape(
operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l);
if (status.ok()) {
status = params->converter->PrepareTensorForShape(
operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r);
}
if (!status.ok()) {
return errors::Internal("Failed to convert binary op ", node_def.name(),
": ", status.error_message());
}
// Check type consistency.
TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype)
<< DebugString(tensor_l->getType()) << " vs " << DebugString(dtype);
TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype)
<< DebugString(tensor_r->getType()) << " vs " << DebugString(dtype);
// Add ElementWise layer.
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(*tensor_l, *tensor_r,
op_pair->second);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
// Pass the output
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
Status ConvertPlugin(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
@ -2777,6 +2510,8 @@ Status ConvertPool(OpConverterParams* params) {
const auto tf_kernel = attrs.get<std::vector<int64>>("ksize");
const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
// Before TRT 5.1.3, we have to calculate padding ourselves.
#if !IS_TRT_VERSION_GE(5, 1, 3, 0)
auto tensor_dim = tensor->getDimensions();
std::vector<std::pair<int, int>> padding;
if (padding_type == "SAME") {
@ -2789,13 +2524,13 @@ Status ConvertPool(OpConverterParams* params) {
} else if (padding_type == "VALID") {
padding = {{0, 0}, {0, 0}};
}
// TensorRT 5.1 added support for asymmetric padding.
#endif
// TensorRT 5.1 added support for asymmetric padding. Before that, we need an
// extra padding layer.
#if !IS_TRT_VERSION_GE(5, 1, 0, 0)
// Asymmetric padding case.
if (padding[0].first != padding[0].second ||
padding[1].first != padding[1].second) {
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
<< padding[1].first << padding[1].second;
auto pad_layer = params->converter->network()->addPadding(
*tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
@ -2817,16 +2552,13 @@ Status ConvertPool(OpConverterParams* params) {
layer->getOutput(0));
layer->setStride(stride);
// TensorRT 5.1.3 added support for padding modes.
#if IS_TRT_VERSION_GE(5, 1, 3, 0)
// VALID padding is the default TRT behavior.
if (attrs.get<string>("padding") == "SAME") {
// SAME_UPPER means that post padding is preferred.
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
#endif
// TensorRT 5.1 has support for asymmetric padding.
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
// If padding mode is not SAME, then these values will be used instead.
#elif IS_TRT_VERSION_GE(5, 1, 0, 0)
layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second});
#else
@ -3350,9 +3082,6 @@ Status ConvertIdentity(OpConverterParams* params) {
Status ConvertBinary(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
// TODO(tmorris): Enable once false is updated to mean either tensor or weight
// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
// false}}));
if (inputs.size() != 2) {
return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
" inputs but expected 2, at ",
@ -3368,33 +3097,45 @@ Status ConvertBinary(OpConverterParams* params) {
"both input as constant at: ",
node_def.name());
}
const TRT_TensorOrWeights& operand_l = inputs.at(0);
const TRT_TensorOrWeights& operand_r = inputs.at(1);
// TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with
// IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For
// now, the performance will be slightly better with IScaleLayer because it
// can be fused in more situations. However, most of the benefits of
// IScaleLayer are when the layer performs both a shift and a scale, which we
// don't do except for convolutions.
//
// Try to convert into Scale layer first (for better performance).
// Since scale layer supports restricted broadcast policy and op types, we
// allow failure and try to handle it through Elementwise op
// (BinaryTensorOpTensor).
Status status = Status::OK();
if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) {
status = BinaryTensorOpWeight(params, inputs.at(0).tensor(),
inputs.at(1).weights(), false);
} else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) {
status = BinaryTensorOpWeight(params, inputs.at(1).tensor(),
inputs.at(0).weights(), true);
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
{"Add", nvinfer1::ElementWiseOperation::kSUM},
{"Mul", nvinfer1::ElementWiseOperation::kPROD},
{"Sub", nvinfer1::ElementWiseOperation::kSUB},
{"Div", nvinfer1::ElementWiseOperation::kDIV},
{"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
{"Pow", nvinfer1::ElementWiseOperation::kPOW},
};
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end()) {
return errors::Unimplemented("Binary op ", node_def.op(),
" not supported at: ", node_def.name());
}
// If both input are tensors, or one of them is weights but the conversion
// above failed, try the conversion using BinaryTensorOpTensor.
if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) {
if (!status.ok()) VLOG(2) << status;
status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1));
}
return status;
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r));
nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
// This will also convert constants to tensors, and set quantization ranges.
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
operand_l, broadcasted_dims_l, params->validation_only, &tensor_l));
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
operand_r, broadcasted_dims_r, params->validation_only, &tensor_r));
if (params->validation_only) return Status::OK();
// Add ElementWise layer.
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(*tensor_l, *tensor_r,
op_pair->second);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
return Status::OK();
}
Status ConvertRsqrt(OpConverterParams* params) {
@ -4547,7 +4288,7 @@ Status ConvertSquaredDifference(OpConverterParams* params) {
const auto& node_def = params->node_def;
// Broadcast inputs.
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape(
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r));
nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
@ -4692,8 +4433,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) {
TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name());
// Create plugin
nvinfer1::IPluginV2* plugin =
creator->createPlugin(node_def.name().c_str(), &fc);
TrtUniquePtrType<nvinfer1::IPluginV2> plugin(
creator->createPlugin(node_def.name().c_str(), &fc));
TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name());
// Set plugin inputs
@ -4875,7 +4616,8 @@ static void RegisterValidatableOpConverters(
for (auto pool_op_type : {"AvgPool", "MaxPool"}) {
(*registration)[pool_op_type] = ConvertPool;
}
for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) {
for (auto normalization_op_type :
{"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"}) {
(*registration)[normalization_op_type] = ConvertFusedBatchNorm;
}
for (auto unary_op_pair : *UnaryOperationMap()) {

View File

@ -512,13 +512,6 @@ class Converter {
const bool validation_only,
nvinfer1::ITensor** tensor);
// Return OK if the broadcast scheme is supported and compute the shapes after
// broadcasting.
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims) const;
// Creates an IConstantLayer using 'weights' whose dimensions are specified by
// 'dims', and returns the output ITensor.
nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights,
@ -592,6 +585,13 @@ class Converter {
friend class OpConverterTest;
};
// Return OK if the broadcast scheme is supported and compute the shapes after
// broadcasting.
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims);
// Map of all supported UnaryOperations
const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
// Map of all supported ActivationTypes

View File

@ -988,19 +988,17 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
operand_2_shape, operand_2_is_tensor, operand_2_batch_size);
// operand_1 broadcast operand_2
ExpectStatus(
this->converter_->GetTrtBroadcastShape(
operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims),
expected_code, expected_error_msg_substr);
ExpectStatus(GetTrtBroadcastShape(operand_1, operand_2, &operand_1_new_dims,
&operand_2_new_dims),
expected_code, expected_error_msg_substr);
if (expected_code == error::OK) {
ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
}
// operand_2 broadcast operand_1
ExpectStatus(
this->converter_->GetTrtBroadcastShape(
operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims),
expected_code, expected_error_msg_substr);
ExpectStatus(GetTrtBroadcastShape(operand_2, operand_1, &operand_2_new_dims,
&operand_1_new_dims),
expected_code, expected_error_msg_substr);
if (expected_code == error::OK) {
ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
@ -1033,18 +1031,29 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims 4 vs broadcast #dims 5)");
symmetric_test({3}, {1, 1, 3}, kIsTensor, kIsNotTensor, {}, {},
error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims 2 vs broadcast #dims 3)",
/*operand_1_batch_size=*/2);
// Both inputs are tensors.
symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims 3 vs broadcast #dims 4)");
symmetric_test({1, 3}, {3}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims 2 vs broadcast #dims 3)");
symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
{2, 1, 4});
symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT,
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims 4 vs broadcast #dims 5)");
symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {},
error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
}
TEST_F(ConverterTest, CreateConstantLayer) {
@ -1070,7 +1079,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
int batch_size = -1;
for (const NodeDef& node : gdef.node()) {
absl::string_view node_name(node.name());
if (str_util::ConsumePrefix(&node_name, kInputPHName)) {
if (absl::ConsumePrefix(&node_name, kInputPHName)) {
int port = -1;
EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
@ -1351,10 +1360,6 @@ class OpConverterTest : public ::testing::Test {
}
}
void TestMatMulHelper(
const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
const std::string& op_name);
// Expose quantization_ranges_ for tests
std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
return converter_->quantization_ranges_;
@ -1682,59 +1687,60 @@ TEST_F(OpConverterTest, ConvertReshape) {
// Helper function for testing MatMul and BatchMatMul
// get_matmul corresponds to the function used to generate the node. It should
// accept (DataType, transpose_a, transpose_b) as parameters.
void OpConverterTest::TestMatMulHelper(
void TestMatMulHelper(
OpConverterTest* test,
const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
const std::string& op_name) {
// HACK: This needs to be done in a better way.
const bool is_batch_matmul = op_name == "BatchMatMul";
{
// Unsupported data type.
Reset();
test->Reset();
NodeDef node_def = get_matmul(DT_INT32, false, false);
AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
AddTestWeights<int32>("weights", {2, 1}, {3, 5});
RunValidationAndConversion(
test->AddTestTensor("input", {2}, /*batch_size=*/1,
nvinfer1::DataType::kINT32);
test->AddTestWeights<int32>("weights", {2, 1}, {3, 5});
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
("Data type int32 is not supported for " + op_name +
", "
"must be one of [float, half], at my_matmul")
StrCat("Data type int32 is not supported for ", op_name,
", must be one of [float, half], at my_matmul")
.c_str());
}
// OK.
for (bool transpose_a : {false, true}) {
for (bool transpose_b : {false, true}) {
Reset();
test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b);
AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
test->AddTestTensor("input", {2}, /*batch_size=*/1);
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) {
if (transpose_a || transpose_b) {
RunValidationAndConversion(
test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension for "
"BatchMatMul, at my_matmul");
} else {
RunValidationAndConversion(
test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension");
}
continue;
} else if (transpose_a) {
RunValidationAndConversion(
test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions");
continue;
}
RunValidationAndConversion(node_def);
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
BuildAndRun(input_data, &output_data);
test->BuildAndRun(input_data, &output_data);
if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else {
@ -1744,31 +1750,31 @@ void OpConverterTest::TestMatMulHelper(
}
// OK, 3D inputs
for (bool transpose_b : {false, true}) {
Reset();
test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b);
AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
test->AddTestTensor("input", {2}, /*batch_size=*/1);
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) {
if (transpose_b) {
RunValidationAndConversion(
test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension for "
"BatchMatMul, at my_matmul");
} else {
RunValidationAndConversion(
test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Input weight attempts to broadcast across batch dimension");
}
continue;
}
RunValidationAndConversion(node_def);
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
BuildAndRun(input_data, &output_data);
test->BuildAndRun(input_data, &output_data);
if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else {
@ -1832,7 +1838,7 @@ TEST_F(OpConverterTest, ConvertMatMul) {
node_def, error::INVALID_ARGUMENT,
"Cannot currently transpose constant input if it is not 2 dimensional");
}
TestMatMulHelper(get_matmul_nodedef, "MatMul");
TestMatMulHelper(this, get_matmul_nodedef, "MatMul");
}
TEST_F(OpConverterTest, ConvertBatchMatMul) {
@ -1889,7 +1895,7 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) {
}
}
TestMatMulHelper(get_batch_matmul_nodedef, "BatchMatMul");
TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul");
}
template <DataType dtype>
@ -2010,250 +2016,82 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) {
}
template <typename OpType, DataType dtype>
void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
for (auto swap_inputs : {false, true}) {
test->Reset();
NodeDef node_def;
if (swap_inputs) {
node_def = GetBinaryOpNodeDef<OpType>("weights", "input", dtype);
} else {
node_def = GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
}
const std::vector<CType> operand1{CType(3), CType(7.5)};
const std::vector<CType> operand2{CType(2), CType(3)};
// It requires the dims to be at least of rank 3 to apply an IScaleLayer.
test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>("weights", /*dims=*/{1, 1, 2},
/*values=*/swap_inputs ? operand1 : operand2);
test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
CheckAddedLayers(test, /*expect_scale_layer=*/true);
// Check the dims of the output ITensor.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
const DataVec input_data{
{"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
test->BuildAndRun(
input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
if (node_def.op() == "Add") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(5), CType(10.5)));
} else if (node_def.op() == "Sub") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1), CType(4.5)));
} else if (node_def.op() == "Mul") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(6), CType(22.5)));
} else if (node_def.op() == "Div") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(2.5)));
} else if (node_def.op() == "RealDiv") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(2.5)));
} else {
ASSERT_TRUE(false);
}
}
}
template <DataType dtype>
void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const NodeDef node_def =
GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
const std::vector<CType> weights{CType(10), CType(20)};
// There are two types of valid dim pairs which requires channel-wise
// broadcasting:
// - input dims (X Y Z) vs weights dims (X 1 1)
// - input dims (X Y Z) vs weights dims (Z)
// Here X=Z=2 and Y=1.
for (auto weights_dims : std::vector<std::vector<int>>{{2, 1, 1}, {2}}) {
test->Reset();
test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>("weights", weights_dims, weights);
test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
CheckAddedLayers(test, /*expect_scale_layer=*/true);
// Check the dims of the output ITensor.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
test->BuildAndRun(input_data, &output_data);
if (weights_dims.size() == 1) {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(22), CType(13), CType(24)));
} else {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(12), CType(23), CType(24)));
}
}
}
template <DataType dtype>
void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const NodeDef node_def =
GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
const std::vector<CType> weights{CType(10)};
test->Reset();
test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>("weights", {1, 1, 1, 1}, weights);
test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
CheckAddedLayers(test, /*expect_scale_layer=*/true);
// Check the dims of the output ITensor.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
test->BuildAndRun(input_data, &output_data);
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(12), CType(13), CType(14)));
}
template <typename OpType>
void TestBinaryTensorOpWeightFallback(OpConverterTest* test,
const std::vector<int32>& input_dims,
const std::vector<int>& weights_dims,
error::Code code = error::OK,
const char* error_msg_substr = nullptr,
const int input_batch_size = 1) {
const DataType dtype = DT_FLOAT;
typedef typename EnumToDataType<dtype>::Type CType;
const size_t num_inputs = TrtTensorDimsNumElements(GetTestDims(input_dims));
const size_t num_weights =
TrtWeightDimsNumElements(GetTestDims(weights_dims));
test->Reset();
const NodeDef node_def =
GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size,
TfDataTypeToTrt(dtype));
test->AddTestWeights<CType>(
"weights", /*dims=*/weights_dims,
/*values=*/std::vector<CType>(num_weights, CType(1)));
test->RunValidationAndConversion(node_def, code, error_msg_substr);
if (code != error::OK) return;
// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
CheckAddedLayers(test, /*expect_scale_layer=*/false);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
// Check the dims of the output ITensor.
std::vector<int> expected_output_dims = input_dims;
for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1;
i >= 0 && j >= 0; --i, --j) {
if (expected_output_dims[i] == 1) {
expected_output_dims[i] = weights_dims[j];
}
}
ExpectTrtDimsEqualsArray(expected_output_dims,
output.tensor()->getDimensions());
// Check the result of running the engine.
const int expected_num_outputs =
TrtTensorDimsNumElements(GetTestDims(expected_output_dims));
const DataVec input_data{
{"input", ConstructTensor<CType>(num_inputs, CType(2))}};
DataVec output_data{
{"my_binary", ConstructTensor<CType>(expected_num_outputs)}};
test->BuildAndRun(input_data, &output_data);
if (node_def.op() == "Add") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(3))));
} else if (node_def.op() == "Minimum") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(1))));
} else {
ASSERT_TRUE(false);
}
}
template <typename OpType, DataType dtype>
void TestBinaryTensorOpTensor(OpConverterTest* test) {
void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor,
bool operand_2_is_tensor) {
typedef typename EnumToDataType<dtype>::Type CType;
test->Reset();
const NodeDef node_def =
GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1,
TfDataTypeToTrt(dtype));
if (operand_1_is_tensor) {
test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/2,
TfDataTypeToTrt(dtype));
} else {
test->AddTestWeights("input1", /*dims=*/{1, 2},
/*values=*/std::vector<CType>{CType(3), CType(6)});
}
if (operand_2_is_tensor) {
test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/2,
TfDataTypeToTrt(dtype));
} else {
test->AddTestWeights("input2", /*dims=*/{2, 1},
/*values=*/std::vector<CType>{CType(2), CType(3)});
}
test->RunValidationAndConversion(node_def);
// Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
CheckAddedLayers(test, /*expect_scale_layer=*/false);
DataVec input_data;
if (operand_1_is_tensor) {
input_data.push_back(
{"input1",
test::AsTensor<CType>({CType(3), CType(6), CType(3), CType(6)})});
}
if (operand_2_is_tensor) {
input_data.push_back(
{"input2",
test::AsTensor<CType>({CType(2), CType(3), CType(2), CType(3)})});
}
DataVec output_data{{"my_binary", ConstructTensor<CType>(8)}};
// Check output dims.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
const DataVec input_data{
{"input1", test::AsTensor<CType>({CType(3), CType(6)})},
{"input2", test::AsTensor<CType>({CType(2), CType(3)})}};
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
// After broadcasting first input becomes {3, 6, 3, 6} and second input
// becomes {2, 3, 2, 3}.
test->BuildAndRun(
input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32,
/*batch_size=*/2);
if (node_def.op() == "Add") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(5), CType(8), CType(6), CType(9)));
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({5, 8, 6, 9, 5, 8, 6, 9})));
} else if (node_def.op() == "Sub") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1), CType(4), CType(0), CType(3)));
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({1, 4, 0, 3, 1, 4, 0, 3})));
} else if (node_def.op() == "Mul") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(6), CType(12), CType(9), CType(18)));
ElementsAreArray(
CastTestVector<int, CType>({6, 12, 9, 18, 6, 12, 9, 18})));
} else if (node_def.op() == "Div") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
ElementsAreArray(CastTestVector<float, CType>(
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
} else if (node_def.op() == "RealDiv") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
ElementsAreArray(CastTestVector<float, CType>(
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
} else if (node_def.op() == "Minimum") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(2), CType(2), CType(3), CType(3)));
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({2, 2, 3, 3, 2, 2, 3, 3})));
} else if (node_def.op() == "Maximum") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(3), CType(6), CType(3), CType(6)));
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({3, 6, 3, 6, 3, 6, 3, 6})));
} else if (node_def.op() == "Pow") {
ExpectArrayNear(
std::vector<CType>{CType(9), CType(36), CType(27), CType(216)},
CastTestVector<int, CType>({9, 36, 27, 216, 9, 36, 27, 216}),
GetSpanForData<CType>(output_data[0]));
} else {
ASSERT_TRUE(false);
@ -2287,58 +2125,48 @@ TEST_F(OpConverterTest, ConvertBinary) {
"both input as constant at: my_add");
}
// Test BinaryTensorOpWeight() without broadcasting.
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_FLOAT>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_FLOAT>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
// Test BinaryTensorOpWeight() with channel-wise broadcasting.
TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
// Test BinaryTensorOpWeight() with uniformly broadcasting.
TestBinaryTensorOpWeightWithUniformlyBroadcast<DT_FLOAT>(this);
// Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor().
// Unsupported op.
TestBinaryTensorOpWeightFallback<ops::Minimum>(this, {1, 1, 1}, {1});
// Rank of input tensor dimension <3.
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1}, {1});
// Broadcast on batch dimension, should fail.
TestBinaryTensorOpWeightFallback<ops::Add>(
this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT,
"Unsupported binary op broadcast scheme for op my_binary",
/*input_batch_size=*/2);
// Incompatible dims with per-channel mode.
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1, 1}, {1, 2, 1});
// Incompatible dims.
TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 2, 1}, {2});
// Test BinaryTensorOpTensor() with broadcasting.
TestBinaryTensorOpTensor<ops::Add, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Sub, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Mul, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Div, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::RealDiv, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Minimum, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Maximum, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Pow, DT_FLOAT>(this);
TestBinaryTensorOpTensor<ops::Add, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Sub, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Mul, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Div, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::RealDiv, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Minimum, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Maximum, DT_HALF>(this);
TestBinaryTensorOpTensor<ops::Pow, DT_HALF>(this);
// Test combinations of tensor vs weight inputs (except when both inputs are
// weights).
for (const bool operand_1_is_tensor : {true, false}) {
for (const bool operand_2_is_tensor : {true, false}) {
if (!operand_1_is_tensor && !operand_2_is_tensor) continue;
// FP32 tests
TestBinaryOp<ops::Add, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Sub, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Mul, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Div, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::RealDiv, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Minimum, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Maximum, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Pow, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
// FP16 tests
// TODO(tmorris): Use templates to avoid duplication.
TestBinaryOp<ops::Add, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Sub, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Mul, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Div, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::RealDiv, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Minimum, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Maximum, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Pow, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
}
}
}
TEST_F(OpConverterTest, ConvertQuantize) {
@ -2583,7 +2411,6 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
// implementation that, the extra output classes that are outside of the
// range specified by valid_detections[i] are not zeros but -1s.
TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}};
const int batch_size = 1;
for (int i = 0; i < kCombinedNMSOKCases; ++i) {
Reset();

View File

@ -14,6 +14,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
@ -32,9 +34,9 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
// TODO(sami): Remove VLOG messages once the code matures
using absl::AsciiStrToUpper;
using absl::StrAppend;
using absl::StrCat;
using str_util::Uppercase;
Status TRTOptimizationPass::Init(
const RewriterConfig_CustomGraphOptimizer* config) {
@ -67,7 +69,7 @@ Status TRTOptimizationPass::Init(
}
if (params.count("precision_mode")) {
TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
Uppercase(params.at("precision_mode").s()), &precision_mode_));
AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_));
}
if (params.count("use_calibration")) {
use_calibration_ = params.at("use_calibration").b();

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <dirent.h>
#include <string.h>
#include <fstream>
#include <vector>
@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
TF_ASSERT_OK(RunOpKernel());
// Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
Tensor* output = context_->mutable_output(0);
EXPECT_EQ("my_serialized_str", output->scalar<string>()());
// string type output will remain on CPU, so we're not using GetOutput() here.
EXPECT_EQ("my_serialized_str",
context_->mutable_output(0)->scalar<string>()());
}
} // namespace tensorrt

View File

@ -87,16 +87,11 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
Tensor* output = OpsTestBase::context_->mutable_output(0);
const auto& tensor_map = output->flat<TypeParam>();
std::vector<TypeParam> output_data(tensor_map.size());
ASSERT_EQ(0, cudaDeviceSynchronize());
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
sizeof(TypeParam) * tensor_map.size(),
cudaMemcpyDeviceToHost));
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
Tensor* output = OpsTestBase::GetOutput(0);
EXPECT_THAT(
absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
output->NumElements()),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
}
} // namespace tensorrt

View File

@ -681,31 +681,33 @@ Status SegmentGraph(const Graph* tf_graph,
<< " with parent=" << segment_root << ":" << s;
}
// Don't use small segments.
if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) {
const int num_effective_nodes = std::count_if(
segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
static auto noops =
new std::set<string>{"Identity", "Snapshot", "StopGradient"};
return noops->count(node->type_string()) == 0;
});
// Don't use segments whose number of effective nodes is small.
if (num_effective_nodes < options.minimum_segment_size) {
VLOG(1) << "Segment " << segments->size() << " has only "
<< segment_nodes.size() << " nodes, dropping";
<< num_effective_nodes << " effective nodes, dropping";
continue;
}
// TODO(sami): Make segmenter placement aware once trtscopes are in place
const auto& dev_itr = device_maps.find(segment_root);
if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
VLOG(1) << "No device assigned to segment " << segments->size();
segments->emplace_back(std::make_pair(segment_nodes, string()));
} else if (dev_itr->second.size() > 1) {
string s("Segment ");
StrAppend(&s, segments->size(), " has multiple devices attached: ");
string s = StrCat("Segment ", segments->size(),
" has multiple devices attached: ");
for (const auto& dev : dev_itr->second) {
StrAppend(&s, dev, ", ");
}
LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin());
segments->emplace_back(
std::make_pair(segment_nodes, *(dev_itr->second.begin())));
} else {
segments->emplace_back(
std::make_pair(segment_nodes, *(dev_itr->second.begin())));
LOG(WARNING) << s;
}
segments->emplace_back(segment_nodes);
}
if (VLOG_IS_ON(1)) {
for (const auto& d : device_maps) {

View File

@ -31,10 +31,8 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
// Vector of segments, each entry contains a set of node pointers and a device
// name in the segment.
using SegmentNodesVector =
std::vector<std::pair<std::set<const Node*>, string>>;
// Vector of segments, each entry contains a set of node pointers.
using SegmentNodesVector = std::vector<std::set<const Node*>>;
struct SegmentOptions {
// Segment must contain at least this many nodes.

View File

@ -77,7 +77,7 @@ class SegmentTest : public ::testing::Test {
EXPECT_EQ(expected_segments.size(), segments.size());
for (int i = 0; i < segments.size(); ++i) {
std::set<string> segment_node_names;
for (const Node* node : segments[i].first) {
for (const Node* node : segments[i]) {
segment_node_names.insert(node->name());
}
const auto& expected = expected_segments[i];
@ -262,6 +262,23 @@ TEST_F(SegmentTest, BigIfElse) {
{{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
}
TEST_F(SegmentTest, IdentityOps) {
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
auto identity0 = ops::Identity(s.WithOpName("identity0"), feed);
auto identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
auto identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
auto identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
Graph g(OpRegistry::Global());
TF_EXPECT_OK(s.ToGraph(&g));
const std::set<string> all_identities = {"identity0", "identity1",
"identity2", "identity3"};
// Identity ops are not counted as effective ops in the segment, so no segment
// will be formed in this case.
RunTest(&g, all_identities, all_identities, all_identities, {});
}
} // namespace test
} // namespace segment
} // namespace tensorrt

View File

@ -1,7 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
package(
default_visibility = [":internal"],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "internal",
packages = [
@ -23,15 +26,12 @@ package_group(
],
)
package(
default_visibility = [":internal"],
)
load(
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
cc_library(
name = "tf2xla_supported_ops_lib",
@ -67,6 +67,19 @@ xla_proto_library(
],
)
# A proto library that is minimal in size and dependencies for platforms like Android.
tf_portable_proto_library(
name = "portable_tf2xla_proto",
config_string = "allow_all:true",
header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"],
portable_deps = ["//tensorflow/core:android_proto_lib"],
proto_deps = [
":tf2xla_proto",
"//tensorflow/core:protos_all_cc",
],
visibility = ["//visibility:public"],
)
xla_py_proto_library(
name = "tf2xla_py",
has_services = False,

View File

@ -1,9 +1,8 @@
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
)
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
tf_gen_op_wrapper_cc(

View File

@ -918,10 +918,16 @@ string Conditional::name() const {
Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
int port) {
NodeBuilder id_builder(replacee->name(), "Identity");
id_builder.Input(if_node, port);
string outside_compilation;
if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName,
&outside_compilation)
.ok()) {
id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
}
Node* id;
TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
.Input(if_node, port)
.Finalize(graph_, &id));
TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
return Status::OK();

View File

@ -247,8 +247,8 @@ Status FunctionalizeControlFlowPass::Run(
// multiple times, and we want to avoid functionalize it again.
static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
new std::map<string, string>{
// TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
{"TPUReplicate", "computation"},
// _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
{"_TPUReplicate", "computation"},
// XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
{"XlaLaunch", "function"},
};

View File

@ -1,9 +1,8 @@
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
)
tf_kernel_library(
@ -195,6 +194,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:training_ops",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -43,7 +43,7 @@ class AssertOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(AssertOp);
};
REGISTER_XLA_OP(Name("Assert"), AssertOp);
REGISTER_XLA_OP(Name("Assert").CompilationOnly(), AssertOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -39,7 +39,10 @@ class FusedBatchNormOp : public XlaOpKernel {
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
}
void Compile(XlaOpKernelContext* ctx) override {
void Compile(XlaOpKernelContext* ctx) override { CompileImpl(ctx); }
protected:
virtual void CompileImpl(XlaOpKernelContext* ctx) {
xla::PrimitiveType input_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
@ -116,8 +119,29 @@ class FusedBatchNormOp : public XlaOpKernel {
bool is_on_gpu_;
};
class FusedBatchNormOpV3 : public FusedBatchNormOp {
public:
explicit FusedBatchNormOpV3(OpKernelConstruction* ctx)
: FusedBatchNormOp(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
FusedBatchNormOp::CompileImpl(ctx);
if (!ctx->status().ok()) {
return;
}
ctx->SetConstantOutput(5, Tensor());
}
private:
float epsilon_;
TensorFormat data_format_;
bool is_training_;
bool is_on_gpu_;
};
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
class FusedBatchNormGradOp : public XlaOpKernel {
public:
@ -233,6 +257,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV3"), FusedBatchNormGradOp);
} // namespace
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/util/tensor_format.h"
@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel {
xla::XlaOp conv =
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
lhs_dilation, rhs_dilation, dims, depth);
// Feature group convolution, will end up with the kernel_size change more
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
conv_dims.back() = depth;
conv_dims.push_back(kernel_size);
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
conv_dims.pop_back();
conv_dims.back() *= kernel_size;
conv = xla::Reshape(conv, conv_dims);
ctx->SetOutput(0, conv);
}

View File

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@ -20,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -148,15 +152,22 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
class GatherOp : public XlaOpKernel {
public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
// Set batch_dims_ to 0 if the attribute does not exist.
if (context->HasAttr("batch_dims")) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
} else {
batch_dims_ = 0;
}
}
void Compile(XlaOpKernelContext* context) override {
xla::XlaBuilder* builder = context->builder();
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
auto indices = context->Input(1);
auto indices_shape = context->InputShape(1);
int64 axis = 0;
absl::optional<int64> axis;
if (context->num_inputs() == 3) {
const TensorShape axis_shape = context->InputShape(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
@ -165,31 +176,73 @@ class GatherOp : public XlaOpKernel {
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
errors::InvalidArgument("axis must be int32 or int64"));
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
int64 axis_input;
OP_REQUIRES_OK(context,
context->ConstantInputAsIntScalar(2, &axis_input));
const auto params_dims = input_shape.dims();
OP_REQUIRES(
context, -params_dims <= axis && axis < params_dims,
errors::InvalidArgument("Expected axis in the range [", -params_dims,
", ", params_dims, "), but got ", axis));
if (axis < 0) {
axis += params_dims;
OP_REQUIRES(context,
-params_dims <= axis_input && axis_input < params_dims,
errors::InvalidArgument("Expected axis in the range [",
-params_dims, ", ", params_dims,
"), but got ", axis_input));
if (axis_input < 0) {
axis_input += params_dims;
}
axis = axis_input;
}
if (batch_dims_ != 0) {
if (batch_dims_ < 0) {
batch_dims_ = indices_shape.dims() + batch_dims_;
}
axis = axis.value_or(batch_dims_);
OP_REQUIRES(context,
batch_dims_ >= -indices_shape.dims() &&
batch_dims_ < indices_shape.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices_shape.dims(), ", ",
indices_shape.dims(), "), but got ",
batch_dims_));
OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than rank(input) (",
input_shape.dims(), ")."));
OP_REQUIRES(context, *axis >= batch_dims_,
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than or equal to ",
"axis (", *axis, ")."));
}
axis = axis.value_or(0);
DataType index_type = input_type(1);
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64"));
xla::XlaOp gather;
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, axis,
/*indices_are_nd=*/false, input_type(0), index_type,
builder, &gather));
if (batch_dims_ > 0) {
gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
} else {
// XlaGather() manages degenerate cases, like empty-indices, which are
// error conditions and caught above if batch_dims is not 0.
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, *axis,
/*indices_are_nd=*/false, input_type(0),
index_type, context->builder(), &gather));
}
context->SetOutput(0, gather);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
// The number of batch dimensions, as passed in the batch_dims attribute.
// It must be less than rank(indices).
int32 batch_dims_ = 0;
};
REGISTER_XLA_OP(Name("Gather"), GatherOp);

View File

@ -81,20 +81,21 @@ class InTopKOp : public XlaOpKernel {
xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});
// Calculate in each row of `predictions`, how many values are larger than
// the value of target class. Then return the result whether the count <= k,
// the value of target class. Then return the result whether the count < k,
// which indicates the target is in topk.
xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0});
xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0});
xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2);
xla::XlaOp num_ge_r1 = xla::Reduce(
xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2);
xla::XlaOp num_gt_r1 = xla::Reduce(
one_hot_r2, zero_r0,
xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});
xla::XlaOp result =
xla::Le(num_ge_r1, xla::ConstantR0<int32>(xla_builder, k));
xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)),
xla::IsFinite(targets_values_r1));
context->SetOutput(0, result);
}

View File

@ -67,9 +67,9 @@ class MatMulOp : public XlaOpKernel {
OP_REQUIRES(ctx,
a_shape.dim_size(first_index) == b_shape.dim_size(second_index),
errors::InvalidArgument("Matrix size-compatible: In[0]: ",
a_shape.DebugString(), ", In[1]: ",
b_shape.DebugString()));
errors::InvalidArgument(
"Matrix size-incompatible: In[0]: ", a_shape.DebugString(),
", In[1]: ", b_shape.DebugString()));
xla::XlaOp a = ctx->Input(0);
xla::XlaOp b = ctx->Input(1);

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <numeric>
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -22,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
@ -77,5 +79,58 @@ class SelectOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Select"), SelectOp);
class SelectOpV2 : public XlaOpKernel {
public:
explicit SelectOpV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape cond_shape = ctx->InputShape(0);
const TensorShape then_shape = ctx->InputShape(1);
const TensorShape else_shape = ctx->InputShape(2);
// Compute the output shape from the broadcast of the two data inputs, with
// the broadcast of the conditional.
// Then Broadcast all three inputs to the output shape and emit a select.
BCast bcast_then_else(BCast::FromShape(then_shape),
BCast::FromShape(else_shape),
/*fewer_dims_optimization=*/false);
if (!bcast_then_else.IsValid()) {
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ", then_shape.DebugString(), " vs. ",
else_shape.DebugString()));
return;
}
BCast bcast(bcast_then_else.output_shape(), BCast::FromShape(cond_shape),
/*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ",
BCast::ToShape(bcast_then_else.output_shape()).DebugString(), " vs. ",
cond_shape.DebugString()));
return;
}
auto bcasted_cond = BroadcastTo(ctx->Input(0), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_cond.status());
auto cond_handle = bcasted_cond.ValueOrDie();
auto bcasted_then = BroadcastTo(ctx->Input(1), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_then.status());
auto then_handle = bcasted_then.ValueOrDie();
auto bcasted_else = BroadcastTo(ctx->Input(2), bcast.output_shape());
OP_REQUIRES_OK(ctx, bcasted_else.status());
auto else_handle = bcasted_else.ValueOrDie();
ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOpV2);
};
REGISTER_XLA_OP(Name("SelectV2"), SelectOpV2);
} // namespace
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for softmax.
#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -145,23 +146,36 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape logits_shape = ctx->InputShape(0);
const TensorShape labels_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
errors::InvalidArgument(
"logits and labels must be same size: logits_size=",
logits_shape.DebugString(),
" labels_size=", labels_shape.DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional"));
// As we already tested that both inputs have the same shape no need to
// check that "labels" is a matrix too.
const DataType type = input_type(0);
const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
auto labels = ctx->Input(1);
const TensorShape logits_shape = ctx->InputShape(0);
const TensorShape labels_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_shape),
errors::InvalidArgument("labels must be 2-dimensional"));
// Confirm that any necessary broadcasting to make the shapes the same will
// succeed.
for (int dim = 0; dim < 2; dim++) {
OP_REQUIRES(
ctx,
labels_shape.dim_size(dim) == 1 ||
logits_shape.dim_size(dim) == labels_shape.dim_size(dim),
errors::InvalidArgument("logits and labels must be same size after "
"broadcasting of labels: logits_size=",
logits_shape.DebugString(),
" labels_size=", labels_shape.DebugString()));
}
if (!logits_shape.IsSameSize(labels_shape)) {
auto labels_or = BroadcastTo(labels, logits_shape.dim_sizes());
OP_REQUIRES_OK(ctx, labels_or.status());
labels = labels_or.ConsumeValueOrDie();
}
xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);

View File

@ -1,9 +1,8 @@
# Utilities for building XLA computations.
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//tensorflow/compiler/tf2xla:friends"],
licenses = ["notice"], # Apache 2.0
)
# Filegroup used to collect source files for dependency checking.

View File

@ -1,9 +1,8 @@
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",

View File

@ -1,9 +1,8 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
load(

View File

@ -550,6 +550,7 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
};
GraphOptimizer::Options graph_optimizer_options;
graph_optimizer_options.cf_consider_fn = cf_consider_fn;
graph_optimizer_options.inline_multi_device_functions = true;
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
/*device=*/nullptr, &graph, graph_optimizer_options);

View File

@ -116,9 +116,12 @@ class XlaOpRegistry {
// If we should cluster operations returning DT_VARIANT.
bool cluster_variant_ops = false;
// Whether ops known to be slow or to have correctness issues should be
// Whether ops known to be slow should be auto-clustered.
bool cluster_slow_ops = false;
// Whether ops known to have numerical accuracy issues should be
// auto-clustered.
bool cluster_slow_and_inaccurate_ops = false;
bool cluster_inaccurate_ops = false;
};
// Registers an XLA backend. `compilation_device_name` is the name of the

View File

@ -1,6 +1,7 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
@ -575,6 +576,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":types",
"@com_google_absl//absl/strings",
],
)
@ -881,6 +883,26 @@ tf_cc_test(
],
)
cc_library(
name = "refcounting_hash_map",
hdrs = ["refcounting_hash_map.h"],
deps = [
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "refcounting_hash_map_test",
srcs = ["refcounting_hash_map_test.cc"],
deps = [
":refcounting_hash_map",
":test",
"//tensorflow/core:test_main",
],
)
# -----------------------------------------------------------------------------
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.

Some files were not shown because too many files have changed in this diff Show More