Merge branch 'master' into eager_op_rewrite_registration

Mahmoud Abuzaina 2019-06-27 11:18:41 -07:00 committed by GitHub
commit 4e48f0664c
2136 changed files with 103041 additions and 56248 deletions


@ -9,6 +9,12 @@ build:android_arm --fat_apk_cpu=armeabi-v7a
build:android_arm64 --config=android
build:android_arm64 --cpu=arm64-v8a
build:android_arm64 --fat_apk_cpu=arm64-v8a
build:android_x86 --config=android
build:android_x86 --cpu=x86
build:android_x86 --fat_apk_cpu=x86
build:android_x86_64 --config=android
build:android_x86_64 --cpu=x86_64
build:android_x86_64 --fat_apk_cpu=x86_64
# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
@ -124,6 +130,26 @@ build --define=INCLUDEDIR=$(PREFIX)/include
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
# Options when using remote execution
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --define=EXECUTOR=remote
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_accept_cached=true
build:rbe --remote_cache=remotebuildexecution.googleapis.com
build:rbe --remote_executor=remotebuildexecution.googleapis.com
build:rbe --remote_local_fallback=false
build:rbe --remote_timeout=600
build:rbe --spawn_strategy=remote
build:rbe --strategy=Genrule=remote
build:rbe --strategy=Closure=remote
build:rbe --strategy=Javac=remote
build:rbe --strategy=TestRunner=remote
build:rbe --tls_enabled
test:rbe --test_env=USER=anon
# Default options should come above this line
# Options from ./configure


@ -1,12 +1,13 @@
# Where component owners are known, add them here.
/tensorflow/c/eager @jaingurav @alextp
/tensorflow/core/common_runtime/eager @jaingaurav @alextp
/tensorflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/go @asimshankar
/tensorflow/java/ @asimshankar
/tensorflow/python/debug @caisq
/tensorflow/python/eager @jaingurav @alextp
/tensorflow/python/tools/api/generator/ @annarev
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust
@ -25,7 +26,7 @@
/tensorflow/contrib/data/ @mrry
/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
/tensorflow/contrib/eager @alextp @asimshankar
/tensorflow/contrib/eager @jaingurav @alextp
/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
/tensorflow/contrib/ffmpeg/ @fredbertsch
/tensorflow/contrib/framework/ @ebrevdo


@ -168,11 +168,11 @@ There are two ways to run TensorFlow unit tests.
[GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile)
for the required packages. Alternatively, use the said
[Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g.,
`tensorflow/tensorflow:nightly-devel` and
`tensorflow/tensorflow:nightly-devel-gpu` for development to avoid
installing the packages directly on your system (in which case remember to
change directory from `/root` to `/tensorflow` once you get into the running
container so `bazel` can find the `tensorflow` workspace).
`tensorflow/tensorflow:devel` and `tensorflow/tensorflow:devel-gpu` for
development to avoid installing the packages directly on your system (in
which case remember to change directory from `/root` to `/tensorflow` once
you get into the running container so `bazel` can find the `tensorflow`
workspace).
Once you have the packages installed, you can run a specific unit test in
bazel as follows:


@ -188,7 +188,7 @@ Copyright 2019 The TensorFlow Authors. All rights reserved.
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017, The TensorFlow Authors.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.


@ -116,7 +116,8 @@ The TensorFlow project strives to abide by generally accepted best practices in
Build Type | Status | Artifacts
--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux s390x Nightly** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)


@ -4,14 +4,20 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file"
http_archive(
name = "io_bazel_rules_closure",
sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
urls = [
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
],
)
# Load tf_repositories() before loading dependencies for other repositories so
# that dependencies like com_google_protobuf won't be overridden.
load("//tensorflow:workspace.bzl", "tf_repositories")
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_repositories()
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
closure_repositories()
@ -83,15 +89,15 @@ swift_rules_dependencies()
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
check_bazel_version_at_least("0.19.0")
load("//tensorflow:workspace.bzl", "tf_workspace")
load("//third_party/android:android_configure.bzl", "android_configure")
android_configure(name="local_config_android")
load("@local_config_android//:android.bzl", "android_workspace")
android_workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
# If a target is bound twice, the later one wins, so we have to do tf bindings
# at the end of the WORKSPACE file.
load("//tensorflow:workspace.bzl", "tf_bind")
tf_bind()
http_archive(
name = "inception_v1",

configure.cmd (new file)

@ -0,0 +1,20 @@
:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
::
:: Licensed under the Apache License, Version 2.0 (the "License");
:: you may not use this file except in compliance with the License.
:: You may obtain a copy of the License at
::
:: http://www.apache.org/licenses/LICENSE-2.0
::
:: Unless required by applicable law or agreed to in writing, software
:: distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
:: WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
:: License for the specific language governing permissions and limitations under
:: the License.
@echo off
set configure_dir=%~dp0
set configure_dir=%configure_dir:~0,-1%
python %configure_dir%\configure.py %* || ( exit /b )
echo Configuration finished


@ -49,6 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '0.24.1'
_TF_MAX_BAZEL_VERSION = '0.26.1'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -234,7 +236,9 @@ def setup_python(environ_cp):
python_lib_path = default_python_lib_path
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
_ = get_python_major_version(python_bin_path)
python_major_version = get_python_major_version(python_bin_path)
if python_major_version == '2':
write_to_bazelrc('build --host_force_python=PY2')
# Convert python path to Windows style before writing into bazel.rc
if is_windows() or is_cygwin():
@ -1391,7 +1395,8 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
current_bazel_version = check_bazel_version('0.24.1', '0.26.1')
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1423,7 +1428,7 @@ def main():
if is_ppc64le():
write_action_env_to_bazelrc('OMP_NUM_THREADS', 1)
xla_enabled_by_default = is_linux()
xla_enabled_by_default = is_linux() or is_macos()
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
xla_enabled_by_default, 'xla')
@ -1582,6 +1587,7 @@ def main():
config_info_line(
'dynamic_kernels',
'(Experimental) Build kernels into separate shared objects.')
config_info_line('v2', 'Build TensorFlow 2.x instead of 1.x.')
print('Preconfigured Bazel build configs to DISABLE default on features:')
config_info_line('noaws', 'Disable AWS S3 filesystem support.')


@ -529,6 +529,10 @@ tf_cc_shared_object(
linkopts = select({
"//tensorflow:macos": [],
"//tensorflow:windows": [],
"//tensorflow:freebsd": [
"-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)",
"-lexecinfo",
],
"//conditions:default": [
"-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)",
],


@ -149,4 +149,5 @@ if hasattr(_current_module, 'keras'):
optimizers = keras.optimizers
initializers = keras.initializers
compat.v2.compat.v1 = compat.v1
# pylint: enable=undefined-variable


@ -141,4 +141,6 @@ try:
vars()['__all__'].remove('compiler')
except NameError:
pass
compat.v2.compat.v1 = compat.v1
# pylint: enable=undefined-variable


@ -77,7 +77,10 @@ tf_cuda_library(
"//tensorflow/core:op_gen_lib",
"//tensorflow/core/distributed_runtime:server_lib",
],
}) + [":tf_status_internal"],
}) + [
":tf_status_internal",
":tf_tensor_internal",
],
)
cc_library(
@ -211,9 +214,10 @@ cc_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -221,6 +225,24 @@ cc_library(
}),
)
tf_cuda_library(
name = "tf_tensor_internal",
hdrs = [
"tf_tensor.h",
"tf_tensor_internal.h",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_datatype",
":tf_status",
"//tensorflow/core:framework",
],
}),
)
tf_cuda_library(
name = "c_api_experimental",
srcs = [
@ -263,7 +285,7 @@ tf_cuda_library(
hdrs = ["tf_status_helper.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_no_xla",
":tf_status",
":tf_status_internal",
"//tensorflow/core:lib",
],
@ -329,17 +351,16 @@ tf_cuda_library(
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
deps = [
":tf_status",
":tf_status_helper",
] + select({
"//tensorflow:android": [
":c_api_no_xla",
":c_api_internal",
":tf_status_helper",
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_no_xla",
":c_api_internal",
":tf_status_helper",
":tf_tensor",
"//tensorflow/core:framework",
],
@ -357,6 +378,8 @@ tf_cuda_library(
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":tf_datatype",
":tf_status",
":tf_status_helper",
] + select({
"//tensorflow:android": [
@ -365,7 +388,7 @@ tf_cuda_library(
"//conditions:default": [
"//tensorflow/core:framework",
],
}) + [":c_api_internal"],
}),
)
# -----------------------------------------------------------------------------


@ -363,7 +363,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
}
*graph_def.mutable_library() = graph.flib_def().ToProto();
session->graph->mu.unlock();
status->status = session->session->Extend(graph_def);
status->status = session->session->Extend(std::move(graph_def));
if (TF_GetCode(status) != TF_OK) {
// Contract is we always delete input_values[i].
return false;


@ -737,13 +737,14 @@ int TF_PickUnusedPortOrDie() {
}
TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
void* data, size_t len) {
void* data, size_t len,
TF_Status* status) {
auto dtype = static_cast<tensorflow::DataType>(data_type);
DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
return new TFE_TensorHandle(tensor);
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}
namespace {
@ -994,3 +995,23 @@ TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext(
<< handle->DebugString();
return ret;
}
void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
TFE_ContextMirroringPolicy policy) {
options->mirroring_policy = policy;
}
void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
ctx->context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy));
}
// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) {
return static_cast<TFE_ContextMirroringPolicy>(
ctx->context->GetMirroringPolicy());
}

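For orientation, a minimal usage sketch of the mirroring-policy entry points defined above (their declarations and the TFE_ContextMirroringPolicy enum appear in c_api_experimental.h later in this diff); the include paths are assumed:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Opt the current client thread into remote-handle mirroring; other threads
// keep the policy the context was constructed with.
void UseMirroringOnThisThread(TFE_Context* ctx) {
  TFE_ContextSetThreadLocalMirroringPolicy(ctx, TFE_MIRRORING_ALL);
  // Per the note above, query from the same client thread, not from the
  // async EagerExecutor threads.
  TFE_ContextMirroringPolicy policy = TFE_ContextGetMirroringPolicy(ctx);
  (void)policy;
}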

@ -285,7 +285,7 @@ TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void);
// Fast path method that makes constructing a single scalar tensor require less
// overhead and copies.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar(
TF_DataType data_type, void* data, size_t len);
TF_DataType data_type, void* data, size_t len, TF_Status* status);
// Specify the server_def that enables collective ops.
// This is different to the above function in that it doesn't create remote

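As a usage sketch of the signature change above (include paths assumed; error handling abbreviated), the new TF_Status out-parameter lets callers observe failures instead of relying on the DCHECK in the implementation:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Builds a scalar float handle; len bytes starting at data are copied.
TFE_TensorHandle* MakeScalarHandle(float value) {
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* handle = TFE_NewTensorHandleFromScalar(
      TF_FLOAT, &value, sizeof(value), status);
  if (TF_GetCode(status) != TF_OK) handle = nullptr;  // caller checks for null
  TF_DeleteStatus(status);
  return handle;
}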

@ -30,6 +30,7 @@ limitations under the License.
// clang-format on
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/framework/op_gen_lib.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
@ -53,14 +54,6 @@ class ServerInterface;
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
struct TF_Tensor {
~TF_Tensor();
TF_DataType dtype;
tensorflow::TensorShape shape;
tensorflow::TensorBuffer* buffer;
};
struct TF_SessionOptions {
tensorflow::SessionOptions options;
};
@ -193,15 +186,6 @@ struct TF_Server {
namespace tensorflow {
class TensorCApi {
public:
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
TensorBuffer* buf) {
return Tensor(static_cast<DataType>(type), shape, buf);
}
};
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);


@ -26,6 +26,7 @@ tf_cuda_library(
srcs = [
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.h",
],
hdrs = ["c_api.h"],
@ -81,6 +82,7 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
@ -274,7 +276,6 @@ filegroup(
],
exclude = [
"c_api_experimental.cc",
"c_api_experimental.h",
"*test*",
],
),


@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
@ -89,12 +90,6 @@ bool IsCPU(const tensorflow::Device* d) {
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
bool IsXLA(const tensorflow::Device* d) {
if (d == nullptr) return false;
const auto& device_type = d->attributes().device_type();
return device_type.find("XLA") != std::string::npos;
}
string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
@ -241,10 +236,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
*base_request.add_cluster_device_attributes() = da;
}
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
grpc_server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
@ -369,7 +364,7 @@ void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
options->policy = policy;
options->device_placement_policy = policy;
}
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
@ -392,7 +387,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options, opts->policy,
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator());
@ -406,7 +402,8 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context(opts->session_options.options, opts->policy,
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, device_mgr, /*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator());
}
@ -476,7 +473,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return new TFE_TensorHandle(tensor);
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
@ -569,14 +566,15 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
"The passed in handle is a nullptr");
return nullptr;
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
tensorflow::TensorHandle* handle = h->handle;
if (h->handle->IsRemote()) {
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
h->handle, h->handle->Context(),
h->handle->Context()->HostCPU()->name().c_str(), &h_cpu);
handle, handle->Context(), handle->Context()->HostCPU()->name().c_str(),
false, &h_cpu);
if (!status->status.ok()) {
return nullptr;
}
@ -585,28 +583,23 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
h_cpu->Unref();
return nullptr;
}
} else {
status->status = h->handle->Tensor(&t);
if (!status->status.ok()) return nullptr;
if (!IsCPU(h->handle->device())) {
status->status = h->handle->CopyToDevice(
h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu);
if (!status->status.ok()) {
return nullptr;
}
status->status = h_cpu->Tensor(&t);
if (!status->status.ok()) {
h_cpu->Unref();
return nullptr;
}
}
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
if (h_cpu != nullptr) {
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
h_cpu->Unref();
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle->device())) {
const tensorflow::Tensor* src = nullptr;
status->status = handle->Tensor(&src);
if (!status->status.ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle->Context();
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
return retval;
}
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
@ -924,9 +917,9 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
tensorflow::TensorHandle* handle;
tensorflow::TensorHandle* handle = nullptr;
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
device_name, &handle);
device_name, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
}
@ -971,8 +964,9 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
} // extern "C"
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
return new TFE_TensorHandle(t);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return TFE_TensorHandle::CreateLocalHandle(t, status);
}
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
@ -996,9 +990,9 @@ TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
// TensorHandles created by PyFuncOp lack context and therefore could
// not be copied.
if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
tensorflow::TensorHandle* handle;
tensorflow::TensorHandle* handle = nullptr;
status->status = tensorflow::EagerCopyToDevice(
h->handle, h->handle->Context(), "CPU:0", &handle);
h->handle, h->handle->Context(), "CPU:0", false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
} else {


@ -60,6 +60,8 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
// Controls how to act when we try to run an operation on a given device but
// some input tensors are not on that device.
// LINT.IfChange
// Note: Keep in sync with internal copy of enum in eager/context.h.
typedef enum TFE_ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
@ -72,6 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
// Placement policy which silently copies int32 tensors but not other dtypes.
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
// Sets the default execution mode (sync/async). Note that this can be
// overridden per thread using TFE_ContextSetAsyncForThread.
@ -465,7 +468,8 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
TF_Status* status);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status);
#endif
#endif // TENSORFLOW_C_EAGER_C_API_H_

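For context, a brief sketch of choosing one of the placement policies above when creating a context; TFE_NewContextOptions, TFE_ContextOptionsSetDevicePlacementPolicy, TFE_NewContext, and TFE_DeleteContextOptions are pre-existing eager C API calls, not additions from this diff:

#include "tensorflow/c/eager/c_api.h"

// Creates a context that silently copies mismatched inputs to the right device.
TFE_Context* MakeSilentPlacementContext(TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);  // nullptr on failure
  TFE_DeleteContextOptions(opts);
  return ctx;
}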

@ -42,10 +42,8 @@ bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler,
TF_Buffer* buf, TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);
if (TF_GetCode(status) != TF_OK) return;
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
TF_Status* status) {
string content;
status->status = profiler->profiler->SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
@ -102,6 +100,28 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
return s.ok();
}
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::client::Monitor(
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
result->length = content.length();
result->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
tensorflow::Set_TF_Status_from_Status(status, s);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);


@ -40,8 +40,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
// The output string is a binary string of tensorflow.tpu.Trace. User can write
// the string to file for offline analysis by tensorboard.
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx,
TFE_Profiler* profiler,
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
TF_Buffer* buf,
TF_Status* status);
@ -88,6 +87,16 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);
// Sends a grpc request to the profiler server (service_addr) to perform
// on-demand monitoring and returns the result in a string. It blocks the
// caller thread until the monitoring result is received.
// This API is designed for TensorBoard; end users should use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead, following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
@ -311,6 +320,29 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
// LINT.IfChange
// Note: Keep in sync with internal copy of enum in eager/context.h.
typedef enum TFE_ContextMirroringPolicy {
// Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
// copies with their own lifetime.
TFE_MIRRORING_NONE = 0,
// Mirroring any remote tensor handles, associating them with the lifetime of
// the local TensorHandle.
TFE_MIRRORING_ALL = 1,
} TFE_ContextMirroringPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
// Sets a thread-local mirroring policy. After this call, other calls to
// TFE_Execute in the same thread will use the mirroring policy specified here
// instead of the mirroring policy used to construct the context. This has no
// effect on the mirroring policy used by other program threads.
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context*, TFE_ContextMirroringPolicy);
// Returns the mirroring policy to be used by this context in the current
// thread.
TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context*);
#ifdef __cplusplus
} /* end extern "C" */
#endif

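A hedged sketch of calling the new monitoring entry point declared above; the server address and monitoring level are placeholders, and the include paths are assumed:

#include <cstdio>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Polls a running profiler server once and prints the returned text.
void MonitorOnce() {
  TF_Status* status = TF_NewStatus();
  TF_Buffer* result = TF_NewBuffer();
  TFE_ProfilerClientMonitor("localhost:6009", /*duration_ms=*/1000,
                            /*monitoring_level=*/1,
                            /*display_timestamp=*/false, result, status);
  if (TF_GetCode(status) == TF_OK) {
    std::fwrite(result->data, 1, result->length, stdout);
  }
  TF_DeleteBuffer(result);
  TF_DeleteStatus(status);
}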

@ -72,7 +72,11 @@ void ExecuteWithProfiling(bool async) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Buffer* profiler_result = TF_NewBuffer();
TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status);
if (async) {
TFE_ContextAsyncWait(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
TFE_DeleteProfiler(profiler);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
profiler::Trace profile_proto;


@ -21,12 +21,12 @@ limitations under the License.
#include <memory>
#include <queue>
#include <string>
#include <thread>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@ -53,19 +54,24 @@ struct TFE_ContextOptions {
TF_SessionOptions session_options;
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
TFE_ContextDevicePlacementPolicy device_placement_policy{
TFE_DEVICE_PLACEMENT_SILENT};
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
};
struct TFE_Context {
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_policy, bool async,
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(new tensorflow::EagerContext(
opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
default_device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
default_mirroring_policy),
async, device_mgr, device_mgr_owned, rendezvous,
custom_kernel_creator)) {}
@ -75,12 +81,25 @@ struct TFE_Context {
};
struct TFE_TensorHandle {
explicit TFE_TensorHandle(tensorflow::TensorHandle* handle)
: handle(handle) {}
explicit TFE_TensorHandle(const tensorflow::Tensor& t)
: handle(new tensorflow::TensorHandle(t)) {}
TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
: handle(new tensorflow::TensorHandle(t, d, nullptr)) {}
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(handle);
}
static tensorflow::Status CreateLocalHandle(const class tensorflow::Tensor& t,
tensorflow::Device* d,
TFE_TensorHandle** h) {
tensorflow::TensorHandle* handle;
TF_RETURN_IF_ERROR(
tensorflow::TensorHandle::CreateLocalHandle(t, d, nullptr, &handle));
*h = new TFE_TensorHandle(handle);
return tensorflow::Status::OK();
}
tensorflow::TensorHandle* handle;
@ -92,7 +111,7 @@ struct TFE_TensorHandle {
};
struct TFE_TensorDebugInfo {
TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
: dev_dims(dims) {}
// Fully-padded, minor-to-major.
@ -124,7 +143,7 @@ struct TFE_ProfilerContext {
};
struct TFE_Profiler {
TFE_Profiler(TFE_ProfilerContext* ctx) {
explicit TFE_Profiler(TFE_ProfilerContext* ctx) {
profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context);
}
@ -211,7 +230,7 @@ struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
};
struct TFE_MonitoringBuckets {
TFE_MonitoringBuckets(
explicit TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
@ -266,7 +285,7 @@ struct TFE_TraceContext {
std::vector<std::pair<tensorflow::TensorHandle*, TF_Output>>* input_tensors =
nullptr;
TFE_TraceContext(TF_Graph* graph) : graph(graph) {}
explicit TFE_TraceContext(TF_Graph* graph) : graph(graph) {}
~TFE_TraceContext() {
delete input_tensors;


@ -19,9 +19,11 @@ limitations under the License.
// maintains the data structures required to do so.
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
@ -99,6 +101,12 @@ class VSpace {
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) const = 0;
// Looks up the ID of a Gradient.
virtual int64 TensorId(Gradient* tensor) const = 0;
// Converts a Gradient to a TapeTensor.
virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0;
// Marks the following gradient as a result so it's not consumed by backward
// functions.
virtual void MarkAsResult(Gradient* gradient) const = 0;
@ -129,7 +137,7 @@ class GradientTape {
void Watch(int64 tensor_id);
void RecordOperation(
const string& op_type, std::vector<TapeTensor>& output_tensors,
const string& op_type, const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const std::function<BackwardFunction*()>& backward_function_getter,
@ -165,6 +173,122 @@ class GradientTape {
bool persistent_;
};
// Computes Jacobian-vector products using forward-mode automatic
// differentiation.
//
// While GradientTape's RecordOperation is trivial, ForwardAccumulator's
// Accumulate runs the gradient computation immediately.
//
// Keeps references to Tensors watched via Watch and computed in Accumulate
// corresponding to output_tensors, and releases these references in its
// destructor. However, waiting until the destructor runs loses the memory
// efficiency of forward-mode autodiff. Instead, language bindings should call
// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
// Tensor passed to Accumulate goes out of scope.
//
// Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class ForwardAccumulator {
public:
// Does not take ownership of `vspace`, which must outlive the
// ForwardAccumulator.
explicit ForwardAccumulator(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
: vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
virtual ~ForwardAccumulator() {
for (auto accumulated : accumulated_gradients_) {
vspace_.DeleteGradient(accumulated.second);
}
}
// Tell the forward accumulator to watch tensor_id, with a Tensor tangent
// vector `tangent` of matching shape and dtype. Tangents are the "vector" in
// "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling
// FetchJVP for it would return `tangent`.
void Watch(int64 tensor_id, Gradient* tangent);
// Removes the gradient associated with tensor_id. Should be called when the
// Tensor associated with `tensor_id` is deleted.
void DeleteGradient(int64 tensor_id);
// Runs forward autodiff. Should be called whenever a new operation is
// available and the accumulator is active.
//
// Like GradientTape::RecordOperation, this method takes the operation type
// `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`,
// `input_tensor_id`, and `input_dtypes`; the latter two are somewhat
// redundant but taken as arguments to avoid repeatedly fetching these values
// between calls to ShouldRecord and Accumulator), and its outputs
// (`output_tensors`).
//
// Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
// immediately. It stores the results, which feed into Accumulate for future
// operations and may be fetched by calling FetchJVP. ForwardAccumulator
// maintains a reference to these JVPs: if an `output_tensors` Tensor is
// deleted, `DeleteGradient` should be called as soon as possible to free the
// (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor
// will release remaining references.
//
// This method is not thread-safe (and in general ForwardAccumulator is not
// thread-safe).
Status Accumulate(
const string& op_type, const std::vector<TapeTensor>& input_tensors,
const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
// a nullptr if none is available.
//
// Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on its
// return value. The caller should increment the reference count before
// deleting the ForwardAccumulator or calling DeleteGradient if keeping a
// persistent reference to a non-null result.
Gradient* FetchJVP(int64 tensor_id);
// Indicates whether the forward accumulator should run on an operation with
// the specified inputs and dtypes.
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes);
private:
// Helper for Accumulate: uses a GradientTape to compute forward gradients
// from a backward gradient function. Fills `out_grads` corresponding to
// `output_tensors`. `out_grads` must not be null.
//
// Executes the backward function in order to trace its gradient, which will
// waste computation if executing eagerly (when graph building the unneeded
// computation is pruned). Temporarily sets `backward_tape_` so that
// Accumulate will forward op executions to the tape while the backward
// function is running; this effectively adds the backward tape to the active
// set (but does not require complicated callbacks to the language bindings).
Status ForwardpropFromTape(
const std::vector<TapeTensor>& output_tensors,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter,
const std::vector<Gradient*>& in_grads,
std::vector<Gradient*>* out_grads);
// Maps from tensor IDs to corresponding JVPs.
std::unordered_map<int64, Gradient*> accumulated_gradients_;
// Not owned; provides operations on Tensors which are currently only
// available in language bindings (e.g. Python).
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
// Set temporarily while in the Accumulate method; if backward_tape_ is not
// nullptr then we forward op executions to it so Accumulate can compute a
// backward pass on its backward function.
//
// Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
// keeps ownership.
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
// While the Accumulate method is running (accumulating_ is True), any op
// executions not forwarded to backward_tape_ should be ignored.
bool accumulating_;
};
// Template instantiations here
inline bool IsDtypeTrainable(DataType dtype) {
@ -206,7 +330,7 @@ void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
const string& op_type, std::vector<TapeTensor>& output_tensors,
const string& op_type, const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const std::function<BackwardFunction*()>& backward_function_getter,
@ -691,6 +815,228 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
return Status::OK();
}
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
if (backward_tape_ != nullptr) {
// If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
// we should also delegate ShouldRecord.
return backward_tape_->ShouldRecord(tensor_ids, dtypes);
}
if (accumulating_) {
return false;
}
for (int i = 0; i < tensor_ids.size(); ++i) {
if (accumulated_gradients_.find(tensor_ids[i]) !=
accumulated_gradients_.end()) {
if (IsDtypeTrainable(dtypes[i])) {
return true;
}
}
}
return false;
}
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
const std::vector<TapeTensor>& output_tensors,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter,
const std::vector<Gradient*>& in_grads, std::vector<Gradient*>* out_grads) {
/* This function is approximately equivalent to this Python code:
forwardprop_aids = tf.ones_like(output_tensors)
with tf.GradientTape() as g:
g.watch(forwardprop_aids)
grad = backward_function(forwardprop_aids)
forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads)
accumulated_gradients_[ID(output_tensors)] = forward_grads
*/
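// In matrix terms: the backward function maps u -> J^T u for the op's
// Jacobian J, so watching u (the forwardprop_aids) and differentiating
// <J^T u, v> with respect to u, with v taken from in_grads, yields J v,
// the Jacobian-vector product that gets stored for output_tensors.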
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
backward_tape_ = tape.get();
auto pop_backward_tape =
gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
std::vector<Gradient*> forwardprop_aids;
std::vector<int64> sources;
std::unordered_set<int64> sources_set;
sources.reserve(output_tensors.size());
for (const TapeTensor& output_tensor : output_tensors) {
// Ownership of `aid` transferred to CallBackwardFunction below.
Gradient* aid = vspace_.Ones(output_tensor);
forwardprop_aids.push_back(aid);
int64 aid_id = vspace_.TensorId(aid);
sources.push_back(aid_id);
sources_set.insert(aid_id);
tape->Watch(aid_id);
}
std::vector<Gradient*> grad;
auto delete_grad = gtl::MakeCleanup([&grad, this] {
for (Gradient* tensor : grad) {
this->vspace_.DeleteGradient(tensor);
}
});
{
std::vector<int64> unneeded_gradients;
std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
backward_function(backward_function_getter(),
backward_function_deleter);
TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
}
// Stop the tape from recording
pop_backward_tape.release()();
std::vector<int64> targets;
std::unordered_map<int64, TapeTensor> sources_that_are_targets;
for (Gradient* grad_tensor : grad) {
if (grad_tensor != nullptr) {
int64 tensor_id = vspace_.TensorId(grad_tensor);
targets.push_back(tensor_id);
if (sources_set.find(tensor_id) != sources_set.end()) {
sources_that_are_targets.emplace(
tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
}
}
}
if (targets.size() > in_grads.size()) {
return tensorflow::errors::Internal("Too many gradients returned.");
}
for (int target_index = 0; target_index < targets.size(); ++target_index) {
Gradient* in_grad = in_grads[target_index];
Gradient* grad_tensor = grad[target_index];
if (grad_tensor != nullptr && in_grad != nullptr) {
// ComputeGradient steals a reference
vspace_.MarkAsResult(in_grad);
}
}
return tape->ComputeGradient(vspace_, targets, sources,
sources_that_are_targets, in_grads, out_grads);
}
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
const string& op_type, const std::vector<TapeTensor>& input_tensors,
const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (backward_tape_ != nullptr) {
// If backward_tape_ is not null, then this call to Accumulate is the result
// of a still-active call to Accumulate which is running operations. We
// forward these operations to backward_tape_ so the outer Accumulate call
// can do its work.
//
// Rather than re-entering and delegating Accumulate like this, we could
// instead allow ForwardAccumulator some control over the current tape set
// (so it can deactivate itself and activate its GradientTape). Currently
// that is managed by the language binding and would require relatively
// messy callbacks.
backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
input_dtypes, backward_function_getter,
backward_function_deleter);
return Status::OK();
}
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
return Status::OK();
}
// We may need to allocate zero inputs for trainable dtypes we don't have JVPs
// for. Make sure they get cleaned up.
std::vector<Gradient*> new_zeros;
auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
for (Gradient* tensor : new_zeros) {
this->vspace_.DeleteGradient(tensor);
}
});
std::vector<Gradient*> in_grads;
in_grads.reserve(input_tensors.size());
for (int target_index = 0; target_index < input_tensors.size();
++target_index) {
const auto current_grad =
accumulated_gradients_.find(input_tensors[target_index].GetID());
if (current_grad == accumulated_gradients_.end()) {
if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
// ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
// GradientTape which uses ones.
Gradient* zero = vspace_.Zeros(input_tensors[target_index]);
new_zeros.push_back(zero);
in_grads.push_back(zero);
} else {
in_grads.push_back(nullptr);
}
} else {
in_grads.push_back(current_grad->second);
}
}
accumulating_ = true;
auto reset_accumulating =
gtl::MakeCleanup([this] { this->accumulating_ = false; });
std::vector<Gradient*> forward_grads;
TF_RETURN_IF_ERROR(
ForwardpropFromTape(output_tensors, backward_function_getter,
backward_function_deleter, in_grads, &forward_grads));
for (int i = 0; i < forward_grads.size(); ++i) {
if (forward_grads[i] != nullptr) {
int64 tensor_id = output_tensors[i].GetID();
auto existing = accumulated_gradients_.find(tensor_id);
if (existing != accumulated_gradients_.end()) {
vspace_.DeleteGradient(existing->second);
}
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
}
}
return Status::OK();
}
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
int64 tensor_id, Gradient* tangent) {
typename std::unordered_map<int64, Gradient*>::iterator existing =
accumulated_gradients_.find(tensor_id);
vspace_.MarkAsResult(tangent);
if (existing == accumulated_gradients_.end()) {
accumulated_gradients_.emplace(tensor_id, tangent);
} else {
std::array<Gradient*, 2> to_aggregate;
to_aggregate[0] = tangent;
to_aggregate[1] = existing->second;
// AggregateGradients steals a reference to each of its arguments. We
// MarkAsResult on `tangent` above so we don't steal a reference to it.
existing->second = vspace_.AggregateGradients(to_aggregate);
}
}
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
int64 tensor_id) {
auto existing = accumulated_gradients_.find(tensor_id);
if (existing != accumulated_gradients_.end()) {
vspace_.DeleteGradient(existing->second);
accumulated_gradients_.erase(existing);
}
}
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP(
int64 tensor_id) {
auto lookup = accumulated_gradients_.find(tensor_id);
if (lookup == accumulated_gradients_.end()) {
return nullptr;
} else {
return lookup->second;
}
}
} // namespace eager
} // namespace tensorflow


@ -16,12 +16,36 @@ limitations under the License.
#ifndef TENSORFLOW_C_KERNELS_H_
#define TENSORFLOW_C_KERNELS_H_
#include "tensorflow/c/c_api.h"
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
typedef struct TF_Tensor TF_Tensor;
// --------------------------------------------------------------------------
// C API for TensorFlow Kernels.
//


@ -14,8 +14,12 @@ tf_kernel_library(
prefix = "bitcast_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:ops",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:lib",
],
)
@ -28,6 +32,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)


@ -16,11 +16,12 @@ limitations under the License.
#include <sstream>
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"
// BitcastOp implements a bitcast kernel, creating an output tensor that shares
// the same data buffer as the input but with a different shape and/or data
@ -135,9 +136,8 @@ static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
TF_DeleteTensor(tensor);
}
static void RegisterBitcastOp() {
void RegisterBitcastOp() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
&BitcastOp_Create, &BitcastOp_Compute,


@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/c/ops.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
@ -78,15 +78,10 @@ void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
TF_SetStatus(status, TF_OK, "");
::tensorflow::OpRegistry::Global()->Register(
[cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
return cc_builder->Finalize(op_reg_data);
Status result = cc_builder->Finalize(op_reg_data);
delete cc_builder;
return result;
});
// Calling ProcessRegistrations ensures that the cc_builder's finalize method
// is called and that the builder can be deleted.
Set_TF_Status_from_Status(
status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());
delete cc_builder;
}
void TF_OpDefinitionBuilderSetShapeInferenceFunction(


@ -73,7 +73,8 @@ limitations under the License.
#include <stdint.h>
#include <stdlib.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#ifdef SWIG
#define TF_CAPI_EXPORT


@ -43,7 +43,6 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Status* status) {
AttrValue attr_val;
mutex_lock l(graph->mu);
op->node.ClearAttr(attr_name);


@ -15,10 +15,13 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/coding.h"
@ -227,13 +230,15 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len);
if (sz < src_len) {
status->status = InvalidArgument("src string is too large to encode");
Set_TF_Status_from_Status(
status, InvalidArgument("src string is too large to encode"));
return 0;
}
if (dst_len < sz) {
status->status =
Set_TF_Status_from_Status(
status,
InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
src_len, "-byte string");
src_len, "-byte string"));
return 0;
}
dst = tensorflow::core::EncodeVarint64(dst, src_len);
@ -259,7 +264,8 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len,
size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
size_t* dst_len, TF_Status* status) {
status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
Set_TF_Status_from_Status(status,
TF_StringDecode_Impl(src, src_len, dst, dst_len));
if (TF_GetCode(status) != TF_OK) return 0;
return static_cast<size_t>(*dst - src) + *dst_len;
}
@ -299,8 +305,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
if (!src.IsInitialized()) {
status->status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
Set_TF_Status_from_Status(
status, FailedPrecondition(
"attempt to use a tensor with an uninitialized value"));
return nullptr;
}
if (src.NumElements() == 0) {
@ -308,13 +315,14 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
}
if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) {
status->status = InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error.");
Set_TF_Status_from_Status(
status, InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error."));
return nullptr;
}
const string str =
@ -353,9 +361,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (TF_GetCode(status) != TF_OK) {
status->status = InvalidArgument(
"invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", status->status.error_message());
Set_TF_Status_from_Status(
status,
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", TF_Message(status)));
delete[] base;
return nullptr;
}
@ -363,9 +372,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
dst_len -= consumed;
}
if (dst != base + size) {
status->status = InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes");
Set_TF_Status_from_Status(
status, InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes"));
delete[] base;
return nullptr;
}

View File

@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
struct TF_Tensor {
~TF_Tensor();
TF_DataType dtype;
tensorflow::TensorShape shape;
tensorflow::TensorBuffer* buffer;
};
namespace tensorflow {
class TensorCApi {
public:
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
TensorBuffer* buf) {
return Tensor(static_cast<DataType>(type), shape, buf);
}
};
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -236,10 +236,13 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//third_party/eigen3",
"@com_google_absl//absl/synchronization",
],
)

View File

@ -141,6 +141,15 @@ Status ClientSession::RunCallable(CallableHandle handle,
run_metadata);
}
Status ClientSession::RunCallable(CallableHandle handle,
const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata,
const thread::ThreadPoolOptions& options) {
return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
run_metadata, options);
}
Status ClientSession::ReleaseCallable(CallableHandle handle) {
return impl()->session_->ReleaseCallable(handle);
}

View File

@ -27,6 +27,12 @@ limitations under the License.
namespace tensorflow {
namespace thread {
struct ThreadPoolOptions;
}
/// @addtogroup core
/// @{
@ -110,6 +116,20 @@ class ClientSession {
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata);
/// \brief Invokes the subgraph named by `handle` with the given options and
/// input tensors.
///
/// The order of tensors in `feed_tensors` must match the order of names in
/// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
/// match the order of names in `CallableOptions::fetch()` when this subgraph
/// was created.
/// NOTE: This API is still experimental and may change.
Status RunCallable(CallableHandle handle,
const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata,
const thread::ThreadPoolOptions& options);
/// \brief Releases resources associated with the given `handle` in this
/// session.
/// NOTE: This API is still experimental and may change.
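For reference, a minimal usage sketch of the new RunCallable overload above (a sketch only, not part of this change; `session`, `handle`, `my_inter_op_pool`, `my_intra_op_pool`, and the feed tensors are hypothetical placeholders for objects the caller already owns):

// Route both inter-op and intra-op work through caller-owned thread pools.
tensorflow::thread::ThreadPoolOptions pool_options;
pool_options.inter_op_threadpool = my_inter_op_pool;  // thread::ThreadPoolInterface*
pool_options.intra_op_threadpool = my_intra_op_pool;  // thread::ThreadPoolInterface*
std::vector<tensorflow::Tensor> fetches;
TF_CHECK_OK(session.RunCallable(handle, {feed_a, feed_b}, &fetches,
                                /*run_metadata=*/nullptr, pool_options));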

View File

@ -13,24 +13,67 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#define EIGEN_USE_THREADS
#include "tensorflow/cc/client/client_session.h"
#include <vector>
#include "absl/synchronization/barrier.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_options.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace {
using ops::Add;
using ops::BatchMatMul;
using ops::Const;
using ops::Mul;
using ops::Placeholder;
using ops::Sub;
class CustomThreadPoolImpl : public thread::ThreadPoolInterface {
public:
explicit CustomThreadPoolImpl(int numThreads) {
underlying_threadpool_.reset(new thread::ThreadPool(
tensorflow::Env::Default(), "custom_threadpool", numThreads));
num_schedule_called_ = 0;
}
void Schedule(std::function<void()> fn) override {
num_schedule_called_ += 1;
underlying_threadpool_->Schedule(std::move(fn));
}
void ScheduleWithHint(std::function<void()> fn, int start, int end) override {
num_schedule_called_ += 1;
underlying_threadpool_->ScheduleWithHint(std::move(fn), start, end);
}
void Cancel() override {}
int NumThreads() const override {
return underlying_threadpool_->NumThreads();
}
int CurrentThreadId() const override {
return underlying_threadpool_->CurrentThreadId();
}
int GetNumScheduleCalled() { return num_schedule_called_; }
private:
int num_schedule_called_;
std::unique_ptr<tensorflow::thread::ThreadPool> underlying_threadpool_;
};
TEST(ClientSessionTest, Basic) {
Scope root = Scope::NewRootScope();
auto c = Const(root, {{1, 1}});
@ -95,7 +138,7 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
TEST(ClientSessionTest, Callable) {
TEST(ClientSessionTest, CallableWithDefaultThreadPool) {
Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32);
auto b = Placeholder(root, DT_INT32);
@ -116,5 +159,60 @@ TEST(ClientSessionTest, Callable) {
TF_EXPECT_OK(session.ReleaseCallable(callable));
}
TEST(ClientSessionTest, CallableWithCustomThreadPool) {
Scope root = Scope::NewRootScope();
int num_threads = 3;
TensorShape data_shape({1, 1});
auto a = Placeholder(root, DT_INT32, Placeholder::Shape(data_shape));
auto b = Placeholder(root, DT_INT32, Placeholder::Shape(data_shape));
auto c = BatchMatMul(root, a, b);
ClientSession session(root);
std::vector<Tensor> outputs;
auto inter_op_threadpool =
absl::make_unique<CustomThreadPoolImpl>(num_threads);
ASSERT_EQ(inter_op_threadpool->GetNumScheduleCalled(), 0);
auto intra_op_threadpool =
absl::make_unique<CustomThreadPoolImpl>(num_threads);
ASSERT_EQ(intra_op_threadpool->GetNumScheduleCalled(), 0);
tensorflow::thread::ThreadPoolOptions threadPoolOptions;
threadPoolOptions.inter_op_threadpool = inter_op_threadpool.get();
threadPoolOptions.intra_op_threadpool = intra_op_threadpool.get();
CallableOptions options;
options.add_feed(a.node()->name());
options.add_feed(b.node()->name());
options.add_fetch(c.node()->name());
ClientSession::CallableHandle callable;
TF_CHECK_OK(session.MakeCallable(options, &callable));
// This is needed to have the BatchMatMul computation scheduled in the
// intra_op_threadpool.
absl::Barrier barrier(num_threads + 1);
for (int i = 0; i < num_threads; i++) {
intra_op_threadpool->Schedule([&barrier, num_threads]() {
tensorflow::SetPerThreadMaxParallelism(num_threads - 1);
barrier.Block();
});
}
barrier.Block();
TF_EXPECT_OK(session.RunCallable(
callable,
{test::AsTensor<int>({2}, {1, 1}), test::AsTensor<int>({10}, {1, 1})},
&outputs, nullptr, threadPoolOptions));
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({20}, {1, 1}));
TF_EXPECT_OK(session.ReleaseCallable(callable));
ASSERT_GT(inter_op_threadpool->GetNumScheduleCalled(), 0);
ASSERT_GT(intra_op_threadpool->GetNumScheduleCalled(), 0);
// Free intra_op_threadpool and wait for its threads to exit before freeing
// other objects (e.g. barrier). This is needed to avoid a data race.
intra_op_threadpool.reset();
}
} // namespace
} // namespace tensorflow

View File

@ -1,10 +1,10 @@
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
tf_cuda_cc_test(
name = "profiler_test",
srcs = ["profiler_test.cc"],

View File

@ -1,13 +1,6 @@
# Description:
# TensorFlow SavedModel.
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
"if_android",
@ -22,6 +15,13 @@ load(
"if_static_and_not_mobile",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
cc_library(
name = "constants",
hdrs = ["constants.h"],

View File

@ -34,7 +34,6 @@ cc_library(
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:cpu_function_runtime",

View File

@ -1,11 +1,11 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
# We disable some tfcompile tests in the open source build with the
# "manual" tag to avoid making our OSS users build LLVM twice
# (once for host and once for target).

View File

@ -286,8 +286,6 @@ def tf_library(
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually
# needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",

View File

@ -1,3 +1,8 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
package(
default_visibility = [
":internal",
@ -22,10 +27,6 @@ package_group(
],
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
name = "jit",
@ -50,7 +51,6 @@ cc_library(
deps = [
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
@ -192,18 +192,11 @@ cc_library(
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:host_constant_op",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:logging_ops",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:stack",
"//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
@ -286,6 +279,7 @@ cc_library(
srcs = ["xla_compilation_cache.cc"],
hdrs = ["xla_compilation_cache.h"],
deps = [
":xla_activity_listener",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
@ -322,9 +316,10 @@ cc_library(
srcs = ["jit_compilation_pass_registration.cc"],
deps = [
":compilation_passes",
":xla_activity_logging_listener",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/core:core_cpu_internal",
],
] + tf_jit_compilation_passes_extra_deps(),
alwayslink = 1,
)
@ -512,6 +507,7 @@ cc_library(
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
"report_clustering_info_pass.cc",
],
hdrs = [
"build_xla_ops_pass.h",
@ -525,6 +521,7 @@ cc_library(
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
"report_clustering_info_pass.h",
],
deps = [
"compilability_check_util",
@ -535,6 +532,7 @@ cc_library(
":resource_operation_safety_analysis",
":shape_inference_helpers",
":union_find",
":xla_activity_listener",
":xla_cluster_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:functional_ops",
@ -577,6 +575,7 @@ cc_library(
hdrs = ["xla_cluster_util.h"],
deps = [
":flags",
":xla_activity_proto_cc",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -843,6 +842,27 @@ tf_cc_test(
],
)
tf_cc_test(
name = "xla_activity_listener_test",
srcs = ["xla_activity_listener_test.cc"],
deps = [
":flags",
":xla_activity_listener",
":xla_cpu_device",
":xla_cpu_jit",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session_internal",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:partitioned_function_ops",
],
)
tf_custom_op_py_library(
name = "xla_ops_py",
kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
@ -855,6 +875,37 @@ tf_custom_op_py_library(
],
)
cc_library(
name = "xla_activity_listener",
srcs = ["xla_activity_listener.cc"],
hdrs = ["xla_activity_listener.h"],
visibility = ["//visibility:public"],
deps = [
":xla_activity_proto_cc",
"//tensorflow/core:lib",
"@com_google_absl//absl/synchronization",
],
)
tf_proto_library(
name = "xla_activity_proto",
srcs = ["xla_activity.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
provide_cc_alias = True,
)
cc_library(
name = "xla_activity_logging_listener",
srcs = ["xla_activity_logging_listener.cc"],
deps = [
":xla_activity_listener",
"//tensorflow/core:logger",
"@com_google_absl//absl/memory",
],
alwayslink = 1,
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",

View File

@ -72,7 +72,47 @@ void LogNotCompilable(const Node& node, absl::string_view reason = "") {
} // anonymous namespace
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
RecursiveCompilabilityChecker::FindUncompilableNodes(
const Node& node, FunctionLibraryRuntime* lib_runtime,
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
node_stack_trace) const {
std::vector<StackFrameView> stack_trace;
// If `node_stack_trace` is provided, that means `node` is inside
// a function body, and therefore, arg nodes and retval nodes are
// not considered uncompilable.
if (node_stack_trace != nullptr) {
for (const auto& frame : *node_stack_trace) {
stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
}
}
stack_trace.emplace_back(StackFrameView{node.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes);
return uncompilable_nodes;
}
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
RecursiveCompilabilityChecker::FindUncompilableNodes(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
node_stack_trace) const {
// If `node_stack_trace` is provided, that means `call_def` is inside
// a function body, and therefore, arg nodes and retval nodes are
// not considered uncompilable.
std::vector<StackFrameView> stack_trace;
if (node_stack_trace != nullptr) {
for (const auto& frame : *node_stack_trace) {
stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
}
}
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes);
return uncompilable_nodes;
}
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
// IsCompilableCall().
@ -104,7 +144,7 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) {
bool RecursiveCompilabilityChecker::IsCompilableWhile(
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) {
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
const NameAttrList* name_attr;
NodeDef call;
Status status;
@ -155,7 +195,7 @@ bool RecursiveCompilabilityChecker::IsCompilableWhile(
bool RecursiveCompilabilityChecker::IsCompilableCall(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) {
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
if (stack_trace->size() > kMaxRecursionDepth) {
std::string uncompilable_reason = "function depth limit exceeded";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
@ -191,22 +231,24 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
return is_compilable;
}
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) {
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
// b/127344411: SelfAdjointEigV2 and Svd precision issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd";
}
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) {
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
// b/135640736: MatrixInverse performance issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd" || node.type_string() == "Qr";
node.type_string() == "Svd" || node.type_string() == "Qr" ||
node.type_string() == "MatrixInverse";
}
bool RecursiveCompilabilityChecker::IsCompilableNode(
const Node& node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) {
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
auto stack_depth = stack_trace->size();
if (node.IsSource() || node.IsSink()) {
absl::string_view uncompilable_reason = "source or sink node";
@ -358,7 +400,7 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
return op_filter;
}
void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
const absl::string_view reason,
const std::vector<StackFrameView>& stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_node_list) {

View File

@ -129,29 +129,25 @@ class RecursiveCompilabilityChecker {
const DeviceType* jit_device_type)
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
// Returns a list of uncompilable nodes.
// Returns a list of uncompilable nodes. When `node` is inside a function
// body, users can set `node_stack_trace` to provide additional context for
// `node`'s placement within the outermost graph.
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
const Node& node, FunctionLibraryRuntime* lib_runtime) {
std::vector<StackFrameView> stack_trace;
stack_trace.emplace_back(StackFrameView{node.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes);
return uncompilable_nodes;
}
const Node& node, FunctionLibraryRuntime* lib_runtime,
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
// Returns a list of nodes in `call_def` that cannot be compiled by XLA. It
// is assumed that `call_def` is a call operation. When `call_def` is inside
// a function body, users can set `node_stack_trace` to provide additional
// context for its placement within the outermost graph.
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime) {
std::vector<StackFrameView> stack_trace;
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
std::vector<UncompilableNodeInfo> uncompilable_nodes;
IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes);
return uncompilable_nodes;
}
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
// Returns true if `node` can be compiled by XLA.
bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) {
bool IsCompilableNode(const Node& node,
FunctionLibraryRuntime* lib_runtime) const {
std::vector<StackFrameView> stack_trace;
stack_trace.emplace_back(StackFrameView{node.name(), ""});
return IsCompilableNode(node, lib_runtime, &stack_trace);
@ -168,8 +164,8 @@ class RecursiveCompilabilityChecker {
// Returns true if XLA supports this Op, but we don't want to cluster it
// (i.e., due to performance or correctness concerns).
bool OpIsInaccurate(const Node& node);
bool OpIsSlow(const Node& node);
bool OpIsInaccurate(const Node& node) const;
bool OpIsSlow(const Node& node) const;
private:
struct StackFrameView {
@ -180,47 +176,47 @@ class RecursiveCompilabilityChecker {
bool IsCompilableNode(
const Node& node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr);
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
bool IsCompilableCall(
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr);
bool IsCompilableWhile(const Node& while_node,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes);
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
bool IsCompilableWhile(
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
bool IsStackOp(const Node& node) {
bool IsStackOp(const Node& node) const {
const XlaResourceOpInfo* op_info =
GetResourceOpInfoForOp(node.type_string());
return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
}
bool IsTensorArrayOp(const Node& node) {
bool IsTensorArrayOp(const Node& node) const {
const XlaResourceOpInfo* op_info =
GetResourceOpInfoForOp(node.type_string());
return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
}
bool IsAssertOrCheckNumerics(absl::string_view op_name) {
bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
return op_name == "Assert" || op_name == "CheckNumerics";
}
bool IsStatefulRandomOp(absl::string_view op_name) {
bool IsStatefulRandomOp(absl::string_view op_name) const {
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
op_name == "TruncatedNormal" || op_name == "Multinomial";
}
bool OpProducesOrConsumesVariant(const Node& node) {
bool OpProducesOrConsumesVariant(const Node& node) const {
auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
return absl::c_any_of(node.input_types(), is_variant) ||
absl::c_any_of(node.output_types(), is_variant);
}
bool HasXLAKernel(const Node& node);
bool HasXLAKernel(const Node& node) const;
void MaybeMarkUncompilableNode(
static void MaybeMarkUncompilableNode(
const absl::string_view reason,
const std::vector<StackFrameView>& stack_trace,
std::vector<UncompilableNodeInfo>* uncompilable_node_list);
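As a rough sketch of how the refactored const entry point might be called (illustrative only; it assumes an `op_filter`, `jit_device_type`, `node`, and `lib_runtime` already exist in the caller):

RecursiveCompilabilityChecker checker(&op_filter, &jit_device_type);
// Outside a function body no stack trace is needed, so the default nullptr
// `node_stack_trace` argument is used.
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo> uncompilable =
    checker.FindUncompilableNodes(*node, lib_runtime);
if (!uncompilable.empty()) {
  VLOG(2) << node->name() << " pulls in " << uncompilable.size()
          << " uncompilable node(s).";
}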

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
@ -36,7 +37,7 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
IntroduceFloatingPointJitterPass);
// from
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// FunctionalizeControlFlowPass: 27
//
// This pass looks at the graph and all associated FunctionDefs, and turns
@ -58,15 +59,22 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
PartiallyDeclusterPass);
// The ReportClusteringInfoPass needs to run after all of the auto-clustering
// passes have run but before encapsulation has run. This way it can easily
// compute a summary of the clustering decisions we made and broadcast it via
// xla_activity_listener.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
ReportClusteringInfoPass);
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes added
// for fetches. Before the _Send nodes are added, fetch nodes are identified by
// name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 60,
BuildXlaOpsPass);
} // namespace tensorflow

View File

@ -28,6 +28,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
@ -525,9 +526,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
// We're missing the must-be-constant inputs; tell `PopulateInputs`
// about this. We don't actually need these inputs because they've
// already been baked into the compiled kernel.
launch_context.PopulateInputs(
ctx, closure.compilation_result(), closure.resource_var_snapshots(),
/*missing_ctx_input_prefix=*/closure.num_constant_args());
{
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
return absl::StrCat(
"Populate Inputs (",
closure.compilation_result()->xla_input_shapes.size(), ")");
},
tensorflow::profiler::TraceMeLevel::kInfo);
launch_context.PopulateInputs(
ctx, closure.compilation_result(), closure.resource_var_snapshots(),
/*missing_ctx_input_prefix=*/closure.num_constant_args());
}
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@ -546,6 +557,12 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
},
tensorflow::profiler::TraceMeLevel::kInfo);
OP_REQUIRES_OK(
ctx,
launch_context.PopulateOutputs(

View File

@ -1014,6 +1014,39 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
return Status::OK();
}
StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
if (!node->IsIdentity()) {
return false;
}
// Check if the Identity is driven by a Switch on its true path.
auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
return e->src()->IsSwitch() && e->src_output() == 1;
});
if (it == node->in_edges().end()) {
return false;
}
const Node* switch_node = (*it)->src();
// Check if the Switch is driven by LoopCond.
const Node* maybe_loop_cond;
TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
if (!maybe_loop_cond->IsLoopCond()) {
return false;
}
// Check if the Identity is driving any const nodes through a control edge.
bool driving_any_consts =
absl::c_any_of(node->out_edges(), [](const Edge* e) {
return e->dst()->IsConstant() && e->IsControlEdge();
});
if (!driving_any_consts) {
return false;
}
return true;
}
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
@ -1135,6 +1168,35 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
}
}
// This is a heuristic to avoid creating a dependency between the while loop
// condition and body computations. Such a dependency can be created if a
// special Identity node matching the following pattern is clustered in:
// an Identity node in the loop cond computation is used to drive const
// nodes consumed by the loop body. If this Identity node goes into the
// same cluster as nodes from the loop body, an extra dependency is created
// between the loop cond and body computations, which stalls the loop cond
// computation at runtime and adds significant overhead. Specifically, we
// look for the pattern below and do not cluster in this Identity to avoid
// the described issue. Since Identity has low execution cost in native TF,
// giving up these special Identity nodes as candidates should not hurt
// performance. If other considerations emerge in the future, we can revisit
// the heuristic and only disallow these Identities from joining a cluster
// with nodes from the loop body while still considering them candidates.
//
// LoopCond ->
// Merge -> Switch -> Identity -> i++ -> ... -> NextIteration
// ..> Const -> LoopBody
// (control edge)
TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
IsIdentityDrivingConstsInLoop(node));
if (is_identity_driving_consts_in_loop) {
VLOG(2) << "Rejecting " << node->name()
<< ": including it can create dependencies between while loop "
"condition and body computations with runtime overhead.";
continue;
}
compilation_candidates_.insert(node);
--(*debug_options_.fuel);
}
@ -1302,46 +1364,36 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
return;
}
std::map<absl::string_view, int> cluster_name_to_size;
std::map<absl::string_view, std::map<absl::string_view, int>>
cluster_name_to_op_histogram;
std::map<absl::string_view, int> unclustered_op_histogram;
int clustered_node_count = 0;
for (Node* n : graph_->nodes()) {
absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
if (cluster_name) {
clustered_node_count++;
cluster_name_to_size[*cluster_name]++;
cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
} else {
unclustered_op_histogram[n->type_string()]++;
}
}
int unclustered_node_count = graph_->num_nodes() - clustered_node_count;
XlaAutoClusteringSummary auto_clustering_info =
GetXlaAutoClusteringSummary(*graph_);
VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
<< RatioToString(clustered_node_count, graph_->num_nodes());
VLOG(2) << " Built " << auto_clustering_info.clusters_size()
<< " clusters, size "
<< RatioToString(auto_clustering_info.clustered_node_count(),
graph_->num_nodes());
for (const auto& cluster_name_size_pair : cluster_name_to_size) {
absl::string_view cluster_name = cluster_name_size_pair.first;
int size = cluster_name_size_pair.second;
for (XlaAutoClusteringSummary::Cluster cluster :
auto_clustering_info.clusters()) {
absl::string_view cluster_name = cluster.name();
int size = cluster.size();
VLOG(2) << " " << cluster_name << " "
<< RatioToString(size, graph_->num_nodes());
for (const auto& op_count_pair :
cluster_name_to_op_histogram[cluster_name]) {
VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second
for (const XlaAutoClusteringSummary::OpAndCount& op_count :
cluster.op_histogram()) {
VLOG(3) << " " << op_count.op() << ": " << op_count.count()
<< " instances";
}
}
if (!unclustered_op_histogram.empty()) {
if (!auto_clustering_info.unclustered_op_histogram().empty()) {
VLOG(2) << " Unclustered nodes: "
<< RatioToString(unclustered_node_count, graph_->num_nodes());
for (const auto& pair : unclustered_op_histogram) {
VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
<< RatioToString(auto_clustering_info.unclustered_node_count(),
graph_->num_nodes());
for (const XlaAutoClusteringSummary::OpAndCount& op_count :
auto_clustering_info.unclustered_op_histogram()) {
VLOG(3) << " " << op_count.op() << ": " << op_count.count()
<< " instances";
}
}

View File

@ -88,7 +88,6 @@ absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
TEST(XlaCompilationTest, Chains) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
@ -114,7 +113,6 @@ TEST(XlaCompilationTest, Chains) {
TEST(XlaCompilationTest, UncompilableCycles) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
@ -135,7 +133,6 @@ TEST(XlaCompilationTest, UncompilableCycles) {
TEST(XlaCompilationTest, CompilableCycles) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
@ -157,7 +154,6 @@ TEST(XlaCompilationTest, CompilableCycles) {
TEST(XlaCompilationTest, StringUnsupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
@ -177,7 +173,6 @@ TEST(XlaCompilationTest, StringUnsupported) {
TEST(XlaCompilationTest, HalfSupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Tensor t(DT_HALF, TensorShape());
@ -253,7 +248,6 @@ TEST(XlaCompilationTest, FunctionCalls) {
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
std::unique_ptr<Graph> graph(new Graph(&flib_def));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
Node* a =
@ -291,7 +285,6 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
std::unique_ptr<Graph> graph(new Graph(&flib_def));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
Node* resource =
@ -381,7 +374,6 @@ REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);
TEST(XlaCompilationTest, SymbolicGradients) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
@ -449,7 +441,6 @@ TEST(XlaCompilationTest, Loops) {
TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
@ -483,7 +474,6 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
@ -512,7 +502,6 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
@ -555,7 +544,6 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
@ -789,7 +777,6 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
TEST(XlaCompilationTest, Retval) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
@ -1677,5 +1664,37 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
}
// Test a pattern where a special Identity node drives consts in a loop.
// Expect that the Identity node will not go into any cluster. Note that we
// create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration)
// that is just enough to exercise the pattern, since a complete graph would
// be cumbersome and unnecessary.
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
Scope root = Scope::NewRootScope().ExitOnError();
Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);
Output identity =
ops::Identity(root.WithOpName("identity"), switch_node.output_true);
Output const_node = ops::Const(root.WithOpName("const"), 1.0f);
root.graph()->AddControlEdge(identity.node(), const_node.node());
Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node);
Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0);
Output add = ops::Add(root.WithOpName("add"), const_node, tanh1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
&graph,
MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
auto clusters = GetClusters(*graph);
EXPECT_EQ(clusters["identity"], "");
}
} // namespace
} // namespace tensorflow

View File

@ -1,10 +1,10 @@
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
)
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
cc_library(
name = "xla_ops",
srcs = ["xla_ops.cc"],

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
namespace tensorflow {
Status ReportClusteringInfoPass::Run(
const GraphOptimizationPassOptions& options) {
XlaAutoClusteringActivity activity;
*activity.mutable_summary() = GetXlaAutoClusteringSummary(**options.graph);
activity.set_global_jit_level(GetGlobalJitLevelForGraph(options));
activity.set_cpu_global_jit_enabled(
GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit);
return BroadcastXlaActivity(std::move(activity));
}
} // namespace tensorflow

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
#define TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// This is not really an optimization pass. It does not change the graph in any
// way; instead it computes a summary of the XLA clusters in the graph and
// broadcasts it via xla_activity_listener.
class ReportClusteringInfoPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_

View File

@ -1,8 +1,8 @@
Clustered nodes: 1988
Unclustered nodes: 3960
Clustered nodes: 1962
Unclustered nodes: 3974
Number of clusters: 29
unclustered size 3960
unclustered size 3974
Add 17
AddN 1
ApplyAdam 38
@ -14,7 +14,7 @@ unclustered size 3960
Cast 8
ConcatOffset 10
ConcatV2 2
Const 704
Const 708
ControlTrigger 5
DynamicStitch 1
Enter 874
@ -24,7 +24,7 @@ unclustered size 3960
FloorDiv 1
FloorMod 1
GreaterEqual 7
Identity 105
Identity 113
IsVariableInitialized 1
IteratorGetNext 1
IteratorV2 1
@ -42,7 +42,7 @@ unclustered size 3960
RefSwitch 166
Reshape 2
ScatterAdd 4
Shape 4
Shape 6
ShapeN 10
Size 2
Snapshot 1
@ -169,15 +169,13 @@ cluster 7 size 11
Mul 2
Pow 1
Sub 1
cluster 10 size 14
Add 2
cluster 10 size 8
Add 1
All 2
Const 4
Const 2
GreaterEqual 1
Identity 1
Less 1
LogicalOr 1
Shape 2
cluster 11 size 226
Add 24
BatchMatMulV2 1
@ -226,13 +224,12 @@ cluster 12 size 430
TanhGrad 17
Tile 2
ZerosLike 1
cluster 13 size 25
Add 3
cluster 13 size 20
Add 2
BiasAdd 1
ConcatV2 1
Const 3
Const 1
GreaterEqual 1
Identity 2
MatMul 1
Mul 3
Select 3
@ -256,13 +253,12 @@ cluster 14 size 52
Slice 2
Sum 9
TanhGrad 2
cluster 15 size 25
Add 3
cluster 15 size 20
Add 2
BiasAdd 1
ConcatV2 1
Const 3
Const 1
GreaterEqual 1
Identity 2
MatMul 1
Mul 3
Select 3
@ -290,14 +286,13 @@ cluster 17 size 52
Slice 2
Sum 9
TanhGrad 2
cluster 19 size 30
Add 3
cluster 19 size 25
Add 2
BiasAdd 1
Cast 1
ConcatV2 1
Const 5
Const 3
GreaterEqual 2
Identity 2
MatMul 1
Mul 5
Select 2
@ -305,77 +300,7 @@ cluster 19 size 30
Snapshot 1
Split 1
Tanh 2
cluster 20 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 21 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 22 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 23 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 24 size 24
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 3
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 25 size 363
cluster 20 size 363
Add 12
AddN 28
BiasAddGrad 6
@ -391,6 +316,71 @@ cluster 25 size 363
Slice 12
Sum 76
TanhGrad 12
cluster 21 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 22 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 23 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 24 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 25 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 3
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 26 size 9
AddN 1
MatMul 2

View File

@ -0,0 +1,96 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow;
import "tensorflow/core/protobuf/config.proto";
// Summarizes the results of auto-clustering a TensorFlow graph.
//
// Next ID: 5
message XlaAutoClusteringSummary {
// Represents a single element in a histogram of ops ("op" as in "TensorFlow
// operation").
//
// Next ID: 3
message OpAndCount {
// The TensorFlow operation (like MatMul, Add, etc.).
string op = 1;
// The number of times this occurs.
int32 count = 2;
}
// Describes a single XLA cluster.
//
// Next ID: 4
message Cluster {
string name = 1;
// The number of nodes in the cluster.
int32 size = 2;
// A histogram of the TF operations in this cluster.
repeated OpAndCount op_histogram = 3;
};
// The number of nodes in the graph that are not inside an XLA cluster.
int32 unclustered_node_count = 1;
// The number of nodes in the graph that are in an XLA cluster.
int32 clustered_node_count = 2;
// All of the XLA clusters in the TF graph.
repeated Cluster clusters = 3;
// A histogram of the TF operations that were not clustered.
repeated OpAndCount unclustered_op_histogram = 4;
}
// Listeners listening for auto-clustering events get messages of this type.
//
// Next ID: 4
message XlaAutoClusteringActivity {
// The value of GlobalJitLevel, as determined by `GetGlobalJitLevelForGraph`.
// This determines if global auto-clustering is enabled.
OptimizerOptions.GlobalJitLevel global_jit_level = 1;
// Whether --tf_xla_cpu_global_jit is enabled in TF_XLA_FLAGS.
bool cpu_global_jit_enabled = 2;
XlaAutoClusteringSummary summary = 3;
}
// Listeners listening for JIT compilation events get messages of this type.
// Each instance of XlaJitCompilationActivity corresponds to a single
// compilation of a single XLA cluster. E.g. if a graph has two clusters, A and
// B, and A is compiled 5 times and B is compiled 2 times then we will generate
// 7 instances of XlaJitCompilationActivity.
//
// Next ID: 5
message XlaJitCompilationActivity {
string cluster_name = 1;
// The number of times this cluster has been compiled.
int32 compile_count = 2;
// Microseconds spent in the individual compilation being reported.
int64 compile_time_us = 3;
// Total microseconds spent in (re-)compiling this cluster so far.
int64 cumulative_compile_time_us = 4;
}
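As a sketch of how a producer might fill in and broadcast one of these messages (the values and the `elapsed_us`/`total_us` variables below are made-up placeholders; `BroadcastXlaActivity` is declared in xla_activity_listener.h and the snippet assumes it runs inside a function returning Status):

XlaJitCompilationActivity jit_activity;
jit_activity.set_cluster_name("cluster_0");
jit_activity.set_compile_count(1);
jit_activity.set_compile_time_us(elapsed_us);            // time spent in this compilation
jit_activity.set_cumulative_compile_time_us(total_us);   // running total for the cluster
TF_RETURN_IF_ERROR(BroadcastXlaActivity(std::move(jit_activity)));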

View File

@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace {
// The list of all registered `XlaActivityListener`s.
struct XlaActivityListenerList {
absl::Mutex mutex;
std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
};
void FlushAllListeners();
XlaActivityListenerList* GetXlaActivityListenerList() {
static XlaActivityListenerList* listener_list = new XlaActivityListenerList;
static int unused = std::atexit(FlushAllListeners);
(void)unused;
return listener_list;
}
template <typename FnTy>
Status ForEachListener(FnTy fn) {
XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
absl::ReaderMutexLock reader_lock(&listener_list->mutex);
for (const std::unique_ptr<XlaActivityListener>& listener :
listener_list->listeners) {
TF_RETURN_IF_ERROR(fn(listener.get()));
}
return Status::OK();
}
void FlushAllListeners() {
Status s = ForEachListener([](XlaActivityListener* listener) {
listener->Flush();
return Status::OK();
});
CHECK(s.ok());
}
} // namespace
Status BroadcastXlaActivity(
XlaAutoClusteringActivity auto_clustering_activity) {
return ForEachListener([&](XlaActivityListener* listener) {
return listener->Listen(auto_clustering_activity);
});
}
Status BroadcastXlaActivity(
XlaJitCompilationActivity jit_compilation_activity) {
return ForEachListener([&](XlaActivityListener* listener) {
return listener->Listen(jit_compilation_activity);
});
}
void RegisterXlaActivityListener(
std::unique_ptr<XlaActivityListener> listener) {
XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
absl::WriterMutexLock writer_lock(&listener_list->mutex);
listener_list->listeners.push_back(std::move(listener));
}
void XlaActivityListener::Flush() {}
XlaActivityListener::~XlaActivityListener() {}
} // namespace tensorflow

View File

@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
#include <memory>
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Broadcast `auto_clustering_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity);
// Broadcast `jit_compilation_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity);
// Various components of the system can subclass XlaActivityListener to
// receive notifications on auto-clustering and JIT compilation events.
//
// Subclasses of XlaActivityListener must be thread safe.
class XlaActivityListener {
public:
// Called after TensorFlow auto-clusters a graph.
virtual Status Listen(
const XlaAutoClusteringActivity& auto_clustering_activity) = 0;
// Called after TensorFlow JIT compiles an XLA cluster.
virtual Status Listen(
const XlaJitCompilationActivity& jit_compilation_activity) = 0;
// Called at program exit in a best-effort manner to give listeners a chance to
// flush their state.
//
// Default implementation is a no-op.
virtual void Flush();
virtual ~XlaActivityListener();
};
// Registers an `XlaActivityListener`, which will be invoked on all subsequent
// `BroadcastXlaActivity` calls.
void RegisterXlaActivityListener(std::unique_ptr<XlaActivityListener> listener);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
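A minimal sketch of a custom listener and its registration (the `CountingListener` class is hypothetical and relies only on the interface declared above; absl::make_unique comes from "absl/memory/memory.h"):

class CountingListener : public tensorflow::XlaActivityListener {
 public:
  tensorflow::Status Listen(
      const tensorflow::XlaAutoClusteringActivity& activity) override {
    ++clustering_events_;  // count auto-clustering reports
    return tensorflow::Status::OK();
  }
  tensorflow::Status Listen(
      const tensorflow::XlaJitCompilationActivity& activity) override {
    ++compile_events_;  // count JIT compilation reports
    return tensorflow::Status::OK();
  }

 private:
  // A real listener must be thread safe; plain ints are used here only for
  // brevity.
  int clustering_events_ = 0;
  int compile_events_ = 0;
};

// Ownership is transferred; the listener then sees every subsequent
// BroadcastXlaActivity call.
tensorflow::RegisterXlaActivityListener(absl::make_unique<CountingListener>());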

View File

@ -0,0 +1,193 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include <cstdlib>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/direct_session.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class TestListener : public XlaActivityListener {
public:
Status Listen(
const XlaAutoClusteringActivity& auto_clustering_activity) override {
auto_clustering_activity_ = auto_clustering_activity;
return Status::OK();
}
Status Listen(
const XlaJitCompilationActivity& jit_compilation_activity) override {
jit_compilation_activity_ = jit_compilation_activity;
return Status::OK();
}
~TestListener() override {}
const XlaAutoClusteringActivity& auto_clustering_activity() const {
return auto_clustering_activity_;
}
const XlaJitCompilationActivity& jit_compilation_activity() const {
return jit_compilation_activity_;
}
private:
XlaAutoClusteringActivity auto_clustering_activity_;
XlaJitCompilationActivity jit_compilation_activity_;
};
class XlaActivityListenerTest : public ::testing::Test {
protected:
XlaActivityListenerTest() {
auto listener = absl::make_unique<TestListener>();
listener_ = listener.get();
RegisterXlaActivityListener(std::move(listener));
}
TestListener* listener() const { return listener_; }
private:
TestListener* listener_;
};
GraphDef CreateGraphDef() {
Scope root = Scope::NewRootScope().ExitOnError().WithAssignedDevice(
"/job:localhost/replica:0/task:0/device:CPU:0");
Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
for (int i = 0; i < 5; i++) {
a = ops::MatMul(root.WithOpName(absl::StrCat("matmul_", i)), a, a);
a = ops::Add(root.WithOpName(absl::StrCat("add_", i)), a, a);
}
GraphDef graph_def;
root.graph()->ToGraphDef(&graph_def);
return graph_def;
}
TEST_F(XlaActivityListenerTest, Test) {
GraphDef graph_def = CreateGraphDef();
SessionOptions options;
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(graph_def));
std::vector<std::string> output_names = {std::string("add_4:0")};
Tensor tensor_2x2(DT_FLOAT, TensorShape({2, 2}));
for (int i = 0; i < 4; i++) {
tensor_2x2.matrix<float>()(i / 2, i % 2) = 5 * i;
}
Tensor tensor_3x3(DT_FLOAT, TensorShape({3, 3}));
for (int i = 0; i < 9; i++) {
tensor_3x3.matrix<float>()(i / 3, i % 3) = 5 * i;
}
std::vector<std::pair<string, Tensor>> inputs_2x2 = {{"A", tensor_2x2}};
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run(inputs_2x2, output_names, /*target_node_names=*/{},
&outputs));
absl::string_view expected_auto_clustering_activity =
R"(global_jit_level: ON_2
cpu_global_jit_enabled: true
summary {
unclustered_node_count: 4
clustered_node_count: 14
clusters {
name: "cluster_0"
size: 14
op_histogram {
op: "Add"
count: 1
}
op_histogram {
op: "Const"
count: 4
}
op_histogram {
op: "MatMul"
count: 5
}
op_histogram {
op: "Mul"
count: 4
}
}
unclustered_op_histogram {
op: "NoOp"
count: 2
}
unclustered_op_histogram {
op: "_Arg"
count: 1
}
unclustered_op_histogram {
op: "_Retval"
count: 1
}
}
)";
EXPECT_EQ(listener()->auto_clustering_activity().DebugString(),
expected_auto_clustering_activity);
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 1);
int64 first_compile_time =
listener()->jit_compilation_activity().compile_time_us();
EXPECT_GT(first_compile_time, 0);
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
first_compile_time);
std::vector<std::pair<string, Tensor>> inputs_3x3 = {{"A", tensor_3x3}};
outputs.clear();
for (int i = 0; i < 3; i++) {
TF_ASSERT_OK(session->Run(inputs_3x3, output_names,
/*target_node_names=*/{}, &outputs));
}
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 2);
EXPECT_GT(listener()->jit_compilation_activity().compile_time_us(), 0);
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
first_compile_time +
listener()->jit_compilation_activity().compile_time_us());
}
} // namespace
} // namespace tensorflow
int main(int argc, char** argv) {
tensorflow::GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit = true;
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
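For readers who want to poke at the same auto-clustering path from Python rather than building this C++ test, a rough sketch follows. It mirrors the MatMul/Add chain and the ON_2 global JIT level used above; the tf.compat.v1 calls are an assumption about a stock TensorFlow install and are not part of this commit.

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # no-op on TF 1.x, needed for placeholders on TF 2.x
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_2
with tf.Graph().as_default():
    a = tf.placeholder(tf.float32, shape=[2, 2], name="A")
    x = a
    for i in range(5):
        x = tf.matmul(x, x, name="matmul_%d" % i)
        x = tf.add(x, x, name="add_%d" % i)
    with tf.Session(config=config) as sess:
        # Feeding the identity keeps the values finite through five iterations.
        print(sess.run(x, {a: np.eye(2, dtype=np.float32)}))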

View File

@ -0,0 +1,69 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/core/platform/logger.h"
namespace tensorflow {
namespace {
// Listens to XLA activity and logs it using tensorflow::Logger.
class XlaActivityLoggingListener final : public XlaActivityListener {
public:
Status Listen(
const XlaAutoClusteringActivity& auto_clustering_activity) override {
if (!IsEnabled()) {
VLOG(3) << "Logging XlaAutoClusteringActivity disabled";
return Status::OK();
}
VLOG(2) << "Logging XlaAutoClusteringActivity";
VLOG(3) << auto_clustering_activity.DebugString();
Logger::Singleton()->LogProto(auto_clustering_activity);
return Status::OK();
}
Status Listen(
const XlaJitCompilationActivity& jit_compilation_activity) override {
if (!IsEnabled()) {
VLOG(3) << "Logging XlaJitCompilationActivity disabled";
return Status::OK();
}
VLOG(2) << "Logging XlaJitCompilationActivity";
VLOG(3) << jit_compilation_activity.DebugString();
Logger::Singleton()->LogProto(jit_compilation_activity);
return Status::OK();
}
private:
bool IsEnabled() {
static bool result = ComputeIsEnabled();
return result;
}
bool ComputeIsEnabled() {
char* log_xla_activity = getenv("TF_LOG_XLA_ACTIVITY");
return log_xla_activity && !strcmp(log_xla_activity, "1");
}
};
bool Register() {
RegisterXlaActivityListener(absl::make_unique<XlaActivityLoggingListener>());
return false;
}
bool unused = Register();
} // namespace
} // namespace tensorflow
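Since the listener above is gated on TF_LOG_XLA_ACTIVITY and the getenv result is cached in a function-local static, the variable has to be in the environment before the first XLA compilation. A minimal sketch of doing that from Python (hypothetical usage, not part of this commit):

import os
os.environ["TF_LOG_XLA_ACTIVITY"] = "1"  # must be set before the first cluster is compiled
import tensorflow as tf  # import and run XLA-compiled code only after the variable is set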

View File

@ -318,4 +318,72 @@ bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
node.type_string() == "Size";
}
namespace {
struct ClusterInfo {
int size;
// Maps op names to the number of times they appear in the cluster.
absl::flat_hash_map<absl::string_view, int> op_histogram;
};
void HistogramMapToRepeatedOpAndCount(
protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
const absl::flat_hash_map<absl::string_view, int>& histogram) {
for (const auto& pair : histogram) {
XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
new_entry->set_op(std::string(pair.first));
new_entry->set_count(pair.second);
}
absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
const XlaAutoClusteringSummary::OpAndCount& b) {
return a.op() < b.op();
});
}
void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
absl::string_view name, const ClusterInfo& info) {
result->set_name(std::string(name));
result->set_size(info.size);
HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
info.op_histogram);
}
} // namespace
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
XlaAutoClusteringSummary result;
absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;
for (Node* n : graph.nodes()) {
absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
if (cluster_name) {
result.set_clustered_node_count(result.clustered_node_count() + 1);
ClusterInfo* info = &cluster_name_to_info[*cluster_name];
info->size++;
info->op_histogram[n->type_string()]++;
} else {
result.set_unclustered_node_count(result.unclustered_node_count() + 1);
unclustered_op_histogram[n->type_string()]++;
}
}
for (const auto& pair : cluster_name_to_info) {
XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
}
absl::c_sort(*result.mutable_clusters(),
[&](const XlaAutoClusteringSummary::Cluster& a,
const XlaAutoClusteringSummary::Cluster& b) {
return a.name() < b.name();
});
HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
unclustered_op_histogram);
return result;
}
} // namespace tensorflow
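For illustration only, the bookkeeping in GetXlaAutoClusteringSummary above can be reproduced in a few lines of standalone Python over (op_type, cluster_or_None) pairs; the sorted() calls play the role of the absl::c_sort by cluster name and by op name:

from collections import Counter, defaultdict

def auto_clustering_summary(nodes):
    clustered, unclustered = 0, 0
    per_cluster = defaultdict(Counter)
    unclustered_hist = Counter()
    for op, cluster in nodes:
        if cluster is not None:
            clustered += 1
            per_cluster[cluster][op] += 1  # per-cluster op histogram
        else:
            unclustered += 1
            unclustered_hist[op] += 1
    return {
        "clustered_node_count": clustered,
        "unclustered_node_count": unclustered,
        "clusters": {c: dict(sorted(h.items())) for c, h in sorted(per_cluster.items())},
        "unclustered_op_histogram": dict(sorted(unclustered_hist.items())),
    }

print(auto_clustering_summary([("MatMul", "cluster_0"), ("Add", "cluster_0"), ("_Retval", None)]))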

View File

@ -18,8 +18,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h"
@ -87,6 +89,11 @@ bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
// Returns true if `node` is an operator that consumes only the shape of its input,
// not the data itself.
bool IsShapeConsumerOp(const Node& node);
// Computes a clustering summary for `graph`. See documentation on
// `XlaAutoClusteringSummary` for details.
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -361,6 +362,16 @@ Status XlaCompilationCache::CompileImpl(
<< tensorflow::strings::HumanReadableElapsedTime(
it->second.cumulative_compile_time_us / 1.0e6)
<< ")";
XlaJitCompilationActivity jit_compilation_activity;
jit_compilation_activity.set_cluster_name(function.name());
jit_compilation_activity.set_compile_count(it->second.compile_count);
jit_compilation_activity.set_compile_time_us(compile_time_us);
jit_compilation_activity.set_cumulative_compile_time_us(
it->second.cumulative_compile_time_us);
TF_RETURN_IF_ERROR(
BroadcastXlaActivity(std::move(jit_compilation_activity)));
}
}
TF_RETURN_IF_ERROR(entry->compilation_status);

View File

@ -375,18 +375,6 @@ Status XlaDevice::FillContextMap(const Graph* graph,
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
profiler::TraceMe activity(
[&] {
return absl::StrCat("XlaDevice::Compute ", op_kernel->name(), ":",
op_kernel->type_string(),
"#step_id=", context->step_id(),
",step_container_name=",
context->step_container() == nullptr
? "n/a"
: context->step_container()->name(),
"#");
},
profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
op_kernel->Compute(context);
}
@ -394,18 +382,6 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
profiler::TraceMe activity(
[&] {
return absl::StrCat("XlaDevice::ComputeAsync ", op_kernel->name(), ":",
op_kernel->type_string(),
"#step_id=", context->step_id(),
",step_container_name=",
context->step_container() == nullptr
? "n/a"
: context->step_container()->name(),
"#");
},
profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
op_kernel->ComputeAsync(context, done);
}

View File

@ -21,22 +21,15 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/data/generator_dataset_op.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/function_ops.h"
#include "tensorflow/core/kernels/host_constant_op.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/logging_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/stack.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
@ -80,24 +73,14 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("Assert") \
.Device(DEVICE) \
.HostMemory("condition") \
.HostMemory("data"), \
AssertOp); \
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \
REGISTER_KERNEL_BUILDER( \
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
ConstantOp); \
REGISTER_KERNEL_BUILDER( \
Name("HostConst").Device(DEVICE).HostMemory("output"), _HostConstantOp); \
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \
ResourceHandleOp<Var>); \
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), VarHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \
ResourceHandlesOp<Var>); \
@ -153,36 +136,6 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
XlaAssignVariableOp); \
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
ControlTriggerOp); \
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
SwitchOp); \
REGISTER_KERNEL_BUILDER( \
Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
NextIterationOp); \
REGISTER_KERNEL_BUILDER(Name("LoopCond") \
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
LoopCondOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
.Device(DEVICE) \
.HostMemory("size") \
.HostMemory("handle"), \
QueueSizeOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
QueueIsClosedOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
@ -262,25 +215,7 @@ class XlaAssignVariableOp : public OpKernel {
.Device(DEVICE) \
.TypeConstraint<string>("T") \
.HostMemory("input"), \
RetvalOp); \
\
REGISTER_KERNEL_BUILDER(Name("StackV2") \
.Device(DEVICE) \
.HostMemory("max_size") \
.HostMemory("handle"), \
StackOp); \
REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
.Device(DEVICE) \
.HostMemory("handle") \
.TypeConstraint("T", TYPES), \
TemplatedStackPushOp</*allow_swapping=*/false>); \
REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
.Device(DEVICE) \
.HostMemory("handle") \
.TypeConstraint("elem_type", TYPES), \
StackPopOp); \
REGISTER_KERNEL_BUILDER( \
Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);
RetvalOp);
// TODO(b/118881356): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read

View File

@ -314,6 +314,21 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "matrix_inverse_op_test",
size = "small",
timeout = "moderate",
srcs = ["matrix_inverse_op_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "matrix_triangular_solve_op_test",
size = "small",
@ -557,6 +572,7 @@ tf_xla_py_test(
name = "image_ops_test",
size = "small",
srcs = ["image_ops_test.py"],
shard_count = 10,
tags = [
"optonly", # Times out frequently in fastbuild mode.
],
@ -598,6 +614,19 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "manip_ops_test",
size = "small",
srcs = ["manip_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:manip_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "matrix_band_part_test",
size = "medium",
@ -970,6 +999,7 @@ tf_xla_py_test(
name = "unary_ops_test",
size = "medium",
srcs = ["unary_ops_test.py"],
tags = ["notap"], # b/136030724
deps = [
":xla_test",
"//tensorflow/python:array_ops",

View File

@ -51,7 +51,7 @@ class ArgMinMaxTest(xla_test.XLATestCase):
def testArgMinMax(self):
# Complex numbers do not support argmin/argmax.
minmax_types = self.all_types & {np.int32, np.int64}
for dtype in minmax_types:
for dtype in self.int_types | self.float_types:
# output_type is a numpy data type that is used to specify the desired
# output type of the op as well as to convert the Python number to the
# array scalar of the type.

View File

@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@ -312,30 +313,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
if dtype in [np.float32, np.float64]:
x = np.array([
-0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0,
1.0
],
dtype=dtype)
y = np.array(
[-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0],
dtype=dtype)
expected = np.nextafter(x, y)
# We use assertAllEqual to expose any bugs hidden by relative or
# absolute error tolerances.
def NextAfterEqualityTest(result, expected, rtol):
del rtol
return self.assertAllEqual(result, expected)
self._testBinary(
math_ops.nextafter,
x,
y,
expected=expected,
equality_test=NextAfterEqualityTest)
# min/max not supported for complex
if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
@ -423,6 +400,32 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36],
dtype=np.int64))
def testNextAfter(self):
for dtype in self.numeric_types:
if dtype in [np.float32, np.float64]:
x = np.array([
-0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0,
1.0
],
dtype=dtype)
y = np.array(
[-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0],
dtype=dtype)
expected = np.nextafter(x, y)
# We use assertAllEqual to expose any bugs hidden by relative or
# absolute error tolerances.
def NextAfterEqualityTest(result, expected, rtol):
del rtol
return self.assertAllEqual(result, expected)
self._testBinary(
math_ops.nextafter,
x,
y,
expected=expected,
equality_test=NextAfterEqualityTest)
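As a standalone illustration of the values this test exercises, plain numpy shows the same nextafter behavior on signed zeros, infinities, and nearby floats:

import numpy as np
x = np.array([-0.0, 0.0, np.inf, -np.inf, 2.0], dtype=np.float32)
y = np.array([0.0, -0.0, 1.0, -1.0, 1.0], dtype=np.float32)
# nextafter(x, y) is the next representable float32 after x in the direction of y,
# e.g. nextafter(2.0, 1.0) is the float32 just below 2.0.
print(np.nextafter(x, y))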
def testComplexOps(self):
for dtype in self.complex_types:
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
@ -1478,10 +1481,12 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=None)
def testMatrixSetDiag(self):
# TODO(penporn): Once XLA supports MatrixSetDiagV2, change the call to
# gen_array_ops.matrix_set_diag (V1) to array_ops.matrix_set_diag (V2).
for dtype in self.numeric_types:
# Square
self._testBinary(
array_ops.matrix_set_diag,
gen_array_ops.matrix_set_diag,
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
dtype=dtype),
np.array([1.0, 2.0, 3.0], dtype=dtype),
@ -1489,7 +1494,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype=dtype))
self._testBinary(
array_ops.matrix_set_diag,
gen_array_ops.matrix_set_diag,
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
dtype=dtype),
@ -1501,19 +1506,19 @@ class BinaryOpsTest(xla_test.XLATestCase):
# Rectangular
self._testBinary(
array_ops.matrix_set_diag,
gen_array_ops.matrix_set_diag,
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
np.array([3.0, 4.0], dtype=dtype),
expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
self._testBinary(
array_ops.matrix_set_diag,
gen_array_ops.matrix_set_diag,
np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
np.array([3.0, 4.0], dtype=dtype),
expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
self._testBinary(
array_ops.matrix_set_diag,
gen_array_ops.matrix_set_diag,
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
np.array([[-1.0, -2.0], [-4.0, -5.0]],

View File

@ -131,15 +131,19 @@ class FFTTest(xla_test.XLATestCase):
signal.ifft3d)
def testRFFT(self):
self._VerifyFftMethod(
INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
lambda x: signal.rfft(x, fft_length=[x.shape[-1].value]))
def _to_expected(x):
return np.fft.rfft(x, n=x.shape[-1])
def _tf_fn(x):
return signal.rfft(x, fft_length=[x.shape[-1]])
self._VerifyFftMethod(INNER_DIMS_1D, np.real, _to_expected, _tf_fn)
def testRFFT2D(self):
def _tf_fn(x):
return signal.rfft2d(
x, fft_length=[x.shape[-2].value, x.shape[-1].value])
return signal.rfft2d(x, fft_length=[x.shape[-2], x.shape[-1]])
self._VerifyFftMethod(
INNER_DIMS_2D, np.real,
@ -153,8 +157,7 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.rfft3d(
x,
fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
x, fft_length=[x.shape[-3], x.shape[-2], x.shape[-1]])
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
@ -168,17 +171,14 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.rfft3d(
x,
fft_length=[
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
])
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
def testIRFFT(self):
def _tf_fn(x):
return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
return signal.irfft(x, fft_length=[2 * (x.shape[-1] - 1)])
self._VerifyFftMethod(
INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
@ -187,8 +187,7 @@ class FFTTest(xla_test.XLATestCase):
def testIRFFT2D(self):
def _tf_fn(x):
return signal.irfft2d(
x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
return signal.irfft2d(x, fft_length=[x.shape[-2], 2 * (x.shape[-1] - 1)])
self._VerifyFftMethod(
INNER_DIMS_2D,
@ -212,10 +211,7 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.irfft3d(
x,
fft_length=[
x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
])
x, fft_length=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
@ -235,10 +231,7 @@ class FFTTest(xla_test.XLATestCase):
def _tf_fn(x):
return signal.irfft3d(
x,
fft_length=[
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
])
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
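The fft_length arithmetic used in these tests can be sanity-checked with numpy alone: rfft of a length-n real signal yields n // 2 + 1 coefficients, and irfft recovers the original length when given fft_length = 2 * (coefficients - 1). A small sketch:

import numpy as np
x = np.random.rand(8).astype(np.float32)
spec = np.fft.rfft(x, n=x.shape[-1])                  # 5 complex coefficients for n = 8
recovered = np.fft.irfft(spec, n=2 * (spec.shape[-1] - 1))
print(spec.shape[-1], np.allclose(x, recovered, atol=1e-5))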

View File

@ -0,0 +1,68 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for manip ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.platform import googletest
class ManipOpsTest(xla_test.XLATestCase):
"""Test cases for manip ops."""
def _testRoll(self, a, shift, axis):
with self.session() as session:
with self.test_scope():
p = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
output = manip_ops.roll(a, shift, axis)
result = session.run(output, {p: a})
self.assertAllEqual(result, np.roll(a, shift, axis))
def testNumericTypes(self):
for t in self.numeric_types:
self._testRoll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
self._testRoll(
np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -6, 6],
[0, 1, 2])
self._testRoll(
np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2],
[1, 2, 3])
def testFloatTypes(self):
for t in self.float_types:
self._testRoll(np.random.rand(5).astype(t), 2, 0)
self._testRoll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
self._testRoll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])
def testComplexTypes(self):
for t in self.complex_types:
x = np.random.rand(4, 4).astype(t)
self._testRoll(x + 1j * x, 2, 0)
x = np.random.rand(2, 5).astype(t)
self._testRoll(x + 1j * x, [1, 2], [1, 0])
x = np.random.rand(3, 2, 1, 1).astype(t)
self._testRoll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
if __name__ == "__main__":
googletest.main()
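A standalone numpy illustration of what _testRoll compares against: shifts may be given per axis, and negative or out-of-range shifts wrap modulo the axis length:

import numpy as np
a = np.arange(12).reshape(3, 4)
print(np.roll(a, 1, axis=0))             # rows rotate down by one
print(np.roll(a, [1, -6], axis=[0, 1]))  # -6 on an axis of size 4 behaves like -2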

View File

@ -0,0 +1,86 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class InverseOpTest(xla_test.XLATestCase):
def _verifyInverse(self, x, np_type):
for adjoint in False, True:
y = x.astype(np_type)
with self.session() as sess:
# Verify that x^{-1} * x == Identity matrix.
p = array_ops.placeholder(dtypes.as_dtype(y.dtype), y.shape, name="x")
with self.test_scope():
inv = linalg_ops.matrix_inverse(p, adjoint=adjoint)
tf_ans = math_ops.matmul(inv, p, adjoint_b=adjoint)
np_ans = np.identity(y.shape[-1])
if x.ndim > 2:
tiling = list(y.shape)
tiling[-2:] = [1, 1]
np_ans = np.tile(np_ans, tiling)
out = sess.run(tf_ans, feed_dict={p: y})
self.assertAllClose(np_ans, out, rtol=1e-3, atol=1e-3)
self.assertShapeEqual(y, tf_ans)
def _verifyInverseReal(self, x):
for np_type in self.float_types & {np.float64, np.float32}:
self._verifyInverse(x, np_type)
def _makeBatch(self, matrix1, matrix2):
matrix_batch = np.concatenate(
[np.expand_dims(matrix1, 0),
np.expand_dims(matrix2, 0)])
matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
return matrix_batch
def testNonsymmetric(self):
# 2x2 matrices
matrix1 = np.array([[1., 2.], [3., 4.]])
matrix2 = np.array([[1., 3.], [3., 5.]])
self._verifyInverseReal(matrix1)
self._verifyInverseReal(matrix2)
# A multidimensional batch of 2x2 matrices
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
def testSymmetricPositiveDefinite(self):
# 2x2 matrices
matrix1 = np.array([[2., 1.], [1., 2.]])
matrix2 = np.array([[3., -1.], [-1., 3.]])
self._verifyInverseReal(matrix1)
self._verifyInverseReal(matrix2)
# A multidimensional batch of 2x2 matrices
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
def testEmpty(self):
self._verifyInverseReal(np.empty([0, 2, 2]))
self._verifyInverseReal(np.empty([2, 0, 0]))
if __name__ == "__main__":
googletest.main()
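The check in _verifyInverse has a direct numpy analogue, shown here as a sketch: the batched inverse times the input should be close to a tiled identity:

import numpy as np
x = np.array([[2., 1.], [1., 2.]])
batch = np.tile(x, [2, 3, 1, 1])                  # a [2, 3, 2, 2] batch of matrices
inv = np.linalg.inv(batch)                        # batched inverse
identity = np.tile(np.identity(2), [2, 3, 1, 1])
print(np.allclose(inv @ batch, identity, atol=1e-3))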

View File

@ -27,6 +27,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@ -107,16 +108,18 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
# TODO(penporn): Once XLA supports MatrixDiagV2, change the call to
# gen_array_ops.matrix_diag* (V1) to array_ops.matrix_diag* (V2).
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
gen_array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
gen_array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
np.array(
[[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag,
gen_array_ops.matrix_diag,
np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype),
np.array(
@ -126,7 +129,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
[0, 0, 12]]]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag_part,
gen_array_ops.matrix_diag_part,
np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))

View File

@ -3,16 +3,8 @@
# and provide TensorRT operators and converter package.
# APIs are meant to change over time.
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_shared_object",
"tf_cc_test",
"tf_copts",
"tf_cuda_library",
@ -31,6 +23,42 @@ load(
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
# Placeholder for Google-internal load statements.
# NOTE: we always assume that if_static returns the "otherwise" list in open source.
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
cc_library(
name = "tensorrt_stub",
srcs = if_tensorrt([
"stub/nvinfer_stub.cc",
"stub/nvinfer_plugin_stub.cc",
]),
textual_hdrs = glob(["stub/*.inc"]),
deps = if_tensorrt([
"@local_config_tensorrt//:tensorrt_headers",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/platform:dso_loader",
]),
)
alias(
name = "tensorrt_lib",
actual = if_static(
"@local_config_tensorrt//:tensorrt",
":tensorrt_stub",
),
visibility = ["//visibility:private"],
)
tf_cuda_cc_test(
name = "tensorrt_test_cc",
size = "small",
@ -46,8 +74,8 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
] + if_tensorrt([
":tensorrt_lib",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_tensorrt//:tensorrt",
]),
)
@ -72,9 +100,7 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:stream_executor_headers_lib",
"//tensorflow/core/grappler/costs:graph_properties",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]) + tf_custom_op_library_additional_deps(),
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
alwayslink = 1,
)
@ -82,7 +108,7 @@ cc_library(
name = "trt_engine_resource_op_kernels",
srcs = ["kernels/trt_engine_resource_ops.cc"],
copts = tf_copts(),
visibility = ["//visibility:private"],
visibility = ["//tensorflow/core:__subpackages__"],
deps = [
":trt_allocator",
":trt_engine_instance_proto_cc",
@ -96,9 +122,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]) + tf_custom_op_library_additional_deps(),
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
alwayslink = 1,
)
@ -130,21 +154,6 @@ tf_cuda_cc_test(
],
)
tf_cc_shared_object(
name = "python/ops/libtftrt.so",
copts = tf_copts(is_external = True),
linkopts = ["-lm"],
deps = [
":trt_op_kernels",
":trt_engine_resource_op_kernels",
":trt_op_libs",
":trt_engine_resource_ops_op_lib",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]) + tf_custom_op_library_additional_deps(),
)
tf_cuda_cc_test(
name = "trt_engine_op_test",
size = "small",
@ -164,6 +173,7 @@ tf_cuda_cc_test(
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@ -197,9 +207,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]),
] + if_tensorrt([":tensorrt_lib"]),
)
tf_gen_op_wrapper_py(
@ -212,18 +220,6 @@ tf_gen_op_wrapper_py(
tf_custom_op_py_library(
name = "trt_ops_loader",
srcs = ["python/ops/trt_ops.py"],
dso = [
"python/ops/libtftrt.so",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]),
kernels = [
":trt_op_kernels",
":trt_engine_resource_op_kernels",
":trt_op_libs",
":trt_engine_resource_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":trt_ops",
@ -254,9 +250,7 @@ tf_cuda_library(
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]),
] + if_tensorrt([":tensorrt_lib"]),
)
tf_cuda_library(
@ -267,9 +261,7 @@ tf_cuda_library(
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]),
] + if_tensorrt([":tensorrt_lib"]),
)
tf_cc_test(
@ -338,9 +330,7 @@ tf_cuda_library(
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]) + tf_custom_op_library_additional_deps(),
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
alwayslink = 1,
)
@ -373,9 +363,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]),
] + if_tensorrt([":tensorrt_lib"]),
)
tf_cuda_cc_test(
@ -409,8 +397,8 @@ tf_cuda_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
] + if_tensorrt([
":tensorrt_lib",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_tensorrt//:tensorrt",
]),
)
@ -428,7 +416,7 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@protobuf_archive//:protobuf_headers",
"@com_google_protobuf//:protobuf_headers",
],
)
@ -463,9 +451,7 @@ tf_cuda_library(
deps = [
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]),
] + if_tensorrt([":tensorrt_lib"]),
)
cc_library(
@ -491,9 +477,7 @@ cc_library(
srcs = ["utils/py_utils.cc"],
hdrs = ["utils/py_utils.h"],
copts = tf_copts(),
deps = if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]),
deps = if_tensorrt([":tensorrt_lib"]),
)
tf_py_wrap_cc(

View File

@ -64,24 +64,6 @@ namespace convert {
using absl::StrAppend;
using absl::StrCat;
TrtCandidateSelector::TrtCandidateSelector(
const grappler::GraphProperties& graph_properties,
TrtPrecisionMode precision_mode)
: graph_properties_(graph_properties), precision_mode_(precision_mode) {}
Status TrtCandidateSelector::IsTensorRTCandidate(const Node* node) {
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
input_node_and_ports.reserve(input_edges.size());
for (const Edge* input_edge : input_edges) {
input_node_and_ports.emplace_back(&input_edge->src()->def(),
input_edge->src_output());
}
return validator_.ValidateNode(node->def(), input_node_and_ports,
precision_mode_, graph_properties_);
}
namespace {
Status BuildNodeMap(const Graph& graph,
@ -478,10 +460,6 @@ Status CreateTRTNode(const ConversionParams& params,
node_builder.ControlInput(c);
}
if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
!info.cached_engine_batches.empty()) {
LOG(WARNING) << "Cached engine batches are ignored for static engines";
}
NodeDef trt_node;
Status status =
node_builder.Attr("input_shapes", input_shape_protos)
@ -740,13 +718,13 @@ Status ConvertAfterShapes(const ConversionParams& params) {
}
segment_options.minimum_segment_size = params.minimum_segment_size;
segment::SegmentNodesVector initial_segments;
TrtCandidateSelector candidate_selector(*params.graph_properties,
params.precision_mode);
TrtNodeValidator validator(*params.graph_properties, params.precision_mode,
params.use_calibration);
TF_RETURN_IF_ERROR(segment::SegmentGraph(
&graph,
std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector,
std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
std::placeholders::_1),
// Input validation is already done by TrtCandidateSelector, so we don't
// Input validation is already done by TrtNodeValidator, so we don't
// need to check the input edges.
[](const Edge* edge) { return true; }, OutputEdgeValidator(),
segment_options, &initial_segments));
@ -782,7 +760,6 @@ Status ConvertAfterShapes(const ConversionParams& params) {
? EngineInfo::EngineType::TRTDynamic
: EngineInfo::EngineType::TRTStatic);
curr_engine.use_calibration = params.use_calibration;
curr_engine.cached_engine_batches = params.cached_engine_batches;
curr_engine.maximum_cached_engines = params.max_cached_engines;
if (params.use_function_backup) {
status = RegisterSegmentFunctionToFunctionLibrary(

View File

@ -31,30 +31,6 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
// Helper class for the segmenter to determine whether given TF node is
// supported by TRT.
class TrtCandidateSelector {
public:
TrtCandidateSelector(const grappler::GraphProperties& graph_properties,
TrtPrecisionMode precision_mode);
// Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
// to TRT subgraph and later converted into TRT engine.
Status IsTensorRTCandidate(const Node* node);
private:
// The TF-TRT node converter used to verify whether individual node is
// supported. It will operate in validation-only mode.
TrtNodeValidator validator_;
// GraphProperties of the graph whose nodes are to be validated by
// IsTensorRTCandidate().
const grappler::GraphProperties& graph_properties_;
// Quantization ops are only converted when using quantized precisions.
const TrtPrecisionMode precision_mode_;
};
struct ConversionParams {
const GraphDef* input_graph_def = nullptr;
const std::vector<string>* output_names = nullptr;
@ -70,8 +46,6 @@ struct ConversionParams {
// maximum number of cached engines
int max_cached_engines = 1;
bool use_calibration = true;
// list of cached engines
std::vector<int> cached_engine_batches;
// Whether to use function fallback for TRTEngineOp
bool use_function_backup = true;
};

View File

@ -50,81 +50,6 @@ void ExpectStatus(Status status, error::Code code = error::OK,
}
}
TEST(TrtCandidateSelector, Basics) {
// Create a graph containing both TRT-compatible and TRT-incompatible nodes
// and use it to test TrtCandidateSelector::IsTensorRTCandidate().
const std::vector<int32> input_shape_array{2, 2};
TensorShape input_shape;
TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_shape_array, &input_shape));
Scope s = Scope::NewRootScope();
ops::Placeholder::Attrs feed_attrs;
TF_EXPECT_OK(
TensorShapeUtils::MakeShape(input_shape_array, &feed_attrs.shape_));
// Compatible input.
auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, feed_attrs);
auto const_1 = ops::Const(s.WithOpName("const_1"), 1.0f, input_shape);
// Compatible MatMul.
auto matmul = ops::MatMul(s.WithOpName("matmul"), feed, const_1);
// Incompatible MatMul.
ops::MatMul::Attrs matmul_attrs;
matmul_attrs.transpose_a_ = true;
auto incompatible_matmul = ops::MatMul(s.WithOpName("incompatible_matmul"),
feed, const_1, matmul_attrs);
// Unsupported op.
auto unsupported_op = ops::Erf(s.WithOpName("sin"), feed);
// Incompatible input.
auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE);
auto const_2 = ops::Const(s.WithOpName("const_2"), 1.0, input_shape);
// Compatible op with incompatible input.
auto matmul_with_incompatible_input =
ops::MatMul(s.WithOpName("matmul_with_incompatible_input"),
incompatible_feed, const_2);
// Quantize ops.
auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("quantize"), feed,
quantize_attrs);
// Get GrapplerItem and GraphProperties.
grappler::GrapplerItem item;
TF_EXPECT_OK(s.ToGraphDef(&item.graph));
Tensor feed_tensor(DT_FLOAT, input_shape);
item.feed.push_back(std::make_pair("feed", feed_tensor));
grappler::GraphProperties graph_properties(item);
TF_EXPECT_OK(graph_properties.InferStatically(true));
for (const TrtPrecisionMode precision_mode :
{TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) {
TrtCandidateSelector selector(graph_properties, precision_mode);
TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node()));
ExpectStatus(
selector.IsTensorRTCandidate(incompatible_matmul.operation.node()),
error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions.");
ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()),
error::UNIMPLEMENTED, "Op type Erf is not supported");
ExpectStatus(
selector.IsTensorRTCandidate(
matmul_with_incompatible_input.operation.node()),
error::INTERNAL,
"Failed to convert input with index 0 to a TRT_TensorOrWeights");
if (precision_mode == TrtPrecisionMode::INT8) {
TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node()));
} else {
ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()),
error::UNIMPLEMENTED,
"Op type FakeQuantWithMinMaxArgs is not supported");
}
}
}
class FakeCluster : public grappler::Cluster {
public:
FakeCluster() : Cluster(0) {}

View File

@ -504,12 +504,11 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
}
// Convert an axis from TF format to TRT format while validating. TF format
// includes the batch dimension, while TRT does not. TF can also use negative
// indices.
// TODO(tmorris): Use this method in more ops.
// includes the batch dimension, while TRT does not if implicit batching is used
// (i.e. for tensors). TF can also use negative indices.
Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
int* trt_axis) {
const int tf_nb_dims = trt_nb_dims + 1;
bool use_implicit_batch, int* trt_axis) {
const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0);
// Check bounds.
if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
return errors::InvalidArgument(
@ -519,13 +518,13 @@ Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
// Make negative axis positive.
if (tf_axis < 0) tf_axis += tf_nb_dims;
// Don't allow axis to be the batch dimension.
if (tf_axis == 0) {
if (use_implicit_batch && tf_axis == 0) {
return errors::Unimplemented(
"TensorRT does not allow manipulation of the batch dimension, at ",
node_name);
}
// Remove batch dimension.
*trt_axis = tf_axis - 1;
// Remove batch dimension if it is implicit.
*trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis;
return Status::OK();
}
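A standalone Python sketch of the axis bookkeeping implemented by ConvertAxis above (illustrative only, not TensorFlow code): with implicit batching the TF rank is the TRT rank plus one, negative axes wrap, the batch axis is rejected, and the result shifts down by one.

def convert_axis(tf_axis, trt_nb_dims, use_implicit_batch=True):
    tf_nb_dims = trt_nb_dims + (1 if use_implicit_batch else 0)
    if not -tf_nb_dims <= tf_axis < tf_nb_dims:
        raise ValueError("axis %d out of bounds for rank %d" % (tf_axis, tf_nb_dims))
    if tf_axis < 0:
        tf_axis += tf_nb_dims            # make negative axis positive
    if use_implicit_batch and tf_axis == 0:
        raise NotImplementedError("cannot manipulate the batch dimension")
    return tf_axis - 1 if use_implicit_batch else tf_axis

print(convert_axis(-1, 3))  # TF axis -1 on a rank-4 TF tensor maps to TRT axis 2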
@ -956,6 +955,31 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
return weights;
}
OpConverterParams::OpConverterParams(
const NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store,
TrtPrecisionMode precision_mode, bool use_calibration)
: node_def(node_def),
inputs(inputs),
outputs(outputs),
validation_only(true),
weight_store(weight_store),
precision_mode(precision_mode),
use_calibration(use_calibration) {}
OpConverterParams::OpConverterParams(
Converter* converter, const NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store)
: converter(converter),
node_def(node_def),
inputs(inputs),
outputs(outputs),
validation_only(false),
weight_store(weight_store),
precision_mode(converter->precision_mode()),
use_calibration(converter->use_calibration()) {}
const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
"QuantizeAndDequantizeV2",
"QuantizeAndDequantizeV3",
@ -963,11 +987,17 @@ const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
"FakeQuantWithMinMaxArgs",
};
TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); }
TrtNodeValidator::TrtNodeValidator(
const grappler::GraphProperties& graph_properties,
TrtPrecisionMode precision_mode, bool use_calibration)
: graph_properties_(graph_properties),
precision_mode_(precision_mode),
use_calibration_(use_calibration) {
RegisterOpValidators();
}
Status TrtNodeValidator::ConvertToTensorOrWeights(
const NodeDef& node_def, int output_port,
const grappler::GraphProperties& graph_properties,
TRT_TensorOrWeights* tensor_or_weights) {
if (node_def.op() == "Const") {
if (output_port != 0) {
@ -983,13 +1013,13 @@ Status TrtNodeValidator::ConvertToTensorOrWeights(
std::vector<TRT_TensorOrWeights> inputs;
return ConvertConstToWeights(node_def, inputs, tensor_or_weights);
}
if (!graph_properties.HasOutputProperties(node_def.name())) {
if (!graph_properties_.HasOutputProperties(node_def.name())) {
return errors::InvalidArgument("Shape and data type are unknown");
}
// Validate and convert shape and dtype.
const auto& output_params =
graph_properties.GetOutputProperties(node_def.name());
graph_properties_.GetOutputProperties(node_def.name());
const auto& tensor_properties = output_params.at(output_port);
const DataType dtype = tensor_properties.dtype();
const PartialTensorShape shape = tensor_properties.shape();
@ -1007,20 +1037,16 @@ Status TrtNodeValidator::ConvertToTensorOrWeights(
return Status::OK();
}
Status TrtNodeValidator::ValidateNode(
const NodeDef& node_def,
const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
const TrtPrecisionMode precision_mode,
const grappler::GraphProperties& graph_properties) {
const string& op = node_def.op();
Status TrtNodeValidator::IsTensorRTCandidate(const Node* node) {
const string& op = node->def().op();
// In INT8 mode, we will always apply the quantization ranges provided by
// these ops to the relevant tensors. This happens regardless of the value of
// use_calibration.
bool is_supported_op = false;
if (quantize_ops->count(op)) {
is_supported_op = (precision_mode == TrtPrecisionMode::INT8);
is_supported_op = (precision_mode_ == TrtPrecisionMode::INT8);
} else {
is_supported_op = op_validators_.count(node_def.op());
is_supported_op = op_validators_.count(op);
}
if (!is_supported_op) {
return errors::Unimplemented("Op type ", op, " is not supported.");
@ -1029,23 +1055,24 @@ Status TrtNodeValidator::ValidateNode(
// Convert input NodeDef and corresponding output ports to
// TRT_TensorOrWeights.
std::vector<TRT_TensorOrWeights> inputs;
for (int i = 0; i < input_node_and_ports.size(); ++i) {
const auto& pair = input_node_and_ports[i];
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
for (const Edge* edge : input_edges) {
TRT_TensorOrWeights tensor_or_weights;
Status status = ConvertToTensorOrWeights(
*pair.first, pair.second, graph_properties, &tensor_or_weights);
const NodeDef& src_def = edge->src()->def();
Status status = ConvertToTensorOrWeights(src_def, edge->src_output(),
&tensor_or_weights);
if (!status.ok()) {
return errors::Internal(
"Failed to convert input with index ", i,
"Failed to convert input ", src_def.name(),
" to a TRT_TensorOrWeights: ", status.error_message());
}
inputs.push_back(tensor_or_weights);
}
OpConverter validator = op_validators_[node_def.op()];
OpConverterParams params(
/*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr,
/*arg_validation_only=*/true, &weight_store_);
OpConverter validator = op_validators_[op];
OpConverterParams params(node->def(), inputs, /*arg_outputs=*/nullptr,
&weight_store_, precision_mode_, use_calibration_);
return validator(&params);
}
@ -1054,9 +1081,8 @@ Status TrtNodeValidator::ConvertConstToWeights(
const std::vector<TRT_TensorOrWeights>& inputs,
TRT_TensorOrWeights* output) {
std::vector<TRT_TensorOrWeights> outputs;
OpConverterParams params(
/*arg_converter=*/nullptr, const_node_def, inputs, &outputs,
/*arg_validation_only=*/true, &weight_store_);
OpConverterParams params(const_node_def, inputs, &outputs, &weight_store_,
precision_mode_, use_calibration_);
Status status = op_validators_["Const"](&params);
if (status.ok() && output) *output = outputs[0];
return status;
@ -1108,8 +1134,7 @@ Status Converter::ConvertNode(const NodeDef& node_def) {
std::vector<TRT_TensorOrWeights> inputs, outputs;
TF_RETURN_IF_ERROR(this->GetInputs(node_def, &inputs));
OpConverterParams params(this, node_def, inputs, &outputs,
/*arg_validation_only=*/false, &weight_store_);
OpConverterParams params(this, node_def, inputs, &outputs, &weight_store_);
const string& op = node_def.op();
auto itr = op_registry_.find(op);
if (itr == op_registry_.end()) {
@ -2004,8 +2029,8 @@ Status ConvertExpandDims(OpConverterParams* params) {
// Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for
// ExpandDim's ability to add an axis at end of the shape.
int trt_axis;
TF_RETURN_IF_ERROR(
ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), &trt_axis));
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(),
/*use_implicit_batch=*/true, &trt_axis));
if (params->validation_only) return Status::OK();
// ExpandDims: Insert new dim of size 1.
@ -2040,8 +2065,8 @@ Status ConvertSqueeze(OpConverterParams* params) {
for (int tf_axis : squeeze_dims) {
// Make sure axis is valid.
int trt_axis;
TF_RETURN_IF_ERROR(
ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
/*use_implicit_batch=*/true, &trt_axis));
// Make sure target dimension is size 1.
if (input_dims[trt_axis] != 1) {
return errors::InvalidArgument(
@ -2071,7 +2096,8 @@ template <typename Container>
Status ConvertStridedSliceHelper(OpConverterParams* params,
const TRT_TensorOrWeights& input,
Container begin, Container size,
const Container& stride) {
const Container& stride,
const nvinfer1::Dims* final_shape = nullptr) {
const auto& node_def = params->node_def;
// Get input dims.
nvinfer1::Dims dims = input.GetTrtDims();
@ -2110,7 +2136,14 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
*input.tensor(), begin_dims, size_dims, stride_dims);
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
nvinfer1::ITensor* tensor = layer->getOutput(0);
// Reshape for shrink_axis.
if (final_shape) {
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false,
&tensor));
}
params->outputs->push_back(TRT_TensorOrWeights(tensor));
return Status::OK();
#else
// Use IPaddingLayer.
@ -2228,8 +2261,13 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
tensor, inv_transpose_order, &tensor));
}
// Restore reshape
if (need_reshape) {
// Reshape for shrink_axis.
if (final_shape) {
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false,
&tensor));
} else if (need_reshape) {
// Restore reshape.
// Calculate output dimensions
for (int i = 0; i < pad_dims.size(); i++) {
const int axis = pad_dims[i];
@ -2313,17 +2351,17 @@ Status ConvertStridedSlice(OpConverterParams* params) {
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
TFAttrs attrs(node_def);
// Unsupported mask options.
for (const string& attr : {"new_axis_mask", "shrink_axis_mask"}) {
int attr_val = attrs.get<int64>(attr);
if (attr_val != 0) {
return errors::Unimplemented(
attr, " is not supported for StridedSlice, at ", node_def.name());
}
// new_axis_mask is not supported.
const int32 new_axis_mask = attrs.get<int64>("new_axis_mask");
if (new_axis_mask != 0) {
return errors::Unimplemented(
"new_axis_mask is not supported for StridedSlice, at ",
node_def.name());
}
const int32 begin_mask = attrs.get<int64>("begin_mask");
const int32 end_mask = attrs.get<int64>("end_mask");
const int32 ellipsis_mask = attrs.get<int64>("ellipsis_mask");
const int32 shrink_axis_mask = attrs.get<int64>("shrink_axis_mask");
// Get input dims.
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
@ -2355,9 +2393,9 @@ Status ConvertStridedSlice(OpConverterParams* params) {
TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
&begin_weights.GetTensor(), &end_weights.GetTensor(),
stride_weights.GetTensor(), input_shape, begin_mask, end_mask,
ellipsis_mask, /*new_axis_mask=*/0,
/*shrink_axis_mask=*/0, &processing_shape, &final_shape, &is_identity,
&is_simple_slice, &slice_dim0, &begin, &end, &strides));
ellipsis_mask, new_axis_mask, shrink_axis_mask, &processing_shape,
&final_shape, &is_identity, &is_simple_slice, &slice_dim0, &begin, &end,
&strides));
// Negative or zero strides currently not supported.
for (int stride : strides) {
@ -2391,13 +2429,29 @@ Status ConvertStridedSlice(OpConverterParams* params) {
node_def.name());
}
}
// Can't shrink axis on batch dimension.
if (shrink_axis_mask & 1) {
return errors::Unimplemented(
"TensorRT does not allow modifications to the batch dimension, at ",
node_def.name());
}
// TRT Slice layer uses (begin, size) instead of (begin, end)
absl::InlinedVector<int64, 4> size(input_dims.size());
for (int i = 0; i < input_dims.size(); i++) {
// Divide by stride (round up)
size[i] = (end[i] - begin[i] + strides[i] - 1) / strides[i];
}
return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, strides);
// shrink_axis_mask requires a reshape after the slice.
nvinfer1::Dims final_shape_dims;
nvinfer1::Dims* final_shape_dims_ptr = nullptr;
if (shrink_axis_mask) {
final_shape_dims =
TensorShapeToTrtDims(final_shape, /*ignore_first_dim=*/true);
final_shape_dims_ptr = &final_shape_dims;
}
return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, strides,
final_shape_dims_ptr);
}
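The (begin, end, stride) to (begin, size) conversion above uses a rounded-up division; a quick standalone check that it matches the number of elements a positive-stride slice visits:

def slice_size(begin, end, stride):
    # Divide by stride, rounding up, exactly as in ConvertStridedSlice above.
    return (end - begin + stride - 1) // stride

for b, e, s in [(0, 7, 2), (1, 8, 3), (2, 3, 1)]:
    assert slice_size(b, e, s) == len(range(b, e, s))
print("size matches len(range(begin, end, stride)) for positive strides")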
Status ConvertConv2D(OpConverterParams* params) {
@ -2795,7 +2849,7 @@ Status ConvertRelu6(OpConverterParams* params) {
#endif
}
Status ConvertBiasAdd(OpConverterParams* params) {
Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(
@ -2893,6 +2947,71 @@ Status ConvertBiasAdd(OpConverterParams* params) {
return Status::OK();
}
Status ConvertBiasAdd(OpConverterParams* params) {
if (params->precision_mode == TrtPrecisionMode::INT8 &&
!params->use_calibration) {
// NOTE(laigd): based on some observation, it seems TensorRT cannot fuse
// IConvolutionLayer and IElementwiseLayer and will require range
// information for the output of Conv2D. Using IScaleLayer will fix the
// problem.
return ConvertBiasAddInt8WithoutCalibration(params);
}
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.size() != 2) {
return errors::InvalidArgument(
"BiasAdd expects exactly 2 inputs, but received ", inputs.size());
}
if (inputs[0].is_weights() && inputs[1].is_weights()) {
return errors::InvalidArgument(
"All inputs are weights, but Grappler is expected to fold them.");
}
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
TFAttrs attrs(node_def);
const string& data_format = attrs.get<string>("data_format");
nvinfer1::Dims input_shape = inputs.at(0).GetTrtDims();
nvinfer1::Dims bias_shape = inputs.at(1).GetTrtDims();
// If the input is NCHW, then we need to unsqueeze the bias such that its last
// dimensions are 1s (and the first dimension is C).
if (data_format == "NCHW") {
bias_shape.nbDims = input_shape.nbDims;
std::fill(bias_shape.d + 1, bias_shape.d + bias_shape.nbDims, 1);
} else {
// Next, broadcast the bias across the input.
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(inputs.at(0), inputs.at(1),
&input_shape, &bias_shape));
}
// Convert input to a TRT tensor
nvinfer1::ITensor* input_tensor{nullptr};
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
inputs.at(0), input_shape, params->validation_only, &input_tensor));
// Finally, reshape bias. Since the bias is usually a constant, this will
// normally happen at conversion-time.
nvinfer1::ITensor* bias_tensor{nullptr};
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
inputs.at(1), bias_shape, params->validation_only, &bias_tensor));
VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape);
if (params->validation_only) return Status::OK();
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(
*input_tensor, *bias_tensor, nvinfer1::ElementWiseOperation::kSUM);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
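For the NCHW branch of ConvertBiasAdd above, the bias of shape [C] is padded with trailing 1s up to the rank of the input so that the elementwise sum broadcasts it per channel. A standalone sketch of that reshape bookkeeping with invented shapes (plain C++, no TensorRT types):

#include <iostream>
#include <vector>

int main() {
  // Hypothetical shapes, excluding the implicit batch dim: input is CHW, bias is [C].
  std::vector<int> input_shape = {16, 28, 28};
  std::vector<int> bias_shape = {16};

  // Unsqueeze the bias so its trailing dimensions are 1s: [16] -> [16, 1, 1].
  bias_shape.resize(input_shape.size(), 1);

  for (int d : bias_shape) std::cout << d << " ";  // Prints "16 1 1".
  std::cout << "\n";
  return 0;
}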
void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
if (tensor.dims() > 0) {
*dims = GetTrtDimsForTensor(tensor);
@ -3289,9 +3408,9 @@ Status ConvertReduce(OpConverterParams* params) {
}
for (int i = 0; i < tf_axes_list.size(); i++) {
int trt_axis;
TF_RETURN_IF_ERROR(ConvertAxis(tf_axes_list[i],
tensor->getDimensions().nbDims,
node_def.name(), &trt_axis));
TF_RETURN_IF_ERROR(
ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims,
node_def.name(), /*use_implicit_batch=*/true, &trt_axis));
axes |= (1 << trt_axis);
}
@ -3358,8 +3477,8 @@ Status ConvertPack(OpConverterParams* params) {
const nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
const int64 tf_axis = attrs.get<int64>("axis");
int trt_axis;
TF_RETURN_IF_ERROR(
ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), &trt_axis));
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(),
/*use_implicit_batch=*/true, &trt_axis));
// Compute expanded dimensions and then reshape input tensors.
std::vector<int> tensor_dims(dims.d, dims.d + dims.nbDims);
@ -3506,8 +3625,8 @@ Status ConvertSplitHelper(OpConverterParams* params,
const nvinfer1::Dims dims = input.GetTrtDims();
// Convert axis.
int trt_axis;
TF_RETURN_IF_ERROR(
ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
/*use_implicit_batch=*/true, &trt_axis));
// Dimension must equal num_splits for Unstack (when squeeze_after is true)
if (squeeze_after && dims.d[trt_axis] != num_splits) {
return errors::InvalidArgument(
@ -3536,31 +3655,23 @@ Status ConvertSplitHelper(OpConverterParams* params,
begin.insert(begin.begin(), 0);
size.insert(size.begin(), 1);
stride.insert(stride.begin(), 1);
// Create final shape for Unpack/Unstack, where split axis is squeezed.
nvinfer1::Dims final_shape_for_unpack;
nvinfer1::Dims* final_shape_for_unpack_ptr = nullptr;
if (squeeze_after) {
std::vector<int> size_after_squeeze(size);
size_after_squeeze.erase(size_after_squeeze.begin() + trt_axis + 1);
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(
size_after_squeeze, &final_shape_for_unpack, /*ignore_first_dim=*/true));
final_shape_for_unpack_ptr = &final_shape_for_unpack;
}
// Slice the input. ConvertStridedSliceHelper will push the outputs onto
// params->outputs.
for (int i = 0; i < num_splits; ++i) {
begin[trt_axis + 1] = i * split_size_on_axis;
TF_RETURN_IF_ERROR(
ConvertStridedSliceHelper(params, input, begin, size, stride));
}
if (params->validation_only) return Status::OK();
// For Unpack/Unstack, remove axis that we split upon.
if (squeeze_after) {
// Create the new shape.
size.erase(size.begin() + trt_axis + 1);
nvinfer1::Dims new_dims;
TF_RETURN_IF_ERROR(
TensorShapeArrayToTrtDims(size, &new_dims, /*ignore_first_dim=*/true));
// Reshape each slice.
for (int i = 0; i < params->outputs->size(); i++) {
nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
params->outputs->at(i), new_dims, /*validation_only=*/false,
&output_tensor));
(*params->outputs)[i] = TRT_TensorOrWeights(output_tensor);
}
TF_RETURN_IF_ERROR(ConvertStridedSliceHelper(
params, input, begin, size, stride, final_shape_for_unpack_ptr));
}
return Status::OK();
}
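ConvertSplitHelper above emits one strided slice per output and only shifts the begin index along the split axis; the "+ 1" offsets account for the implicit batch dimension. A plain C++ sketch of how those (begin, size) pairs are laid out, with invented dimensions:

#include <iostream>
#include <vector>

int main() {
  // Hypothetical input shape (index 0 is the implicit batch dim) and split setup.
  std::vector<int> input_shape = {1, 6, 4};
  const int trt_axis = 0;  // Split axis among the non-batch dims.
  const int num_splits = 3;
  const int split_size_on_axis = input_shape[trt_axis + 1] / num_splits;  // 2

  std::vector<int> begin(input_shape.size(), 0);
  std::vector<int> size = input_shape;
  size[trt_axis + 1] = split_size_on_axis;

  // One slice per output; only the begin index on the split axis moves.
  for (int i = 0; i < num_splits; ++i) {
    begin[trt_axis + 1] = i * split_size_on_axis;
    std::cout << "slice " << i << ": begin=" << begin[trt_axis + 1]
              << " size=" << size[trt_axis + 1] << "\n";
  }
  return 0;
}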
@ -3635,8 +3746,8 @@ Status ConvertConcat(OpConverterParams* params) {
}
int trt_axis = 0;
const auto dim = inputs.at(0).GetTrtDims();
TF_RETURN_IF_ERROR(
ConvertAxis(axis[0], dim.nbDims, node_def.name(), &trt_axis));
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dim.nbDims, node_def.name(),
/*use_implicit_batch=*/true, &trt_axis));
// Check that dimensions match on non-concatenate axis.
TF_RETURN_IF_ERROR(VerifyShapesMatch(
absl::Span<const TRT_TensorOrWeights>(inputs).first(num_inputs), trt_axis,
@ -3795,29 +3906,58 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
Status ConvertGather(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
// TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an
// option for an input to be either tensor or weight.
if (inputs.size() != 3) {
return errors::InvalidArgument("GatherV2 got ", inputs.size(),
" inputs but expected 3, at ",
node_def.name());
}
const auto& params_input = inputs.at(0);
const auto& indices_input = inputs.at(1);
const auto& axis_input = inputs.at(2);
if (!axis_input.is_weights()) {
return errors::Unimplemented(
"The input \"axis\" for GatherV2 must be a constant, at ",
node_def.name());
}
if (!indices_input.is_tensor()) {
return errors::Unimplemented(
"The input \"indices\" for GatherV2 must be a tensor, at ",
node_def.name());
}
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
/*dtype_attr_name=*/"Tparams"));
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
TF_RETURN_IF_ERROR(AllowDataTypes(*params, {DataType::DT_INT32},
/*dtype_attr_name=*/"Tindices"));
absl::Span<const int> axis = axis_input.weights().GetSpan<int>();
if (axis.size() != 1) {
return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
node_def.name());
}
int trt_axis = 0;
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
node_def.name(), &trt_axis));
const TRT_TensorOrWeights& params_tensor = inputs.at(0);
const TRT_TensorOrWeights& indices_tensor = inputs.at(1);
if (indices_tensor.batch_size() != 1) {
return errors::InvalidArgument("Only indices with batch 1 are supported.");
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], params_input.GetTrtDims().nbDims,
node_def.name(), params_input.is_tensor(),
&trt_axis));
if (params_input.is_weights() && trt_axis != 0) {
return errors::Unimplemented(
"The input axis must be zero when params is a weight.");
}
if (params_input.is_tensor() && indices_input.batch_size() != 1) {
return errors::Unimplemented(
"Indices must have a batch size of 1 when params is a tensor.");
}
// Both inputs are tensors, and the TF gather result will have rank:
// (params.nbDims + 1) + (indices.nbDims + 1) - 1,
// where "+ 1" adds the batch dim.
const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims +
indices_tensor.GetTrtDims().nbDims + 1;
// where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
// the TF rank so we don't have to add + 1.
const int params_tf_rank =
params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0);
const int indices_tf_rank = indices_input.GetTrtDims().nbDims + 1;
const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) {
return errors::InvalidArgument(
"Result of gather has dimension greater than ",
@ -3825,38 +3965,50 @@ Status ConvertGather(OpConverterParams* params) {
}
if (params->validation_only) return Status::OK();
// Convert params to a tensor if it is a weight.
nvinfer1::ITensor* params_tensor = nullptr;
if (params_input.is_weights()) {
params_tensor = params->converter->CreateConstantLayer(
params_input.weights(), params_input.GetTrtDims());
} else {
params_tensor = params_input.tensor();
}
// Note on how IGatherLayer works: if both the data and indices tensors have
// a batch size dimension of size N, it performs:
// for batchid in xrange(N):
// output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
// data[batchid, a0, ..., an, indices[batchid, i, ..., j], b0, ..., bn])
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
*params_tensor.tensor(), *indices_tensor.tensor(), trt_axis);
*params_tensor, *indices_input.tensor(), trt_axis);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* gather_output = layer->getOutput(0);
nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions();
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions();
// Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
// and the other is for the output dimension that is squeezed by IGatherLayer
// because of the implicit batch dim in the indices (see the above note).
if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) {
const int expected_trt_output_rank =
tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1);
if (trt_gather_output_dims.nbDims != expected_trt_output_rank) {
return errors::Internal(
"Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
tf_gather_output_rank - 2,
expected_trt_output_rank,
", actual nbDims: ", trt_gather_output_dims.nbDims);
}
// Reshape the output so after adding the implicit batch dim it'll match the
// output shape of TF GatherV2.
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
}
trt_gather_output_dims.d[trt_axis] = 1;
++trt_gather_output_dims.nbDims;
if (params_input.is_tensor()) {
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
}
trt_gather_output_dims.d[trt_axis] = 1;
++trt_gather_output_dims.nbDims;
nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
TRT_TensorOrWeights(gather_output), trt_gather_output_dims,
/*validation_only=*/false, &output_tensor));
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
TRT_TensorOrWeights(output_tensor), trt_gather_output_dims,
/*validation_only=*/false, &output_tensor));
}
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
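The rank bookkeeping in ConvertGather is easy to lose track of: TF ranks include the batch dimension, TRT ranks of tensors do not, and IGatherLayer squeezes the dimension that corresponds to the indices' implicit batch. A quick standalone check of the expected ranks under the same conventions as the code above (the rank values are made up):

#include <cassert>
#include <iostream>

int main() {
  // Hypothetical TRT ranks, i.e. excluding the implicit batch dim.
  const int params_trt_rank = 3;
  const int indices_trt_rank = 2;
  const bool params_is_tensor = true;

  // TF ranks include the batch dim; weights keep the TF rank as-is.
  const int params_tf_rank = params_trt_rank + (params_is_tensor ? 1 : 0);
  const int indices_tf_rank = indices_trt_rank + 1;
  const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;

  // For tensor params the TRT output loses the TF batch dim and the dim that
  // IGatherLayer squeezes for the indices' implicit batch, hence "- 2".
  const int expected_trt_output_rank =
      tf_gather_output_rank - (params_is_tensor ? 2 : 1);

  std::cout << "TF rank: " << tf_gather_output_rank
            << ", expected TRT rank: " << expected_trt_output_rank << "\n";
  assert(expected_trt_output_rank == 4);
  return 0;
}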
@ -4116,8 +4268,8 @@ Status ConvertArgMinMax(OpConverterParams* params) {
int tf_axis = inputs.at(1).weights().GetSpan<int>()[0];
int trt_axis;
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
TF_RETURN_IF_ERROR(
ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
/*use_implicit_batch=*/true, &trt_axis));
nvinfer1::TopKOperation topk_op;
if (node_def.op() == "ArgMin") {
topk_op = nvinfer1::TopKOperation::kMIN;

View File

@ -116,7 +116,6 @@ struct EngineInfo {
EngineType engine_type;
int64 max_workspace_size_bytes;
int maximum_cached_engines;
std::vector<int> cached_engine_batches;
TrtPrecisionMode precision_mode;
bool use_calibration;
};
@ -354,23 +353,27 @@ class Converter;
// Parameters for each op converter.
struct OpConverterParams {
OpConverterParams(Converter* arg_converter, const NodeDef& arg_node_def,
const std::vector<TRT_TensorOrWeights>& arg_inputs,
std::vector<TRT_TensorOrWeights>* arg_outputs,
bool arg_validation_only, TrtWeightStore* arg_weight_store)
: converter(arg_converter),
node_def(arg_node_def),
inputs(arg_inputs),
outputs(arg_outputs),
validation_only(arg_validation_only),
weight_store(arg_weight_store) {}
// Constructor used for validation only.
OpConverterParams(const NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs,
TrtWeightStore* weight_store,
TrtPrecisionMode precision_mode, bool use_calibration);
Converter* converter;
// Constructor used for conversion.
OpConverterParams(Converter* converter, const NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs,
TrtWeightStore* weight_store);
Converter* converter = nullptr;
const NodeDef& node_def;
const std::vector<TRT_TensorOrWeights>& inputs;
std::vector<TRT_TensorOrWeights>* outputs;
const bool validation_only;
TrtWeightStore* weight_store;
const TrtPrecisionMode precision_mode;
const bool use_calibration;
};
using OpConverter = std::function<Status(OpConverterParams*)>;
@ -378,21 +381,15 @@ using OpConverter = std::function<Status(OpConverterParams*)>;
// Class to verify if specific TF node is supported by TRT.
class TrtNodeValidator {
public:
TrtNodeValidator();
// 'graph_properties' is the GraphProperties of the graph whose nodes will be
// checked by IsTensorRTCandidate() later. It is used to get the shape and
// data type information of a tensor for validation purpose.
TrtNodeValidator(const grappler::GraphProperties& graph_properties,
TrtPrecisionMode precision_mode, bool use_calibration);
// Validate the node, and return ok if it's supported by TRT.
//
// - 'node_def' is the node to validate.
// - 'input_node_and_ports' are the input NodeDefs and their output ports that
// are connected to 'node_def' in the TF graph.
// - 'graph_properties' is the GraphProperties of the graph where 'node_def'
// belongs. It is used to get the shape and data type information of a
// tensor for validation purpose.
Status ValidateNode(
const NodeDef& node_def,
const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
const TrtPrecisionMode precision_mode,
const grappler::GraphProperties& graph_properties);
// Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
// to TRT subgraph and later converted into TRT engine.
Status IsTensorRTCandidate(const Node* node);
private:
static const std::set<string>* quantize_ops;
@ -407,10 +404,8 @@ class TrtNodeValidator {
// Convert the output tensor at 'output_port' of 'node_def' to a
// TRT_TensorOrWeights which will be later used as an input to other nodes and
// passed to ValidateNode() below.
Status ConvertToTensorOrWeights(
const NodeDef& node_def, int output_port,
const grappler::GraphProperties& graph_properties,
TRT_TensorOrWeights* tensor_or_weights);
Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port,
TRT_TensorOrWeights* tensor_or_weights);
// Stores all the validators by op type. If no validator is registered for
// specific op, it means no validation is needed and ValidateNode() will
@ -421,6 +416,15 @@ class TrtNodeValidator {
// validation for Const node) may produce weights.
TrtWeightStore weight_store_;
// GraphProperties of the graph whose nodes are to be validated by
// IsTensorRTCandidate().
const grappler::GraphProperties& graph_properties_;
// Quantization ops are only converted when using quantized precisions.
const TrtPrecisionMode precision_mode_;
const bool use_calibration_;
friend class ValidatorTest;
friend class OpConverterTest;
};

File diff suppressed because it is too large

View File

@ -54,13 +54,6 @@ Status TRTOptimizationPass::Init(
if (params.count("is_dynamic_op")) {
is_dynamic_op_ = params.at("is_dynamic_op").b();
}
if (params.count("cached_engine_batches")) {
auto batch_vec = params.at("cached_engine_batches").list();
batches_.reserve(batch_vec.i_size());
for (const auto i : batch_vec.i()) {
batches_.push_back(i);
}
}
if (params.count("maximum_cached_engines")) {
max_cached_batches_ = params.at("maximum_cached_engines").i();
}
@ -264,7 +257,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
cp.graph_properties = &static_graph_properties;
cp.cluster = cluster;
cp.is_dyn_op = is_dynamic_op_;
cp.cached_engine_batches = batches_;
cp.max_cached_engines = max_cached_batches_;
cp.use_calibration = use_calibration_;
cp.use_function_backup = use_function_backup_;

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@ -49,6 +50,7 @@ static Logger logger;
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;
using ::stream_executor::port::StatusOr;
// A helper class to call done() when destructed for asynchronous execution.
// Helps simultaneous execution of native and TRT engines.
@ -80,6 +82,10 @@ class TRTEngineOp : public AsyncOpKernel {
AsyncOpKernel::DoneCallback done) override;
private:
using CacheType =
LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
VectorTensorShapeHasher>;
// Execute calibration
void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
@ -98,12 +104,16 @@ class TRTEngineOp : public AsyncOpKernel {
TRTCalibrationResource** cr);
// Get engine for the input shape
EngineContext* GetEngine(const std::vector<TensorShape>& input_shapes,
OpKernelContext* ctx);
StatusOr<EngineContext*> GetEngine(
const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx);
// Verify that the input shapes are consistent and can be handled by this op.
Status VerifyInputShapes(const std::vector<TensorShape>& shapes);
// Return the engine batch in cached_engine_batches_ which is closest to the
// input batch.
bool GetCompatibleCachedEngine(
Status GetEngineInputShapes(
const CacheType& cache,
const std::vector<TensorShape>& actual_input_shapes,
std::vector<TensorShape>* engine_input_shapes);
@ -131,9 +141,6 @@ class TRTEngineOp : public AsyncOpKernel {
// Whether to calibrate INT8 engine.
bool calibration_mode_;
// Batches of the cached engines
std::vector<int> cached_engine_batches_;
// Maximum number of cached engines
int max_cached_engines_;
@ -160,6 +167,7 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
default: {
LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
return nullptr;
@ -195,12 +203,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
context->GetAttr("workspace_size_bytes", &workspace_size_));
OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
if (!static_engine_) {
if (!segment_graph_.ParseFromString(serialized_segment_)) {
LOG(ERROR) << "Parsing segment graph failed!";
context->SetStatus(
errors::InvalidArgument("Failed to parse segment graphdef!"));
return;
}
OP_REQUIRES(context, segment_graph_.ParseFromString(serialized_segment_),
errors::InvalidArgument("Failed to parse segment graphdef!"));
VLOG(1) << "Size of serialized GraphDef: "
<< serialized_segment_.capacity();
string tmp;
@ -230,16 +234,6 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
native_func_ = kInvalidHandle;
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
&max_cached_engines_));
OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
&cached_engine_batches_));
std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
if (VLOG_IS_ON(1)) {
string s("Engine Batches= ");
for (auto i : cached_engine_batches_) {
StrAppend(&s, i, " ");
}
VLOG(1) << s;
}
}
void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
@ -298,11 +292,10 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
for (int i = 0; i < num_inputs; i++) {
const Tensor& t = ctx->input(i);
void* data_address = GetTensorAddress(&t);
if (data_address == nullptr) {
ctx->SetStatus(errors::InvalidArgument(
"Unsupported data type encountered in input ", i));
return;
}
OP_REQUIRES_ASYNC(ctx, data_address,
errors::InvalidArgument(
"Unsupported data type encountered in input ", i),
*helper);
// Check the allocated buffer is sufficient for input
const auto device_tensor =
calib_res->device_tensors_.at(i).AccessTensor(ctx);
@ -331,34 +324,74 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
ExecuteNativeSegment(ctx, helper);
}
bool TRTEngineOp::GetCompatibleCachedEngine(
const std::vector<TensorShape>& actual_input_shapes,
Status TRTEngineOp::VerifyInputShapes(const std::vector<TensorShape>& shapes) {
if (shapes.empty()) {
return errors::InvalidArgument("Input shapes are empty, for ", name());
}
if (shapes[0].dims() < 1) {
return errors::InvalidArgument("Input shapes contain scalar, for ", name(),
": ",
TensorShapeUtils::ShapeListString(shapes));
}
const int batch_size = shapes[0].dim_size(0);
for (const TensorShape& shape : shapes) {
if (shape.dims() < 1 || batch_size != shape.dim_size(0)) {
return errors::InvalidArgument(
"Input shapes are inconsistent on the batch dimension, for ", name(),
": ", TensorShapeUtils::ShapeListString(shapes));
}
}
return Status::OK();
}
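VerifyInputShapes above rejects empty shape lists, scalar inputs, and inputs whose leading (batch) dimensions disagree. A compact standalone sketch of the same check over plain vectors, independent of TensorFlow types:

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// True if every shape has rank >= 1 and shares the same leading (batch) dim.
bool ShapesConsistent(const std::vector<Shape>& shapes) {
  if (shapes.empty() || shapes[0].empty()) return false;
  const int64_t batch = shapes[0][0];
  for (const Shape& s : shapes) {
    if (s.empty() || s[0] != batch) return false;
  }
  return true;
}

int main() {
  std::cout << ShapesConsistent({{4, 2}, {4, 3, 3}}) << "\n";  // 1: batch of 4 everywhere.
  std::cout << ShapesConsistent({{4, 2}, {2, 2}}) << "\n";     // 0: batch mismatch.
  std::cout << ShapesConsistent({Shape{}, {4}}) << "\n";       // 0: scalar input.
  return 0;
}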
Status TRTEngineOp::GetEngineInputShapes(
const CacheType& cache, const std::vector<TensorShape>& actual_input_shapes,
std::vector<TensorShape>* engine_input_shapes) {
const int batch_size = actual_input_shapes[0].dim_size(0);
int smallest_batch_size = -1;
// Output shape will always be the same as the input but we will overwrite the
// batch size.
auto match_shape = [](const TensorShape& actual_shape,
const TensorShape& cached_shape) {
// Match the rank.
if (actual_shape.dims() != cached_shape.dims()) return false;
// Match the batch size.
if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
// Match remaining dimensions.
for (int i = 1; i < actual_shape.dims(); ++i) {
if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
}
return true;
};
auto match_shapes = [&](const std::vector<TensorShape>& actual_shapes,
const std::vector<TensorShape>& cached_shapes) {
for (int i = 0; i < actual_shapes.size(); ++i) {
if (!match_shape(actual_shapes[i], cached_shapes[i])) {
return false;
}
}
return true;
};
// VerifyInputShapes() already ensured that all input shapes have the same
// batch size and are not scalars.
*engine_input_shapes = actual_input_shapes;
for (const int cached_batch_size : cached_engine_batches_) {
// Check if compatible: batch <= cached batch.
//
// TODO(laigd): here it only compares the first dim, a.k.a. the batch size,
// we'll need to support non-batch dimensions as well. This will be done
// as part of the offline conversion implementation.
if (batch_size <= cached_batch_size) {
// First case: first compatible engine found
// Second case: smaller batch size engine found
if ((smallest_batch_size == -1) ||
(cached_batch_size < smallest_batch_size)) {
smallest_batch_size = cached_batch_size;
// Overwrite batch size for output
for (int i = 0; i < engine_input_shapes->size(); i++) {
(*engine_input_shapes)[i].set_dim(0, smallest_batch_size);
}
int64 min_matched_batch_size = kint64max;
for (const auto& pair : cache) {
const std::vector<TensorShape>& cached_input_shapes = pair.first;
// This should not happen, but just for safety.
if (actual_input_shapes.size() != cached_input_shapes.size()) {
return errors::InvalidArgument(
"Input shape list size mismatch for ", name(),
", cached size: ", cached_input_shapes.size(),
" vs. actual size: ", actual_input_shapes.size());
}
if (match_shapes(actual_input_shapes, cached_input_shapes)) {
const int cached_batch_size = cached_input_shapes[0].dim_size(0);
if (min_matched_batch_size > cached_batch_size) {
min_matched_batch_size = cached_batch_size;
*engine_input_shapes = cached_input_shapes;
}
}
}
return (smallest_batch_size != -1);
return Status::OK();
}
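GetEngineInputShapes above picks, among the cached shape lists whose rank and non-batch dimensions match and whose batch is at least the actual batch, the one with the smallest batch size; if nothing matches, the actual shapes are returned and a new engine gets built. A standalone sketch of that selection over single shapes rather than shape lists (invented shapes, no TF types):

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

using Shape = std::vector<int64_t>;

// True if `actual` can run on an engine built for `cached`: same rank,
// batch (dim 0) no larger, and all other dims equal.
bool MatchShape(const Shape& actual, const Shape& cached) {
  if (actual.size() != cached.size()) return false;
  if (actual[0] > cached[0]) return false;
  for (size_t i = 1; i < actual.size(); ++i) {
    if (actual[i] != cached[i]) return false;
  }
  return true;
}

int main() {
  // Hypothetical cache keys: the input shape each cached engine was built for.
  std::vector<Shape> cache = {{2, 2}, {3, 2}, {10, 10}};
  Shape actual = {1, 2};

  Shape chosen = actual;  // Fall back to the actual shape -> build a new engine.
  int64_t min_batch = std::numeric_limits<int64_t>::max();
  for (const Shape& cached : cache) {
    if (MatchShape(actual, cached) && cached[0] < min_batch) {
      min_batch = cached[0];
      chosen = cached;
    }
  }
  std::cout << chosen[0] << " " << chosen[1] << "\n";  // Prints "2 2".
  return 0;
}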
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
@ -375,7 +408,10 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
for (int i = 0; i < ctx->num_inputs(); ++i) {
input_shapes.push_back(ctx->input(i).shape());
}
EngineContext* engine_context = GetEngine(input_shapes, ctx);
OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_shapes), *helper);
StatusOr<EngineContext*> status = GetEngine(input_shapes, ctx);
OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);
EngineContext* engine_context = status.ValueOrDie();
if (!engine_context->cuda_engine) {
VLOG(1) << "Engine retrieval for input shapes: "
<< TensorShapeUtils::ShapeListString(input_shapes)
@ -519,7 +555,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
return !kRetry;
}
EngineContext* TRTEngineOp::GetEngine(
StatusOr<EngineContext*> TRTEngineOp::GetEngine(
const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx) {
static EngineContext empty_context;
mutex_lock lock(engine_mutex_);
@ -609,21 +645,11 @@ EngineContext* TRTEngineOp::GetEngine(
// See if there is a compatible engine cached. The batch size should be <= the
// cached batch size.
std::vector<TensorShape> engine_input_shapes;
const bool matched_successfully =
GetCompatibleCachedEngine(input_shapes, &engine_input_shapes);
TF_RETURN_IF_ERROR(
GetEngineInputShapes(cache, input_shapes, &engine_input_shapes));
// If matched, use that engine. Otherwise, we will look in cache for that
// exact shape and possibly create a new engine if it is not in cache.
if (!matched_successfully) {
engine_input_shapes = input_shapes;
if (!cached_engine_batches_.empty()) {
// If user has explicitly defined cached_engine_batches, we should
// warn them that their input was non-compatible (batch size too high)
LOG(WARNING) << "No compatible cached engine was found for batch size: "
<< batch_size << ". A new engine will be created.";
cached_engine_batches_.push_back(batch_size);
}
}
if (!cache.count(engine_input_shapes)) {
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
bool convert_successfully = false;

View File

@ -17,17 +17,20 @@ limitations under the License.
#include <string.h>
#include <fstream>
#include <numeric>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
@ -38,48 +41,115 @@ namespace tensorflow {
namespace tensorrt {
using ::testing::ElementsAre;
class TRTEngineOpTestBase : public OpsTestBase {
public:
void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1) {
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
// Create simple TF graph.
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
ops::Placeholder::Shape({-1, -1}));
auto add = ops::Add(s.WithOpName("add"), feed, feed);
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
GraphDef graph_def;
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
PartialTensorShape shape({-1, -1});
// Create the op.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
TF_ASSERT_OK(NodeDefBuilder("myop", "TRTEngineOp")
.Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape})
.Attr("output_shapes", {shape})
.Attr("static_engine", false)
.Attr("segment_funcdef_name", "") // no native fallback
.Attr("serialized_segment", graph_def.SerializeAsString())
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", max_cached_engines_count)
.Attr("workspace_size_bytes", 1 << 20)
.Attr("precision_mode", "FP32")
.Attr("use_calibration", false)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(OpsTestBase::InitOp());
}
template <typename T>
void AddSimpleInput(const TensorShape& shape) {
std::vector<T> input(shape.num_elements());
std::iota(input.begin(), input.end(), T(0));
OpsTestBase::AddInputFromArray<T>(shape, input);
}
void ResetInputs() {
inputs_.clear();
gtl::STLDeleteElements(&tensors_);
}
};
TEST_F(TRTEngineOpTestBase, dynamic_shapes) {
TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);
// Execute the op with batch size > 1.
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({2, 2}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Get the engine cache.
TRTEngineCacheResource* cache_resource = nullptr;
TF_ASSERT_OK(device_->resource_manager()->Lookup("TF-TRT-Engine-Cache",
"myop", &cache_resource));
core::ScopedUnref sc(cache_resource);
// It should contain only one engine.
auto cache = &cache_resource->cache_;
EXPECT_EQ(1, cache->size());
EXPECT_THAT(cache->begin()->first, ElementsAre(TensorShape({2, 2})));
// Execute the op with batch size 1. It should reuse existing engine to
// execute.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({1, 2}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
EXPECT_EQ(1, cache->size());
EXPECT_THAT(cache->begin()->first, ElementsAre(TensorShape({2, 2})));
// Execute the op with a larger batch size.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({3, 2}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
EXPECT_EQ(2, cache->size());
EXPECT_THAT(cache->begin()->first, ElementsAre(TensorShape({3, 2})));
EXPECT_THAT((++cache->begin())->first, ElementsAre(TensorShape({2, 2})));
// Execute the op with an input that has different non-batch dimension.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({10, 10}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Execute it again with an input that has the same non-batch dimension but
// a smaller batch size. It should find the correct engine to use.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({1, 10}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
EXPECT_EQ(3, cache->size()); // Should only create 3 engines in total.
auto iter = cache->begin();
EXPECT_THAT(iter->first, ElementsAre(TensorShape({10, 10})));
EXPECT_THAT((++iter)->first, ElementsAre(TensorShape({3, 2})));
EXPECT_THAT((++iter)->first, ElementsAre(TensorShape({2, 2})));
}
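The test above exercises the new cache keyed by the full list of input shapes (see the CacheType alias in the op header earlier in this diff). A rough standalone sketch of such a shape-list-keyed cache; std::map stands in for the real LRUCache and its vector-of-shapes hasher, and the engine names are invented:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;
using ShapeList = std::vector<Shape>;

int main() {
  // Hypothetical cache: list of input shapes -> engine name (stand-in for EngineContext).
  std::map<ShapeList, std::string> cache;
  cache[ShapeList{{2, 2}}] = "engine_batch2";
  cache[ShapeList{{3, 2}}] = "engine_batch3";

  // Exact lookup with the key selected by the shape-matching step.
  const ShapeList key = {{2, 2}};
  const auto it = cache.find(key);
  std::cout << (it != cache.end() ? it->second : "miss") << "\n";  // engine_batch2
  std::cout << cache.size() << "\n";                               // 2 engines cached.
  return 0;
}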
template <typename T>
class TRTEngineOpTest : public OpsTestBase {};
class TRTEngineOpTest : public TRTEngineOpTestBase {};
using TypeList = ::testing::Types<float, Eigen::half>;
TYPED_TEST_SUITE(TRTEngineOpTest, TypeList);
TYPED_TEST(TRTEngineOpTest, Basic) {
DataType dtype = DataTypeToEnum<TypeParam>::v();
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
// Create simple TF graph.
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
ops::Placeholder::Shape({1, 2}));
auto add = ops::Add(s.WithOpName("add"), feed, feed);
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
GraphDef graph_def;
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
TensorShapeProto shape;
TensorShape({1, 2}).AsProto(&shape);
// Create the op.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
TF_ASSERT_OK(NodeDefBuilder("op", "TRTEngineOp")
.Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape})
.Attr("output_shapes", {shape})
.Attr("static_engine", false)
.Attr("segment_funcdef_name", "") // no native fallback
.Attr("serialized_segment", graph_def.SerializeAsString())
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", 1)
.Attr("workspace_size_bytes", 1 << 20)
.Attr("precision_mode", "FP32")
.Attr("use_calibration", false)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(OpsTestBase::InitOp());
TRTEngineOpTestBase::AddSimpleTrtOp(DataTypeToEnum<TypeParam>::v());
// Execute the op.
OpsTestBase::AddInputFromArray<TypeParam>(TensorShape({1, 2}),

View File

@ -38,7 +38,6 @@ REGISTER_OP("TRTEngineOp")
.Attr("segment_funcdef_name: string")
.Attr("InT: list({int8,float16,float32,int32})")
.Attr("OutT: list({int8,float16,float32,int32})")
.Attr("cached_engine_batches: list(int) >= 0 = []")
.Attr("max_cached_engines_count: int = 1")
.Attr("workspace_size_bytes: int")
.Attr("precision_mode: {'FP32', 'FP16', 'INT8'}")
@ -54,6 +53,7 @@ REGISTER_OP("TRTEngineOp")
// inference function as a workaround.
.SetShapeFn(shape_inference::UnknownShape)
// Deprecated attributes.
.Attr("cached_engine_batches: list(int) >= 0 = []")
.Attr("fixed_input_size: bool = true")
.Attr("static_engine: bool = true");
} // namespace tensorflow

View File

@ -1,75 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper of TRTEngineOp."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import platform
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
from tensorflow.python.framework import errors
_tf_trt_so = None
_module_lock = threading.Lock()
def load_trt_ops():
"""Load TF-TRT op libraries so if it hasn't been loaded already."""
global _tf_trt_so
if not is_tensorrt_enabled():
return
if platform.system() == "Windows":
raise RuntimeError("Windows platforms are not supported")
with _module_lock:
if _tf_trt_so:
return
try:
# pylint: disable=g-import-not-at-top,unused-variable
# This will call register_op_list() in
# tensorflow/python/framework/op_def_registry.py, but it doesn't register
# the op or the op kernel in C++ runtime.
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
# pylint: enable=g-import-not-at-top,unused-variable
except ImportError as e:
print("**** Failed to import TF-TRT ops. This is because the binary was "
"not built with CUDA or TensorRT enabled. ****")
raise e
try:
# pylint: disable=g-import-not-at-top
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# pylint: enable=g-import-not-at-top
# Loading the shared object will cause registration of the op and the op
# kernel if we link TF-TRT dynamically.
_tf_trt_so = load_library.load_op_library(
resource_loader.get_path_to_datafile("libtftrt.so"))
except errors.NotFoundError as e:
no_trt_message = (
"**** Failed to initialize TensorRT. This is either because the "
"TensorRT installation path is not in LD_LIBRARY_PATH, or because "
"you do not have it installed. If not installed, please go to "
"https://developer.nvidia.com/tensorrt to download and install "
"TensorRT ****")
print(no_trt_message)
raise e

View File

@ -0,0 +1,87 @@
// Auto-generated, do not edit.
extern "C" {
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
int nmsMaxOut, float iouThreshold,
float minBoxSize, float spatialScale,
nvinfer1::DimsHW pooling,
nvinfer1::Weights anchorRatios,
nvinfer1::Weights anchorScales) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
}
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
bool acrossSpatial,
bool channelShared, float eps) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
return func_ptr(scales, acrossSpatial, channelShared, eps);
}
nvinfer1::IPluginV2* createPriorBoxPlugin(
nvinfer1::plugin::PriorBoxParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
return func_ptr(param, numLayers);
}
nvinfer1::IPluginV2* createNMSPlugin(
nvinfer1::plugin::DetectionOutputParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
return func_ptr(negSlope);
}
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
return func_ptr(stride);
}
nvinfer1::IPluginV2* createRegionPlugin(
nvinfer1::plugin::RegionParameters params) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
return func_ptr(params);
}
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
float clipMax) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
return func_ptr(layerName, clipMin, clipMax);
}
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
using FuncPtr = bool ( *)(void *, const char *);
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
return func_ptr(logger, libNamespace);
}
} // extern "C"

View File

@ -0,0 +1,95 @@
// Auto-generated, do not edit.
extern "C" {
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
int nmsMaxOut, float iouThreshold,
float minBoxSize, float spatialScale,
nvinfer1::DimsHW pooling,
nvinfer1::Weights anchorRatios,
nvinfer1::Weights anchorScales) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
}
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
bool acrossSpatial,
bool channelShared, float eps) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
return func_ptr(scales, acrossSpatial, channelShared, eps);
}
nvinfer1::IPluginV2* createPriorBoxPlugin(
nvinfer1::plugin::PriorBoxParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
return func_ptr(param, numLayers);
}
nvinfer1::IPluginV2* createNMSPlugin(
nvinfer1::plugin::DetectionOutputParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
return func_ptr(param);
}
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
return func_ptr(negSlope);
}
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
return func_ptr(stride);
}
nvinfer1::IPluginV2* createRegionPlugin(
nvinfer1::plugin::RegionParameters params) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
return func_ptr(params);
}
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
float clipMax) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
return func_ptr(layerName, clipMin, clipMax);
}
nvinfer1::IPluginV2* createBatchedNMSPlugin(
nvinfer1::plugin::NMSParameters param) {
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::NMSParameters);
static auto func_ptr = LoadSymbol<FuncPtr>("createBatchedNMSPlugin");
if (!func_ptr) LogFatalSymbolNotFound("createBatchedNMSPlugin");
return func_ptr(param);
}
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
using FuncPtr = bool ( *)(void *, const char *);
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
return func_ptr(logger, libNamespace);
}
} // extern "C"

View File

@ -0,0 +1,40 @@
// Auto-generated, do not edit.
extern "C" {
void* createInferBuilder_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
return func_ptr(logger, version);
}
void* createInferRuntime_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
return func_ptr(logger, version);
}
nvinfer1::ILogger* getLogger() {
using FuncPtr = nvinfer1::ILogger * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
return func_ptr();
}
int getInferLibVersion() {
using FuncPtr = int (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
return func_ptr();
}
nvinfer1::IPluginRegistry* getPluginRegistry() {
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
return func_ptr();
}
} // extern "C"

View File

@ -0,0 +1,47 @@
// Auto-generated, do not edit.
extern "C" {
void* createInferBuilder_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
return func_ptr(logger, version);
}
void* createInferRefitter_INTERNAL(void* engine, void* logger, int version) {
using FuncPtr = void * (*)(void *, void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRefitter_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferRefitter_INTERNAL");
return func_ptr(engine, logger, version);
}
void* createInferRuntime_INTERNAL(void* logger, int version) {
using FuncPtr = void * (*)(void *, int);
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
return func_ptr(logger, version);
}
nvinfer1::ILogger* getLogger() {
using FuncPtr = nvinfer1::ILogger * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
return func_ptr();
}
int getInferLibVersion() {
using FuncPtr = int (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
return func_ptr();
}
nvinfer1::IPluginRegistry* getPluginRegistry() {
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
return func_ptr();
}
} // extern "C"

View File

@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInferPlugin.h"
// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.
namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
return nullptr;
#else
static auto handle = []() -> void* {
auto handle_or =
stream_executor::internal::DsoLoader::GetNvInferPluginDsoHandle();
if (!handle_or.ok()) return nullptr;
return handle_or.ValueOrDie();
}();
return handle;
#endif
}
template <typename T>
T LoadSymbol(const char* symbol_name) {
void* symbol = nullptr;
if (auto handle = GetDsoHandle()) {
tensorflow::Env::Default()
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
.IgnoreError();
}
return reinterpret_cast<T>(symbol);
}
void LogFatalSymbolNotFound(const char* symbol_name) {
LOG(FATAL) << symbol_name << " symbol not found.";
}
} // namespace
#if NV_TENSORRT_MAJOR < 5
#error TensorRT version earlier than 5 is not supported.
#elif NV_TENSORRT_MINOR < 1
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc"
#else
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc"
#endif
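The stub above resolves every TensorRT entry point lazily: the DSO is opened once, each symbol is looked up on first use, and a missing symbol is reported fatally. A minimal standalone sketch of the same pattern with plain dlfcn; the library name libexample.so and the entry point example_add are hypothetical:

// Build with: g++ -c trt_stub_sketch.cc   (link the final binary with -ldl)
#include <dlfcn.h>

#include <cstdio>
#include <cstdlib>

namespace {

// Open the (hypothetical) DSO once and cache the handle; nullptr if missing.
void* GetDsoHandle() {
  static void* handle = dlopen("libexample.so", RTLD_LAZY);
  return handle;
}

// Resolve a symbol from the DSO, or return nullptr if DSO or symbol is absent.
template <typename T>
T LoadSymbol(const char* symbol_name) {
  void* symbol = nullptr;
  if (void* handle = GetDsoHandle()) {
    symbol = dlsym(handle, symbol_name);
  }
  return reinterpret_cast<T>(symbol);
}

[[noreturn]] void LogFatalSymbolNotFound(const char* symbol_name) {
  std::fprintf(stderr, "%s symbol not found.\n", symbol_name);
  std::abort();
}

}  // namespace

// Forwarding stub for a hypothetical C entry point `int example_add(int, int)`.
extern "C" int example_add(int a, int b) {
  using FuncPtr = int (*)(int, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("example_add");
  if (!func_ptr) LogFatalSymbolNotFound("example_add");
  return func_ptr(a, b);
}

Built as a shared library, such a stub lets the binary link without the real DSO installed and fail only when one of its entry points is actually called.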

View File

@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInfer.h"
// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.
namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
return nullptr;
#else
static auto handle = []() -> void* {
auto handle_or =
stream_executor::internal::DsoLoader::GetNvInferDsoHandle();
if (!handle_or.ok()) return nullptr;
return handle_or.ValueOrDie();
}();
return handle;
#endif
}
template <typename T>
T LoadSymbol(const char* symbol_name) {
void* symbol = nullptr;
if (auto handle = GetDsoHandle()) {
tensorflow::Env::Default()
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
.IgnoreError();
}
return reinterpret_cast<T>(symbol);
}
void LogFatalSymbolNotFound(const char* symbol_name) {
LOG(FATAL) << symbol_name << " symbol not found.";
}
} // namespace
#if NV_TENSORRT_MAJOR < 5
#error TensorRT version earlier than 5 is not supported.
#elif NV_TENSORRT_MINOR < 1
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc"
#else
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc"
#endif

View File

@ -1,4 +1,10 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
load(
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
package(
default_visibility = [":internal"],
@ -26,13 +32,6 @@ package_group(
],
)
load(
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
cc_library(
name = "tf2xla_supported_ops_lib",
srcs = ["tf2xla_supported_ops.cc"],
@ -40,7 +39,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -108,7 +106,6 @@ cc_library(
":tf2xla_proto",
":tf2xla_util",
":xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:xla_computation",

View File

@ -1,10 +1,10 @@
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
licenses = ["notice"], # Apache 2.0
)
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
tf_gen_op_wrapper_cc(
name = "xla_ops_gen",
out_ops_file = "ops/xla_ops",

View File

@ -54,6 +54,7 @@ tf_kernel_library(
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
"matrix_inverse_op.cc",
"matrix_set_diag_op.cc",
"matrix_triangular_solve_op.cc",
"mirror_pad_op.cc",
@ -79,6 +80,7 @@ tf_kernel_library(
"retval_op.cc",
"reverse_op.cc",
"reverse_sequence_op.cc",
"roll_op.cc",
"scan_ops.cc",
"scatter_nd_op.cc",
"segment_reduction_ops.cc",
@ -345,50 +347,3 @@ tf_kernel_library(
],
alwayslink = 1,
)
# Kernels that only work on CPU, because they use XLA custom calls.
# Only link this when using the CPU backend for XLA.
tf_kernel_library(
name = "xla_cpu_only_ops",
srcs = ["index_ops_cpu.cc"],
deps = [
":index_ops_kernel_argmax_float_1d",
":index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:lib",
],
)
cc_library(
name = "index_ops_kernel_argmax_float_1d",
srcs = ["index_ops_kernel_argmax_float_1d.cc"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
alwayslink = 1,
)
cc_library(
name = "index_ops_kernel_argmax_float_2d",
srcs = ["index_ops_kernel_argmax_float_2d.cc"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
alwayslink = 1,
)

View File

@ -1,142 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Native XLA implementations of indexing ops.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
// The logic below uses a custom-call to implement argmax when possible. When
// custom-call is not allowed or input shapes are not supported, this kernel
// falls back to using XLA HLO native ArgMax.
//
// Also see b/29507024 for first-class XLA support for indexing ops.
class ArgMaxCustomCallOp : public XlaOpKernel {
public:
explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape dimension_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape),
errors::InvalidArgument(
"dim must be a scalar, but received tensor of shape: ",
dimension_shape.DebugString()));
// We require that the dimension argument is a constant, since it lets us
// dispatch to a specialized custom-call function without any run-time
// overhead, when compiling ahead-of-time.
int64 dim;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
const int input_dims = input_shape.dims();
const int axis = dim < 0 ? dim + input_dims : dim;
OP_REQUIRES(ctx, axis >= 0 && axis < input_dims,
errors::InvalidArgument("Expected dimension in the range [",
-input_dims, ", ", input_dims,
"), but got ", dim));
const int64 axis_size = input_shape.dim_size(axis);
OP_REQUIRES(ctx, axis_size > 0,
errors::InvalidArgument(
"Reduction axis ", dim,
" is empty in shape: ", input_shape.DebugString()));
const DataType dtype = output_type(0);
xla::PrimitiveType output_type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type));
// Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input
// shape isn't supported.
if (!ctx->compiler()->options().allow_cpu_custom_calls ||
(input_dims != 1 && input_dims != 2)) {
xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis);
ctx->SetOutput(0, output);
return;
}
xla::XlaOp output;
// The output shape is the input shape contracted along axis.
TensorShape output_shape;
for (int d = 0; d < input_shape.dims() - 1; ++d) {
output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
}
xla::XlaBuilder& b = *ctx->builder();
// XLA passes <out> to the function, so it is not included here.
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
&b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
&b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(axis)));
}
// The argmax function expects row-major layout.
xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
xla::S64, output_shape.dim_sizes());
std::vector<xla::Shape> arg_shapes;
for (const xla::XlaOp& arg : args) {
auto shape_status = b.GetShape(arg);
OP_REQUIRES_OK(ctx, shape_status.status());
xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
*arg_shape.mutable_layout() =
xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
arg_shapes.push_back(std::move(arg_shape));
}
// Tell XLA to call the custom code, defined in
// index_ops_kernel_argmax_float_{1, 2}d.cc.
if (input_dims == 1) {
output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
xla_shape, arg_shapes);
} else {
output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
xla_shape, arg_shapes);
}
output = xla::ConvertElementType(output, output_type);
ctx->SetOutput(0, output);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
};
REGISTER_XLA_OP(Name("ArgMax")
.TypeConstraint("T", DT_FLOAT)
.Device(DEVICE_CPU_XLA_JIT)
.CompileTimeConstantInput("dimension"),
ArgMaxCustomCallOp);
} // namespace
} // namespace tensorflow

View File

@ -1,52 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
// Data is managed by the JIT code so msan can't tell it's initialized.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*));
float* input = static_cast<float*>(data[0]);
int64 input_size = *static_cast<int64*>(data[1]);
Eigen::DSizes<Eigen::DenseIndex, 1> in_eig_sizes(input_size);
TTypes<float, 1>::ConstTensor in_eig(input, in_eig_sizes);
Eigen::DSizes<Eigen::DenseIndex, 0> out_eig_sizes;
int64* out_t = static_cast<int64*>(out);
TTypes<int64, 0>::Tensor out_eig(out_t, out_eig_sizes);
out_eig = in_eig.argmax(0).cast<int64>();
}
} // namespace tensorflow
// Implements argmax on CPU. This is called by an XLA custom call, set up by
// index_ops.cc.
extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_1d_xla_impl(out, data);
}
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);

View File

@ -1,54 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
  // Data is managed by the JIT code so msan can't tell it's initialized.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*));
float* in = static_cast<float*>(data[0]);
int64* in_sizes = static_cast<int64*>(data[1]);
int64* out_sizes = static_cast<int64*>(data[2]);
int32 dim = *static_cast<int32*>(data[3]);
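  // Shapes follow the convention set up in index_ops.cc: in_sizes holds the
  // two input dimensions, out_sizes holds the single output dimension, and
  // dim selects the reduced axis (e.g. a 2x3 input with dim = 1 yields a
  // length-2 output of per-row argmax indices).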
Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes(in_sizes[0], in_sizes[1]);
TTypes<float, 2>::ConstTensor in_eig(in, in_eig_sizes);
int64* out_t = static_cast<int64*>(out);
Eigen::DSizes<Eigen::DenseIndex, 1> out_eig_sizes(out_sizes[0]);
TTypes<int64, 1>::Tensor out_eig(out_t, out_eig_sizes);
out_eig = in_eig.argmax(dim).cast<int64>();
}
} // namespace tensorflow
// Implements argmax on CPU. This is called by an XLA custom call, set up by
// index_ops.cc.
extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_2d_xla_impl(out, data);
}
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);

View File

@ -0,0 +1,68 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
class MatrixInverseOp : public XlaOpKernel {
public:
explicit MatrixInverseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_));
}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
int64 ndims = input_shape.dims();
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("Input must have rank >= 2, got ", ndims));
OP_REQUIRES(
ctx, input_shape.dim_size(ndims - 2) == input_shape.dim_size(ndims - 1),
errors::InvalidArgument("Input matrices must be squares, got",
input_shape.dim_size(ndims - 2),
" != ", input_shape.dim_size(ndims - 1)));
xla::XlaOp input = xla::MaybeTransposeInMinorDims(ctx->Input(0), adjoint_);
// TODO(b/111271662): Using LU decomposition instead of QR should be faster.
auto qr = xla::QRDecomposition(input, /*full_matrices=*/false);
OP_REQUIRES_OK(ctx, qr.status());
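    // Given the reduced QR factorization A = Q * R (with A transposed first
    // when adjoint is requested), the inverse is A^-1 = R^-1 * Q^T, obtained
    // below by solving the upper-triangular system R * X = Q^T.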
xla::XlaOp output = xla::TriangularSolve(
qr.ValueOrDie().r, xla::TransposeInMinorDims(qr.ValueOrDie().q),
/*left_side=*/true,
/*lower=*/false, /*unit_diagonal=*/false,
/*transpose_a=*/
xla::TriangularSolveOptions::NO_TRANSPOSE);
ctx->SetOutput(0, output);
}
private:
bool adjoint_;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixInverseOp);
};
// TODO(b/135640736): Allow this for integer and complex types.
REGISTER_XLA_OP(Name("MatrixInverse").TypeConstraint("T", kFloatTypes),
MatrixInverseOp);
} // namespace
} // namespace tensorflow

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
@ -65,24 +65,18 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
"Second dimension of 2D input must be of size 2, but got shape ",
input_tensor_shape.DebugString()));
}
std::vector<int32> dst_indices(4, 0);
int32 dst_indices[4];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
if (src_format_[i] == dst_format_[j]) {
dst_indices[i] = j;
dst_indices[j] = i;
break;
}
}
}
auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
if (input_rank == 2) {
keys = xla::BroadcastInDim(keys, {4, 2}, {0});
}
auto sorted = xla::Sort({keys, ctx->Input(0)},
xla::CreateScalarLtComputation(
{xla::S32, ctx->input_xla_type(0)}, builder),
0);
auto output = xla::GetTupleElement(sorted, 1);
xla::XlaOp indices =
xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
xla::XlaOp output = xla::TorchIndexSelect(ctx->Input(0), indices, 0);
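    // With the gather-based version (dst_indices[j] = i), e.g. src_format
    // "NHWC" and dst_format "NCHW" give dst_indices = {0, 3, 1, 2}, so a
    // vector laid out in NHWC order is reordered into NCHW order.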
ctx->SetOutput(0, output);
}

View File

@ -0,0 +1,89 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
namespace tensorflow {
namespace {
class RollOp : public XlaOpKernel {
public:
explicit RollOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
xla::XlaOp shift = ctx->Input(1);
const TensorShape shift_shape = ctx->InputShape(1);
const TensorShape axis_shape = ctx->InputShape(2);
OP_REQUIRES(ctx, input_shape.dims() >= 1,
errors::InvalidArgument("input must be 1-D or higher"));
OP_REQUIRES(ctx, shift_shape.dims() <= 1,
errors::InvalidArgument(
"shift must be a scalar or a 1-D vector. Found: ",
shift_shape.DebugString()));
OP_REQUIRES(
ctx, shift_shape.dims() == axis_shape.dims(),
errors::InvalidArgument("shift and axis must have the same size"));
xla::Literal axis;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &axis));
xla::XlaOp output = ctx->Input(0);
xla::PrimitiveType shift_type = ctx->input_xla_type(1);
int64 num_axes = axis_shape.dims() == 0 ? 1 : axis_shape.dim_size(0);
for (int64 i = 0; i != num_axes; ++i) {
auto cur_axis_status = axis_shape.dims() == 0
? axis.GetIntegralAsS64({})
: axis.GetIntegralAsS64({i});
OP_REQUIRES_OK(ctx, cur_axis_status.status());
int64 cur_axis = cur_axis_status.ValueOrDie();
xla::XlaOp offset =
shift_shape.dims() == 0
? shift
: xla::Reshape(xla::SliceInDim(shift, /*start_index=*/i,
/*limit_index=*/i + 1,
/*stride=*/1, /*dimno=*/0),
{});
xla::XlaOp axis_size = xla::ConstantR0WithType(
ctx->builder(), shift_type, input_shape.dim_size(cur_axis));
// Adjust large offsets into [0, axis_size). This also makes negative
// offsets positive.
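      // e.g. with axis_size = 5: offset 7 -> ((7 % 5) + 5) % 5 = 2, and
      // offset -2 -> ((-2 % 5) + 5) % 5 = 3.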
offset = ((offset % axis_size) + axis_size) % axis_size;
      // Concatenate two copies of the input along the current axis, then
      // slice out a window of the original size starting at the calculated
      // offset.
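      // e.g. rolling [0, 1, 2, 3, 4] by offset 2 along axis 0: the
      // concatenation is [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], the start index is
      // 5 - 2 = 3, and the length-5 slice is [3, 4, 0, 1, 2].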
xla::XlaOp concat =
xla::ConcatInDim(ctx->builder(), {output, output}, cur_axis);
std::vector<xla::XlaOp> start_indices(
input_shape.dims(), xla::Zero(ctx->builder(), shift_type));
start_indices[cur_axis] = axis_size - offset;
output =
xla::DynamicSlice(concat, start_indices, input_shape.dim_sizes());
}
ctx->SetOutput(0, output);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(RollOp);
};
REGISTER_XLA_OP(Name("Roll").CompileTimeConstantInput("axis"), RollOp);
} // namespace
} // namespace tensorflow

View File

@ -1,14 +1,14 @@
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_wrapper_py",
)
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "xla_ops",
srcs = ["xla_ops.cc"],

View File

@ -633,12 +633,14 @@ REGISTER_OP("XlaEinsum")
if (context->RankKnown(input_a)) {
rank_a = context->Rank(input_a);
} else {
return errors::InvalidArgument("input 0's rank is unknown.");
context->set_output(0, context->UnknownShape());
return Status::OK();
}
if (context->RankKnown(input_b)) {
rank_b = context->Rank(input_b);
} else {
return errors::InvalidArgument("input 1's rank is unknown.");
context->set_output(0, context->UnknownShape());
return Status::OK();
}
string equation;
TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));

View File

@ -303,10 +303,27 @@ Status BuildComputation(
handle = identity_op(handle);
// Set layout of the retval to device representation layout.
if (resource->representation_shape().has_value()) {
retval_index_and_layout.emplace_back(
elems.size(), resource->representation_shape()->layout());
absl::optional<xla::Shape> representation_shape;
if (shape_representation_fn) {
TF_ASSIGN_OR_RETURN(
xla::Shape xla_shape,
shape_representation_fn(resource->shape(), resource->type()));
representation_shape = xla_shape;
}
if (resource->representation_shape().has_value()) {
const xla::Shape& xla_shape = resource->representation_shape().value();
if (representation_shape) {
TF_RET_CHECK(
xla::ShapeUtil::Compatible(*representation_shape, xla_shape));
} else {
representation_shape = xla_shape;
}
}
if (representation_shape) {
retval_index_and_layout.emplace_back(elems.size(),
representation_shape->layout());
}
elems.push_back(handle);
}
}
@ -553,6 +570,7 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
GraphOptimizer::Options graph_optimizer_options;
graph_optimizer_options.cf_consider_fn = cf_consider_fn;
graph_optimizer_options.inline_multi_device_functions = true;
graph_optimizer_options.inline_impl_selection_group_functions = true;
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
/*device=*/nullptr, &graph, graph_optimizer_options);

Some files were not shown because too many files have changed in this diff.