Merge branch 'master' into eager_op_rewrite_registration
commit 4e48f0664c

.bazelrc (26 lines changed)
@@ -9,6 +9,12 @@ build:android_arm --fat_apk_cpu=armeabi-v7a
build:android_arm64 --config=android
build:android_arm64 --cpu=arm64-v8a
build:android_arm64 --fat_apk_cpu=arm64-v8a
build:android_x86 --config=android
build:android_x86 --cpu=x86
build:android_x86 --fat_apk_cpu=x86
build:android_x86_64 --config=android
build:android_x86_64 --cpu=x86_64
build:android_x86_64 --fat_apk_cpu=x86_64

# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
@@ -124,6 +130,26 @@ build --define=INCLUDEDIR=$(PREFIX)/include
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING

# Options when using remote execution
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --define=EXECUTOR=remote
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_accept_cached=true
build:rbe --remote_cache=remotebuildexecution.googleapis.com
build:rbe --remote_executor=remotebuildexecution.googleapis.com
build:rbe --remote_local_fallback=false
build:rbe --remote_timeout=600
build:rbe --spawn_strategy=remote
build:rbe --strategy=Genrule=remote
build:rbe --strategy=Closure=remote
build:rbe --strategy=Javac=remote
build:rbe --strategy=TestRunner=remote
build:rbe --tls_enabled
test:rbe --test_env=USER=anon

# Default options should come above this line

# Options from ./configure
@@ -1,12 +1,13 @@
# Where component owners are known, add them here.

/tensorflow/c/eager @jaingurav @alextp
/tensorflow/core/common_runtime/eager @jaingaurav @alextp
/tensorflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/go @asimshankar
/tensorflow/java/ @asimshankar
/tensorflow/python/debug @caisq
/tensorflow/python/eager @jaingurav @alextp
/tensorflow/python/tools/api/generator/ @annarev
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust
@@ -25,7 +26,7 @@
/tensorflow/contrib/data/ @mrry
/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
/tensorflow/contrib/eager @alextp @asimshankar
/tensorflow/contrib/eager @jaingurav @alextp
/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
/tensorflow/contrib/ffmpeg/ @fredbertsch
/tensorflow/contrib/framework/ @ebrevdo
@@ -168,11 +168,11 @@ There are two ways to run TensorFlow unit tests.
[GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile)
for the required packages. Alternatively, use the said
[Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g.,
`tensorflow/tensorflow:nightly-devel` and
`tensorflow/tensorflow:nightly-devel-gpu` for development to avoid
installing the packages directly on your system (in which case remember to
change directory from `/root` to `/tensorflow` once you get into the running
container so `bazel` can find the `tensorflow` workspace).
`tensorflow/tensorflow:devel` and `tensorflow/tensorflow:devel-gpu` for
development to avoid installing the packages directly on your system (in
which case remember to change directory from `/root` to `/tensorflow` once
you get into the running container so `bazel` can find the `tensorflow`
workspace).

Once you have the packages installed, you can run a specific unit test in
bazel by doing as follows:
LICENSE (2 lines changed)
@@ -188,7 +188,7 @@ Copyright 2019 The TensorFlow Authors. All rights reserved.
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2017, The TensorFlow Authors.
Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -116,7 +116,8 @@ The TensorFlow project strives to abide by generally accepted best practices in

Build Type | Status | Artifacts
--- | --- | ---
**Linux s390x Nightly** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x** Nightly | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
WORKSPACE (22 lines changed)
@@ -4,14 +4,20 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file"

http_archive(
    name = "io_bazel_rules_closure",
    sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
    strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
    sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
    strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
    urls = [
        "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
        "https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
        "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
        "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
    ],
)

# Load tf_repositories() before loading dependencies for other repository so
# that dependencies like com_google_protobuf won't be overridden.
load("//tensorflow:workspace.bzl", "tf_repositories")
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_repositories()

load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")

closure_repositories()
@@ -83,15 +89,15 @@ swift_rules_dependencies()
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
check_bazel_version_at_least("0.19.0")

load("//tensorflow:workspace.bzl", "tf_workspace")

load("//third_party/android:android_configure.bzl", "android_configure")
android_configure(name="local_config_android")
load("@local_config_android//:android.bzl", "android_workspace")
android_workspace()

# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
# If a target is bound twice, the later one wins, so we have to do tf bindings
# at the end of the WORKSPACE file.
load("//tensorflow:workspace.bzl", "tf_bind")
tf_bind()

http_archive(
    name = "inception_v1",
configure.cmd (new file, 20 lines)
@@ -0,0 +1,20 @@
:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
::
:: Licensed under the Apache License, Version 2.0 (the "License");
:: you may not use this file except in compliance with the License.
:: You may obtain a copy of the License at
::
:: http://www.apache.org/licenses/LICENSE-2.0
::
:: Unless required by applicable law or agreed to in writing, software
:: distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
:: WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
:: License for the specific language governing permissions and limitations under
:: the License.

@echo off

set configure_dir=%~dp0
set configure_dir=%configure_dir:~0,-1%
python %configure_dir%\configure.py %* || ( exit /b )
echo Configuration finished
configure.py (12 lines changed)
@@ -49,6 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '0.24.1'
_TF_MAX_BAZEL_VERSION = '0.26.1'

NCCL_LIB_PATHS = [
    'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@@ -234,7 +236,9 @@ def setup_python(environ_cp):
    python_lib_path = default_python_lib_path
  environ_cp['PYTHON_LIB_PATH'] = python_lib_path

  _ = get_python_major_version(python_bin_path)
  python_major_version = get_python_major_version(python_bin_path)
  if python_major_version == '2':
    write_to_bazelrc('build --host_force_python=PY2')

  # Convert python path to Windows style before writing into bazel.rc
  if is_windows() or is_cygwin():
@@ -1391,7 +1395,8 @@ def main():
  # environment variables.
  environ_cp = dict(os.environ)

  current_bazel_version = check_bazel_version('0.24.1', '0.26.1')
  current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
                                              _TF_MAX_BAZEL_VERSION)
  _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)

  reset_tf_configure_bazelrc()
@@ -1423,7 +1428,7 @@ def main():
  if is_ppc64le():
    write_action_env_to_bazelrc('OMP_NUM_THREADS', 1)

  xla_enabled_by_default = is_linux()
  xla_enabled_by_default = is_linux() or is_macos()
  set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
                xla_enabled_by_default, 'xla')
@@ -1582,6 +1587,7 @@ def main():
  config_info_line(
      'dynamic_kernels',
      '(Experimental) Build kernels into separate shared objects.')
  config_info_line('v2', 'Build TensorFlow 2.x instead of 1.x.')

  print('Preconfigured Bazel build configs to DISABLE default on features:')
  config_info_line('noaws', 'Disable AWS S3 filesystem support.')
@@ -529,6 +529,10 @@ tf_cc_shared_object(
    linkopts = select({
        "//tensorflow:macos": [],
        "//tensorflow:windows": [],
        "//tensorflow:freebsd": [
            "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)",
            "-lexecinfo",
        ],
        "//conditions:default": [
            "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)",
        ],
@@ -149,4 +149,5 @@ if hasattr(_current_module, 'keras'):
  optimizers = keras.optimizers
  initializers = keras.initializers

compat.v2.compat.v1 = compat.v1
# pylint: enable=undefined-variable
@@ -141,4 +141,6 @@ try:
  vars()['__all__'].remove('compiler')
except NameError:
  pass

compat.v2.compat.v1 = compat.v1
# pylint: enable=undefined-variable
@ -77,7 +77,10 @@ tf_cuda_library(
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
],
|
||||
}) + [":tf_status_internal"],
|
||||
}) + [
|
||||
":tf_status_internal",
|
||||
":tf_tensor_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -211,9 +214,10 @@ cc_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -221,6 +225,24 @@ cc_library(
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "tf_tensor_internal",
|
||||
hdrs = [
|
||||
"tf_tensor.h",
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_experimental",
|
||||
srcs = [
|
||||
@ -263,7 +285,7 @@ tf_cuda_library(
|
||||
hdrs = ["tf_status_helper.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c_api_no_xla",
|
||||
":tf_status",
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -329,17 +351,16 @@ tf_cuda_library(
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_no_xla",
|
||||
":c_api_internal",
|
||||
":tf_status_helper",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_no_xla",
|
||||
":c_api_internal",
|
||||
":tf_status_helper",
|
||||
":tf_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
@ -357,6 +378,8 @@ tf_cuda_library(
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
@ -365,7 +388,7 @@ tf_cuda_library(
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}) + [":c_api_internal"],
|
||||
}),
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
@@ -363,7 +363,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
    }
    *graph_def.mutable_library() = graph.flib_def().ToProto();
    session->graph->mu.unlock();
    status->status = session->session->Extend(graph_def);
    status->status = session->session->Extend(std::move(graph_def));
    if (TF_GetCode(status) != TF_OK) {
      // Contract is we always delete input_values[i].
      return false;
@@ -737,13 +737,14 @@ int TF_PickUnusedPortOrDie() {
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len) {
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
  return new TFE_TensorHandle(tensor);
  return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}

namespace {
@@ -994,3 +995,23 @@ TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext(
      << handle->DebugString();
  return ret;
}

void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
                                          TFE_ContextMirroringPolicy policy) {
  options->mirroring_policy = policy;
}

void TFE_ContextSetThreadLocalMirroringPolicy(
    TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
  ctx->context->SetThreadLocalMirroringPolicy(
      static_cast<tensorflow::ContextMirroringPolicy>(policy));
}

// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextMirroringPolicy>(
      ctx->context->GetMirroringPolicy());
}
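A minimal usage sketch (not part of this commit) of the updated calling convention: TFE_NewTensorHandleFromScalar now threads a TF_Status through the call, so callers are expected to check it. The scalar value and the MakeScalarHandle wrapper are illustrative only:

    #include <stdio.h>

    #include "tensorflow/c/c_api_experimental.h"
    #include "tensorflow/c/eager/c_api.h"

    // Builds a scalar float handle and checks the status that the new
    // signature reports errors through.
    void MakeScalarHandle() {
      float value = 42.0f;  // illustrative scalar
      TF_Status* status = TF_NewStatus();
      TFE_TensorHandle* h = TFE_NewTensorHandleFromScalar(
          TF_FLOAT, &value, sizeof(value), status);
      if (TF_GetCode(status) != TF_OK) {
        fprintf(stderr, "handle creation failed: %s\n", TF_Message(status));
      } else {
        TFE_DeleteTensorHandle(h);
      }
      TF_DeleteStatus(status);
    }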
@@ -285,7 +285,7 @@ TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void);
// Fast path method that makes constructing a single scalar tensor require less
// overhead and copies.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar(
    TF_DataType data_type, void* data, size_t len);
    TF_DataType data_type, void* data, size_t len, TF_Status* status);

// Specify the server_def that enables collective ops.
// This is different to the above function in that it doesn't create remote
@ -30,6 +30,7 @@ limitations under the License.
|
||||
// clang-format on
|
||||
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
@ -53,14 +54,6 @@ class ServerInterface;
|
||||
// Internal structures used by the C API. These are likely to change and should
|
||||
// not be depended on.
|
||||
|
||||
struct TF_Tensor {
|
||||
~TF_Tensor();
|
||||
|
||||
TF_DataType dtype;
|
||||
tensorflow::TensorShape shape;
|
||||
tensorflow::TensorBuffer* buffer;
|
||||
};
|
||||
|
||||
struct TF_SessionOptions {
|
||||
tensorflow::SessionOptions options;
|
||||
};
|
||||
@ -193,15 +186,6 @@ struct TF_Server {
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
public:
|
||||
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
|
||||
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
|
||||
TensorBuffer* buf) {
|
||||
return Tensor(static_cast<DataType>(type), shape, buf);
|
||||
}
|
||||
};
|
||||
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
||||
|
@ -26,6 +26,7 @@ tf_cuda_library(
|
||||
srcs = [
|
||||
"c_api.cc",
|
||||
"c_api_debug.cc",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
@ -81,6 +82,7 @@ tf_cuda_library(
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = ["c_api_experimental.h"],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
@ -274,7 +276,6 @@ filegroup(
|
||||
],
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"c_api_experimental.h",
|
||||
"*test*",
|
||||
],
|
||||
),
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||
@ -89,12 +90,6 @@ bool IsCPU(const tensorflow::Device* d) {
|
||||
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
|
||||
}
|
||||
|
||||
bool IsXLA(const tensorflow::Device* d) {
|
||||
if (d == nullptr) return false;
|
||||
const auto& device_type = d->attributes().device_type();
|
||||
return device_type.find("XLA") != std::string::npos;
|
||||
}
|
||||
|
||||
string DeviceName(const tensorflow::Device* d) {
|
||||
return (d == nullptr) ? "cpu:0" : d->name();
|
||||
}
|
||||
@ -241,10 +236,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
|
||||
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
|
||||
grpc_server->channel_cache();
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
|
||||
tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// Initialize remote eager workers.
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
|
||||
@ -369,7 +364,7 @@ void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
|
||||
|
||||
void TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
|
||||
options->policy = policy;
|
||||
options->device_placement_policy = policy;
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
|
||||
@ -392,7 +387,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context(opts->session_options.options, opts->policy,
|
||||
return new TFE_Context(opts->session_options.options,
|
||||
opts->device_placement_policy, opts->mirroring_policy,
|
||||
opts->async, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator());
|
||||
@ -406,7 +402,8 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return new TFE_Context(opts->session_options.options, opts->policy,
|
||||
return new TFE_Context(opts->session_options.options,
|
||||
opts->device_placement_policy, opts->mirroring_policy,
|
||||
opts->async, device_mgr, /*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator());
|
||||
}
|
||||
@ -476,7 +473,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return new TFE_TensorHandle(tensor);
|
||||
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
|
||||
}
|
||||
|
||||
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
@ -569,14 +566,15 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* h_cpu = nullptr;
|
||||
tensorflow::TensorHandle* handle = h->handle;
|
||||
|
||||
if (h->handle->IsRemote()) {
|
||||
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
|
||||
if (handle->IsRemote()) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* h_cpu = nullptr;
|
||||
status->status = EagerCopyToDevice(
|
||||
h->handle, h->handle->Context(),
|
||||
h->handle->Context()->HostCPU()->name().c_str(), &h_cpu);
|
||||
handle, handle->Context(), handle->Context()->HostCPU()->name().c_str(),
|
||||
false, &h_cpu);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -585,28 +583,23 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
h_cpu->Unref();
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
status->status = h->handle->Tensor(&t);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
if (!IsCPU(h->handle->device())) {
|
||||
status->status = h->handle->CopyToDevice(
|
||||
h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
status->status = h_cpu->Tensor(&t);
|
||||
if (!status->status.ok()) {
|
||||
h_cpu->Unref();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
|
||||
if (h_cpu != nullptr) {
|
||||
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
|
||||
h_cpu->Unref();
|
||||
return retval;
|
||||
} else {
|
||||
tensorflow::Tensor tensor;
|
||||
if (IsCPU(handle->device())) {
|
||||
const tensorflow::Tensor* src = nullptr;
|
||||
status->status = handle->Tensor(&src);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
tensor = *src;
|
||||
} else {
|
||||
tensorflow::EagerContext* ctx = handle->Context();
|
||||
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
}
|
||||
return tensorflow::TF_TensorFromTensor(tensor, status);
|
||||
}
|
||||
return retval;
|
||||
}
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
@ -924,9 +917,9 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TFE_Context* ctx,
|
||||
const char* device_name,
|
||||
TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
device_name, &handle);
|
||||
device_name, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
}
|
||||
@ -971,8 +964,9 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
|
||||
} // extern "C"
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
|
||||
return new TFE_TensorHandle(t);
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
TF_Status* status) {
|
||||
return TFE_TensorHandle::CreateLocalHandle(t, status);
|
||||
}
|
||||
|
||||
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
|
||||
@ -996,9 +990,9 @@ TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
|
||||
// TensorHandles created by PyFuncOp lack context and therefore could
|
||||
// not be copied.
|
||||
if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
h->handle, h->handle->Context(), "CPU:0", &handle);
|
||||
h->handle, h->handle->Context(), "CPU:0", false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
} else {
|
||||
|
@ -60,6 +60,8 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
|
||||
|
||||
// Controls how to act when we try to run an operation on a given device but
|
||||
// some input tensors are not on that device.
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with internal copy of enum in eager/context.h.
|
||||
typedef enum TFE_ContextDevicePlacementPolicy {
|
||||
// Running operations with input tensors on the wrong device will fail.
|
||||
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||
@ -72,6 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
|
||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||
} TFE_ContextDevicePlacementPolicy;
|
||||
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
|
||||
|
||||
// Sets the default execution mode (sync/async). Note that this can be
|
||||
// overridden per thread using TFE_ContextSetAsyncForThread.
|
||||
@ -465,7 +468,8 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
TF_Status* status);
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_H_
|
||||
|
@@ -42,10 +42,8 @@ bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {

void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }

void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler,
                                   TF_Buffer* buf, TF_Status* status) {
  TFE_ContextAsyncWait(ctx, status);
  if (TF_GetCode(status) != TF_OK) return;
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
                                   TF_Status* status) {
  string content;
  status->status = profiler->profiler->SerializeToString(&content);
  void* data = tensorflow::port::Malloc(content.length());
@@ -102,6 +100,28 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
  return s.ok();
}

void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
                               int monitoring_level, bool display_timestamp,
                               TF_Buffer* result, TF_Status* status) {
  tensorflow::Status s =
      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return;
  }
  string content;
  s = tensorflow::profiler::client::Monitor(
      service_addr, duration_ms, monitoring_level, display_timestamp, &content);
  void* data = tensorflow::port::Malloc(content.length());
  content.copy(static_cast<char*>(data), content.length(), 0);
  result->data = data;
  result->length = content.length();
  result->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  tensorflow::Set_TF_Status_from_Status(status, s);
}

void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
                                          int64_t value) {
  cell->cell.IncrementBy(value);
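A sketch of how a client might drive the new monitoring entry point (not from this commit); the profiler-server address, duration, and monitoring level below are placeholders, and only C API calls visible in this diff or the stable C API are used:

    #include <stdio.h>

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/eager/c_api_experimental.h"

    // Requests one on-demand monitoring snapshot and prints the returned text.
    void MonitorOnce() {
      TF_Status* status = TF_NewStatus();
      TF_Buffer* result = TF_NewBuffer();
      TFE_ProfilerClientMonitor("localhost:6009", /*duration_ms=*/2000,
                                /*monitoring_level=*/1,
                                /*display_timestamp=*/true, result, status);
      if (TF_GetCode(status) == TF_OK) {
        fwrite(result->data, 1, result->length, stdout);
      }
      TF_DeleteBuffer(result);  // also runs the data_deallocator installed above
      TF_DeleteStatus(status);
    }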
@@ -40,8 +40,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);

// The output string is a binary string of tensorflow.tpu.Trace. User can write
// the string to file for offline analysis by tensorboard.
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx,
                                                         TFE_Profiler* profiler,
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
                                                         TF_Buffer* buf,
                                                         TF_Status* status);
@@ -88,6 +87,16 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
    bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
    TF_Status* status);

// Send a grpc request to profiler server (service_addr) to perform on-demand
// monitoring and return the result in a string. It will block the
// caller thread until receiving the monitoring result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
    const char* service_addr, int duration_ms, int monitoring_level,
    bool display_timestamp, TF_Buffer* result, TF_Status* status);

// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
@@ -311,6 +320,29 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
    TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);

// LINT.IfChange
// Note: Keep in sync with internal copy of enum in eager/context.h.
typedef enum TFE_ContextMirroringPolicy {
  // Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
  // copies with their own lifetime.
  TFE_MIRRORING_NONE = 0,
  // Mirroring any remote tensor handles, associating them with the lifetime of
  // the local TensorHandle.
  TFE_MIRRORING_ALL = 1,
} TFE_ContextMirroringPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)

// Sets a thread-local mirroring policy. After this call, other calls to
// TFE_Execute in the same thread will use the mirroring policy specified here
// instead of the mirroring policy used to construct the context. This has no
// effect on the mirroring policy used by other program threads.
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
    TFE_Context*, TFE_ContextMirroringPolicy);

// Returns the mirroring policy to be used by this context in the current
// thread.
TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
    TFE_Context*);
#ifdef __cplusplus
} /* end extern "C" */
#endif
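A minimal sketch (not part of the commit) of how the new thread-local mirroring API composes, assuming an already-created TFE_Context* named ctx; the WithMirroring wrapper is a made-up name:

    #include "tensorflow/c/eager/c_api.h"
    #include "tensorflow/c/eager/c_api_experimental.h"

    // Temporarily turns on remote-handle mirroring for work done on this
    // thread, then restores whatever policy was in effect before.
    void WithMirroring(TFE_Context* ctx) {
      TFE_ContextMirroringPolicy previous = TFE_ContextGetMirroringPolicy(ctx);
      TFE_ContextSetThreadLocalMirroringPolicy(ctx, TFE_MIRRORING_ALL);
      // ... issue TFE_Execute calls that may touch remote tensor handles ...
      TFE_ContextSetThreadLocalMirroringPolicy(ctx, previous);
    }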
@@ -72,7 +72,11 @@ void ExecuteWithProfiling(bool async) {
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TF_Buffer* profiler_result = TF_NewBuffer();
  TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status);
  if (async) {
    TFE_ContextAsyncWait(ctx, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  TFE_ProfilerSerializeToString(profiler, profiler_result, status);
  TFE_DeleteProfiler(profiler);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  profiler::Trace profile_proto;
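A sketch of the updated call order shown by this test (not from the commit itself), assuming a context and profiler created elsewhere; DumpTrace is a hypothetical helper name:

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/eager/c_api.h"
    #include "tensorflow/c/eager/c_api_experimental.h"

    // Serialization no longer waits on the context, so in async mode the
    // caller flushes pending work first, mirroring the updated test.
    void DumpTrace(TFE_Context* ctx, TFE_Profiler* profiler, bool async,
                   TF_Buffer* out, TF_Status* status) {
      if (async) {
        TFE_ContextAsyncWait(ctx, status);
        if (TF_GetCode(status) != TF_OK) return;
      }
      TFE_ProfilerSerializeToString(profiler, out, status);
    }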
@ -21,12 +21,12 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
@ -53,19 +54,24 @@ struct TFE_ContextOptions {
|
||||
TF_SessionOptions session_options;
|
||||
// true if async execution is enabled.
|
||||
bool async = false;
|
||||
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
|
||||
TFE_ContextDevicePlacementPolicy device_placement_policy{
|
||||
TFE_DEVICE_PLACEMENT_SILENT};
|
||||
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
|
||||
};
|
||||
|
||||
struct TFE_Context {
|
||||
TFE_Context(const tensorflow::SessionOptions& opts,
|
||||
TFE_ContextDevicePlacementPolicy default_policy, bool async,
|
||||
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
|
||||
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
|
||||
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
tensorflow::Rendezvous* rendezvous,
|
||||
const tensorflow::CustomKernelCreator* custom_kernel_creator)
|
||||
: context(new tensorflow::EagerContext(
|
||||
opts,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
default_policy),
|
||||
default_device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(
|
||||
default_mirroring_policy),
|
||||
async, device_mgr, device_mgr_owned, rendezvous,
|
||||
custom_kernel_creator)) {}
|
||||
|
||||
@ -75,12 +81,25 @@ struct TFE_Context {
|
||||
};
|
||||
|
||||
struct TFE_TensorHandle {
|
||||
explicit TFE_TensorHandle(tensorflow::TensorHandle* handle)
|
||||
: handle(handle) {}
|
||||
explicit TFE_TensorHandle(const tensorflow::Tensor& t)
|
||||
: handle(new tensorflow::TensorHandle(t)) {}
|
||||
TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
|
||||
: handle(new tensorflow::TensorHandle(t, d, nullptr)) {}
|
||||
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
|
||||
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
|
||||
TF_Status* s) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
|
||||
if (!s->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorHandle(handle);
|
||||
}
|
||||
static tensorflow::Status CreateLocalHandle(const class tensorflow::Tensor& t,
|
||||
tensorflow::Device* d,
|
||||
TFE_TensorHandle** h) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::TensorHandle::CreateLocalHandle(t, d, nullptr, &handle));
|
||||
*h = new TFE_TensorHandle(handle);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle* handle;
|
||||
|
||||
@ -92,7 +111,7 @@ struct TFE_TensorHandle {
|
||||
};
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
: dev_dims(dims) {}
|
||||
|
||||
// Fully-padded, minor-to-major.
|
||||
@ -124,7 +143,7 @@ struct TFE_ProfilerContext {
|
||||
};
|
||||
|
||||
struct TFE_Profiler {
|
||||
TFE_Profiler(TFE_ProfilerContext* ctx) {
|
||||
explicit TFE_Profiler(TFE_ProfilerContext* ctx) {
|
||||
profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context);
|
||||
}
|
||||
|
||||
@ -211,7 +230,7 @@ struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
TFE_MonitoringBuckets(
|
||||
explicit TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
@ -266,7 +285,7 @@ struct TFE_TraceContext {
|
||||
std::vector<std::pair<tensorflow::TensorHandle*, TF_Output>>* input_tensors =
|
||||
nullptr;
|
||||
|
||||
TFE_TraceContext(TF_Graph* graph) : graph(graph) {}
|
||||
explicit TFE_TraceContext(TF_Graph* graph) : graph(graph) {}
|
||||
|
||||
~TFE_TraceContext() {
|
||||
delete input_tensors;
|
||||
|
@ -19,9 +19,11 @@ limitations under the License.
|
||||
// maintains the data structures required to do so.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -99,6 +101,12 @@ class VSpace {
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::vector<Gradient*>* result) const = 0;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
virtual int64 TensorId(Gradient* tensor) const = 0;
|
||||
|
||||
// Converts a Gradient to a TapeTensor.
|
||||
virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0;
|
||||
|
||||
// Marks the following gradient as a result so it's not consumed by backward
|
||||
// functions.
|
||||
virtual void MarkAsResult(Gradient* gradient) const = 0;
|
||||
@ -129,7 +137,7 @@ class GradientTape {
|
||||
void Watch(int64 tensor_id);
|
||||
|
||||
void RecordOperation(
|
||||
const string& op_type, std::vector<TapeTensor>& output_tensors,
|
||||
const string& op_type, const std::vector<TapeTensor>& output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
@ -165,6 +173,122 @@ class GradientTape {
|
||||
bool persistent_;
|
||||
};
|
||||
|
||||
// Computes Jacobian-vector products using forward-mode automatic
|
||||
// differentiation.
|
||||
//
|
||||
// While GradientTape's RecordOperation is trivial, ForwardAccumulator's
|
||||
// Accumulate runs the gradient computation immediately.
|
||||
//
|
||||
// Keeps references to Tensors watched via Watch and computed in Accumulate
|
||||
// corresponding to output_tensors, and releases these references in its
|
||||
// destructor. However, waiting until the destructor runs loses the memory
|
||||
// efficiency of forward-mode autodiff. Instead, language bindings should call
|
||||
// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
|
||||
// Tensor passed to Accumulate goes out of scope.
|
||||
//
|
||||
// Not thread-safe.
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
class ForwardAccumulator {
|
||||
public:
|
||||
// Does not take ownership of `vspace`, which must outlive the
|
||||
// ForwardAccumulator.
|
||||
explicit ForwardAccumulator(
|
||||
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
|
||||
: vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
|
||||
|
||||
virtual ~ForwardAccumulator() {
|
||||
for (auto accumulated : accumulated_gradients_) {
|
||||
vspace_.DeleteGradient(accumulated.second);
|
||||
}
|
||||
}
|
||||
|
||||
// Tell the forward accumulator to watch tensor_id, with a Tensor tangent
|
||||
// vector `tangent` of matching shape and dtype. Tangents are the "vector" in
|
||||
// "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling
|
||||
// FetchJVP for it would return `tangent`.
|
||||
void Watch(int64 tensor_id, Gradient* tangent);
|
||||
|
||||
// Removes the gradient associated with tensor_id. Should be called when the
|
||||
// Tensor associated with `tensor_id` is deleted.
|
||||
void DeleteGradient(int64 tensor_id);
|
||||
|
||||
// Runs forward autodiff. Should be called whenever a new operation is
|
||||
// available and the accumulator is active.
|
||||
//
|
||||
// Like GradientTape::RecordOperation, this method takes the operation type
|
||||
// `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`,
|
||||
// `input_tensor_id`, and `input_dtypes`; the latter two are somewhat
|
||||
// redundant but taken as arguments to avoid repeatedly fetching these values
|
||||
// between calls to ShouldRecord and Accumulator), and its outputs
|
||||
// (`output_tensors`).
|
||||
//
|
||||
// Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
|
||||
// immediately. It stores the results, which feed into Accumulate for future
|
||||
// operations and may be fetched by calling FetchJVP. ForwardAccumulator
|
||||
// maintains a reference to these JVPs: if an `output_tensors` Tensor is
|
||||
// deleted, `DeleteGradient` should be called as soon as possible to free the
|
||||
// (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor
|
||||
// will release remaining references.
|
||||
//
|
||||
// This method is not thread-safe (and in general ForwardAccumulator is not
|
||||
// thread-safe).
|
||||
Status Accumulate(
|
||||
const string& op_type, const std::vector<TapeTensor>& input_tensors,
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter);
|
||||
|
||||
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
|
||||
// a nullptr if none is available.
|
||||
//
|
||||
// Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on its
|
||||
// return value. The caller should increment the reference count before
|
||||
// deleting the ForwardAccumulator or calling DeleteGradient if keeping a
|
||||
// persistent reference to a non-null result.
|
||||
Gradient* FetchJVP(int64 tensor_id);
|
||||
|
||||
// Indicates whether the forward accumulator should run on an operation with
|
||||
// the specified inputs and dtypes.
|
||||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes);
|
||||
|
||||
private:
|
||||
// Helper for Accumulate: uses a GradientTape to compute forward gradients
|
||||
// from a backward gradient function. Fills `out_grads` corresponding to
|
||||
// `output_tensors`. `out_grads` must not be null.
|
||||
//
|
||||
// Executes the backward function in order to trace its gradient, which will
|
||||
// waste computation if executing eagerly (when graph building the unneeded
|
||||
// computation is pruned). Temporarily sets `backward_tape_` so that
|
||||
// Accumulate will forward op executions to the tape while the backward
|
||||
// function is running; this effectively adds the backward tape to the active
|
||||
// set (but does not require complicated callbacks to the language bindings).
|
||||
Status ForwardpropFromTape(
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter,
|
||||
const std::vector<Gradient*>& in_grads,
|
||||
std::vector<Gradient*>* out_grads);
|
||||
|
||||
// Maps from tensor IDs to corresponding JVPs.
|
||||
std::unordered_map<int64, Gradient*> accumulated_gradients_;
|
||||
// Not owned; provides operations on Tensors which are currently only
|
||||
// available in language bindings (e.g. Python).
|
||||
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
|
||||
// Set temporarily while in the Accumulate method; if backward_tape_ is not
|
||||
// nullptr then we forward op executions to it so Accumulate can compute a
|
||||
// backward pass on its backward function.
|
||||
//
|
||||
// Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
|
||||
// keeps ownership.
|
||||
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
|
||||
// While the Accumulate method is running (accumulating_ is True), any op
|
||||
// executions not forwarded to backward_tape_ should be ignored.
|
||||
bool accumulating_;
|
||||
};
|
||||
|
||||
// Template instantiations here
|
||||
|
||||
inline bool IsDtypeTrainable(DataType dtype) {
|
||||
@ -206,7 +330,7 @@ void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
|
||||
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
|
||||
const string& op_type, std::vector<TapeTensor>& output_tensors,
|
||||
const string& op_type, const std::vector<TapeTensor>& output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
@ -691,6 +815,228 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
|
||||
gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes) {
|
||||
if (backward_tape_ != nullptr) {
|
||||
// If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
|
||||
// we should also delegate ShouldRecord.
|
||||
return backward_tape_->ShouldRecord(tensor_ids, dtypes);
|
||||
}
|
||||
if (accumulating_) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < tensor_ids.size(); ++i) {
|
||||
if (accumulated_gradients_.find(tensor_ids[i]) !=
|
||||
accumulated_gradients_.end()) {
|
||||
if (IsDtypeTrainable(dtypes[i])) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
Status
|
||||
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter,
|
||||
const std::vector<Gradient*>& in_grads, std::vector<Gradient*>* out_grads) {
|
||||
/* This function is approximately equivalent to this Python code:
|
||||
|
||||
forwardprop_aids = tf.ones_like(output_tensors)
|
||||
with tf.GradientTape() as g:
|
||||
g.watch(forwardprop_aids)
|
||||
grad = backward_function(forwardprop_aids)
|
||||
forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads)
|
||||
accumulated_gradients_[ID(output_tensors)] = forward_grads
|
||||
*/
|
||||
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
|
||||
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
|
||||
backward_tape_ = tape.get();
|
||||
auto pop_backward_tape =
|
||||
gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
|
||||
std::vector<Gradient*> forwardprop_aids;
|
||||
std::vector<int64> sources;
|
||||
std::unordered_set<int64> sources_set;
|
||||
sources.reserve(output_tensors.size());
|
||||
for (const TapeTensor& output_tensor : output_tensors) {
|
||||
// Ownership of `aid` transferred to CallBackwardFunction below.
|
||||
Gradient* aid = vspace_.Ones(output_tensor);
|
||||
forwardprop_aids.push_back(aid);
|
||||
int64 aid_id = vspace_.TensorId(aid);
|
||||
sources.push_back(aid_id);
|
||||
sources_set.insert(aid_id);
|
||||
tape->Watch(aid_id);
|
||||
}
|
||||
std::vector<Gradient*> grad;
|
||||
auto delete_grad = gtl::MakeCleanup([&grad, this] {
|
||||
for (Gradient* tensor : grad) {
|
||||
this->vspace_.DeleteGradient(tensor);
|
||||
}
|
||||
});
|
||||
{
|
||||
std::vector<int64> unneeded_gradients;
|
||||
std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
|
||||
backward_function(backward_function_getter(),
|
||||
backward_function_deleter);
|
||||
TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
|
||||
backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
|
||||
}
|
||||
|
||||
// Stop the tape from recording
|
||||
pop_backward_tape.release()();
|
||||
|
||||
std::vector<int64> targets;
|
||||
std::unordered_map<int64, TapeTensor> sources_that_are_targets;
|
||||
for (Gradient* grad_tensor : grad) {
|
||||
if (grad_tensor != nullptr) {
|
||||
int64 tensor_id = vspace_.TensorId(grad_tensor);
|
||||
targets.push_back(tensor_id);
|
||||
if (sources_set.find(tensor_id) != sources_set.end()) {
|
||||
sources_that_are_targets.emplace(
|
||||
tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (targets.size() > in_grads.size()) {
|
||||
return tensorflow::errors::Internal("Too many gradients returned.");
|
||||
}
|
||||
|
||||
for (int target_index = 0; target_index < targets.size(); ++target_index) {
|
||||
Gradient* in_grad = in_grads[target_index];
|
||||
Gradient* grad_tensor = grad[target_index];
|
||||
if (grad_tensor != nullptr && in_grad != nullptr) {
|
||||
// ComputeGradient steals a reference
|
||||
vspace_.MarkAsResult(in_grad);
|
||||
}
|
||||
}
|
||||
|
||||
return tape->ComputeGradient(vspace_, targets, sources,
|
||||
sources_that_are_targets, in_grads, out_grads);
|
||||
}
|
||||
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
||||
const string& op_type, const std::vector<TapeTensor>& input_tensors,
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
|
||||
if (backward_tape_ != nullptr) {
|
||||
// If backward_tape_ is not null, then this call to Accumulate is the result
|
||||
// of a still-active call to Accumulate which is running operations. We
|
||||
// forward these operations to backward_tape_ so the outer Accumulate call
|
||||
// can do its work.
|
||||
//
|
||||
// Rather than re-entering and delegating Accumulate like this, we could
|
||||
// instead allow ForwardAccumulator some control over the current tape set
|
||||
// (so it can deactivate itself and activate its GradientTape). Currently
|
||||
// that is managed by the language binding and would require relatively
|
||||
// messy callbacks.
|
||||
backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
|
||||
input_dtypes, backward_function_getter,
|
||||
backward_function_deleter);
|
||||
return Status::OK();
|
||||
}
|
||||
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// We may need to allocate zero inputs for trainable dtypes we don't have JVPs
|
||||
// for. Make sure they get cleaned up.
|
||||
std::vector<Gradient*> new_zeros;
|
||||
auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
|
||||
for (Gradient* tensor : new_zeros) {
|
||||
this->vspace_.DeleteGradient(tensor);
|
||||
}
|
||||
});
|
||||
std::vector<Gradient*> in_grads;
|
||||
in_grads.reserve(input_tensors.size());
|
||||
for (int target_index = 0; target_index < input_tensors.size();
|
||||
++target_index) {
|
||||
const auto current_grad =
|
||||
accumulated_gradients_.find(input_tensors[target_index].GetID());
|
||||
if (current_grad == accumulated_gradients_.end()) {
|
||||
if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
|
||||
// ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
|
||||
// GradientTape which uses ones.
|
||||
Gradient* zero = vspace_.Zeros(input_tensors[target_index]);
|
||||
new_zeros.push_back(zero);
|
||||
in_grads.push_back(zero);
|
||||
} else {
|
||||
in_grads.push_back(nullptr);
|
||||
}
|
||||
} else {
|
||||
in_grads.push_back(current_grad->second);
|
||||
}
|
||||
}
|
||||
|
||||
accumulating_ = true;
|
||||
auto reset_accumulating =
|
||||
gtl::MakeCleanup([this] { this->accumulating_ = false; });
|
||||
|
||||
std::vector<Gradient*> forward_grads;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ForwardpropFromTape(output_tensors, backward_function_getter,
|
||||
backward_function_deleter, in_grads, &forward_grads));
|
||||
|
||||
for (int i = 0; i < forward_grads.size(); ++i) {
|
||||
if (forward_grads[i] != nullptr) {
|
||||
int64 tensor_id = output_tensors[i].GetID();
|
||||
auto existing = accumulated_gradients_.find(tensor_id);
|
||||
if (existing != accumulated_gradients_.end()) {
|
||||
vspace_.DeleteGradient(existing->second);
|
||||
}
|
||||
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id, Gradient* tangent) {
  typename std::unordered_map<int64, Gradient*>::iterator existing =
      accumulated_gradients_.find(tensor_id);
  vspace_.MarkAsResult(tangent);
  if (existing == accumulated_gradients_.end()) {
    accumulated_gradients_.emplace(tensor_id, tangent);
  } else {
    std::array<Gradient*, 2> to_aggregate;
    to_aggregate[0] = tangent;
    to_aggregate[1] = existing->second;
    // AggregateGradients steals a reference to each of its arguments. We
    // MarkAsResult on `tangent` above so we don't steal a reference to it.
    existing->second = vspace_.AggregateGradients(to_aggregate);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
    int64 tensor_id) {
  auto existing = accumulated_gradients_.find(tensor_id);
  if (existing != accumulated_gradients_.end()) {
    vspace_.DeleteGradient(existing->second);
    accumulated_gradients_.erase(existing);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP(
    int64 tensor_id) {
  auto lookup = accumulated_gradients_.find(tensor_id);
  if (lookup == accumulated_gradients_.end()) {
    return nullptr;
  } else {
    return lookup->second;
  }
}

}  // namespace eager
}  // namespace tensorflow
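Editor's note: the two gtl::MakeCleanup guards in Accumulate above (`delete_new_zeros`, `reset_accumulating`) rely on scope-exit cleanup so the temporary zero gradients and the `accumulating_` flag are released on every return path. A minimal, standalone sketch of that scope-guard pattern, written against the standard library only; it is not TensorFlow's gtl implementation.

#include <iostream>

template <typename F>
class ScopeGuard {
 public:
  explicit ScopeGuard(F f) : f_(f) {}
  ~ScopeGuard() { f_(); }  // Runs on every exit path, including early returns.
  ScopeGuard(const ScopeGuard&) = delete;
  ScopeGuard& operator=(const ScopeGuard&) = delete;

 private:
  F f_;
};

int main() {
  bool accumulating = false;
  {
    accumulating = true;
    ScopeGuard reset_accumulating([&] { accumulating = false; });
    std::cout << "inside scope: " << accumulating << "\n";  // prints 1
  }  // The guard's destructor fires here, like `reset_accumulating` above.
  std::cout << "after scope: " << accumulating << "\n";  // prints 0
  return 0;
}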
@ -16,12 +16,36 @@ limitations under the License.
#ifndef TENSORFLOW_C_KERNELS_H_
#define TENSORFLOW_C_KERNELS_H_

#include "tensorflow/c/c_api.h"
#include <stdint.h>

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"

// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif  // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // _WIN32
#endif  // SWIG

#ifdef __cplusplus
extern "C" {
#endif

typedef struct TF_Tensor TF_Tensor;

// --------------------------------------------------------------------------
// C API for TensorFlow Kernels.
//
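Editor's note: the macro above marks C API symbols as exported from the shared library. A small, self-contained sketch of the same visibility-macro pattern; the MY_API macro and my_api_add function are hypothetical stand-ins, not part of the TensorFlow C API.

#include <iostream>

// Hypothetical macro, analogous to TF_CAPI_EXPORT when building the library.
#if defined(_WIN32)
#define MY_API __declspec(dllexport)
#else
#define MY_API __attribute__((visibility("default")))
#endif

// The macro keeps this symbol visible when the translation unit is built into
// a shared library with hidden default visibility (or into a Windows DLL).
MY_API int my_api_add(int a, int b) { return a + b; }

int main() {
  std::cout << my_api_add(2, 3) << "\n";  // prints 5
  return 0;
}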
@ -14,8 +14,12 @@ tf_kernel_library(
    prefix = "bitcast_op",
    deps = [
        "//tensorflow/c:kernels",
        "//tensorflow/c:ops",
        "//tensorflow/c:tf_datatype",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_tensor",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:lib",
    ],
)

@ -28,6 +32,7 @@ tf_cc_test(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
@ -16,11 +16,12 @@ limitations under the License.
#include <sstream>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"

// BitcastOp implements a bitcast kernel, creating an output tensor that shares
// the same data buffer as the input but with a different shape and/or data

@ -135,9 +136,8 @@ static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
  TF_DeleteTensor(tensor);
}

static void RegisterBitcastOp() {
void RegisterBitcastOp() {
  TF_Status* status = TF_NewStatus();

  {
    auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
                                        &BitcastOp_Create, &BitcastOp_Compute,
@ -15,8 +15,8 @@ limitations under the License.

#include "tensorflow/c/ops.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"

@ -78,15 +78,10 @@ void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
  TF_SetStatus(status, TF_OK, "");
  ::tensorflow::OpRegistry::Global()->Register(
      [cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
        return cc_builder->Finalize(op_reg_data);
        Status result = cc_builder->Finalize(op_reg_data);
        delete cc_builder;
        return result;
      });

  // Calling ProcessRegistrations ensures that the cc_builder's finalize method
  // is called and that the builder can be deleted.
  Set_TF_Status_from_Status(
      status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());

  delete cc_builder;
}

void TF_OpDefinitionBuilderSetShapeInferenceFunction(
@ -73,7 +73,8 @@ limitations under the License.
#include <stdint.h>
#include <stdlib.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"

#ifdef SWIG
#define TF_CAPI_EXPORT

@ -43,7 +43,6 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,

void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
               TF_Status* status) {
  AttrValue attr_val;

  mutex_lock l(graph->mu);
  op->node.ClearAttr(attr_name);
@ -15,10 +15,13 @@ limitations under the License.

#include "tensorflow/c/tf_tensor.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/coding.h"

@ -227,13 +230,15 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                       size_t dst_len, TF_Status* status) {
  const size_t sz = TF_StringEncodedSize(src_len);
  if (sz < src_len) {
    status->status = InvalidArgument("src string is too large to encode");
    Set_TF_Status_from_Status(
        status, InvalidArgument("src string is too large to encode"));
    return 0;
  }
  if (dst_len < sz) {
    status->status =
    Set_TF_Status_from_Status(
        status,
        InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
                        src_len, "-byte string");
                        src_len, "-byte string"));
    return 0;
  }
  dst = tensorflow::core::EncodeVarint64(dst, src_len);

@ -259,7 +264,8 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len,

size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
                       size_t* dst_len, TF_Status* status) {
  status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
  Set_TF_Status_from_Status(status,
                            TF_StringDecode_Impl(src, src_len, dst, dst_len));
  if (TF_GetCode(status) != TF_OK) return 0;
  return static_cast<size_t>(*dst - src) + *dst_len;
}
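Editor's note: TF_StringEncode above writes a length prefix via core::EncodeVarint64 and then the string bytes. A rough, standalone sketch of that varint-length-prefixed layout, using only the standard library; the helper names here are illustrative and error handling is omitted, unlike the real implementation.

#include <cstdint>
#include <iostream>
#include <string>

// Appends `value` to `out` as a little-endian base-128 varint.
void AppendVarint64(std::string* out, uint64_t value) {
  while (value >= 0x80) {
    out->push_back(static_cast<char>((value & 0x7F) | 0x80));
    value >>= 7;
  }
  out->push_back(static_cast<char>(value));
}

// Length-prefixed encoding: varint(src.size()) followed by the bytes of src.
std::string EncodeLengthPrefixed(const std::string& src) {
  std::string dst;
  AppendVarint64(&dst, src.size());
  dst.append(src);
  return dst;
}

int main() {
  std::string encoded = EncodeLengthPrefixed("hello");
  // One length byte (0x05) plus five payload bytes.
  std::cout << "encoded size: " << encoded.size() << "\n";  // prints 6
  return 0;
}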
@ -299,8 +305,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
if (!src.IsInitialized()) {
|
||||
status->status = FailedPrecondition(
|
||||
"attempt to use a tensor with an uninitialized value");
|
||||
Set_TF_Status_from_Status(
|
||||
status, FailedPrecondition(
|
||||
"attempt to use a tensor with an uninitialized value"));
|
||||
return nullptr;
|
||||
}
|
||||
if (src.NumElements() == 0) {
|
||||
@ -308,13 +315,14 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
}
|
||||
if (src.dtype() == tensorflow::DT_RESOURCE) {
|
||||
if (src.shape().dims() != 0) {
|
||||
status->status = InvalidArgument(
|
||||
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
||||
src.shape().DebugString(),
|
||||
"). Please file a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues/new, "
|
||||
"ideally with a "
|
||||
"short code snippet that reproduces this error.");
|
||||
Set_TF_Status_from_Status(
|
||||
status, InvalidArgument(
|
||||
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
||||
src.shape().DebugString(),
|
||||
"). Please file a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues/new, "
|
||||
"ideally with a "
|
||||
"short code snippet that reproduces this error."));
|
||||
return nullptr;
|
||||
}
|
||||
const string str =
|
||||
@ -353,9 +361,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
const string& s = srcarray(i);
|
||||
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
status->status = InvalidArgument(
|
||||
"invalid string tensor encoding (string #", i, " of ",
|
||||
srcarray.size(), "): ", status->status.error_message());
|
||||
Set_TF_Status_from_Status(
|
||||
status,
|
||||
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
|
||||
srcarray.size(), "): ", TF_Message(status)));
|
||||
delete[] base;
|
||||
return nullptr;
|
||||
}
|
||||
@ -363,9 +372,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
dst_len -= consumed;
|
||||
}
|
||||
if (dst != base + size) {
|
||||
status->status = InvalidArgument(
|
||||
"invalid string tensor encoding (decoded ", (dst - base),
|
||||
" bytes, but the tensor is encoded in ", size, " bytes");
|
||||
Set_TF_Status_from_Status(
|
||||
status, InvalidArgument(
|
||||
"invalid string tensor encoding (decoded ", (dst - base),
|
||||
" bytes, but the tensor is encoded in ", size, " bytes"));
|
||||
delete[] base;
|
||||
return nullptr;
|
||||
}
|
||||
|
46
tensorflow/c/tf_tensor_internal.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

// Internal structures used by the C API. These are likely to change and should
// not be depended on.

struct TF_Tensor {
  ~TF_Tensor();

  TF_DataType dtype;
  tensorflow::TensorShape shape;
  tensorflow::TensorBuffer* buffer;
};

namespace tensorflow {

class TensorCApi {
 public:
  static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
  static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
                           TensorBuffer* buf) {
    return Tensor(static_cast<DataType>(type), shape, buf);
  }
};

}  // namespace tensorflow
#endif  // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
@ -236,10 +236,13 @@ tf_cc_test(
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_experimental",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -141,6 +141,15 @@ Status ClientSession::RunCallable(CallableHandle handle,
                                  run_metadata);
}

Status ClientSession::RunCallable(CallableHandle handle,
                                  const std::vector<Tensor>& feed_tensors,
                                  std::vector<Tensor>* fetch_tensors,
                                  RunMetadata* run_metadata,
                                  const thread::ThreadPoolOptions& options) {
  return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
                                       run_metadata, options);
}

Status ClientSession::ReleaseCallable(CallableHandle handle) {
  return impl()->session_->ReleaseCallable(handle);
}

@ -27,6 +27,12 @@ limitations under the License.

namespace tensorflow {

namespace thread {

struct ThreadPoolOptions;

}

/// @addtogroup core
/// @{

@ -110,6 +116,20 @@ class ClientSession {
                     std::vector<Tensor>* fetch_tensors,
                     RunMetadata* run_metadata);

  /// \brief Invokes the subgraph named by `handle` with the given options and
  /// input tensors.
  ///
  /// The order of tensors in `feed_tensors` must match the order of names in
  /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
  /// match the order of names in `CallableOptions::fetch()` when this subgraph
  /// was created.
  /// NOTE: This API is still experimental and may change.
  Status RunCallable(CallableHandle handle,
                     const std::vector<Tensor>& feed_tensors,
                     std::vector<Tensor>* fetch_tensors,
                     RunMetadata* run_metadata,
                     const thread::ThreadPoolOptions& options);

  /// \brief Releases resources associated with the given `handle` in this
  /// session.
  /// NOTE: This API is still experimental and may change.
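Editor's note: a condensed usage sketch of the new RunCallable overload, mirroring the CallableWithCustomThreadPool test added later in this commit. It assumes a TensorFlow build environment (not a standalone program) and a caller-provided thread::ThreadPoolInterface implementation such as the test's CustomThreadPoolImpl; the helper function name is illustrative.

#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/core/threadpool_options.h"

namespace example {

// Builds a tiny graph, makes a callable, and runs it on caller-provided pools.
tensorflow::Status RunWithCustomPools(
    tensorflow::thread::ThreadPoolInterface* inter_op,
    tensorflow::thread::ThreadPoolInterface* intra_op,
    const std::vector<tensorflow::Tensor>& feeds,
    std::vector<tensorflow::Tensor>* outputs) {
  using namespace tensorflow;  // NOLINT: brevity in this sketch only.

  Scope root = Scope::NewRootScope();
  auto a = ops::Placeholder(root, DT_INT32);
  auto b = ops::Placeholder(root, DT_INT32);
  auto c = ops::Add(root, a, b);
  ClientSession session(root);

  // The feed order fixed here must match the order of `feeds` below.
  CallableOptions options;
  options.add_feed(a.node()->name());
  options.add_feed(b.node()->name());
  options.add_fetch(c.node()->name());
  ClientSession::CallableHandle callable;
  Status s = session.MakeCallable(options, &callable);
  if (!s.ok()) return s;

  // Route inter-op and intra-op work to the caller's thread pools.
  thread::ThreadPoolOptions pool_options;
  pool_options.inter_op_threadpool = inter_op;
  pool_options.intra_op_threadpool = intra_op;

  s = session.RunCallable(callable, feeds, outputs,
                          /*run_metadata=*/nullptr, pool_options);
  if (!s.ok()) return s;
  return session.ReleaseCallable(callable);
}

}  // namespace example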
@ -13,24 +13,67 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/cc/client/client_session.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/synchronization/barrier.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/core/threadpool_options.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using ops::Add;
|
||||
using ops::BatchMatMul;
|
||||
using ops::Const;
|
||||
using ops::Mul;
|
||||
using ops::Placeholder;
|
||||
using ops::Sub;
|
||||
|
||||
class CustomThreadPoolImpl : public thread::ThreadPoolInterface {
|
||||
public:
|
||||
explicit CustomThreadPoolImpl(int numThreads) {
|
||||
underlying_threadpool_.reset(new thread::ThreadPool(
|
||||
tensorflow::Env::Default(), "custom_threadpool", numThreads));
|
||||
num_schedule_called_ = 0;
|
||||
}
|
||||
|
||||
void Schedule(std::function<void()> fn) override {
|
||||
num_schedule_called_ += 1;
|
||||
underlying_threadpool_->Schedule(std::move(fn));
|
||||
}
|
||||
|
||||
void ScheduleWithHint(std::function<void()> fn, int start, int end) override {
|
||||
num_schedule_called_ += 1;
|
||||
underlying_threadpool_->ScheduleWithHint(std::move(fn), start, end);
|
||||
}
|
||||
|
||||
void Cancel() override {}
|
||||
|
||||
int NumThreads() const override {
|
||||
return underlying_threadpool_->NumThreads();
|
||||
}
|
||||
|
||||
int CurrentThreadId() const override {
|
||||
return underlying_threadpool_->CurrentThreadId();
|
||||
}
|
||||
|
||||
int GetNumScheduleCalled() { return num_schedule_called_; }
|
||||
|
||||
private:
|
||||
int num_schedule_called_;
|
||||
std::unique_ptr<tensorflow::thread::ThreadPool> underlying_threadpool_;
|
||||
};
|
||||
|
||||
TEST(ClientSessionTest, Basic) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = Const(root, {{1, 1}});
|
||||
@ -95,7 +138,7 @@ TEST(ClientSessionTest, MultiThreaded) {
|
||||
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
|
||||
}
|
||||
|
||||
TEST(ClientSessionTest, Callable) {
|
||||
TEST(ClientSessionTest, CallableWithDefaultThreadPool) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto a = Placeholder(root, DT_INT32);
|
||||
auto b = Placeholder(root, DT_INT32);
|
||||
@ -116,5 +159,60 @@ TEST(ClientSessionTest, Callable) {
|
||||
TF_EXPECT_OK(session.ReleaseCallable(callable));
|
||||
}
|
||||
|
||||
TEST(ClientSessionTest, CallableWithCustomThreadPool) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
int num_threads = 3;
|
||||
|
||||
TensorShape data_shape({1, 1});
|
||||
auto a = Placeholder(root, DT_INT32, Placeholder::Shape(data_shape));
|
||||
auto b = Placeholder(root, DT_INT32, Placeholder::Shape(data_shape));
|
||||
auto c = BatchMatMul(root, a, b);
|
||||
ClientSession session(root);
|
||||
std::vector<Tensor> outputs;
|
||||
|
||||
auto inter_op_threadpool =
|
||||
absl::make_unique<CustomThreadPoolImpl>(num_threads);
|
||||
ASSERT_EQ(inter_op_threadpool->GetNumScheduleCalled(), 0);
|
||||
|
||||
auto intra_op_threadpool =
|
||||
absl::make_unique<CustomThreadPoolImpl>(num_threads);
|
||||
ASSERT_EQ(intra_op_threadpool->GetNumScheduleCalled(), 0);
|
||||
|
||||
tensorflow::thread::ThreadPoolOptions threadPoolOptions;
|
||||
threadPoolOptions.inter_op_threadpool = inter_op_threadpool.get();
|
||||
threadPoolOptions.intra_op_threadpool = intra_op_threadpool.get();
|
||||
|
||||
CallableOptions options;
|
||||
options.add_feed(a.node()->name());
|
||||
options.add_feed(b.node()->name());
|
||||
options.add_fetch(c.node()->name());
|
||||
ClientSession::CallableHandle callable;
|
||||
TF_CHECK_OK(session.MakeCallable(options, &callable));
|
||||
|
||||
// This is needed to have BatchMatMul computation be scheduled in the
|
||||
// intra_op_threadpool.
|
||||
absl::Barrier barrier(num_threads + 1);
|
||||
for (int i = 0; i < num_threads; i++) {
|
||||
intra_op_threadpool->Schedule([&barrier, num_threads]() {
|
||||
tensorflow::SetPerThreadMaxParallelism(num_threads - 1);
|
||||
barrier.Block();
|
||||
});
|
||||
}
|
||||
barrier.Block();
|
||||
|
||||
TF_EXPECT_OK(session.RunCallable(
|
||||
callable,
|
||||
{test::AsTensor<int>({2}, {1, 1}), test::AsTensor<int>({10}, {1, 1})},
|
||||
&outputs, nullptr, threadPoolOptions));
|
||||
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({20}, {1, 1}));
|
||||
TF_EXPECT_OK(session.ReleaseCallable(callable));
|
||||
ASSERT_GT(inter_op_threadpool->GetNumScheduleCalled(), 0);
|
||||
ASSERT_GT(intra_op_threadpool->GetNumScheduleCalled(), 0);
|
||||
|
||||
// Free intra_op_threadpool and wait for its threads to exit before freeing
|
||||
// other objects (e.g. barrier). This is needed to avoid data race.
|
||||
intra_op_threadpool.reset();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,10 +1,10 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "profiler_test",
|
||||
srcs = ["profiler_test.cc"],
|
||||
|
@ -1,13 +1,6 @@
|
||||
# Description:
|
||||
# TensorFlow SavedModel.
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
@ -22,6 +15,13 @@ load(
|
||||
"if_static_and_not_mobile",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "constants",
|
||||
hdrs = ["constants.h"],
|
||||
|
@ -34,7 +34,6 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:cpu_function_runtime",
|
||||
|
@ -1,11 +1,11 @@
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
# We disable some tfcompile tests in the open source build with the
|
||||
# "manual" tag to avoid making our OSS users build LLVM twice
|
||||
# (once for host and once for target).
|
||||
|
@ -286,8 +286,6 @@ def tf_library(
|
||||
] or []) + (include_standard_runtime_deps and [
|
||||
# TODO(cwhipkey): only depend on kernel code that the model actually
|
||||
# needed.
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
|
||||
|
@ -1,3 +1,8 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
":internal",
|
||||
@ -22,10 +27,6 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
|
||||
# Target that bundles up the XLA CPU and GPU JIT devices.
|
||||
cc_library(
|
||||
name = "jit",
|
||||
@ -50,7 +51,6 @@ cc_library(
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
@ -192,18 +192,11 @@ cc_library(
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:fifo_queue",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:host_constant_op",
|
||||
"//tensorflow/core/kernels:identity_n_op",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:logging_ops",
|
||||
"//tensorflow/core/kernels:no_op",
|
||||
"//tensorflow/core/kernels:queue_op",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:stack",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
@ -286,6 +279,7 @@ cc_library(
|
||||
srcs = ["xla_compilation_cache.cc"],
|
||||
hdrs = ["xla_compilation_cache.h"],
|
||||
deps = [
|
||||
":xla_activity_listener",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -322,9 +316,10 @@ cc_library(
|
||||
srcs = ["jit_compilation_pass_registration.cc"],
|
||||
deps = [
|
||||
":compilation_passes",
|
||||
":xla_activity_logging_listener",
|
||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
],
|
||||
] + tf_jit_compilation_passes_extra_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -512,6 +507,7 @@ cc_library(
|
||||
"mark_for_compilation_pass.cc",
|
||||
"mark_for_compilation_pass_test_helper.cc",
|
||||
"partially_decluster_pass.cc",
|
||||
"report_clustering_info_pass.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"build_xla_ops_pass.h",
|
||||
@ -525,6 +521,7 @@ cc_library(
|
||||
"mark_for_compilation_pass.h",
|
||||
"mark_for_compilation_pass_test_helper.h",
|
||||
"partially_decluster_pass.h",
|
||||
"report_clustering_info_pass.h",
|
||||
],
|
||||
deps = [
|
||||
"compilability_check_util",
|
||||
@ -535,6 +532,7 @@ cc_library(
|
||||
":resource_operation_safety_analysis",
|
||||
":shape_inference_helpers",
|
||||
":union_find",
|
||||
":xla_activity_listener",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
@ -577,6 +575,7 @@ cc_library(
|
||||
hdrs = ["xla_cluster_util.h"],
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -843,6 +842,27 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_activity_listener_test",
|
||||
srcs = ["xla_activity_listener_test.cc"],
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_activity_listener",
|
||||
":xla_cpu_device",
|
||||
":xla_cpu_jit",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:direct_session_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:matmul_op",
|
||||
"//tensorflow/core/kernels:partitioned_function_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "xla_ops_py",
|
||||
kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
|
||||
@ -855,6 +875,37 @@ tf_custom_op_py_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_activity_listener",
|
||||
srcs = ["xla_activity_listener.cc"],
|
||||
hdrs = ["xla_activity_listener.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "xla_activity_proto",
|
||||
srcs = ["xla_activity.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos(),
|
||||
provide_cc_alias = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_activity_logging_listener",
|
||||
srcs = ["xla_activity_logging_listener.cc"],
|
||||
deps = [
|
||||
":xla_activity_listener",
|
||||
"//tensorflow/core:logger",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
|
||||
cc_header_only_library(
|
||||
name = "xla_jit_headers_lib",
|
||||
|
@ -72,7 +72,47 @@ void LogNotCompilable(const Node& node, absl::string_view reason = "") {
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
RecursiveCompilabilityChecker::FindUncompilableNodes(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
|
||||
node_stack_trace) const {
|
||||
std::vector<StackFrameView> stack_trace;
|
||||
// If `node_stack_trace` is provided, that means `node` is inside
|
||||
// a function body, and therefore, arg nodes and retval nodes are
|
||||
// not considered uncompilable.
|
||||
if (node_stack_trace != nullptr) {
|
||||
for (const auto& frame : *node_stack_trace) {
|
||||
stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
|
||||
}
|
||||
}
|
||||
stack_trace.emplace_back(StackFrameView{node.name(), ""});
|
||||
std::vector<UncompilableNodeInfo> uncompilable_nodes;
|
||||
IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes);
|
||||
return uncompilable_nodes;
|
||||
}
|
||||
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
RecursiveCompilabilityChecker::FindUncompilableNodes(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
|
||||
node_stack_trace) const {
|
||||
// If `node_stack_trace` is provided, that means `call_def` is inside
|
||||
// a function body, and therefore, arg nodes and retval nodes are
|
||||
// not considered uncompilable.
|
||||
std::vector<StackFrameView> stack_trace;
|
||||
if (node_stack_trace != nullptr) {
|
||||
for (const auto& frame : *node_stack_trace) {
|
||||
stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name});
|
||||
}
|
||||
}
|
||||
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
|
||||
std::vector<UncompilableNodeInfo> uncompilable_nodes;
|
||||
IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes);
|
||||
return uncompilable_nodes;
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
|
||||
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
|
||||
// is really a kind of function call and will be handled by
|
||||
// IsCompilableCall().
|
||||
@ -104,7 +144,7 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) {
|
||||
bool RecursiveCompilabilityChecker::IsCompilableWhile(
|
||||
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) {
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
const NameAttrList* name_attr;
|
||||
NodeDef call;
|
||||
Status status;
|
||||
@ -155,7 +195,7 @@ bool RecursiveCompilabilityChecker::IsCompilableWhile(
|
||||
bool RecursiveCompilabilityChecker::IsCompilableCall(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) {
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
if (stack_trace->size() > kMaxRecursionDepth) {
|
||||
std::string uncompilable_reason = "function depth limit exceeded";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
@ -191,22 +231,24 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
|
||||
return is_compilable;
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) {
|
||||
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
|
||||
// b/127344411: SelfAdjointEigV2 and Svd precision issues.
|
||||
return node.type_string() == "SelfAdjointEigV2" ||
|
||||
node.type_string() == "Svd";
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) {
|
||||
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
|
||||
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
|
||||
// b/135640736: MatrixInverse performance issues.
|
||||
return node.type_string() == "SelfAdjointEigV2" ||
|
||||
node.type_string() == "Svd" || node.type_string() == "Qr";
|
||||
node.type_string() == "Svd" || node.type_string() == "Qr" ||
|
||||
node.type_string() == "MatrixInverse";
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) {
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
auto stack_depth = stack_trace->size();
|
||||
if (node.IsSource() || node.IsSink()) {
|
||||
absl::string_view uncompilable_reason = "source or sink node";
|
||||
@ -358,7 +400,7 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
return op_filter;
|
||||
}
|
||||
|
||||
void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
|
||||
/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
|
||||
const absl::string_view reason,
|
||||
const std::vector<StackFrameView>& stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_node_list) {
|
||||
|
@ -129,29 +129,25 @@ class RecursiveCompilabilityChecker {
|
||||
const DeviceType* jit_device_type)
|
||||
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
|
||||
|
||||
// Returns a list of uncompilable nodes.
|
||||
// Returns a list of uncompilable nodes. When `node` is inside a function
|
||||
// body, users can set `node_stack_trace` to provide an additional
|
||||
// context for `node`'s placement within the outer most graph.
|
||||
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime) {
|
||||
std::vector<StackFrameView> stack_trace;
|
||||
stack_trace.emplace_back(StackFrameView{node.name(), ""});
|
||||
std::vector<UncompilableNodeInfo> uncompilable_nodes;
|
||||
IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes);
|
||||
return uncompilable_nodes;
|
||||
}
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
|
||||
|
||||
// Returns a list of uncompilable nodes in `call_def` that cannot be
|
||||
// compiled by XLA. It is assumed that `call_def` is a call operation.
|
||||
// When `node` is inside a function body, users can set
|
||||
// `node_stack_trace` to provide an additional context for `node`'s
|
||||
// placement within the outer most graph.
|
||||
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime) {
|
||||
std::vector<StackFrameView> stack_trace;
|
||||
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
|
||||
std::vector<UncompilableNodeInfo> uncompilable_nodes;
|
||||
IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes);
|
||||
return uncompilable_nodes;
|
||||
}
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
|
||||
|
||||
// Returns true if `node` can be compiled by XLA.
|
||||
bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) {
|
||||
bool IsCompilableNode(const Node& node,
|
||||
FunctionLibraryRuntime* lib_runtime) const {
|
||||
std::vector<StackFrameView> stack_trace;
|
||||
stack_trace.emplace_back(StackFrameView{node.name(), ""});
|
||||
return IsCompilableNode(node, lib_runtime, &stack_trace);
|
||||
@ -168,8 +164,8 @@ class RecursiveCompilabilityChecker {
|
||||
|
||||
// Returns true if XLA supports this Op, but we don't want to cluster it (ie:
|
||||
// due to performance or correctness concerns).
|
||||
bool OpIsInaccurate(const Node& node);
|
||||
bool OpIsSlow(const Node& node);
|
||||
bool OpIsInaccurate(const Node& node) const;
|
||||
bool OpIsSlow(const Node& node) const;
|
||||
|
||||
private:
|
||||
struct StackFrameView {
|
||||
@ -180,47 +176,47 @@ class RecursiveCompilabilityChecker {
|
||||
bool IsCompilableNode(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr);
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
|
||||
bool IsCompilableCall(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr);
|
||||
bool IsCompilableWhile(const Node& while_node,
|
||||
FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes);
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
|
||||
bool IsCompilableWhile(
|
||||
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
|
||||
|
||||
bool IsStackOp(const Node& node) {
|
||||
bool IsStackOp(const Node& node) const {
|
||||
const XlaResourceOpInfo* op_info =
|
||||
GetResourceOpInfoForOp(node.type_string());
|
||||
return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
|
||||
}
|
||||
|
||||
bool IsTensorArrayOp(const Node& node) {
|
||||
bool IsTensorArrayOp(const Node& node) const {
|
||||
const XlaResourceOpInfo* op_info =
|
||||
GetResourceOpInfoForOp(node.type_string());
|
||||
return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
|
||||
}
|
||||
|
||||
bool IsAssertOrCheckNumerics(absl::string_view op_name) {
|
||||
bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
|
||||
return op_name == "Assert" || op_name == "CheckNumerics";
|
||||
}
|
||||
|
||||
bool IsStatefulRandomOp(absl::string_view op_name) {
|
||||
bool IsStatefulRandomOp(absl::string_view op_name) const {
|
||||
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
|
||||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
|
||||
op_name == "TruncatedNormal" || op_name == "Multinomial";
|
||||
}
|
||||
|
||||
bool OpProducesOrConsumesVariant(const Node& node) {
|
||||
bool OpProducesOrConsumesVariant(const Node& node) const {
|
||||
auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
|
||||
return absl::c_any_of(node.input_types(), is_variant) ||
|
||||
absl::c_any_of(node.output_types(), is_variant);
|
||||
}
|
||||
|
||||
bool HasXLAKernel(const Node& node);
|
||||
bool HasXLAKernel(const Node& node) const;
|
||||
|
||||
void MaybeMarkUncompilableNode(
|
||||
static void MaybeMarkUncompilableNode(
|
||||
const absl::string_view reason,
|
||||
const std::vector<StackFrameView>& stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_node_list);
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

@ -36,7 +37,7 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
                      IntroduceFloatingPointJitterPass);

// from
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// FunctionalizeControlFlowPass: 27
//
// This pass looks at the graph and all associated FunctionDefs, and turns

@ -58,15 +59,22 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                      PartiallyDeclusterPass);

// ReportClusteringInfoPass needs to run after all of the auto-clustering
// passes have run but before encapsulation has run. This way it can easily
// compute a summary of the clustering decisions we made and broadcast it via
// xla_activity_listener.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
                      ReportClusteringInfoPass);

// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes
// added for fetches. Before the _Send nodes are added, fetch nodes are
// identified by name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
                      EncapsulateSubgraphsPass);

// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 60,
                      BuildXlaOpsPass);

}  // namespace tensorflow
@ -28,6 +28,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
|
||||
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
|
||||
@ -525,9 +526,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
// We're missing the must-be-constant inputs, tell `PopulateInputs`
|
||||
// about this. We don't actually need these inputs because they've
|
||||
// already been baked into the compiled kernel.
|
||||
launch_context.PopulateInputs(
|
||||
ctx, closure.compilation_result(), closure.resource_var_snapshots(),
|
||||
/*missing_ctx_input_prefix=*/closure.num_constant_args());
|
||||
{
|
||||
tensorflow::profiler::TraceMe hlo_module_activity(
|
||||
[&] {
|
||||
return absl::StrCat(
|
||||
"Populate Inputs (",
|
||||
closure.compilation_result()->xla_input_shapes.size(), ")");
|
||||
},
|
||||
tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
|
||||
launch_context.PopulateInputs(
|
||||
ctx, closure.compilation_result(), closure.resource_var_snapshots(),
|
||||
/*missing_ctx_input_prefix=*/closure.num_constant_args());
|
||||
}
|
||||
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
@ -546,6 +557,12 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
auto elapsed = env->NowMicros() - start_time;
|
||||
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
|
||||
|
||||
tensorflow::profiler::TraceMe hlo_module_activity(
|
||||
[&] {
|
||||
return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
|
||||
},
|
||||
tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
launch_context.PopulateOutputs(
|
||||
|
@ -1014,6 +1014,39 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
  return Status::OK();
}

StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
  if (!node->IsIdentity()) {
    return false;
  }

  // Check if the Identity is driven by a Switch on its true path.
  auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
    return e->src()->IsSwitch() && e->src_output() == 1;
  });
  if (it == node->in_edges().end()) {
    return false;
  }
  const Node* switch_node = (*it)->src();

  // Check if the Switch is driven by LoopCond.
  const Node* maybe_loop_cond;
  TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
  if (!maybe_loop_cond->IsLoopCond()) {
    return false;
  }

  // Check if the Identity is driving any const nodes through a control edge.
  bool driving_any_consts =
      absl::c_any_of(node->out_edges(), [](const Edge* e) {
        return e->dst()->IsConstant() && e->IsControlEdge();
      });
  if (!driving_any_consts) {
    return false;
  }

  return true;
}

Status MarkForCompilationPassImpl::FindCompilationCandidates() {
  OptimizerOptions opts;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(

@ -1135,6 +1168,35 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
      }
    }

    // This is a heuristic to avoid creating dependency between while loop
    // condition and body computations. Dependency between them can be created
    // if a special Identity node in the following pattern is clustered in.
    // That is, an Identity node in the loop cond computation is used to drive
    // const nodes consumed by the loop body. If this Identity node goes into
    // the same cluster with nodes from the loop body, extra dependency is
    // created between the loop cond and body computations and it hinders the
    // progression of the loop cond computation at runtime with significant
    // overhead. Specifically, we look for the below pattern and do not cluster
    // in this Identity to avoid the described issue. Since Identity has low
    // execution cost in native TF, the fact that this heuristic gives up these
    // special Identity nodes as candidates should not harm any performance. If
    // other considerations emerge in the future, we can revisit the heuristic
    // and only disallow these Identities to go into the cluster with nodes from
    // the loop body but still consider them candidates.
    //
    // LoopCond ->
    // Merge -> Switch -> Identity -> i++ -> ... -> NextIteration
    //                          ..> Const -> LoopBody
    //                        (control edge)
    TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
                        IsIdentityDrivingConstsInLoop(node));
    if (is_identity_driving_consts_in_loop) {
      VLOG(2) << "Rejecting " << node->name()
              << ": including it can create dependencies between while loop "
                 "condition and body computations with runtime overhead.";
      continue;
    }

    compilation_candidates_.insert(node);
    --(*debug_options_.fuel);
  }
@ -1302,46 +1364,36 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<absl::string_view, int> cluster_name_to_size;
|
||||
std::map<absl::string_view, std::map<absl::string_view, int>>
|
||||
cluster_name_to_op_histogram;
|
||||
std::map<absl::string_view, int> unclustered_op_histogram;
|
||||
int clustered_node_count = 0;
|
||||
|
||||
for (Node* n : graph_->nodes()) {
|
||||
absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
|
||||
if (cluster_name) {
|
||||
clustered_node_count++;
|
||||
cluster_name_to_size[*cluster_name]++;
|
||||
cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
|
||||
} else {
|
||||
unclustered_op_histogram[n->type_string()]++;
|
||||
}
|
||||
}
|
||||
|
||||
int unclustered_node_count = graph_->num_nodes() - clustered_node_count;
|
||||
XlaAutoClusteringSummary auto_clustering_info =
|
||||
GetXlaAutoClusteringSummary(*graph_);
|
||||
|
||||
VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
|
||||
VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
|
||||
<< RatioToString(clustered_node_count, graph_->num_nodes());
|
||||
VLOG(2) << " Built " << auto_clustering_info.clusters_size()
|
||||
<< " clusters, size "
|
||||
<< RatioToString(auto_clustering_info.clustered_node_count(),
|
||||
graph_->num_nodes());
|
||||
|
||||
for (const auto& cluster_name_size_pair : cluster_name_to_size) {
|
||||
absl::string_view cluster_name = cluster_name_size_pair.first;
|
||||
int size = cluster_name_size_pair.second;
|
||||
for (XlaAutoClusteringSummary::Cluster cluster :
|
||||
auto_clustering_info.clusters()) {
|
||||
absl::string_view cluster_name = cluster.name();
|
||||
int size = cluster.size();
|
||||
VLOG(2) << " " << cluster_name << " "
|
||||
<< RatioToString(size, graph_->num_nodes());
|
||||
for (const auto& op_count_pair :
|
||||
cluster_name_to_op_histogram[cluster_name]) {
|
||||
VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second
|
||||
for (const XlaAutoClusteringSummary::OpAndCount& op_count :
|
||||
cluster.op_histogram()) {
|
||||
VLOG(3) << " " << op_count.op() << ": " << op_count.count()
|
||||
<< " instances";
|
||||
}
|
||||
}
|
||||
|
||||
if (!unclustered_op_histogram.empty()) {
|
||||
if (!auto_clustering_info.unclustered_op_histogram().empty()) {
|
||||
VLOG(2) << " Unclustered nodes: "
|
||||
<< RatioToString(unclustered_node_count, graph_->num_nodes());
|
||||
for (const auto& pair : unclustered_op_histogram) {
|
||||
VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
|
||||
<< RatioToString(auto_clustering_info.unclustered_node_count(),
|
||||
graph_->num_nodes());
|
||||
for (const XlaAutoClusteringSummary::OpAndCount& op_count :
|
||||
auto_clustering_info.unclustered_op_histogram()) {
|
||||
VLOG(3) << " " << op_count.op() << ": " << op_count.count()
|
||||
<< " instances";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,6 @@ absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
|
||||
|
||||
TEST(XlaCompilationTest, Chains) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
@ -114,7 +113,6 @@ TEST(XlaCompilationTest, Chains) {
|
||||
|
||||
TEST(XlaCompilationTest, UncompilableCycles) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
@ -135,7 +133,6 @@ TEST(XlaCompilationTest, UncompilableCycles) {
|
||||
|
||||
TEST(XlaCompilationTest, CompilableCycles) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
@ -157,7 +154,6 @@ TEST(XlaCompilationTest, CompilableCycles) {
|
||||
|
||||
TEST(XlaCompilationTest, StringUnsupported) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp(
|
||||
@ -177,7 +173,6 @@ TEST(XlaCompilationTest, StringUnsupported) {
|
||||
|
||||
TEST(XlaCompilationTest, HalfSupported) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Tensor t(DT_HALF, TensorShape());
|
||||
@ -253,7 +248,6 @@ TEST(XlaCompilationTest, FunctionCalls) {
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(&flib_def));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
|
||||
Node* a =
|
||||
@ -291,7 +285,6 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(&flib_def));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
|
||||
Node* resource =
|
||||
@ -381,7 +374,6 @@ REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);
|
||||
|
||||
TEST(XlaCompilationTest, SymbolicGradients) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
@ -449,7 +441,6 @@ TEST(XlaCompilationTest, Loops) {
|
||||
|
||||
TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
@ -483,7 +474,6 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
|
||||
|
||||
TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
@ -512,7 +502,6 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
|
||||
|
||||
TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
@ -555,7 +544,6 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
|
||||
|
||||
TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
@ -789,7 +777,6 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
|
||||
|
||||
TEST(XlaCompilationTest, Retval) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
@ -1677,5 +1664,37 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
|
||||
EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
|
||||
}
|
||||
|
||||
// Test a pattern where a special Identity node is driving consts in a loop.
|
||||
// Expect that the Identity node will not go into any clusters. Note that we
|
||||
// create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration,
|
||||
// etc.) just enough to test the pattern, as a complete graph may be too
|
||||
// cumbersome and unnecessary.
|
||||
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
|
||||
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
|
||||
Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
|
||||
ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);
|
||||
|
||||
Output identity =
|
||||
ops::Identity(root.WithOpName("identity"), switch_node.output_true);
|
||||
Output const_node = ops::Const(root.WithOpName("const"), 1.0f);
|
||||
root.graph()->AddControlEdge(identity.node(), const_node.node());
|
||||
Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node);
|
||||
Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0);
|
||||
Output add = ops::Add(root.WithOpName("add"), const_node, tanh1);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph,
|
||||
MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_EQ(clusters["identity"], "");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,10 +1,10 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
|
||||
cc_library(
|
||||
name = "xla_ops",
|
||||
srcs = ["xla_ops.cc"],
|
||||
|
32
tensorflow/compiler/jit/report_clustering_info_pass.cc
Normal file
@ -0,0 +1,32 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
Status ReportClusteringInfoPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
XlaAutoClusteringActivity activity;
|
||||
*activity.mutable_summary() = GetXlaAutoClusteringSummary(**options.graph);
|
||||
activity.set_global_jit_level(GetGlobalJitLevelForGraph(options));
|
||||
activity.set_cpu_global_jit_enabled(
|
||||
GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit);
|
||||
return BroadcastXlaActivity(std::move(activity));
|
||||
}
|
||||
} // namespace tensorflow
|
32
tensorflow/compiler/jit/report_clustering_info_pass.h
Normal file
@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
#define TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// This is not really an optimization pass. It does not change the graph in any
// way; instead it computes a summary of the XLA clusters in the graph and
// broadcasts it via xla_activity_listener.
class ReportClusteringInfoPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;
};
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_
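Reviewer note: the header only declares the pass; where it is wired into the optimization pipeline is not part of this excerpt. A rough sketch of what a registration could look like, assuming the standard REGISTER_OPTIMIZATION macro from optimization_registry.h (the grouping and phase below are illustrative placeholders, not taken from this change):

#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
// Grouping and phase are placeholders for illustration; the real registration
// (not shown in this excerpt) would pick a phase that runs after clustering.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
                      ReportClusteringInfoPass);
}  // namespace tensorflow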
@ -1,8 +1,8 @@
|
||||
Clustered nodes: 1988
|
||||
Unclustered nodes: 3960
|
||||
Clustered nodes: 1962
|
||||
Unclustered nodes: 3974
|
||||
Number of clusters: 29
|
||||
|
||||
unclustered size 3960
|
||||
unclustered size 3974
|
||||
Add 17
|
||||
AddN 1
|
||||
ApplyAdam 38
|
||||
@ -14,7 +14,7 @@ unclustered size 3960
|
||||
Cast 8
|
||||
ConcatOffset 10
|
||||
ConcatV2 2
|
||||
Const 704
|
||||
Const 708
|
||||
ControlTrigger 5
|
||||
DynamicStitch 1
|
||||
Enter 874
|
||||
@ -24,7 +24,7 @@ unclustered size 3960
|
||||
FloorDiv 1
|
||||
FloorMod 1
|
||||
GreaterEqual 7
|
||||
Identity 105
|
||||
Identity 113
|
||||
IsVariableInitialized 1
|
||||
IteratorGetNext 1
|
||||
IteratorV2 1
|
||||
@ -42,7 +42,7 @@ unclustered size 3960
|
||||
RefSwitch 166
|
||||
Reshape 2
|
||||
ScatterAdd 4
|
||||
Shape 4
|
||||
Shape 6
|
||||
ShapeN 10
|
||||
Size 2
|
||||
Snapshot 1
|
||||
@ -169,15 +169,13 @@ cluster 7 size 11
|
||||
Mul 2
|
||||
Pow 1
|
||||
Sub 1
|
||||
cluster 10 size 14
|
||||
Add 2
|
||||
cluster 10 size 8
|
||||
Add 1
|
||||
All 2
|
||||
Const 4
|
||||
Const 2
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
Less 1
|
||||
LogicalOr 1
|
||||
Shape 2
|
||||
cluster 11 size 226
|
||||
Add 24
|
||||
BatchMatMulV2 1
|
||||
@ -226,13 +224,12 @@ cluster 12 size 430
|
||||
TanhGrad 17
|
||||
Tile 2
|
||||
ZerosLike 1
|
||||
cluster 13 size 25
|
||||
Add 3
|
||||
cluster 13 size 20
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
ConcatV2 1
|
||||
Const 3
|
||||
Const 1
|
||||
GreaterEqual 1
|
||||
Identity 2
|
||||
MatMul 1
|
||||
Mul 3
|
||||
Select 3
|
||||
@ -256,13 +253,12 @@ cluster 14 size 52
|
||||
Slice 2
|
||||
Sum 9
|
||||
TanhGrad 2
|
||||
cluster 15 size 25
|
||||
Add 3
|
||||
cluster 15 size 20
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
ConcatV2 1
|
||||
Const 3
|
||||
Const 1
|
||||
GreaterEqual 1
|
||||
Identity 2
|
||||
MatMul 1
|
||||
Mul 3
|
||||
Select 3
|
||||
@ -290,14 +286,13 @@ cluster 17 size 52
|
||||
Slice 2
|
||||
Sum 9
|
||||
TanhGrad 2
|
||||
cluster 19 size 30
|
||||
Add 3
|
||||
cluster 19 size 25
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
Const 5
|
||||
Const 3
|
||||
GreaterEqual 2
|
||||
Identity 2
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
@ -305,77 +300,7 @@ cluster 19 size 30
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 20 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 21 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 22 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 23 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 24 size 24
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 3
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 25 size 363
|
||||
cluster 20 size 363
|
||||
Add 12
|
||||
AddN 28
|
||||
BiasAddGrad 6
|
||||
@ -391,6 +316,71 @@ cluster 25 size 363
|
||||
Slice 12
|
||||
Sum 76
|
||||
TanhGrad 12
|
||||
cluster 21 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 22 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 23 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 24 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 25 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 3
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 26 size 9
|
||||
AddN 1
|
||||
MatMul 2
|
||||
|
96
tensorflow/compiler/jit/xla_activity.proto
Normal file
@ -0,0 +1,96 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package tensorflow;

import "tensorflow/core/protobuf/config.proto";

// Summarizes the results of auto-clustering a TensorFlow graph.
//
// Next ID: 5
message XlaAutoClusteringSummary {
  // Represents a single element in a histogram of ops ("op" as in "TensorFlow
  // operation").
  //
  // Next ID: 3
  message OpAndCount {
    // The TensorFlow operation (like MatMul, Add, etc.).
    string op = 1;

    // The number of times this op occurs.
    int32 count = 2;
  }

  // Describes a single XLA cluster.
  //
  // Next ID: 4
  message Cluster {
    string name = 1;

    // The number of nodes in the cluster.
    int32 size = 2;

    // A histogram of the TF operations in this cluster.
    repeated OpAndCount op_histogram = 3;
  };

  // The number of nodes in the graph that are not inside an XLA cluster.
  int32 unclustered_node_count = 1;

  // The number of nodes in the graph that are in an XLA cluster.
  int32 clustered_node_count = 2;

  // All of the XLA clusters in the TF graph.
  repeated Cluster clusters = 3;

  // A histogram of the TF operations that were not clustered.
  repeated OpAndCount unclustered_op_histogram = 4;
}

// Listeners listening for auto-clustering events get messages of this type.
//
// Next ID: 4
message XlaAutoClusteringActivity {
  // The value of GlobalJitLevel, as determined by `GetGlobalJitLevelForGraph`.
  // This determines if global auto-clustering is enabled.
  OptimizerOptions.GlobalJitLevel global_jit_level = 1;

  // Whether --tf_xla_cpu_global_jit is enabled in TF_XLA_FLAGS.
  bool cpu_global_jit_enabled = 2;

  XlaAutoClusteringSummary summary = 3;
}

// Listeners listening for JIT compilation events get messages of this type.
// Each instance of XlaJitCompilationActivity corresponds to a single
// compilation of a single XLA cluster. E.g. if a graph has two clusters, A and
// B, and A is compiled 5 times and B is compiled 2 times, then we will
// generate 7 instances of XlaJitCompilationActivity.
//
// Next ID: 5
message XlaJitCompilationActivity {
  string cluster_name = 1;

  // The number of times this cluster has been compiled.
  int32 compile_count = 2;

  // Microseconds spent in the individual compilation being reported.
  int64 compile_time_us = 3;

  // Total microseconds spent in (re-)compiling this cluster so far.
  int64 cumulative_compile_time_us = 4;
}
86
tensorflow/compiler/jit/xla_activity_listener.cc
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// The list of all registered `XlaActivityListener`s.
|
||||
struct XlaActivityListenerList {
|
||||
absl::Mutex mutex;
|
||||
std::vector<std::unique_ptr<XlaActivityListener>> listeners GUARDED_BY(mutex);
|
||||
};
|
||||
|
||||
void FlushAllListeners();
|
||||
|
||||
XlaActivityListenerList* GetXlaActivityListenerList() {
|
||||
static XlaActivityListenerList* listener_list = new XlaActivityListenerList;
|
||||
static int unused = std::atexit(FlushAllListeners);
|
||||
(void)unused;
|
||||
return listener_list;
|
||||
}
|
||||
|
||||
template <typename FnTy>
|
||||
Status ForEachListener(FnTy fn) {
|
||||
XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
|
||||
absl::ReaderMutexLock reader_lock(&listener_list->mutex);
|
||||
|
||||
for (const std::unique_ptr<XlaActivityListener>& listener :
|
||||
listener_list->listeners) {
|
||||
TF_RETURN_IF_ERROR(fn(listener.get()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void FlushAllListeners() {
|
||||
Status s = ForEachListener([](XlaActivityListener* listener) {
|
||||
listener->Flush();
|
||||
return Status::OK();
|
||||
});
|
||||
CHECK(s.ok());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status BroadcastXlaActivity(
|
||||
XlaAutoClusteringActivity auto_clustering_activity) {
|
||||
return ForEachListener([&](XlaActivityListener* listener) {
|
||||
return listener->Listen(auto_clustering_activity);
|
||||
});
|
||||
}
|
||||
|
||||
Status BroadcastXlaActivity(
|
||||
XlaJitCompilationActivity jit_compilation_activity) {
|
||||
return ForEachListener([&](XlaActivityListener* listener) {
|
||||
return listener->Listen(jit_compilation_activity);
|
||||
});
|
||||
}
|
||||
|
||||
void RegisterXlaActivityListener(
|
||||
std::unique_ptr<XlaActivityListener> listener) {
|
||||
XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
|
||||
absl::WriterMutexLock writer_lock(&listener_list->mutex);
|
||||
|
||||
listener_list->listeners.push_back(std::move(listener));
|
||||
}
|
||||
|
||||
void XlaActivityListener::Flush() {}
|
||||
|
||||
XlaActivityListener::~XlaActivityListener() {}
|
||||
|
||||
} // namespace tensorflow
|
58
tensorflow/compiler/jit/xla_activity_listener.h
Normal file
@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_

#include <memory>

#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
// Broadcast `auto_clustering_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity);

// Broadcast `jit_compilation_activity` to all the registered listeners.
Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity);

// Various components of the system can subclass XlaActivityListener to receive
// notifications on auto-clustering and JIT compilation events.
//
// Subclasses of XlaActivityListener must be thread safe.
class XlaActivityListener {
 public:
  // Called after TensorFlow auto-clusters a graph.
  virtual Status Listen(
      const XlaAutoClusteringActivity& auto_clustering_activity) = 0;

  // Called after TensorFlow JIT compiles an XLA cluster.
  virtual Status Listen(
      const XlaJitCompilationActivity& jit_compilation_activity) = 0;

  // Called at program exit in a best-effort manner to give listeners a chance
  // to flush their state.
  //
  // The default implementation is a no-op.
  virtual void Flush();

  virtual ~XlaActivityListener();
};

// Registers an `XlaActivityListener`, which will be invoked on all subsequent
// `BroadcastXlaActivity` calls.
void RegisterXlaActivityListener(std::unique_ptr<XlaActivityListener> listener);
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_
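For context when reviewing the call sites below: a subclass only has to override the two Listen overloads, since Flush has a default no-op implementation and the registry takes ownership of the listener. A minimal sketch of a custom listener and its registration, assuming only the API declared above (the class and function names in the sketch are illustrative, not part of this change):

#include <atomic>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"

namespace tensorflow {

// Illustrative listener that just counts the events it receives.
class CountingXlaActivityListener : public XlaActivityListener {
 public:
  Status Listen(const XlaAutoClusteringActivity&) override {
    ++clustering_events_;
    return Status::OK();
  }

  Status Listen(const XlaJitCompilationActivity&) override {
    ++compilation_events_;
    return Status::OK();
  }

 private:
  std::atomic<int> clustering_events_{0};
  std::atomic<int> compilation_events_{0};
};

// Once registered, the listener sees every subsequent BroadcastXlaActivity
// call; there is no unregister, so registration is effectively for the
// lifetime of the process.
void InstallCountingListener() {
  RegisterXlaActivityListener(absl::make_unique<CountingXlaActivityListener>());
}

}  // namespace tensorflow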
193
tensorflow/compiler/jit/xla_activity_listener_test.cc
Normal file
@ -0,0 +1,193 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/list_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/direct_session.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class TestListener : public XlaActivityListener {
|
||||
public:
|
||||
Status Listen(
|
||||
const XlaAutoClusteringActivity& auto_clustering_activity) override {
|
||||
auto_clustering_activity_ = auto_clustering_activity;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Listen(
|
||||
const XlaJitCompilationActivity& jit_compilation_activity) override {
|
||||
jit_compilation_activity_ = jit_compilation_activity;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
~TestListener() override {}
|
||||
|
||||
const XlaAutoClusteringActivity& auto_clustering_activity() const {
|
||||
return auto_clustering_activity_;
|
||||
}
|
||||
const XlaJitCompilationActivity& jit_compilation_activity() const {
|
||||
return jit_compilation_activity_;
|
||||
}
|
||||
|
||||
private:
|
||||
XlaAutoClusteringActivity auto_clustering_activity_;
|
||||
XlaJitCompilationActivity jit_compilation_activity_;
|
||||
};
|
||||
|
||||
class XlaActivityListenerTest : public ::testing::Test {
|
||||
protected:
|
||||
XlaActivityListenerTest() {
|
||||
auto listener = absl::make_unique<TestListener>();
|
||||
listener_ = listener.get();
|
||||
RegisterXlaActivityListener(std::move(listener));
|
||||
}
|
||||
|
||||
TestListener* listener() const { return listener_; }
|
||||
|
||||
private:
|
||||
TestListener* listener_;
|
||||
};
|
||||
|
||||
GraphDef CreateGraphDef() {
|
||||
Scope root = Scope::NewRootScope().ExitOnError().WithAssignedDevice(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
a = ops::MatMul(root.WithOpName(absl::StrCat("matmul_", i)), a, a);
|
||||
a = ops::Add(root.WithOpName(absl::StrCat("add_", i)), a, a);
|
||||
}
|
||||
|
||||
GraphDef graph_def;
|
||||
root.graph()->ToGraphDef(&graph_def);
|
||||
return graph_def;
|
||||
}
|
||||
|
||||
TEST_F(XlaActivityListenerTest, Test) {
|
||||
GraphDef graph_def = CreateGraphDef();
|
||||
SessionOptions options;
|
||||
options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_global_jit_level(OptimizerOptions::ON_2);
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
|
||||
TF_ASSERT_OK(session->Create(graph_def));
|
||||
|
||||
std::vector<std::string> output_names = {std::string("add_4:0")};
|
||||
|
||||
Tensor tensor_2x2(DT_FLOAT, TensorShape({2, 2}));
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tensor_2x2.matrix<float>()(i / 2, i % 2) = 5 * i;
|
||||
}
|
||||
|
||||
Tensor tensor_3x3(DT_FLOAT, TensorShape({3, 3}));
|
||||
for (int i = 0; i < 9; i++) {
|
||||
tensor_3x3.matrix<float>()(i / 3, i % 3) = 5 * i;
|
||||
}
|
||||
|
||||
std::vector<std::pair<string, Tensor>> inputs_2x2 = {{"A", tensor_2x2}};
|
||||
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(session->Run(inputs_2x2, output_names, /*target_node_names=*/{},
|
||||
&outputs));
|
||||
|
||||
absl::string_view expected_auto_clustering_activity =
|
||||
R"(global_jit_level: ON_2
|
||||
cpu_global_jit_enabled: true
|
||||
summary {
|
||||
unclustered_node_count: 4
|
||||
clustered_node_count: 14
|
||||
clusters {
|
||||
name: "cluster_0"
|
||||
size: 14
|
||||
op_histogram {
|
||||
op: "Add"
|
||||
count: 1
|
||||
}
|
||||
op_histogram {
|
||||
op: "Const"
|
||||
count: 4
|
||||
}
|
||||
op_histogram {
|
||||
op: "MatMul"
|
||||
count: 5
|
||||
}
|
||||
op_histogram {
|
||||
op: "Mul"
|
||||
count: 4
|
||||
}
|
||||
}
|
||||
unclustered_op_histogram {
|
||||
op: "NoOp"
|
||||
count: 2
|
||||
}
|
||||
unclustered_op_histogram {
|
||||
op: "_Arg"
|
||||
count: 1
|
||||
}
|
||||
unclustered_op_histogram {
|
||||
op: "_Retval"
|
||||
count: 1
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_EQ(listener()->auto_clustering_activity().DebugString(),
|
||||
expected_auto_clustering_activity);
|
||||
|
||||
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
|
||||
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 1);
|
||||
|
||||
int64 first_compile_time =
|
||||
listener()->jit_compilation_activity().compile_time_us();
|
||||
EXPECT_GT(first_compile_time, 0);
|
||||
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
|
||||
first_compile_time);
|
||||
|
||||
std::vector<std::pair<string, Tensor>> inputs_3x3 = {{"A", tensor_3x3}};
|
||||
|
||||
outputs.clear();
|
||||
for (int i = 0; i < 3; i++) {
|
||||
TF_ASSERT_OK(session->Run(inputs_3x3, output_names,
|
||||
/*target_node_names=*/{}, &outputs));
|
||||
}
|
||||
|
||||
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
|
||||
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 2);
|
||||
|
||||
EXPECT_GT(listener()->jit_compilation_activity().compile_time_us(), 0);
|
||||
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
|
||||
first_compile_time +
|
||||
listener()->jit_compilation_activity().compile_time_us());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::GetMarkForCompilationPassFlags()->tf_xla_cpu_global_jit = true;
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
69
tensorflow/compiler/jit/xla_activity_logging_listener.cc
Normal file
@ -0,0 +1,69 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Listens to XLA activity events and logs them using tensorflow::Logger.
|
||||
class XlaActivityLoggingListener final : public XlaActivityListener {
|
||||
public:
|
||||
Status Listen(
|
||||
const XlaAutoClusteringActivity& auto_clustering_activity) override {
|
||||
if (!IsEnabled()) {
|
||||
VLOG(3) << "Logging XlaAutoClusteringActivity disabled";
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(2) << "Logging XlaAutoClusteringActivity";
|
||||
VLOG(3) << auto_clustering_activity.DebugString();
|
||||
Logger::Singleton()->LogProto(auto_clustering_activity);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Listen(
|
||||
const XlaJitCompilationActivity& jit_compilation_activity) override {
|
||||
if (!IsEnabled()) {
|
||||
VLOG(3) << "Logging XlaJitCompilationActivity disabled";
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(2) << "Logging XlaJitCompilationActivity";
|
||||
VLOG(3) << jit_compilation_activity.DebugString();
|
||||
Logger::Singleton()->LogProto(jit_compilation_activity);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsEnabled() {
|
||||
static bool result = ComputeIsEnabled();
|
||||
return result;
|
||||
}
|
||||
|
||||
bool ComputeIsEnabled() {
|
||||
char* log_xla_activity = getenv("TF_LOG_XLA_ACTIVITY");
|
||||
return log_xla_activity && !strcmp(log_xla_activity, "1");
|
||||
}
|
||||
};
|
||||
|
||||
bool Register() {
|
||||
RegisterXlaActivityListener(absl::make_unique<XlaActivityLoggingListener>());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool unused = Register();
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
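Worth noting for reviewers: the logging listener registers itself through the static initializer above, but it only forwards events when TF_LOG_XLA_ACTIVITY is set to "1", and the check is cached on first use. A sketch of opting in from a test binary, assuming a POSIX environment (this main function is illustrative, not part of the change):

#include <cstdlib>

#include "tensorflow/core/platform/test.h"

int main(int argc, char** argv) {
  // Must be set before the first XLA activity is broadcast, because
  // XlaActivityLoggingListener::IsEnabled() caches the result.
  setenv("TF_LOG_XLA_ACTIVITY", "1", /*overwrite=*/1);
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}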
@ -318,4 +318,72 @@ bool IsShapeConsumerOp(const Node& node) {
|
||||
return node.type_string() == "Shape" || node.type_string() == "Rank" ||
|
||||
node.type_string() == "Size";
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ClusterInfo {
|
||||
int size;
|
||||
|
||||
// Maps op names to the number of times they appear in the cluster.
|
||||
absl::flat_hash_map<absl::string_view, int> op_histogram;
|
||||
};
|
||||
|
||||
void HistogramMapToRepeatedOpAndCount(
|
||||
protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
|
||||
const absl::flat_hash_map<absl::string_view, int>& histogram) {
|
||||
for (const auto& pair : histogram) {
|
||||
XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
|
||||
new_entry->set_op(std::string(pair.first));
|
||||
new_entry->set_count(pair.second);
|
||||
}
|
||||
|
||||
absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
|
||||
const XlaAutoClusteringSummary::OpAndCount& b) {
|
||||
return a.op() < b.op();
|
||||
});
|
||||
}
|
||||
|
||||
void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
|
||||
absl::string_view name, const ClusterInfo& info) {
|
||||
result->set_name(std::string(name));
|
||||
result->set_size(info.size);
|
||||
HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
|
||||
info.op_histogram);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
|
||||
absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
|
||||
XlaAutoClusteringSummary result;
|
||||
|
||||
absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;
|
||||
|
||||
for (Node* n : graph.nodes()) {
|
||||
absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
|
||||
if (cluster_name) {
|
||||
result.set_clustered_node_count(result.clustered_node_count() + 1);
|
||||
ClusterInfo* info = &cluster_name_to_info[*cluster_name];
|
||||
info->size++;
|
||||
info->op_histogram[n->type_string()]++;
|
||||
} else {
|
||||
result.set_unclustered_node_count(result.unclustered_node_count() + 1);
|
||||
unclustered_op_histogram[n->type_string()]++;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& pair : cluster_name_to_info) {
|
||||
XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
|
||||
ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
|
||||
}
|
||||
|
||||
absl::c_sort(*result.mutable_clusters(),
|
||||
[&](const XlaAutoClusteringSummary::Cluster& a,
|
||||
const XlaAutoClusteringSummary::Cluster& b) {
|
||||
return a.name() < b.name();
|
||||
});
|
||||
|
||||
HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
|
||||
unclustered_op_histogram);
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -18,8 +18,10 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
@ -87,6 +89,11 @@ bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
|
||||
// Returns true if `node` is an operator that consumes only the shape of its
// input, not the data itself.
|
||||
bool IsShapeConsumerOp(const Node& node);
|
||||
|
||||
// Computes a clustering summary for `graph`. See documentation on
|
||||
// `XlaAutoClusteringSummary` for details.
|
||||
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
|
||||
|
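As a side note, the new summary helper is easy to use on its own. A short sketch, assuming a tensorflow::Graph that has already been auto-clustered (the function name and logging here are illustrative):

#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/platform/logging.h"

// `graph` is assumed to have already been processed by the
// mark-for-compilation pass, e.g. inside a test or a debugging tool.
void LogClusteringSummary(const tensorflow::Graph& graph) {
  tensorflow::XlaAutoClusteringSummary summary =
      tensorflow::GetXlaAutoClusteringSummary(graph);
  LOG(INFO) << "Clustered " << summary.clustered_node_count() << " of "
            << summary.clustered_node_count() +
                   summary.unclustered_node_count()
            << " nodes; details:\n"
            << summary.DebugString();
}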
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
@ -361,6 +362,16 @@ Status XlaCompilationCache::CompileImpl(
|
||||
<< tensorflow::strings::HumanReadableElapsedTime(
|
||||
it->second.cumulative_compile_time_us / 1.0e6)
|
||||
<< ")";
|
||||
|
||||
XlaJitCompilationActivity jit_compilation_activity;
|
||||
jit_compilation_activity.set_cluster_name(function.name());
|
||||
jit_compilation_activity.set_compile_count(it->second.compile_count);
|
||||
jit_compilation_activity.set_compile_time_us(compile_time_us);
|
||||
jit_compilation_activity.set_cumulative_compile_time_us(
|
||||
it->second.cumulative_compile_time_us);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastXlaActivity(std::move(jit_compilation_activity)));
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(entry->compilation_status);
|
||||
|
@ -375,18 +375,6 @@ Status XlaDevice::FillContextMap(const Graph* graph,
|
||||
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
|
||||
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
|
||||
<< op_kernel->type_string();
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
return absl::StrCat("XlaDevice::Compute ", op_kernel->name(), ":",
|
||||
op_kernel->type_string(),
|
||||
"#step_id=", context->step_id(),
|
||||
",step_container_name=",
|
||||
context->step_container() == nullptr
|
||||
? "n/a"
|
||||
: context->step_container()->name(),
|
||||
"#");
|
||||
},
|
||||
profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
|
||||
op_kernel->Compute(context);
|
||||
}
|
||||
|
||||
@ -394,18 +382,6 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
|
||||
AsyncOpKernel::DoneCallback done) {
|
||||
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
|
||||
<< op_kernel->type_string();
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
return absl::StrCat("XlaDevice::ComputeAsync ", op_kernel->name(), ":",
|
||||
op_kernel->type_string(),
|
||||
"#step_id=", context->step_id(),
|
||||
",step_container_name=",
|
||||
context->step_container() == nullptr
|
||||
? "n/a"
|
||||
: context->step_container()->name(),
|
||||
"#");
|
||||
},
|
||||
profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
|
||||
op_kernel->ComputeAsync(context, done);
|
||||
}
|
||||
|
||||
|
@ -21,22 +21,15 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/kernels/constant_op.h"
|
||||
#include "tensorflow/core/kernels/control_flow_ops.h"
|
||||
#include "tensorflow/core/kernels/data/generator_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||
#include "tensorflow/core/kernels/data/optional_ops.h"
|
||||
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/fifo_queue.h"
|
||||
#include "tensorflow/core/kernels/function_ops.h"
|
||||
#include "tensorflow/core/kernels/host_constant_op.h"
|
||||
#include "tensorflow/core/kernels/identity_n_op.h"
|
||||
#include "tensorflow/core/kernels/identity_op.h"
|
||||
#include "tensorflow/core/kernels/logging_ops.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
#include "tensorflow/core/kernels/queue_op.h"
|
||||
#include "tensorflow/core/kernels/resource_variable_ops.h"
|
||||
#include "tensorflow/core/kernels/shape_ops.h"
|
||||
#include "tensorflow/core/kernels/stack.h"
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -80,24 +73,14 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
|
||||
|
||||
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Assert") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("condition") \
|
||||
.HostMemory("data"), \
|
||||
AssertOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
|
||||
ConstantOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("HostConst").Device(DEVICE).HostMemory("output"), _HostConstantOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \
|
||||
ResourceHandleOp<Var>); \
|
||||
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), VarHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \
|
||||
ResourceHandlesOp<Var>); \
|
||||
@ -153,36 +136,6 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
|
||||
XlaAssignVariableOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
|
||||
ControlTriggerOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
|
||||
SwitchOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
|
||||
NextIterationOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("LoopCond") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("output"), \
|
||||
LoopCondOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("size") \
|
||||
.HostMemory("handle"), \
|
||||
QueueSizeOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
|
||||
QueueIsClosedOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
|
||||
@ -262,25 +215,7 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<string>("T") \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("StackV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("max_size") \
|
||||
.HostMemory("handle"), \
|
||||
StackOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("handle") \
|
||||
.TypeConstraint("T", TYPES), \
|
||||
TemplatedStackPushOp</*allow_swapping=*/false>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("handle") \
|
||||
.TypeConstraint("elem_type", TYPES), \
|
||||
StackPopOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);
|
||||
RetvalOp);
|
||||
|
||||
// TODO(b/118881356): currently we do not register the QueueEnqueueMany,
|
||||
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
|
||||
|
@ -314,6 +314,21 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "matrix_inverse_op_test",
|
||||
size = "small",
|
||||
timeout = "moderate",
|
||||
srcs = ["matrix_inverse_op_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "matrix_triangular_solve_op_test",
|
||||
size = "small",
|
||||
@ -557,6 +572,7 @@ tf_xla_py_test(
|
||||
name = "image_ops_test",
|
||||
size = "small",
|
||||
srcs = ["image_ops_test.py"],
|
||||
shard_count = 10,
|
||||
tags = [
|
||||
"optonly", # Times out frequently in fastbuild mode.
|
||||
],
|
||||
@ -598,6 +614,19 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "manip_ops_test",
|
||||
size = "small",
|
||||
srcs = ["manip_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:manip_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "matrix_band_part_test",
|
||||
size = "medium",
|
||||
@ -970,6 +999,7 @@ tf_xla_py_test(
|
||||
name = "unary_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["unary_ops_test.py"],
|
||||
tags = ["notap"], # b/136030724
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -51,7 +51,7 @@ class ArgMinMaxTest(xla_test.XLATestCase):
|
||||
def testArgMinMax(self):
|
||||
# Complex numbers do not support argmin/argmax.
|
||||
minmax_types = self.all_types & {np.int32, np.int64}
|
||||
for dtype in minmax_types:
|
||||
for dtype in self.int_types | self.float_types:
|
||||
# output_type is a numpy data type that is used to specify the desired
|
||||
# output type of the op as well as to convert the Python number to the
|
||||
# array scalar of the type.
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import bitwise_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -312,30 +313,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
dtype(7),
|
||||
expected=np.array([[-6], [-5]], dtype=dtype))
|
||||
|
||||
if dtype in [np.float32, np.float64]:
|
||||
x = np.array([
|
||||
-0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0,
|
||||
1.0
|
||||
],
|
||||
dtype=dtype)
|
||||
y = np.array(
|
||||
[-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0],
|
||||
dtype=dtype)
|
||||
expected = np.nextafter(x, y)
|
||||
|
||||
# We use assertAllEqual to expose any bugs hidden by relative or
|
||||
# absolute error tolerances.
|
||||
def NextAfterEqualityTest(result, expected, rtol):
|
||||
del rtol
|
||||
return self.assertAllEqual(result, expected)
|
||||
|
||||
self._testBinary(
|
||||
math_ops.nextafter,
|
||||
x,
|
||||
y,
|
||||
expected=expected,
|
||||
equality_test=NextAfterEqualityTest)
|
||||
|
||||
# min/max not supported for complex
|
||||
if dtype not in self.complex_types | {np.uint8, np.int8}:
|
||||
self._testBinary(
|
||||
@ -423,6 +400,32 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36],
|
||||
dtype=np.int64))
|
||||
|
||||
def testNextAfter(self):
|
||||
for dtype in self.numeric_types:
|
||||
if dtype in [np.float32, np.float64]:
|
||||
x = np.array([
|
||||
-0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0,
|
||||
1.0
|
||||
],
|
||||
dtype=dtype)
|
||||
y = np.array(
|
||||
[-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0],
|
||||
dtype=dtype)
|
||||
expected = np.nextafter(x, y)
|
||||
|
||||
# We use assertAllEqual to expose any bugs hidden by relative or
|
||||
# absolute error tolerances.
|
||||
def NextAfterEqualityTest(result, expected, rtol):
|
||||
del rtol
|
||||
return self.assertAllEqual(result, expected)
|
||||
|
||||
self._testBinary(
|
||||
math_ops.nextafter,
|
||||
x,
|
||||
y,
|
||||
expected=expected,
|
||||
equality_test=NextAfterEqualityTest)
|
||||
|
||||
def testComplexOps(self):
|
||||
for dtype in self.complex_types:
|
||||
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
|
||||
@ -1478,10 +1481,12 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
expected=None)
|
||||
|
||||
def testMatrixSetDiag(self):
|
||||
# TODO(penporn): Once XLA supports MatrixSetDiagV2, change the call to
|
||||
# gen_array_ops.matrix_set_diag (V1) to array_ops.matrix_set_diag (V2).
|
||||
for dtype in self.numeric_types:
|
||||
# Square
|
||||
self._testBinary(
|
||||
array_ops.matrix_set_diag,
|
||||
gen_array_ops.matrix_set_diag,
|
||||
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
dtype=dtype),
|
||||
np.array([1.0, 2.0, 3.0], dtype=dtype),
|
||||
@ -1489,7 +1494,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
array_ops.matrix_set_diag,
|
||||
gen_array_ops.matrix_set_diag,
|
||||
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
|
||||
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
|
||||
dtype=dtype),
|
||||
@ -1501,19 +1506,19 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
|
||||
# Rectangular
|
||||
self._testBinary(
|
||||
array_ops.matrix_set_diag,
|
||||
gen_array_ops.matrix_set_diag,
|
||||
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
|
||||
np.array([3.0, 4.0], dtype=dtype),
|
||||
expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
array_ops.matrix_set_diag,
|
||||
gen_array_ops.matrix_set_diag,
|
||||
np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
|
||||
np.array([3.0, 4.0], dtype=dtype),
|
||||
expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
array_ops.matrix_set_diag,
|
||||
gen_array_ops.matrix_set_diag,
|
||||
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
|
||||
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
|
||||
np.array([[-1.0, -2.0], [-4.0, -5.0]],
|
||||
|
@ -131,15 +131,19 @@ class FFTTest(xla_test.XLATestCase):
|
||||
signal.ifft3d)
|
||||
|
||||
def testRFFT(self):
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
|
||||
lambda x: signal.rfft(x, fft_length=[x.shape[-1].value]))
|
||||
|
||||
def _to_expected(x):
|
||||
return np.fft.rfft(x, n=x.shape[-1])
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft(x, fft_length=[x.shape[-1]])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_1D, np.real, _to_expected, _tf_fn)
|
||||
|
||||
def testRFFT2D(self):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft2d(
|
||||
x, fft_length=[x.shape[-2].value, x.shape[-1].value])
|
||||
return signal.rfft2d(x, fft_length=[x.shape[-2], x.shape[-1]])
|
||||
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_2D, np.real,
|
||||
@ -153,8 +157,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft3d(
|
||||
x,
|
||||
fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
|
||||
x, fft_length=[x.shape[-3], x.shape[-2], x.shape[-1]])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
|
||||
|
||||
@ -168,17 +171,14 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.rfft3d(
|
||||
x,
|
||||
fft_length=[
|
||||
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
|
||||
])
|
||||
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
|
||||
|
||||
def testIRFFT(self):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
|
||||
return signal.irfft(x, fft_length=[2 * (x.shape[-1] - 1)])
|
||||
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
|
||||
@ -187,8 +187,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
def testIRFFT2D(self):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft2d(
|
||||
x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
|
||||
return signal.irfft2d(x, fft_length=[x.shape[-2], 2 * (x.shape[-1] - 1)])
|
||||
|
||||
self._VerifyFftMethod(
|
||||
INNER_DIMS_2D,
|
||||
@ -212,10 +211,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft3d(
|
||||
x,
|
||||
fft_length=[
|
||||
x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
|
||||
])
|
||||
x, fft_length=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
|
||||
|
||||
@ -235,10 +231,7 @@ class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _tf_fn(x):
|
||||
return signal.irfft3d(
|
||||
x,
|
||||
fft_length=[
|
||||
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
|
||||
])
|
||||
x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
|
||||
|
||||
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
|
||||
|
||||
|
68
tensorflow/compiler/tests/manip_ops_test.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for manip ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import manip_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ManipOpsTest(xla_test.XLATestCase):
|
||||
"""Test cases for manip ops."""
|
||||
|
||||
def _testRoll(self, a, shift, axis):
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
p = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
|
||||
output = manip_ops.roll(a, shift, axis)
|
||||
result = session.run(output, {p: a})
|
||||
self.assertAllEqual(result, np.roll(a, shift, axis))
|
||||
|
||||
def testNumericTypes(self):
|
||||
for t in self.numeric_types:
|
||||
self._testRoll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
|
||||
self._testRoll(
|
||||
np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -6, 6],
|
||||
[0, 1, 2])
|
||||
self._testRoll(
|
||||
np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2],
|
||||
[1, 2, 3])
|
||||
|
||||
def testFloatTypes(self):
|
||||
for t in self.float_types:
|
||||
self._testRoll(np.random.rand(5).astype(t), 2, 0)
|
||||
self._testRoll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
|
||||
self._testRoll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])
|
||||
|
||||
def testComplexTypes(self):
|
||||
for t in self.complex_types:
|
||||
x = np.random.rand(4, 4).astype(t)
|
||||
self._testRoll(x + 1j * x, 2, 0)
|
||||
x = np.random.rand(2, 5).astype(t)
|
||||
self._testRoll(x + 1j * x, [1, 2], [1, 0])
|
||||
x = np.random.rand(3, 2, 1, 1).astype(t)
|
||||
self._testRoll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
86
tensorflow/compiler/tests/matrix_inverse_op_test.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class InverseOpTest(xla_test.XLATestCase):
|
||||
|
||||
def _verifyInverse(self, x, np_type):
|
||||
for adjoint in False, True:
|
||||
y = x.astype(np_type)
|
||||
with self.session() as sess:
|
||||
# Verify that x^{-1} * x == Identity matrix.
|
||||
p = array_ops.placeholder(dtypes.as_dtype(y.dtype), y.shape, name="x")
|
||||
with self.test_scope():
|
||||
inv = linalg_ops.matrix_inverse(p, adjoint=adjoint)
|
||||
tf_ans = math_ops.matmul(inv, p, adjoint_b=adjoint)
|
||||
np_ans = np.identity(y.shape[-1])
|
||||
if x.ndim > 2:
|
||||
tiling = list(y.shape)
|
||||
tiling[-2:] = [1, 1]
|
||||
np_ans = np.tile(np_ans, tiling)
|
||||
out = sess.run(tf_ans, feed_dict={p: y})
|
||||
self.assertAllClose(np_ans, out, rtol=1e-3, atol=1e-3)
|
||||
self.assertShapeEqual(y, tf_ans)
|
||||
|
||||
def _verifyInverseReal(self, x):
|
||||
for np_type in self.float_types & {np.float64, np.float32}:
|
||||
self._verifyInverse(x, np_type)
|
||||
|
||||
def _makeBatch(self, matrix1, matrix2):
|
||||
matrix_batch = np.concatenate(
|
||||
[np.expand_dims(matrix1, 0),
|
||||
np.expand_dims(matrix2, 0)])
|
||||
matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
|
||||
return matrix_batch
|
||||
|
||||
def testNonsymmetric(self):
|
||||
# 2x2 matrices
|
||||
matrix1 = np.array([[1., 2.], [3., 4.]])
|
||||
matrix2 = np.array([[1., 3.], [3., 5.]])
|
||||
self._verifyInverseReal(matrix1)
|
||||
self._verifyInverseReal(matrix2)
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
||||
|
||||
def testSymmetricPositiveDefinite(self):
|
||||
# 2x2 matrices
|
||||
matrix1 = np.array([[2., 1.], [1., 2.]])
|
||||
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
||||
self._verifyInverseReal(matrix1)
|
||||
self._verifyInverseReal(matrix2)
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
||||
|
||||
def testEmpty(self):
|
||||
self._verifyInverseReal(np.empty([0, 2, 2]))
|
||||
self._verifyInverseReal(np.empty([2, 0, 0]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -27,6 +27,7 @@ from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import bitwise_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
@ -107,16 +108,18 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([[-1, 1]], dtype=dtype),
|
||||
expected=np.array([[-1, 1]], dtype=dtype))
|
||||
|
||||
# TODO(penporn): Once XLA supports MatrixDiagV2, change the call to
|
||||
# gen_array_ops.matrix_diag* (V1) to array_ops.matrix_diag* (V2).
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
gen_array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
|
||||
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
|
||||
gen_array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
|
||||
np.array(
|
||||
[[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
|
||||
dtype=dtype))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.matrix_diag,
|
||||
gen_array_ops.matrix_diag,
|
||||
np.array(
|
||||
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype),
|
||||
np.array(
|
||||
@ -126,7 +129,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
[0, 0, 12]]]],
|
||||
dtype=dtype))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.matrix_diag_part,
|
||||
gen_array_ops.matrix_diag_part,
|
||||
np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
|
||||
np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))
|
||||
|
||||
|
@ -3,16 +3,8 @@
|
||||
# and provide TensorRT operators and converter package.
|
||||
# APIs are meant to change over time.
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_shared_object",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
"tf_cuda_library",
|
||||
@ -31,6 +23,42 @@ load(
|
||||
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
|
||||
# Placeholder for Google-internal load statements.
|
||||
|
||||
# NOTE: we always assume that if_static returns "otherwise" list in open source.
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"if_static",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "tensorrt_stub",
|
||||
srcs = if_tensorrt([
|
||||
"stub/nvinfer_stub.cc",
|
||||
"stub/nvinfer_plugin_stub.cc",
|
||||
]),
|
||||
textual_hdrs = glob(["stub/*.inc"]),
|
||||
deps = if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt_headers",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensorrt_lib",
|
||||
actual = if_static(
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
":tensorrt_stub",
|
||||
),
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "tensorrt_test_cc",
|
||||
size = "small",
|
||||
@ -46,8 +74,8 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
] + if_tensorrt([
|
||||
":tensorrt_lib",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
)
|
||||
|
||||
@ -72,9 +100,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:stream_executor_headers_lib",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -82,7 +108,7 @@ cc_library(
|
||||
name = "trt_engine_resource_op_kernels",
|
||||
srcs = ["kernels/trt_engine_resource_ops.cc"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:private"],
|
||||
visibility = ["//tensorflow/core:__subpackages__"],
|
||||
deps = [
|
||||
":trt_allocator",
|
||||
":trt_engine_instance_proto_cc",
|
||||
@ -96,9 +122,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -130,21 +154,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
name = "python/ops/libtftrt.so",
|
||||
copts = tf_copts(is_external = True),
|
||||
linkopts = ["-lm"],
|
||||
deps = [
|
||||
":trt_op_kernels",
|
||||
":trt_engine_resource_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "trt_engine_op_test",
|
||||
size = "small",
|
||||
@ -164,6 +173,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
@ -197,9 +207,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
@ -212,18 +220,6 @@ tf_gen_op_wrapper_py(
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "trt_ops_loader",
|
||||
srcs = ["python/ops/trt_ops.py"],
|
||||
dso = [
|
||||
"python/ops/libtftrt.so",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
kernels = [
|
||||
":trt_op_kernels",
|
||||
":trt_engine_resource_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":trt_ops",
|
||||
@ -254,9 +250,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
@ -267,9 +261,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -338,9 +330,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -373,9 +363,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
@ -409,8 +397,8 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
] + if_tensorrt([
|
||||
":tensorrt_lib",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
)
|
||||
|
||||
@ -428,7 +416,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@protobuf_archive//:protobuf_headers",
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -463,9 +451,7 @@ tf_cuda_library(
|
||||
deps = [
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -491,9 +477,7 @@ cc_library(
|
||||
srcs = ["utils/py_utils.cc"],
|
||||
hdrs = ["utils/py_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
deps = if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_py_wrap_cc(
|
||||
|
@ -64,24 +64,6 @@ namespace convert {
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
TrtCandidateSelector::TrtCandidateSelector(
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
TrtPrecisionMode precision_mode)
|
||||
: graph_properties_(graph_properties), precision_mode_(precision_mode) {}
|
||||
|
||||
Status TrtCandidateSelector::IsTensorRTCandidate(const Node* node) {
|
||||
std::vector<const Edge*> input_edges;
|
||||
TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
|
||||
std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
|
||||
input_node_and_ports.reserve(input_edges.size());
|
||||
for (const Edge* input_edge : input_edges) {
|
||||
input_node_and_ports.emplace_back(&input_edge->src()->def(),
|
||||
input_edge->src_output());
|
||||
}
|
||||
return validator_.ValidateNode(node->def(), input_node_and_ports,
|
||||
precision_mode_, graph_properties_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
Status BuildNodeMap(const Graph& graph,
|
||||
@ -478,10 +460,6 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
node_builder.ControlInput(c);
|
||||
}
|
||||
|
||||
if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
|
||||
!info.cached_engine_batches.empty()) {
|
||||
LOG(WARNING) << "Cached engine batches are ignored for static engines";
|
||||
}
|
||||
NodeDef trt_node;
|
||||
Status status =
|
||||
node_builder.Attr("input_shapes", input_shape_protos)
|
||||
@ -740,13 +718,13 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
}
|
||||
segment_options.minimum_segment_size = params.minimum_segment_size;
|
||||
segment::SegmentNodesVector initial_segments;
|
||||
TrtCandidateSelector candidate_selector(*params.graph_properties,
|
||||
params.precision_mode);
|
||||
TrtNodeValidator validator(*params.graph_properties, params.precision_mode,
|
||||
params.use_calibration);
|
||||
TF_RETURN_IF_ERROR(segment::SegmentGraph(
|
||||
&graph,
|
||||
std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector,
|
||||
std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
|
||||
std::placeholders::_1),
|
||||
// Input validation is already done by TrtCandidateSelector, so we don't
|
||||
// Input validation is already done by TrtNodeValidator, so we don't
|
||||
// need to check the input edges.
|
||||
[](const Edge* edge) { return true; }, OutputEdgeValidator(),
|
||||
segment_options, &initial_segments));
|
||||
@ -782,7 +760,6 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
? EngineInfo::EngineType::TRTDynamic
|
||||
: EngineInfo::EngineType::TRTStatic);
|
||||
curr_engine.use_calibration = params.use_calibration;
|
||||
curr_engine.cached_engine_batches = params.cached_engine_batches;
|
||||
curr_engine.maximum_cached_engines = params.max_cached_engines;
|
||||
if (params.use_function_backup) {
|
||||
status = RegisterSegmentFunctionToFunctionLibrary(
|
||||
|
@ -31,30 +31,6 @@ namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace convert {
|
||||
|
||||
// Helper class for the segmenter to determine whether given TF node is
|
||||
// supported by TRT.
|
||||
class TrtCandidateSelector {
|
||||
public:
|
||||
TrtCandidateSelector(const grappler::GraphProperties& graph_properties,
|
||||
TrtPrecisionMode precision_mode);
|
||||
|
||||
// Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
|
||||
// to TRT subgraph and later converted into TRT engine.
|
||||
Status IsTensorRTCandidate(const Node* node);
|
||||
|
||||
private:
|
||||
// The TF-TRT node converter used to verify whether individual node is
|
||||
// supported. It will operate in validation-only mode.
|
||||
TrtNodeValidator validator_;
|
||||
|
||||
// GraphProperties of the graph whose nodes are to be validated by
|
||||
// IsTensorRTCandidate().
|
||||
const grappler::GraphProperties& graph_properties_;
|
||||
|
||||
// Quantization ops are only converted when using quantized precisions.
|
||||
const TrtPrecisionMode precision_mode_;
|
||||
};
|
||||
|
||||
struct ConversionParams {
|
||||
const GraphDef* input_graph_def = nullptr;
|
||||
const std::vector<string>* output_names = nullptr;
|
||||
@ -70,8 +46,6 @@ struct ConversionParams {
|
||||
// maximum number of cached engines
|
||||
int max_cached_engines = 1;
|
||||
bool use_calibration = true;
|
||||
// list of cached engines
|
||||
std::vector<int> cached_engine_batches;
|
||||
// Whether to use function fallback for TRTEngineOp
|
||||
bool use_function_backup = true;
|
||||
};
|
||||
|
@ -50,81 +50,6 @@ void ExpectStatus(Status status, error::Code code = error::OK,
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TrtCandidateSelector, Basics) {
|
||||
// Create a graph containing both TRT-compatible and TRT-incompatible nodes
|
||||
// and use it to test TrtCandidateSelector::IsTensorRTCandidate().
|
||||
const std::vector<int32> input_shape_array{2, 2};
|
||||
TensorShape input_shape;
|
||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_shape_array, &input_shape));
|
||||
|
||||
Scope s = Scope::NewRootScope();
|
||||
ops::Placeholder::Attrs feed_attrs;
|
||||
TF_EXPECT_OK(
|
||||
TensorShapeUtils::MakeShape(input_shape_array, &feed_attrs.shape_));
|
||||
|
||||
// Compatible input.
|
||||
auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, feed_attrs);
|
||||
auto const_1 = ops::Const(s.WithOpName("const_1"), 1.0f, input_shape);
|
||||
|
||||
// Compatible MatMul.
|
||||
auto matmul = ops::MatMul(s.WithOpName("matmul"), feed, const_1);
|
||||
|
||||
// Incompatible MatMul.
|
||||
ops::MatMul::Attrs matmul_attrs;
|
||||
matmul_attrs.transpose_a_ = true;
|
||||
auto incompatible_matmul = ops::MatMul(s.WithOpName("incompatible_matmul"),
|
||||
feed, const_1, matmul_attrs);
|
||||
|
||||
// Unsupported op.
|
||||
auto unsupported_op = ops::Erf(s.WithOpName("sin"), feed);
|
||||
|
||||
// Incompatible input.
|
||||
auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE);
|
||||
auto const_2 = ops::Const(s.WithOpName("const_2"), 1.0, input_shape);
|
||||
// Compatible op with incompatible input.
|
||||
auto matmul_with_incompatible_input =
|
||||
ops::MatMul(s.WithOpName("matmul_with_incompatible_input"),
|
||||
incompatible_feed, const_2);
|
||||
|
||||
// Quantize ops.
|
||||
auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
|
||||
auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("quantize"), feed,
|
||||
quantize_attrs);
|
||||
|
||||
// Get GrapplerItem and GraphProperties.
|
||||
grappler::GrapplerItem item;
|
||||
TF_EXPECT_OK(s.ToGraphDef(&item.graph));
|
||||
Tensor feed_tensor(DT_FLOAT, input_shape);
|
||||
item.feed.push_back(std::make_pair("feed", feed_tensor));
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
TF_EXPECT_OK(graph_properties.InferStatically(true));
|
||||
|
||||
for (const TrtPrecisionMode precision_mode :
|
||||
{TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) {
|
||||
TrtCandidateSelector selector(graph_properties, precision_mode);
|
||||
TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node()));
|
||||
ExpectStatus(
|
||||
selector.IsTensorRTCandidate(incompatible_matmul.operation.node()),
|
||||
error::INVALID_ARGUMENT,
|
||||
"Cannot transpose first input if it is a tensor with fewer than 2 "
|
||||
"non-batch dimensions.");
|
||||
ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()),
|
||||
error::UNIMPLEMENTED, "Op type Erf is not supported");
|
||||
ExpectStatus(
|
||||
selector.IsTensorRTCandidate(
|
||||
matmul_with_incompatible_input.operation.node()),
|
||||
error::INTERNAL,
|
||||
"Failed to convert input with index 0 to a TRT_TensorOrWeights");
|
||||
if (precision_mode == TrtPrecisionMode::INT8) {
|
||||
TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node()));
|
||||
} else {
|
||||
ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()),
|
||||
error::UNIMPLEMENTED,
|
||||
"Op type FakeQuantWithMinMaxArgs is not supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class FakeCluster : public grappler::Cluster {
|
||||
public:
|
||||
FakeCluster() : Cluster(0) {}
|
||||
|
@ -504,12 +504,11 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
|
||||
}
|
||||
|
||||
// Convert an axis from TF format to TRT format while validating. TF format
|
||||
// includes the batch dimension, while TRT does not. TF can also use negative
|
||||
// indices.
|
||||
// TODO(tmorris): Use this method in more ops.
|
||||
// includes the batch dimension, while TRT does not if implicit batching is used
|
||||
// (i.e. for tensors). TF can also use negative indices.
|
||||
Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
|
||||
int* trt_axis) {
|
||||
const int tf_nb_dims = trt_nb_dims + 1;
|
||||
bool use_implicit_batch, int* trt_axis) {
|
||||
const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0);
|
||||
// Check bounds.
|
||||
if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
|
||||
return errors::InvalidArgument(
|
||||
@ -519,13 +518,13 @@ Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
|
||||
// Make negative axis positive.
|
||||
if (tf_axis < 0) tf_axis += tf_nb_dims;
|
||||
// Don't allow axis to be the batch dimension.
|
||||
if (tf_axis == 0) {
|
||||
if (use_implicit_batch && tf_axis == 0) {
|
||||
return errors::Unimplemented(
|
||||
"TensorRT does not allow manipulation of the batch dimension, at ",
|
||||
node_name);
|
||||
}
|
||||
// Remove batch dimension.
|
||||
*trt_axis = tf_axis - 1;
|
||||
// Remove batch dimension if it is implicit.
|
||||
*trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis;
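// Illustrative example (not part of the change): under implicit batching a
// NHWC tensor has trt_nb_dims == 3 and tf_nb_dims == 4, so tf_axis == -1
// (the channel dim) maps to trt_axis == 2, while tf_axis == 0 (the batch dim)
// is rejected above; with use_implicit_batch == false the axis passes through
// unchanged.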
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -956,6 +955,31 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
|
||||
return weights;
|
||||
}
|
||||
|
||||
OpConverterParams::OpConverterParams(
|
||||
const NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration)
|
||||
: node_def(node_def),
|
||||
inputs(inputs),
|
||||
outputs(outputs),
|
||||
validation_only(true),
|
||||
weight_store(weight_store),
|
||||
precision_mode(precision_mode),
|
||||
use_calibration(use_calibration) {}
|
||||
|
||||
OpConverterParams::OpConverterParams(
|
||||
Converter* converter, const NodeDef& node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store)
|
||||
: converter(converter),
|
||||
node_def(node_def),
|
||||
inputs(inputs),
|
||||
outputs(outputs),
|
||||
validation_only(false),
|
||||
weight_store(weight_store),
|
||||
precision_mode(converter->precision_mode()),
|
||||
use_calibration(converter->use_calibration()) {}
|
||||
|
||||
const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
|
||||
"QuantizeAndDequantizeV2",
|
||||
"QuantizeAndDequantizeV3",
|
||||
@ -963,11 +987,17 @@ const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
|
||||
"FakeQuantWithMinMaxArgs",
|
||||
};
|
||||
|
||||
TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); }
|
||||
TrtNodeValidator::TrtNodeValidator(
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration)
|
||||
: graph_properties_(graph_properties),
|
||||
precision_mode_(precision_mode),
|
||||
use_calibration_(use_calibration) {
|
||||
RegisterOpValidators();
|
||||
}
|
||||
|
||||
Status TrtNodeValidator::ConvertToTensorOrWeights(
|
||||
const NodeDef& node_def, int output_port,
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
TRT_TensorOrWeights* tensor_or_weights) {
|
||||
if (node_def.op() == "Const") {
|
||||
if (output_port != 0) {
|
||||
@ -983,13 +1013,13 @@ Status TrtNodeValidator::ConvertToTensorOrWeights(
|
||||
std::vector<TRT_TensorOrWeights> inputs;
|
||||
return ConvertConstToWeights(node_def, inputs, tensor_or_weights);
|
||||
}
|
||||
if (!graph_properties.HasOutputProperties(node_def.name())) {
|
||||
if (!graph_properties_.HasOutputProperties(node_def.name())) {
|
||||
return errors::InvalidArgument("Shape and data type are unknown");
|
||||
}
|
||||
|
||||
// Validate and convert shape and dtype.
|
||||
const auto& output_params =
|
||||
graph_properties.GetOutputProperties(node_def.name());
|
||||
graph_properties_.GetOutputProperties(node_def.name());
|
||||
const auto& tensor_properties = output_params.at(output_port);
|
||||
const DataType dtype = tensor_properties.dtype();
|
||||
const PartialTensorShape shape = tensor_properties.shape();
|
||||
@ -1007,20 +1037,16 @@ Status TrtNodeValidator::ConvertToTensorOrWeights(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TrtNodeValidator::ValidateNode(
|
||||
const NodeDef& node_def,
|
||||
const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
|
||||
const TrtPrecisionMode precision_mode,
|
||||
const grappler::GraphProperties& graph_properties) {
|
||||
const string& op = node_def.op();
|
||||
Status TrtNodeValidator::IsTensorRTCandidate(const Node* node) {
|
||||
const string& op = node->def().op();
|
||||
// In INT8 mode, we will always apply the quantization ranges provided by
|
||||
// these ops to the relevant tensors. This happens regardless of the value of
|
||||
// use_calibration.
|
||||
bool is_supported_op = false;
|
||||
if (quantize_ops->count(op)) {
|
||||
is_supported_op = (precision_mode == TrtPrecisionMode::INT8);
|
||||
is_supported_op = (precision_mode_ == TrtPrecisionMode::INT8);
|
||||
} else {
|
||||
is_supported_op = op_validators_.count(node_def.op());
|
||||
is_supported_op = op_validators_.count(op);
|
||||
}
|
||||
if (!is_supported_op) {
|
||||
return errors::Unimplemented("Op type ", op, " is not supported.");
|
||||
@ -1029,23 +1055,24 @@ Status TrtNodeValidator::ValidateNode(
|
||||
// Convert input NodeDef and corresponding output ports to
|
||||
// TRT_TensorOrWeights.
|
||||
std::vector<TRT_TensorOrWeights> inputs;
|
||||
for (int i = 0; i < input_node_and_ports.size(); ++i) {
|
||||
const auto& pair = input_node_and_ports[i];
|
||||
std::vector<const Edge*> input_edges;
|
||||
TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
|
||||
for (const Edge* edge : input_edges) {
|
||||
TRT_TensorOrWeights tensor_or_weights;
|
||||
Status status = ConvertToTensorOrWeights(
|
||||
*pair.first, pair.second, graph_properties, &tensor_or_weights);
|
||||
const NodeDef& src_def = edge->src()->def();
|
||||
Status status = ConvertToTensorOrWeights(src_def, edge->src_output(),
|
||||
&tensor_or_weights);
|
||||
if (!status.ok()) {
|
||||
return errors::Internal(
|
||||
"Failed to convert input with index ", i,
|
||||
"Failed to convert input ", src_def.name(),
|
||||
" to a TRT_TensorOrWeights: ", status.error_message());
|
||||
}
|
||||
inputs.push_back(tensor_or_weights);
|
||||
}
|
||||
|
||||
OpConverter validator = op_validators_[node_def.op()];
|
||||
OpConverterParams params(
|
||||
/*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr,
|
||||
/*arg_validation_only=*/true, &weight_store_);
|
||||
OpConverter validator = op_validators_[op];
|
||||
OpConverterParams params(node->def(), inputs, /*arg_outputs=*/nullptr,
|
||||
&weight_store_, precision_mode_, use_calibration_);
|
||||
return validator(¶ms);
|
||||
}
|
||||
|
||||
@ -1054,9 +1081,8 @@ Status TrtNodeValidator::ConvertConstToWeights(
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
TRT_TensorOrWeights* output) {
|
||||
std::vector<TRT_TensorOrWeights> outputs;
|
||||
OpConverterParams params(
|
||||
/*arg_converter=*/nullptr, const_node_def, inputs, &outputs,
|
||||
/*arg_validation_only=*/true, &weight_store_);
|
||||
OpConverterParams params(const_node_def, inputs, &outputs, &weight_store_,
|
||||
precision_mode_, use_calibration_);
|
||||
Status status = op_validators_["Const"](¶ms);
|
||||
if (status.ok() && output) *output = outputs[0];
|
||||
return status;
|
||||
@ -1108,8 +1134,7 @@ Status Converter::ConvertNode(const NodeDef& node_def) {
|
||||
std::vector<TRT_TensorOrWeights> inputs, outputs;
|
||||
TF_RETURN_IF_ERROR(this->GetInputs(node_def, &inputs));
|
||||
|
||||
OpConverterParams params(this, node_def, inputs, &outputs,
|
||||
/*arg_validation_only=*/false, &weight_store_);
|
||||
OpConverterParams params(this, node_def, inputs, &outputs, &weight_store_);
|
||||
const string& op = node_def.op();
|
||||
auto itr = op_registry_.find(op);
|
||||
if (itr == op_registry_.end()) {
|
||||
@ -2004,8 +2029,8 @@ Status ConvertExpandDims(OpConverterParams* params) {
|
||||
// Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for
|
||||
// ExpandDim's ability to add an axis at end of the shape.
|
||||
int trt_axis;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), &trt_axis));
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(),
|
||||
/*use_implicit_batch=*/true, &trt_axis));
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// ExpandDims: Insert new dim of size 1.
|
||||
@ -2040,8 +2065,8 @@ Status ConvertSqueeze(OpConverterParams* params) {
|
||||
for (int tf_axis : squeeze_dims) {
|
||||
// Make sure axis is valid.
|
||||
int trt_axis;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
|
||||
/*use_implicit_batch=*/true, &trt_axis));
|
||||
// Make sure target dimension is size 1.
|
||||
if (input_dims[trt_axis] != 1) {
|
||||
return errors::InvalidArgument(
|
||||
@ -2071,7 +2096,8 @@ template <typename Container>
|
||||
Status ConvertStridedSliceHelper(OpConverterParams* params,
|
||||
const TRT_TensorOrWeights& input,
|
||||
Container begin, Container size,
|
||||
const Container& stride) {
|
||||
const Container& stride,
|
||||
const nvinfer1::Dims* final_shape = nullptr) {
|
||||
const auto& node_def = params->node_def;
|
||||
// Get input dims.
|
||||
nvinfer1::Dims dims = input.GetTrtDims();
|
||||
@ -2110,7 +2136,14 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
|
||||
|
||||
nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
|
||||
*input.tensor(), begin_dims, size_dims, stride_dims);
|
||||
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||
nvinfer1::ITensor* tensor = layer->getOutput(0);
|
||||
// Reshape for shrink_axis.
|
||||
if (final_shape) {
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false,
|
||||
&tensor));
|
||||
}
|
||||
params->outputs->push_back(TRT_TensorOrWeights(tensor));
|
||||
return Status::OK();
|
||||
#else
|
||||
// Use IPaddingLayer.
|
||||
@ -2228,8 +2261,13 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
|
||||
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
|
||||
tensor, inv_transpose_order, &tensor));
|
||||
}
|
||||
// Restore reshape
|
||||
if (need_reshape) {
|
||||
// Reshape for shrink_axis.
|
||||
if (final_shape) {
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false,
|
||||
&tensor));
|
||||
} else if (need_reshape) {
|
||||
// Restore reshape.
|
||||
// Calculate output dimensions
|
||||
for (int i = 0; i < pad_dims.size(); i++) {
|
||||
const int axis = pad_dims[i];
|
||||
@ -2313,17 +2351,17 @@ Status ConvertStridedSlice(OpConverterParams* params) {
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
|
||||
TFAttrs attrs(node_def);
|
||||
// Unsupported mask options.
|
||||
for (const string& attr : {"new_axis_mask", "shrink_axis_mask"}) {
|
||||
int attr_val = attrs.get<int64>(attr);
|
||||
if (attr_val != 0) {
|
||||
return errors::Unimplemented(
|
||||
attr, " is not supported for StridedSlice, at ", node_def.name());
|
||||
}
|
||||
// new_axis_mask is not supported.
|
||||
const int32 new_axis_mask = attrs.get<int64>("new_axis_mask");
|
||||
if (new_axis_mask != 0) {
|
||||
return errors::Unimplemented(
|
||||
"new_axis_mask is not supported for StridedSlice, at ",
|
||||
node_def.name());
|
||||
}
|
||||
const int32 begin_mask = attrs.get<int64>("begin_mask");
|
||||
const int32 end_mask = attrs.get<int64>("end_mask");
|
||||
const int32 ellipsis_mask = attrs.get<int64>("ellipsis_mask");
|
||||
const int32 shrink_axis_mask = attrs.get<int64>("shrink_axis_mask");
|
||||
|
||||
// Get input dims.
|
||||
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
|
||||
@ -2355,9 +2393,9 @@ Status ConvertStridedSlice(OpConverterParams* params) {
|
||||
TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
|
||||
&begin_weights.GetTensor(), &end_weights.GetTensor(),
|
||||
stride_weights.GetTensor(), input_shape, begin_mask, end_mask,
|
||||
ellipsis_mask, /*new_axis_mask=*/0,
|
||||
/*shrink_axis_mask=*/0, &processing_shape, &final_shape, &is_identity,
|
||||
&is_simple_slice, &slice_dim0, &begin, &end, &strides));
|
||||
ellipsis_mask, new_axis_mask, shrink_axis_mask, &processing_shape,
|
||||
&final_shape, &is_identity, &is_simple_slice, &slice_dim0, &begin, &end,
|
||||
&strides));
|
||||
|
||||
// Negative or zero strides currently not supported.
|
||||
for (int stride : strides) {
|
||||
@ -2391,13 +2429,29 @@ Status ConvertStridedSlice(OpConverterParams* params) {
|
||||
node_def.name());
|
||||
}
|
||||
}
|
||||
// Can't shrink axis on batch dimension.
|
||||
if (shrink_axis_mask & 1) {
|
||||
return errors::Unimplemented(
|
||||
"TensorRT does not allow modifications to the batch dimension, at ",
|
||||
node_def.name());
|
||||
}
|
||||
// TRT Slice layer uses (begin, size) instead of (begin, end)
|
||||
absl::InlinedVector<int64, 4> size(input_dims.size());
|
||||
for (int i = 0; i < input_dims.size(); i++) {
|
||||
// Divide by stride (round up)
|
||||
size[i] = (end[i] - begin[i] + strides[i] - 1) / strides[i];
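// Example (illustrative, not part of the change): begin = 1, end = 7,
// stride = 2 gives size = (7 - 1 + 2 - 1) / 2 = 3, i.e. indices 1, 3 and 5 --
// the ceiling of (end - begin) / stride.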
|
||||
}
|
||||
return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, strides);
|
||||
|
||||
// shrink_axis_mask requires a reshape after the slice.
|
||||
nvinfer1::Dims final_shape_dims;
|
||||
nvinfer1::Dims* final_shape_dims_ptr = nullptr;
|
||||
if (shrink_axis_mask) {
|
||||
final_shape_dims =
|
||||
TensorShapeToTrtDims(final_shape, /*ignore_first_dim=*/true);
|
||||
final_shape_dims_ptr = &final_shape_dims;
|
||||
}
|
||||
return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, strides,
|
||||
final_shape_dims_ptr);
|
||||
}
|
||||
|
||||
Status ConvertConv2D(OpConverterParams* params) {
|
||||
@ -2795,7 +2849,7 @@ Status ConvertRelu6(OpConverterParams* params) {
|
||||
#endif
|
||||
}
|
||||
|
||||
Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -2893,6 +2947,71 @@ Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
if (params->precision_mode == TrtPrecisionMode::INT8 &&
|
||||
!params->use_calibration) {
|
||||
// NOTE(laigd): based on some observation, it seems TensorRT cannot fuse
|
||||
// IConvolutionLayer and IElementwiseLayer and will require range
|
||||
// information for the output of Conv2D. Using IScaleLayer will fix the
|
||||
// problem.
|
||||
return ConvertBiasAddInt8WithoutCalibration(params);
|
||||
}
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
|
||||
if (inputs.size() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"BiasAdd expects exactly 2 inputs, but received ", inputs.size());
|
||||
}
|
||||
|
||||
if (inputs[0].is_weights() && inputs[1].is_weights()) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs are weights, but Grappler is expected to fold them.");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
|
||||
TFAttrs attrs(node_def);
|
||||
const string& data_format = attrs.get<string>("data_format");
|
||||
|
||||
nvinfer1::Dims input_shape = inputs.at(0).GetTrtDims();
|
||||
nvinfer1::Dims bias_shape = inputs.at(1).GetTrtDims();
|
||||
// If the input is NCHW, then we need to unsqueeze the bias such that its last
|
||||
// dimensions are 1s (and the first dimension is C).
|
||||
if (data_format == "NCHW") {
|
||||
bias_shape.nbDims = input_shape.nbDims;
|
||||
std::fill(bias_shape.d + 1, bias_shape.d + bias_shape.nbDims, 1);
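// Example (illustrative, not part of the change): an NCHW input with TRT dims
// [C, H, W] and a bias with TRT dims [C] produces bias_shape [C, 1, 1] here,
// so the elementwise kSUM below broadcasts the bias over H and W.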
|
||||
} else {
|
||||
// Next, broadcast the bias across the input.
|
||||
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(inputs.at(0), inputs.at(1),
|
||||
&input_shape, &bias_shape));
|
||||
}
|
||||
|
||||
// Convert input to a TRT tensor
|
||||
nvinfer1::ITensor* input_tensor{nullptr};
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
inputs.at(0), input_shape, params->validation_only, &input_tensor));
|
||||
|
||||
// Finally, reshape bias. Since the bias is usually a constant, this will
|
||||
// normally happen at conversion-time.
|
||||
nvinfer1::ITensor* bias_tensor{nullptr};
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
inputs.at(1), bias_shape, params->validation_only, &bias_tensor));
|
||||
VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape);
|
||||
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
nvinfer1::IElementWiseLayer* layer =
|
||||
params->converter->network()->addElementWise(
|
||||
*input_tensor, *bias_tensor, nvinfer1::ElementWiseOperation::kSUM);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
|
||||
|
||||
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
|
||||
if (tensor.dims() > 0) {
|
||||
*dims = GetTrtDimsForTensor(tensor);
|
||||
@ -3289,9 +3408,9 @@ Status ConvertReduce(OpConverterParams* params) {
|
||||
}
|
||||
for (int i = 0; i < tf_axes_list.size(); i++) {
|
||||
int trt_axis;
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(tf_axes_list[i],
|
||||
tensor->getDimensions().nbDims,
|
||||
node_def.name(), &trt_axis));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims,
|
||||
node_def.name(), /*use_implicit_batch=*/true, &trt_axis));
|
||||
axes |= (1 << trt_axis);
|
||||
}
|
||||
|
||||
@ -3358,8 +3477,8 @@ Status ConvertPack(OpConverterParams* params) {
|
||||
const nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
|
||||
const int64 tf_axis = attrs.get<int64>("axis");
|
||||
int trt_axis;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), &trt_axis));
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(),
|
||||
/*use_implicit_batch=*/true, &trt_axis));
|
||||
|
||||
// Compute expanded dimensions and then reshape input tensors.
|
||||
std::vector<int> tensor_dims(dims.d, dims.d + dims.nbDims);
|
||||
@ -3506,8 +3625,8 @@ Status ConvertSplitHelper(OpConverterParams* params,
|
||||
const nvinfer1::Dims dims = input.GetTrtDims();
|
||||
// Convert axis.
|
||||
int trt_axis;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
|
||||
/*use_implicit_batch=*/true, &trt_axis));
|
||||
// Dimension must equal num_splits for Unstack (when squeeze_after is true)
|
||||
if (squeeze_after && dims.d[trt_axis] != num_splits) {
|
||||
return errors::InvalidArgument(
|
||||
@ -3536,31 +3655,23 @@ Status ConvertSplitHelper(OpConverterParams* params,
|
||||
begin.insert(begin.begin(), 0);
|
||||
size.insert(size.begin(), 1);
|
||||
stride.insert(stride.begin(), 1);
|
||||
// Create final shape for Unpack/Unstack, where split axis is squeezed.
|
||||
nvinfer1::Dims final_shape_for_unpack;
|
||||
nvinfer1::Dims* final_shape_for_unpack_ptr = nullptr;
|
||||
if (squeeze_after) {
|
||||
std::vector<int> size_after_squeeze(size);
|
||||
size_after_squeeze.erase(size_after_squeeze.begin() + trt_axis + 1);
|
||||
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(
|
||||
size_after_squeeze, &final_shape_for_unpack, /*ignore_first_dim=*/true));
|
||||
final_shape_for_unpack_ptr = &final_shape_for_unpack;
|
||||
}
|
||||
|
||||
// Slice the input. ConvertStridedSliceHelper will push the outputs onto
|
||||
// params->outputs.
|
||||
for (int i = 0; i < num_splits; ++i) {
|
||||
begin[trt_axis + 1] = i * split_size_on_axis;
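// Example (illustrative, not part of the change): unstacking a tensor with
// TRT dims [6, 8] along trt_axis 0 gives num_splits = 6 and
// split_size_on_axis = 1, so slice i uses begin = {0, i, 0}, size = {1, 1, 8}
// and stride = {1, 1, 1}; final_shape_for_unpack then drops the split axis,
// leaving TRT dims [8].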
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertStridedSliceHelper(params, input, begin, size, stride));
|
||||
}
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// For Unpack/Unstack, remove axis that we split upon.
|
||||
if (squeeze_after) {
|
||||
// Create the new shape.
|
||||
size.erase(size.begin() + trt_axis + 1);
|
||||
nvinfer1::Dims new_dims;
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeArrayToTrtDims(size, &new_dims, /*ignore_first_dim=*/true));
|
||||
// Reshape each slice.
|
||||
for (int i = 0; i < params->outputs->size(); i++) {
|
||||
nvinfer1::ITensor* output_tensor = nullptr;
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
params->outputs->at(i), new_dims, /*validation_only=*/false,
|
||||
&output_tensor));
|
||||
(*params->outputs)[i] = TRT_TensorOrWeights(output_tensor);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ConvertStridedSliceHelper(
|
||||
params, input, begin, size, stride, final_shape_for_unpack_ptr));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -3635,8 +3746,8 @@ Status ConvertConcat(OpConverterParams* params) {
|
||||
}
|
||||
int trt_axis = 0;
|
||||
const auto dim = inputs.at(0).GetTrtDims();
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAxis(axis[0], dim.nbDims, node_def.name(), &trt_axis));
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dim.nbDims, node_def.name(),
|
||||
/*use_implicit_batch=*/true, &trt_axis));
|
||||
// Check that dimensions match on non-concatenate axis.
|
||||
TF_RETURN_IF_ERROR(VerifyShapesMatch(
|
||||
absl::Span<const TRT_TensorOrWeights>(inputs).first(num_inputs), trt_axis,
|
||||
@ -3795,29 +3906,58 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
|
||||
Status ConvertGather(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(
|
||||
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
|
||||
// TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an
|
||||
// option for an input to be either tensor or weight.
|
||||
if (inputs.size() != 3) {
|
||||
return errors::InvalidArgument("GatherV2 got ", inputs.size(),
|
||||
" inputs but expected 3, at ",
|
||||
node_def.name());
|
||||
}
|
||||
const auto& params_input = inputs.at(0);
|
||||
const auto& indices_input = inputs.at(1);
|
||||
const auto& axis_input = inputs.at(2);
|
||||
if (!axis_input.is_weights()) {
|
||||
return errors::Unimplemented(
|
||||
"The input \"axis\" for GatherV2 must be a constant, at ",
|
||||
node_def.name());
|
||||
}
|
||||
if (!indices_input.is_tensor()) {
|
||||
return errors::Unimplemented(
|
||||
"The input \"indices\" for GatherV2 must be a tensor, at ",
|
||||
node_def.name());
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
|
||||
/*dtype_attr_name=*/"Tparams"));
|
||||
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(*params, {DataType::DT_INT32},
|
||||
/*dtype_attr_name=*/"Tindices"));
|
||||
|
||||
absl::Span<const int> axis = axis_input.weights().GetSpan<int>();
|
||||
if (axis.size() != 1) {
|
||||
return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
|
||||
node_def.name());
|
||||
}
|
||||
int trt_axis = 0;
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
|
||||
node_def.name(), &trt_axis));
|
||||
const TRT_TensorOrWeights& params_tensor = inputs.at(0);
|
||||
const TRT_TensorOrWeights& indices_tensor = inputs.at(1);
|
||||
if (indices_tensor.batch_size() != 1) {
|
||||
return errors::InvalidArgument("Only indices with batch 1 are supported.");
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], params_input.GetTrtDims().nbDims,
|
||||
node_def.name(), params_input.is_tensor(),
|
||||
&trt_axis));
|
||||
if (params_input.is_weights() && trt_axis != 0) {
|
||||
return errors::Unimplemented(
|
||||
"The input axis must be zero when params is a weight.");
|
||||
}
|
||||
if (params_input.is_tensor() && indices_input.batch_size() != 1) {
|
||||
return errors::Unimplemented(
|
||||
"Indices must have a batch size of 1 when params is a tensor.");
|
||||
}
|
||||
// Both inputs are tensors, and the TF gather result will have rank:
|
||||
// (params.nbDims + 1) + (indices.nbDims + 1) - 1,
|
||||
// where "+ 1" adds the batch dim.
|
||||
const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims +
|
||||
indices_tensor.GetTrtDims().nbDims + 1;
|
||||
// where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
|
||||
// the TF rank so we don't have to add + 1.
|
||||
const int params_tf_rank =
|
||||
params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0);
|
||||
const int indices_tf_rank = indices_input.GetTrtDims().nbDims + 1;
|
||||
const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
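// Worked example (illustrative, not part of the change): params as a tensor
// with TRT dims [5, 3] (TF rank 3) and indices with TRT dims [4] (TF rank 2)
// give tf_gather_output_rank = 3 + 2 - 1 = 4, and the IGatherLayer output is
// later expected to have TRT rank 4 - 2 = 2 (one for the implicit batch dim,
// one for the dim squeezed for the indices' batch). If params were a weight
// with dims [5, 3] (TF rank 2), the result would be 2 + 2 - 1 = 3 with an
// expected TRT output rank of 3 - 1 = 2.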
|
||||
if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Result of gather has dimension greater than ",
|
||||
@ -3825,38 +3965,50 @@ Status ConvertGather(OpConverterParams* params) {
|
||||
}
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// Convert params to a tensor if it is a weight.
|
||||
nvinfer1::ITensor* params_tensor = nullptr;
|
||||
if (params_input.is_weights()) {
|
||||
params_tensor = params->converter->CreateConstantLayer(
|
||||
params_input.weights(), params_input.GetTrtDims());
|
||||
} else {
|
||||
params_tensor = params_input.tensor();
|
||||
}
|
||||
|
||||
// Note on how IGatherLayer works: if both the data and indices tensors have
|
||||
// a batch size dimension of size N, it performs:
|
||||
// for batchid in xrange(N):
|
||||
// output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
|
||||
// data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
|
||||
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
|
||||
*params_tensor.tensor(), *indices_tensor.tensor(), trt_axis);
|
||||
*params_tensor, *indices_input.tensor(), trt_axis);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
|
||||
nvinfer1::ITensor* gather_output = layer->getOutput(0);
|
||||
nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions();
|
||||
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
|
||||
nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions();
|
||||
// Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
|
||||
// and the other is for the output dimension that is squeezed by IGatherLayer
|
||||
// because of the implicit batch dim in the indices (see the above note).
|
||||
if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) {
|
||||
const int expected_trt_output_rank =
|
||||
tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1);
|
||||
if (trt_gather_output_dims.nbDims != expected_trt_output_rank) {
|
||||
return errors::Internal(
|
||||
"Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
|
||||
tf_gather_output_rank - 2,
|
||||
expected_trt_output_rank,
|
||||
", actual nbDims: ", trt_gather_output_dims.nbDims);
|
||||
}
|
||||
// Reshape the output so after adding the implicit batch dim it'll match the
|
||||
// output shape of TF GatherV2.
|
||||
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
|
||||
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
|
||||
}
|
||||
trt_gather_output_dims.d[trt_axis] = 1;
|
||||
++trt_gather_output_dims.nbDims;
|
||||
if (params_input.is_tensor()) {
|
||||
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
|
||||
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
|
||||
}
|
||||
trt_gather_output_dims.d[trt_axis] = 1;
|
||||
++trt_gather_output_dims.nbDims;
|
||||
|
||||
nvinfer1::ITensor* output_tensor = nullptr;
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(gather_output), trt_gather_output_dims,
|
||||
/*validation_only=*/false, &output_tensor));
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(output_tensor), trt_gather_output_dims,
|
||||
/*validation_only=*/false, &output_tensor));
|
||||
}
|
||||
|
||||
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
|
||||
return Status::OK();
|
||||
@ -4116,8 +4268,8 @@ Status ConvertArgMinMax(OpConverterParams* params) {
|
||||
int tf_axis = inputs.at(1).weights().GetSpan<int>()[0];
|
||||
int trt_axis;
|
||||
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
|
||||
/*use_implicit_batch=*/true, &trt_axis));
|
||||
nvinfer1::TopKOperation topk_op;
|
||||
if (node_def.op() == "ArgMin") {
|
||||
topk_op = nvinfer1::TopKOperation::kMIN;
|
||||
|
@ -116,7 +116,6 @@ struct EngineInfo {
|
||||
EngineType engine_type;
|
||||
int64 max_workspace_size_bytes;
|
||||
int maximum_cached_engines;
|
||||
std::vector<int> cached_engine_batches;
|
||||
TrtPrecisionMode precision_mode;
|
||||
bool use_calibration;
|
||||
};
|
||||
@ -354,23 +353,27 @@ class Converter;
|
||||
|
||||
// Parameters for each op converter.
|
||||
struct OpConverterParams {
|
||||
OpConverterParams(Converter* arg_converter, const NodeDef& arg_node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& arg_inputs,
|
||||
std::vector<TRT_TensorOrWeights>* arg_outputs,
|
||||
bool arg_validation_only, TrtWeightStore* arg_weight_store)
|
||||
: converter(arg_converter),
|
||||
node_def(arg_node_def),
|
||||
inputs(arg_inputs),
|
||||
outputs(arg_outputs),
|
||||
validation_only(arg_validation_only),
|
||||
weight_store(arg_weight_store) {}
|
||||
// Constructor used for validation only.
|
||||
OpConverterParams(const NodeDef& node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs,
|
||||
TrtWeightStore* weight_store,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration);
|
||||
|
||||
Converter* converter;
|
||||
// Constructor used for conversion.
|
||||
OpConverterParams(Converter* converter, const NodeDef& node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs,
|
||||
TrtWeightStore* weight_store);
|
||||
|
||||
Converter* converter = nullptr;
|
||||
const NodeDef& node_def;
|
||||
const std::vector<TRT_TensorOrWeights>& inputs;
|
||||
std::vector<TRT_TensorOrWeights>* outputs;
|
||||
const bool validation_only;
|
||||
TrtWeightStore* weight_store;
|
||||
const TrtPrecisionMode precision_mode;
|
||||
const bool use_calibration;
|
||||
};
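// Illustrative note (not part of the change): the two constructors above are
// used as in the convert_nodes.cc hunks of this diff; the local names are
// placeholders.
//
//   // Validation path (TrtNodeValidator): no Converter, precision mode and
//   // calibration flag passed explicitly, validation_only == true.
//   OpConverterParams params(node_def, inputs, /*arg_outputs=*/nullptr,
//                            &weight_store, precision_mode, use_calibration);
//
//   // Conversion path (Converter::ConvertNode): the Converter supplies
//   // precision_mode() and use_calibration(), validation_only == false.
//   OpConverterParams params(converter, node_def, inputs, &outputs,
//                            &weight_store);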
|
||||
|
||||
using OpConverter = std::function<Status(OpConverterParams*)>;
|
||||
@ -378,21 +381,15 @@ using OpConverter = std::function<Status(OpConverterParams*)>;
|
||||
// Class to verify if specific TF node is supported by TRT.
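// Illustrative usage sketch (not part of the change; it mirrors the call site
// in ConvertAfterShapes() elsewhere in this diff): one validator is built per
// graph and queried once per node.
//
//   TrtNodeValidator validator(graph_properties, precision_mode,
//                              use_calibration);
//   for (const Node* node : graph.op_nodes()) {
//     if (validator.IsTensorRTCandidate(node).ok()) {
//       // 'node' may be placed in a TRT segment.
//     }
//   }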
|
||||
class TrtNodeValidator {
|
||||
public:
|
||||
TrtNodeValidator();
|
||||
// 'graph_properties' is the GraphProperties of the graph whose nodes will be
|
||||
// checked by IsTensorRTCandidate() later. It is used to get the shape and
|
||||
// data type information of a tensor for validation purpose.
|
||||
TrtNodeValidator(const grappler::GraphProperties& graph_properties,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration);
|
||||
|
||||
// Validate the node, and return ok if it's supported by TRT.
|
||||
//
|
||||
// - 'node_def' is the node to validate.
|
||||
// - 'input_node_and_ports' are the input NodeDefs and their output ports that
|
||||
// are connected to 'node_def' in the TF graph.
|
||||
// - 'graph_properties' is the GraphProperties of the graph where 'node_def'
|
||||
// belongs. It is used to get the shape and data type information of a
|
||||
// tensor for validation purpose.
|
||||
Status ValidateNode(
|
||||
const NodeDef& node_def,
|
||||
const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
|
||||
const TrtPrecisionMode precision_mode,
|
||||
const grappler::GraphProperties& graph_properties);
|
||||
// Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
|
||||
// to TRT subgraph and later converted into TRT engine.
|
||||
Status IsTensorRTCandidate(const Node* node);
|
||||
|
||||
private:
|
||||
static const std::set<string>* quantize_ops;
|
||||
@ -407,10 +404,8 @@ class TrtNodeValidator {
|
||||
// Convert the output tensor at 'output_port' of 'node_def' to a
|
||||
// TRT_TensorOrWeights which will be later used as an input to other nodes and
|
||||
// passed to ValidateNode() below.
|
||||
Status ConvertToTensorOrWeights(
|
||||
const NodeDef& node_def, int output_port,
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
TRT_TensorOrWeights* tensor_or_weights);
|
||||
Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port,
|
||||
TRT_TensorOrWeights* tensor_or_weights);
|
||||
|
||||
// Stores all the validators by op type. If no validator is registered for
|
||||
// specific op, it means no validation is needed and ValidateNode() will
|
||||
@ -421,6 +416,15 @@ class TrtNodeValidator {
|
||||
// validation for Const node) may produce weights.
|
||||
TrtWeightStore weight_store_;
|
||||
|
||||
// GraphProperties of the graph whose nodes are to be validated by
|
||||
// IsTensorRTCandidate().
|
||||
const grappler::GraphProperties& graph_properties_;
|
||||
|
||||
// Quantization ops are only converted when using quantized precisions.
|
||||
const TrtPrecisionMode precision_mode_;
|
||||
|
||||
const bool use_calibration_;
|
||||
|
||||
friend class ValidatorTest;
|
||||
friend class OpConverterTest;
|
||||
};
|
||||
|
File diff suppressed because it is too large
@ -54,13 +54,6 @@ Status TRTOptimizationPass::Init(
|
||||
if (params.count("is_dynamic_op")) {
|
||||
is_dynamic_op_ = params.at("is_dynamic_op").b();
|
||||
}
|
||||
if (params.count("cached_engine_batches")) {
|
||||
auto batch_vec = params.at("cached_engine_batches").list();
|
||||
batches_.reserve(batch_vec.i_size());
|
||||
for (const auto i : batch_vec.i()) {
|
||||
batches_.push_back(i);
|
||||
}
|
||||
}
|
||||
if (params.count("maximum_cached_engines")) {
|
||||
max_cached_batches_ = params.at("maximum_cached_engines").i();
|
||||
}
|
||||
@ -264,7 +257,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
|
||||
cp.graph_properties = &static_graph_properties;
|
||||
cp.cluster = cluster;
|
||||
cp.is_dyn_op = is_dynamic_op_;
|
||||
cp.cached_engine_batches = batches_;
|
||||
cp.max_cached_engines = max_cached_batches_;
|
||||
cp.use_calibration = use_calibration_;
|
||||
cp.use_function_backup = use_function_backup_;
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
@ -49,6 +50,7 @@ static Logger logger;
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
using ::nvinfer1::IRuntime;
|
||||
using ::stream_executor::port::StatusOr;
|
||||
|
||||
// A helper class to call done() when destructed for asynchronous execution.
|
||||
// Helps simultaneous execution of native and TRT engines.
|
||||
@ -80,6 +82,10 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
AsyncOpKernel::DoneCallback done) override;
|
||||
|
||||
private:
|
||||
using CacheType =
|
||||
LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
|
||||
VectorTensorShapeHasher>;
|
||||
|
||||
// Execute calibration
|
||||
void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
|
||||
|
||||
@ -98,12 +104,16 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
TRTCalibrationResource** cr);
|
||||
|
||||
// Get engine for the input shape
|
||||
EngineContext* GetEngine(const std::vector<TensorShape>& input_shapes,
|
||||
OpKernelContext* ctx);
|
||||
StatusOr<EngineContext*> GetEngine(
|
||||
const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx);
|
||||
|
||||
// Verify that the input shapes are consistent and can be handled by this op.
|
||||
Status VerifyInputShapes(const std::vector<TensorShape>& shapes);
|
||||
|
||||
// Return engine batch in cached_engne_batch_sizes_ which is closest to input
|
||||
// batch.
|
||||
bool GetCompatibleCachedEngine(
|
||||
Status GetEngineInputShapes(
|
||||
const CacheType& cache,
|
||||
const std::vector<TensorShape>& actual_input_shapes,
|
||||
std::vector<TensorShape>* engine_input_shapes);
|
||||
|
||||
@ -131,9 +141,6 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
// Whether to calibrate INT8 engine.
|
||||
bool calibration_mode_;
|
||||
|
||||
// Batches of the cached engines
|
||||
std::vector<int> cached_engine_batches_;
|
||||
|
||||
// Maximum number of cached engines
|
||||
int max_cached_engines_;
|
||||
|
||||
@ -160,6 +167,7 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr);
TYPECASE(DT_HALF, tensor_ptr, dest_ptr);
TYPECASE(DT_INT8, tensor_ptr, dest_ptr);
TYPECASE(DT_INT32, tensor_ptr, dest_ptr);
default: {
LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
return nullptr;
@ -195,12 +203,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
context->GetAttr("workspace_size_bytes", &workspace_size_));
OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
if (!static_engine_) {
if (!segment_graph_.ParseFromString(serialized_segment_)) {
LOG(ERROR) << "Parsing segment graph failed!";
context->SetStatus(
errors::InvalidArgument("Failed to parse segment graphdef!"));
return;
}
OP_REQUIRES(context, segment_graph_.ParseFromString(serialized_segment_),
errors::InvalidArgument("Failed to parse segment graphdef!"));
VLOG(1) << "Size of serialized GraphDef: "
<< serialized_segment_.capacity();
string tmp;
@ -230,16 +234,6 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
native_func_ = kInvalidHandle;
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
&max_cached_engines_));
OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
&cached_engine_batches_));
std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
if (VLOG_IS_ON(1)) {
string s("Engine Batches= ");
for (auto i : cached_engine_batches_) {
StrAppend(&s, i, " ");
}
VLOG(1) << s;
}
}

void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
@ -298,11 +292,10 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
for (int i = 0; i < num_inputs; i++) {
const Tensor& t = ctx->input(i);
void* data_address = GetTensorAddress(&t);
if (data_address == nullptr) {
ctx->SetStatus(errors::InvalidArgument(
"Unsupported data type encountered in input ", i));
return;
}
OP_REQUIRES_ASYNC(ctx, data_address,
errors::InvalidArgument(
"Unsupported data type encountered in input ", i),
*helper);
// Check the allocated buffer is sufficient for input
const auto device_tensor =
calib_res->device_tensors_.at(i).AccessTensor(ctx);
@ -331,34 +324,74 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
ExecuteNativeSegment(ctx, helper);
}

bool TRTEngineOp::GetCompatibleCachedEngine(
const std::vector<TensorShape>& actual_input_shapes,
Status TRTEngineOp::VerifyInputShapes(const std::vector<TensorShape>& shapes) {
if (shapes.empty()) {
return errors::InvalidArgument("Input shapes are empty, for ", name());
}
if (shapes[0].dims() < 1) {
return errors::InvalidArgument("Input shapes contain scalar, for ", name(),
": ",
TensorShapeUtils::ShapeListString(shapes));
}

const int batch_size = shapes[0].dim_size(0);
for (const TensorShape& shape : shapes) {
if (shape.dims() < 1 || batch_size != shape.dim_size(0)) {
return errors::InvalidArgument(
"Input shapes are inconsistent on the batch dimension, for ", name(),
": ", TensorShapeUtils::ShapeListString(shapes));
}
}
return Status::OK();
}

Status TRTEngineOp::GetEngineInputShapes(
const CacheType& cache, const std::vector<TensorShape>& actual_input_shapes,
std::vector<TensorShape>* engine_input_shapes) {
const int batch_size = actual_input_shapes[0].dim_size(0);
int smallest_batch_size = -1;
// Output shape will always be the same as the input but we will overwrite the
// batch size.
auto match_shape = [](const TensorShape& actual_shape,
const TensorShape& cached_shape) {
// Match the rank.
if (actual_shape.dims() != cached_shape.dims()) return false;
// Match the batch size.
if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
// Match remaining dimensions.
for (int i = 1; i < actual_shape.dims(); ++i) {
if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
}
return true;
};
auto match_shapes = [&](const std::vector<TensorShape>& actual_shapes,
const std::vector<TensorShape>& cached_shapes) {
for (int i = 0; i < actual_shapes.size(); ++i) {
if (!match_shape(actual_shapes[i], cached_shapes[i])) {
return false;
}
}
return true;
};

// VerifyInputShapes() already ensured that all input shapes have same
// batch size, and are not scalars.
*engine_input_shapes = actual_input_shapes;
for (const int cached_batch_size : cached_engine_batches_) {
// Check if compatible: batch <= cached batch.
//
// TODO(laigd): here it only compare the first dim a.k.a the batch size,
// we'll need to to support non-batch dimensions as well. This will be done
// as part of the offline conversion implementation.
if (batch_size <= cached_batch_size) {
// First case: first compatible engine found
// Second case: smaller batch size engine found
if ((smallest_batch_size == -1) ||
(cached_batch_size < smallest_batch_size)) {
smallest_batch_size = cached_batch_size;
// Overwrite batch size for output
for (int i = 0; i < engine_input_shapes->size(); i++) {
(*engine_input_shapes)[i].set_dim(0, smallest_batch_size);
}
int64 min_matched_batch_size = kint64max;
for (const auto& pair : cache) {
const std::vector<TensorShape>& cached_input_shapes = pair.first;
// This should not happen, but just for safety.
if (actual_input_shapes.size() != cached_input_shapes.size()) {
return errors::InvalidArgument(
"Input shape list size mismatch for ", name(),
", cached size: ", cached_input_shapes.size(),
" vs. actual size: ", actual_input_shapes.size());
}
if (match_shapes(actual_input_shapes, cached_input_shapes)) {
const int cached_batch_size = cached_input_shapes[0].dim_size(0);
if (min_matched_batch_size > cached_batch_size) {
min_matched_batch_size = cached_batch_size;
*engine_input_shapes = cached_input_shapes;
}
}
}
return (smallest_batch_size != -1);
return Status::OK();
}
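The two functions above split the old GetCompatibleCachedEngine() into a validity check and a cache-key search. A minimal standalone sketch of the matching rule follows, using std::vector<int64_t> in place of TensorShape and an ordinary std::map in place of the LRU cache (both simplifications for illustration): a cached key is usable when every shape has the same rank, the same non-batch dimensions, and a batch no smaller than the actual one; among all matches the smallest cached batch wins.

#include <cstdint>
#include <iostream>
#include <limits>
#include <map>
#include <vector>

using Shape = std::vector<int64_t>;
using ShapeVec = std::vector<Shape>;

bool ShapeMatches(const Shape& actual, const Shape& cached) {
  if (actual.size() != cached.size()) return false;  // Same rank.
  if (actual[0] > cached[0]) return false;           // Batch must fit.
  for (size_t i = 1; i < actual.size(); ++i) {
    if (actual[i] != cached[i]) return false;        // Other dims equal.
  }
  return true;
}

// Returns the cache key to use for `actual`: the best matching cached key, or
// `actual` itself when nothing in the cache is compatible.
ShapeVec SelectEngineShapes(const std::map<ShapeVec, int>& cache,
                            const ShapeVec& actual) {
  ShapeVec selected = actual;
  int64_t best_batch = std::numeric_limits<int64_t>::max();
  for (const auto& entry : cache) {
    const ShapeVec& cached = entry.first;
    if (cached.size() != actual.size()) continue;
    bool all_match = true;
    for (size_t i = 0; i < actual.size(); ++i) {
      if (!ShapeMatches(actual[i], cached[i])) {
        all_match = false;
        break;
      }
    }
    if (all_match && cached[0][0] < best_batch) {
      best_batch = cached[0][0];
      selected = cached;
    }
  }
  return selected;
}

int main() {
  std::map<ShapeVec, int> cache;
  cache[{{2, 2}}] = 0;  // Engine built for inputs of shape [2, 2].
  cache[{{3, 2}}] = 1;  // Engine built for inputs of shape [3, 2].
  ShapeVec key = SelectEngineShapes(cache, /*actual=*/{{1, 2}});
  std::cout << "selected batch: " << key[0][0] << "\n";  // Prints 2.
  return 0;
}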

void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
@ -375,7 +408,10 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
for (int i = 0; i < ctx->num_inputs(); ++i) {
input_shapes.push_back(ctx->input(i).shape());
}
EngineContext* engine_context = GetEngine(input_shapes, ctx);
OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_shapes), *helper);
StatusOr<EngineContext*> status = GetEngine(input_shapes, ctx);
OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);
EngineContext* engine_context = status.ValueOrDie();
if (!engine_context->cuda_engine) {
VLOG(1) << "Engine retrieval for input shapes: "
<< TensorShapeUtils::ShapeListString(input_shapes)
@ -519,7 +555,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
return !kRetry;
}

EngineContext* TRTEngineOp::GetEngine(
StatusOr<EngineContext*> TRTEngineOp::GetEngine(
const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx) {
static EngineContext empty_context;
mutex_lock lock(engine_mutex_);
@ -609,21 +645,11 @@ EngineContext* TRTEngineOp::GetEngine(
// See if there is a compatible engine cached. The batch size should be <= the
// cached batch size.
std::vector<TensorShape> engine_input_shapes;
const bool matched_successfully =
GetCompatibleCachedEngine(input_shapes, &engine_input_shapes);
TF_RETURN_IF_ERROR(
GetEngineInputShapes(cache, input_shapes, &engine_input_shapes));

// If matched, use that engine. Otherwise, we will look in cache for that
// exact shape and possibly create a new engine if it is not in cache.
if (!matched_successfully) {
engine_input_shapes = input_shapes;
if (!cached_engine_batches_.empty()) {
// If user has explicitly defined cached_engine_batches, we should
// warn them that their input was non-compatible (batch size too high)
LOG(WARNING) << "No compatible cached engine was found for batch size: "
<< batch_size << ". A new engine will be created.";
cached_engine_batches_.push_back(batch_size);
}
}

if (!cache.count(engine_input_shapes)) {
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
bool convert_successfully = false;
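With GetEngine() now returning StatusOr, ComputeAsync() checks status() before touching the pointer. A compact sketch of that calling convention, written against absl::StatusOr as an assumed open-source stand-in for ::stream_executor::port::StatusOr (the EngineContext here is a dummy struct, not the real one):

#include <iostream>
#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct EngineContext {
  const char* name;
};

absl::StatusOr<EngineContext*> GetEngine(bool available) {
  static EngineContext ctx{"cached-engine"};
  if (!available) {
    return absl::NotFoundError("no compatible engine in cache");
  }
  return &ctx;  // Success: the pointer is wrapped in the StatusOr.
}

int main() {
  absl::StatusOr<EngineContext*> engine = GetEngine(/*available=*/true);
  if (!engine.ok()) {
    // Error path, analogous to OP_REQUIRES_OK_ASYNC(ctx, status.status(), ...).
    std::cerr << engine.status() << "\n";
    return 1;
  }
  std::cout << (*engine)->name << "\n";  // ValueOrDie()/value() equivalent.
  return 0;
}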
@ -17,17 +17,20 @@ limitations under the License.
#include <string.h>

#include <fstream>
#include <numeric>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/test.h"

#if GOOGLE_CUDA
@ -38,48 +41,115 @@ namespace tensorflow {
namespace tensorrt {
using ::testing::ElementsAre;

class TRTEngineOpTestBase : public OpsTestBase {
public:
void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1) {
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));

// Create simple TF graph.
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
ops::Placeholder::Shape({-1, -1}));
auto add = ops::Add(s.WithOpName("add"), feed, feed);
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);

// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
GraphDef graph_def;
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
PartialTensorShape shape({-1, -1});

// Create the op.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
TF_ASSERT_OK(NodeDefBuilder("myop", "TRTEngineOp")
.Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape})
.Attr("output_shapes", {shape})
.Attr("static_engine", false)
.Attr("segment_funcdef_name", "") // no native fallback
.Attr("serialized_segment", graph_def.SerializeAsString())
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", max_cached_engines_count)
.Attr("workspace_size_bytes", 1 << 20)
.Attr("precision_mode", "FP32")
.Attr("use_calibration", false)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(OpsTestBase::InitOp());
}

template <typename T>
void AddSimpleInput(const TensorShape& shape) {
std::vector<T> input(shape.num_elements());
std::iota(input.begin(), input.end(), T(0));
OpsTestBase::AddInputFromArray<T>(shape, input);
}

void ResetInputs() {
inputs_.clear();
gtl::STLDeleteElements(&tensors_);
}
};

TEST_F(TRTEngineOpTestBase, dynamic_shapes) {
TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);

// Execute the op with batch size > 1.
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({2, 2}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());

// Get the engine cache.
TRTEngineCacheResource* cache_resource = nullptr;
TF_ASSERT_OK(device_->resource_manager()->Lookup("TF-TRT-Engine-Cache",
"myop", &cache_resource));
core::ScopedUnref sc(cache_resource);

// It should contain only one engine.
auto cache = &cache_resource->cache_;
EXPECT_EQ(1, cache->size());
EXPECT_THAT(cache->begin()->first, ElementsAre(TensorShape({2, 2})));

// Execute the op with batch size 1. It should reuse existing engine to
// execute.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({1, 2}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
EXPECT_EQ(1, cache->size());
EXPECT_THAT(cache->begin()->first, ElementsAre(TensorShape({2, 2})));

// Execute the op with a larger batch size.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({3, 2}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
EXPECT_EQ(2, cache->size());
EXPECT_THAT(cache->begin()->first, ElementsAre(TensorShape({3, 2})));
EXPECT_THAT((++cache->begin())->first, ElementsAre(TensorShape({2, 2})));

// Execute the op with an input that has different non-batch dimension.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({10, 10}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Execute it again with an input that has the same non-batch dimension but
// smallest batch size. It should find the correct engine to use.
ResetInputs();
TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({1, 10}));
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
EXPECT_EQ(3, cache->size()); // Should only create 3 engines in total.
auto iter = cache->begin();
EXPECT_THAT(iter->first, ElementsAre(TensorShape({10, 10})));
EXPECT_THAT((++iter)->first, ElementsAre(TensorShape({3, 2})));
EXPECT_THAT((++iter)->first, ElementsAre(TensorShape({2, 2})));
}

template <typename T>
class TRTEngineOpTest : public OpsTestBase {};
class TRTEngineOpTest : public TRTEngineOpTestBase {};

using TypeList = ::testing::Types<float, Eigen::half>;
TYPED_TEST_SUITE(TRTEngineOpTest, TypeList);

TYPED_TEST(TRTEngineOpTest, Basic) {
DataType dtype = DataTypeToEnum<TypeParam>::v();
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));

// Create simple TF graph.
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
ops::Placeholder::Shape({1, 2}));
auto add = ops::Add(s.WithOpName("add"), feed, feed);
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);

// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
GraphDef graph_def;
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
TensorShapeProto shape;
TensorShape({1, 2}).AsProto(&shape);

// Create the op.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
TF_ASSERT_OK(NodeDefBuilder("op", "TRTEngineOp")
.Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape})
.Attr("output_shapes", {shape})
.Attr("static_engine", false)
.Attr("segment_funcdef_name", "") // no native fallback
.Attr("serialized_segment", graph_def.SerializeAsString())
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", 1)
.Attr("workspace_size_bytes", 1 << 20)
.Attr("precision_mode", "FP32")
.Attr("use_calibration", false)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(OpsTestBase::InitOp());
TRTEngineOpTestBase::AddSimpleTrtOp(DataTypeToEnum<TypeParam>::v());

// Execute the op.
OpsTestBase::AddInputFromArray<TypeParam>(TensorShape({1, 2}),
@ -38,7 +38,6 @@ REGISTER_OP("TRTEngineOp")
.Attr("segment_funcdef_name: string")
.Attr("InT: list({int8,float16,float32,int32})")
.Attr("OutT: list({int8,float16,float32,int32})")
.Attr("cached_engine_batches: list(int) >= 0 = []")
.Attr("max_cached_engines_count: int = 1")
.Attr("workspace_size_bytes: int")
.Attr("precision_mode: {'FP32', 'FP16', 'INT8'}")
@ -54,6 +53,7 @@ REGISTER_OP("TRTEngineOp")
// inference function as a workaround.
.SetShapeFn(shape_inference::UnknownShape)
// Deprecated attributes.
.Attr("cached_engine_batches: list(int) >= 0 = []")
.Attr("fixed_input_size: bool = true")
.Attr("static_engine: bool = true");
} // namespace tensorflow
@ -1,75 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper of TRTEngineOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import platform
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
from tensorflow.python.framework import errors

_tf_trt_so = None
_module_lock = threading.Lock()


def load_trt_ops():
"""Load TF-TRT op libraries so if it hasn't been loaded already."""
global _tf_trt_so

if not is_tensorrt_enabled():
return

if platform.system() == "Windows":
raise RuntimeError("Windows platforms are not supported")

with _module_lock:
if _tf_trt_so:
return

try:
# pylint: disable=g-import-not-at-top,unused-variable
# This will call register_op_list() in
# tensorflow/python/framework/op_def_registry.py, but it doesn't register
# the op or the op kernel in C++ runtime.
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
# pylint: enable=g-import-not-at-top,unused-variable
except ImportError as e:
print("**** Failed to import TF-TRT ops. This is because the binary was "
"not built with CUDA or TensorRT enabled. ****")
raise e

try:
# pylint: disable=g-import-not-at-top
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# pylint: enable=g-import-not-at-top

# Loading the shared object will cause registration of the op and the op
# kernel if we link TF-TRT dynamically.
_tf_trt_so = load_library.load_op_library(
resource_loader.get_path_to_datafile("libtftrt.so"))
except errors.NotFoundError as e:
no_trt_message = (
"**** Failed to initialize TensorRT. This is either because the "
"TensorRT installation path is not in LD_LIBRARY_PATH, or because "
"you do not have it installed. If not installed, please go to "
"https://developer.nvidia.com/tensorrt to download and install "
"TensorRT ****")
print(no_trt_message)
raise e
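The deleted Python helper above loaded the TF-TRT op library at most once, guarded by a module-level lock. A C++ sketch of the same load-once behavior using std::call_once; RegisterTrtOps() is a hypothetical placeholder for the real library loading and op registration:

#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

int load_count = 0;

void RegisterTrtOps() {
  // Placeholder for loading libtftrt.so and registering the op and kernels.
  ++load_count;
}

void LoadTrtOpsOnce() {
  static std::once_flag once;
  std::call_once(once, RegisterTrtOps);  // Later callers return immediately.
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 8; ++i) threads.emplace_back(LoadTrtOpsOnce);
  for (auto& t : threads) t.join();
  std::cout << "loaded " << load_count << " time(s)\n";  // Prints 1.
  return 0;
}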
87
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc
Normal file
@ -0,0 +1,87 @@
|
||||
// Auto-generated, do not edit.
|
||||
|
||||
extern "C" {
|
||||
|
||||
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
|
||||
int nmsMaxOut, float iouThreshold,
|
||||
float minBoxSize, float spatialScale,
|
||||
nvinfer1::DimsHW pooling,
|
||||
nvinfer1::Weights anchorRatios,
|
||||
nvinfer1::Weights anchorScales) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
|
||||
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
|
||||
bool acrossSpatial,
|
||||
bool channelShared, float eps) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
|
||||
return func_ptr(scales, acrossSpatial, channelShared, eps);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createPriorBoxPlugin(
|
||||
nvinfer1::plugin::PriorBoxParameters param) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
|
||||
return func_ptr(param);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
|
||||
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
|
||||
return func_ptr(param, numLayers);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createNMSPlugin(
|
||||
nvinfer1::plugin::DetectionOutputParameters param) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
|
||||
return func_ptr(param);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
|
||||
return func_ptr(negSlope);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
|
||||
return func_ptr(stride);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createRegionPlugin(
|
||||
nvinfer1::plugin::RegionParameters params) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
|
||||
return func_ptr(params);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
|
||||
float clipMax) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
|
||||
return func_ptr(layerName, clipMin, clipMax);
|
||||
}
|
||||
|
||||
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
|
||||
using FuncPtr = bool ( *)(void *, const char *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
|
||||
return func_ptr(logger, libNamespace);
|
||||
}
|
||||
|
||||
} // extern "C"
|
95
tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc
Normal file
@ -0,0 +1,95 @@
|
||||
// Auto-generated, do not edit.
|
||||
|
||||
extern "C" {
|
||||
|
||||
nvinfer1::IPluginV2* createRPNROIPlugin(int featureStride, int preNmsTop,
|
||||
int nmsMaxOut, float iouThreshold,
|
||||
float minBoxSize, float spatialScale,
|
||||
nvinfer1::DimsHW pooling,
|
||||
nvinfer1::Weights anchorRatios,
|
||||
nvinfer1::Weights anchorScales) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int, int, int, float, float, float, nvinfer1::DimsHW, nvinfer1::Weights, nvinfer1::Weights);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createRPNROIPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createRPNROIPlugin");
|
||||
return func_ptr(featureStride, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, spatialScale, pooling, anchorRatios, anchorScales);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createNormalizePlugin(const nvinfer1::Weights* scales,
|
||||
bool acrossSpatial,
|
||||
bool channelShared, float eps) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const nvinfer1::Weights *, bool, bool, float);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createNormalizePlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createNormalizePlugin");
|
||||
return func_ptr(scales, acrossSpatial, channelShared, eps);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createPriorBoxPlugin(
|
||||
nvinfer1::plugin::PriorBoxParameters param) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::PriorBoxParameters);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createPriorBoxPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createPriorBoxPlugin");
|
||||
return func_ptr(param);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createAnchorGeneratorPlugin(
|
||||
nvinfer1::plugin::GridAnchorParameters* param, int numLayers) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::GridAnchorParameters *, int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createAnchorGeneratorPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createAnchorGeneratorPlugin");
|
||||
return func_ptr(param, numLayers);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createNMSPlugin(
|
||||
nvinfer1::plugin::DetectionOutputParameters param) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::DetectionOutputParameters);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createNMSPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createNMSPlugin");
|
||||
return func_ptr(param);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createLReLUPlugin(float negSlope) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(float);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createLReLUPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createLReLUPlugin");
|
||||
return func_ptr(negSlope);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createReorgPlugin(int stride) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createReorgPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createReorgPlugin");
|
||||
return func_ptr(stride);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createRegionPlugin(
|
||||
nvinfer1::plugin::RegionParameters params) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::RegionParameters);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createRegionPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createRegionPlugin");
|
||||
return func_ptr(params);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createClipPlugin(const char* layerName, float clipMin,
|
||||
float clipMax) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(const char *, float, float);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createClipPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createClipPlugin");
|
||||
return func_ptr(layerName, clipMin, clipMax);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2* createBatchedNMSPlugin(
|
||||
nvinfer1::plugin::NMSParameters param) {
|
||||
using FuncPtr = nvinfer1::IPluginV2 * ( *)(nvinfer1::plugin::NMSParameters);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createBatchedNMSPlugin");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createBatchedNMSPlugin");
|
||||
return func_ptr(param);
|
||||
}
|
||||
|
||||
bool initLibNvInferPlugins(void* logger, const char* libNamespace) {
|
||||
using FuncPtr = bool ( *)(void *, const char *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("initLibNvInferPlugins");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("initLibNvInferPlugins");
|
||||
return func_ptr(logger, libNamespace);
|
||||
}
|
||||
|
||||
} // extern "C"
|
40
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc
Normal file
@ -0,0 +1,40 @@
|
||||
// Auto-generated, do not edit.
|
||||
|
||||
extern "C" {
|
||||
|
||||
void* createInferBuilder_INTERNAL(void* logger, int version) {
|
||||
using FuncPtr = void * (*)(void *, int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
|
||||
return func_ptr(logger, version);
|
||||
}
|
||||
|
||||
void* createInferRuntime_INTERNAL(void* logger, int version) {
|
||||
using FuncPtr = void * (*)(void *, int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
|
||||
return func_ptr(logger, version);
|
||||
}
|
||||
|
||||
nvinfer1::ILogger* getLogger() {
|
||||
using FuncPtr = nvinfer1::ILogger * (*)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
int getInferLibVersion() {
|
||||
using FuncPtr = int (*)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
nvinfer1::IPluginRegistry* getPluginRegistry() {
|
||||
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
} // extern "C"
|
47
tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc
Normal file
@ -0,0 +1,47 @@
|
||||
// Auto-generated, do not edit.
|
||||
|
||||
extern "C" {
|
||||
|
||||
void* createInferBuilder_INTERNAL(void* logger, int version) {
|
||||
using FuncPtr = void * (*)(void *, int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createInferBuilder_INTERNAL");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createInferBuilder_INTERNAL");
|
||||
return func_ptr(logger, version);
|
||||
}
|
||||
|
||||
void* createInferRefitter_INTERNAL(void* engine, void* logger, int version) {
|
||||
using FuncPtr = void * (*)(void *, void *, int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRefitter_INTERNAL");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createInferRefitter_INTERNAL");
|
||||
return func_ptr(engine, logger, version);
|
||||
}
|
||||
|
||||
void* createInferRuntime_INTERNAL(void* logger, int version) {
|
||||
using FuncPtr = void * (*)(void *, int);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("createInferRuntime_INTERNAL");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("createInferRuntime_INTERNAL");
|
||||
return func_ptr(logger, version);
|
||||
}
|
||||
|
||||
nvinfer1::ILogger* getLogger() {
|
||||
using FuncPtr = nvinfer1::ILogger * (*)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("getLogger");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("getLogger");
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
int getInferLibVersion() {
|
||||
using FuncPtr = int (*)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("getInferLibVersion");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("getInferLibVersion");
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
nvinfer1::IPluginRegistry* getPluginRegistry() {
|
||||
using FuncPtr = nvinfer1::IPluginRegistry * (*)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("getPluginRegistry");
|
||||
if (!func_ptr) LogFatalSymbolNotFound("getPluginRegistry");
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
} // extern "C"
|
59
tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
#include "third_party/tensorrt/NvInferPlugin.h"
|
||||
|
||||
// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.
|
||||
|
||||
namespace {
|
||||
// Returns DSO handle or null if loading the DSO fails.
|
||||
void* GetDsoHandle() {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
return nullptr;
|
||||
#else
|
||||
static auto handle = []() -> void* {
|
||||
auto handle_or =
|
||||
stream_executor::internal::DsoLoader::GetNvInferPluginDsoHandle();
|
||||
if (!handle_or.ok()) return nullptr;
|
||||
return handle_or.ValueOrDie();
|
||||
}();
|
||||
return handle;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T LoadSymbol(const char* symbol_name) {
|
||||
void* symbol = nullptr;
|
||||
if (auto handle = GetDsoHandle()) {
|
||||
tensorflow::Env::Default()
|
||||
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
|
||||
.IgnoreError();
|
||||
}
|
||||
return reinterpret_cast<T>(symbol);
|
||||
}
|
||||
|
||||
void LogFatalSymbolNotFound(const char* symbol_name) {
|
||||
LOG(FATAL) << symbol_name << " symbol not found.";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#if NV_TENSORRT_MAJOR < 5
|
||||
#error TensorRT version earlier than 5 is not supported.
|
||||
#elif NV_TENSORRT_MINOR < 1
|
||||
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_0.inc"
|
||||
#else
|
||||
#include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_5_1.inc"
|
||||
#endif
|
59
tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc
Normal file
@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "third_party/tensorrt/NvInfer.h"

// Implements the TensorRT API by forwarding to TensorRT loaded from the DSO.

namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
return nullptr;
#else
static auto handle = []() -> void* {
auto handle_or =
stream_executor::internal::DsoLoader::GetNvInferDsoHandle();
if (!handle_or.ok()) return nullptr;
return handle_or.ValueOrDie();
}();
return handle;
#endif
}

template <typename T>
T LoadSymbol(const char* symbol_name) {
void* symbol = nullptr;
if (auto handle = GetDsoHandle()) {
tensorflow::Env::Default()
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
.IgnoreError();
}
return reinterpret_cast<T>(symbol);
}

void LogFatalSymbolNotFound(const char* symbol_name) {
LOG(FATAL) << symbol_name << " symbol not found.";
}
} // namespace

#if NV_TENSORRT_MAJOR < 5
#error TensorRT version earlier than 5 is not supported.
#elif NV_TENSORRT_MINOR < 1
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_0.inc"
#else
#include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_5_1.inc"
#endif
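The stub above resolves each TensorRT symbol lazily through a function-local static and forwards the call; the generated .inc files provide one such trampoline per API entry point. A self-contained sketch of that pattern using POSIX dlopen/dlsym in place of DsoLoader and Env::GetSymbolFromLibrary; the library name "libexample.so" and the symbol "example_add" are made up for illustration:

#include <dlfcn.h>

#include <cstdio>
#include <cstdlib>

void* GetDsoHandle() {
  static void* handle = dlopen("libexample.so", RTLD_LAZY | RTLD_LOCAL);
  return handle;  // nullptr if the DSO is not installed.
}

template <typename T>
T LoadSymbol(const char* name) {
  void* sym = nullptr;
  if (void* handle = GetDsoHandle()) sym = dlsym(handle, name);
  return reinterpret_cast<T>(sym);
}

// The stub exported to callers: it looks like the real API function but binds
// to the implementation in the dlopen'ed library only on first use.
extern "C" int example_add(int a, int b) {
  using FuncPtr = int (*)(int, int);
  static auto func_ptr = LoadSymbol<FuncPtr>("example_add");
  if (!func_ptr) {
    std::fprintf(stderr, "example_add symbol not found.\n");
    std::abort();
  }
  return func_ptr(a, b);
}

int main() {
  if (!GetDsoHandle()) {
    std::fprintf(stderr, "libexample.so not found; stub left unbound.\n");
    return 0;
  }
  return example_add(1, 2) == 3 ? 0 : 1;
}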
@ -1,4 +1,10 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
|
||||
"if_cuda_is_configured",
|
||||
)
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
|
||||
|
||||
package(
|
||||
default_visibility = [":internal"],
|
||||
@ -26,13 +32,6 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
load(
|
||||
"//tensorflow/core:platform/default/cuda_build_defs.bzl",
|
||||
"if_cuda_is_configured",
|
||||
)
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
|
||||
|
||||
cc_library(
|
||||
name = "tf2xla_supported_ops_lib",
|
||||
srcs = ["tf2xla_supported_ops.cc"],
|
||||
@ -40,7 +39,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -108,7 +106,6 @@ cc_library(
|
||||
":tf2xla_proto",
|
||||
":tf2xla_util",
|
||||
":xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
|
@ -1,10 +1,10 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc")
|
||||
|
||||
tf_gen_op_wrapper_cc(
|
||||
name = "xla_ops_gen",
|
||||
out_ops_file = "ops/xla_ops",
|
||||
|
@ -54,6 +54,7 @@ tf_kernel_library(
|
||||
"lrn_ops.cc",
|
||||
"matmul_op.cc",
|
||||
"matrix_band_part_op.cc",
|
||||
"matrix_inverse_op.cc",
|
||||
"matrix_set_diag_op.cc",
|
||||
"matrix_triangular_solve_op.cc",
|
||||
"mirror_pad_op.cc",
|
||||
@ -79,6 +80,7 @@ tf_kernel_library(
|
||||
"retval_op.cc",
|
||||
"reverse_op.cc",
|
||||
"reverse_sequence_op.cc",
|
||||
"roll_op.cc",
|
||||
"scan_ops.cc",
|
||||
"scatter_nd_op.cc",
|
||||
"segment_reduction_ops.cc",
|
||||
@ -345,50 +347,3 @@ tf_kernel_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Kernels that only work on CPU, because they use XLA custom calls.
|
||||
# Only link this when using the CPU backend for XLA.
|
||||
tf_kernel_library(
|
||||
name = "xla_cpu_only_ops",
|
||||
srcs = ["index_ops_cpu.cc"],
|
||||
deps = [
|
||||
":index_ops_kernel_argmax_float_1d",
|
||||
":index_ops_kernel_argmax_float_2d",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_bounds_check",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "index_ops_kernel_argmax_float_1d",
|
||||
srcs = ["index_ops_kernel_argmax_float_1d.cc"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "index_ops_kernel_argmax_float_2d",
|
||||
srcs = ["index_ops_kernel_argmax_float_2d.cc"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1,142 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Native XLA implementations of indexing ops.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// The logic below uses a custom-call to implement argmax when possible. When
|
||||
// custom-call is not allowed or input shapes are not supported, this kernel
|
||||
// falls back to using XLA HLO native ArgMax.
|
||||
//
|
||||
// Also see b/29507024 for first-class XLA support for indexing ops.
|
||||
class ArgMaxCustomCallOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
const TensorShape dimension_shape = ctx->InputShape(1);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape),
|
||||
errors::InvalidArgument(
|
||||
"dim must be a scalar, but received tensor of shape: ",
|
||||
dimension_shape.DebugString()));
|
||||
|
||||
// We require that the dimension argument is a constant, since it lets us
|
||||
// dispatch to a specialized custom-call function without any run-time
|
||||
// overhead, when compiling ahead-of-time.
|
||||
int64 dim;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
|
||||
|
||||
const int input_dims = input_shape.dims();
|
||||
const int axis = dim < 0 ? dim + input_dims : dim;
|
||||
OP_REQUIRES(ctx, axis >= 0 && axis < input_dims,
|
||||
errors::InvalidArgument("Expected dimension in the range [",
|
||||
-input_dims, ", ", input_dims,
|
||||
"), but got ", dim));
|
||||
|
||||
const int64 axis_size = input_shape.dim_size(axis);
|
||||
OP_REQUIRES(ctx, axis_size > 0,
|
||||
errors::InvalidArgument(
|
||||
"Reduction axis ", dim,
|
||||
" is empty in shape: ", input_shape.DebugString()));
|
||||
|
||||
const DataType dtype = output_type(0);
|
||||
xla::PrimitiveType output_type;
|
||||
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type));
|
||||
|
||||
// Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input
|
||||
// shape isn't supported.
|
||||
if (!ctx->compiler()->options().allow_cpu_custom_calls ||
|
||||
(input_dims != 1 && input_dims != 2)) {
|
||||
xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis);
|
||||
ctx->SetOutput(0, output);
|
||||
return;
|
||||
}
|
||||
|
||||
xla::XlaOp output;
|
||||
// The output shape is the input shape contracted along axis.
|
||||
TensorShape output_shape;
|
||||
for (int d = 0; d < input_shape.dims() - 1; ++d) {
|
||||
output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
|
||||
}
|
||||
|
||||
xla::XlaBuilder& b = *ctx->builder();
|
||||
|
||||
// XLA passes <out> to the function, so it is not included here.
|
||||
std::vector<xla::XlaOp> args;
|
||||
args.push_back(ctx->Input(0));
|
||||
args.push_back(xla::ConstantLiteral(
|
||||
&b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
|
||||
if (input_shape.dims() > 1) {
|
||||
// Don't bother passing the output shape and dim for the 1d case, since
|
||||
// the shape is always a scalar and the dim is always 0.
|
||||
args.push_back(xla::ConstantLiteral(
|
||||
&b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
|
||||
args.push_back(
|
||||
xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(axis)));
|
||||
}
|
||||
|
||||
// The argmax function expects row-major layout.
|
||||
xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
xla::S64, output_shape.dim_sizes());
|
||||
std::vector<xla::Shape> arg_shapes;
|
||||
for (const xla::XlaOp& arg : args) {
|
||||
auto shape_status = b.GetShape(arg);
|
||||
OP_REQUIRES_OK(ctx, shape_status.status());
|
||||
xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
|
||||
*arg_shape.mutable_layout() =
|
||||
xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
|
||||
arg_shapes.push_back(std::move(arg_shape));
|
||||
}
|
||||
|
||||
// Tell XLA to call the custom code, defined in
|
||||
// index_ops_kernel_argmax_float_{1, 2}d.cc.
|
||||
if (input_dims == 1) {
|
||||
output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
|
||||
xla_shape, arg_shapes);
|
||||
} else {
|
||||
output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
|
||||
xla_shape, arg_shapes);
|
||||
}
|
||||
output = xla::ConvertElementType(output, output_type);
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("ArgMax")
|
||||
.TypeConstraint("T", DT_FLOAT)
|
||||
.Device(DEVICE_CPU_XLA_JIT)
|
||||
.CompileTimeConstantInput("dimension"),
|
||||
ArgMaxCustomCallOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -1,52 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
|
||||
// Data is managed by the JIT code so msan can't tell it's initialized.
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*));
|
||||
|
||||
float* input = static_cast<float*>(data[0]);
|
||||
int64 input_size = *static_cast<int64*>(data[1]);
|
||||
|
||||
Eigen::DSizes<Eigen::DenseIndex, 1> in_eig_sizes(input_size);
|
||||
TTypes<float, 1>::ConstTensor in_eig(input, in_eig_sizes);
|
||||
|
||||
Eigen::DSizes<Eigen::DenseIndex, 0> out_eig_sizes;
|
||||
int64* out_t = static_cast<int64*>(out);
|
||||
TTypes<int64, 0>::Tensor out_eig(out_t, out_eig_sizes);
|
||||
|
||||
out_eig = in_eig.argmax(0).cast<int64>();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
// Implements argmax on CPU. This is called by an XLA custom call, set up by
|
||||
// index_ops.cc.
|
||||
extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
|
||||
tensorflow::argmax_float_1d_xla_impl(out, data);
|
||||
}
|
||||
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
|
@ -1,54 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
|
||||
// data is managed by the JIT code so msan can't tell it's initialized.
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*));
|
||||
|
||||
float* in = static_cast<float*>(data[0]);
|
||||
int64* in_sizes = static_cast<int64*>(data[1]);
|
||||
int64* out_sizes = static_cast<int64*>(data[2]);
|
||||
int32 dim = *static_cast<int32*>(data[3]);
|
||||
|
||||
Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes(in_sizes[0], in_sizes[1]);
|
||||
TTypes<float, 2>::ConstTensor in_eig(in, in_eig_sizes);
|
||||
|
||||
int64* out_t = static_cast<int64*>(out);
|
||||
Eigen::DSizes<Eigen::DenseIndex, 1> out_eig_sizes(out_sizes[0]);
|
||||
TTypes<int64, 1>::Tensor out_eig(out_t, out_eig_sizes);
|
||||
|
||||
out_eig = in_eig.argmax(dim).cast<int64>();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
// Implements argmax on CPU. This is called by an XLA custom call, set up by
|
||||
// index_ops.cc.
|
||||
extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
|
||||
tensorflow::argmax_float_2d_xla_impl(out, data);
|
||||
}
|
||||
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
|
68
tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc
Normal file
@ -0,0 +1,68 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {
namespace {

class MatrixInverseOp : public XlaOpKernel {
public:
explicit MatrixInverseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_));
}

void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
int64 ndims = input_shape.dims();
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("Input must have rank >= 2, got ", ndims));
OP_REQUIRES(
ctx, input_shape.dim_size(ndims - 2) == input_shape.dim_size(ndims - 1),
errors::InvalidArgument("Input matrices must be squares, got",
input_shape.dim_size(ndims - 2),
" != ", input_shape.dim_size(ndims - 1)));

xla::XlaOp input = xla::MaybeTransposeInMinorDims(ctx->Input(0), adjoint_);

// TODO(b/111271662): Using LU decomposition instead of QR should be faster.
auto qr = xla::QRDecomposition(input, /*full_matrices=*/false);
OP_REQUIRES_OK(ctx, qr.status());

xla::XlaOp output = xla::TriangularSolve(
qr.ValueOrDie().r, xla::TransposeInMinorDims(qr.ValueOrDie().q),
/*left_side=*/true,
/*lower=*/false, /*unit_diagonal=*/false,
/*transpose_a=*/
xla::TriangularSolveOptions::NO_TRANSPOSE);
ctx->SetOutput(0, output);
}

private:
bool adjoint_;

TF_DISALLOW_COPY_AND_ASSIGN(MatrixInverseOp);
};

// TODO(b/135640736): Allow this for integer and complex types.
REGISTER_XLA_OP(Name("MatrixInverse").TypeConstraint("T", kFloatTypes),
MatrixInverseOp);

} // namespace
} // namespace tensorflow
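For reference, the algebra behind the kernel above (a sketch, not taken from the source): with A = Q R and Q orthogonal, A^-1 = R^-1 Q^T, so solving the upper-triangular system R X = Q^T — which is exactly the TriangularSolve call with lower=false — yields the inverse; when adjoint is requested, the same factorization is applied to A^T after the initial transpose.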
@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
@ -65,24 +65,18 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
"Second dimension of 2D input must be of size 2, but got shape ",
input_tensor_shape.DebugString()));
}
std::vector<int32> dst_indices(4, 0);
int32 dst_indices[4];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
if (src_format_[i] == dst_format_[j]) {
dst_indices[i] = j;
dst_indices[j] = i;
break;
}
}
}
auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
if (input_rank == 2) {
keys = xla::BroadcastInDim(keys, {4, 2}, {0});
}
auto sorted = xla::Sort({keys, ctx->Input(0)},
xla::CreateScalarLtComputation(
{xla::S32, ctx->input_xla_type(0)}, builder),
0);
auto output = xla::GetTupleElement(sorted, 1);
xla::XlaOp indices =
xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
xla::XlaOp output = xla::TorchIndexSelect(ctx->Input(0), indices, 0);
ctx->SetOutput(0, output);
}

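The rewrite above replaces the sort-based permutation with an index gather: output position j takes the input position whose source-format character matches dst_format[j]. A standalone sketch of that computation, with plain std::vector standing in for the XLA ops:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::vector<int32_t> PermutationIndices(const std::string& src_format,
                                        const std::string& dst_format) {
  std::vector<int32_t> dst_indices(4, 0);
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      if (src_format[i] == dst_format[j]) {
        dst_indices[j] = i;  // Output position j reads input position i.
        break;
      }
    }
  }
  return dst_indices;
}

std::vector<int64_t> Permute(const std::vector<int64_t>& input,
                             const std::vector<int32_t>& indices) {
  std::vector<int64_t> output(indices.size());
  for (size_t j = 0; j < indices.size(); ++j) output[j] = input[indices[j]];
  return output;
}

int main() {
  // An NHWC dimension vector [N=1, H=224, W=224, C=3] rewritten as NCHW.
  std::vector<int64_t> nhwc = {1, 224, 224, 3};
  std::vector<int64_t> nchw = Permute(nhwc, PermutationIndices("NHWC", "NCHW"));
  for (int64_t d : nchw) std::cout << d << " ";  // Prints: 1 3 224 224
  std::cout << "\n";
  return 0;
}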
89
tensorflow/compiler/tf2xla/kernels/roll_op.cc
Normal file
@ -0,0 +1,89 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"

namespace tensorflow {
namespace {

class RollOp : public XlaOpKernel {
 public:
  explicit RollOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    xla::XlaOp shift = ctx->Input(1);
    const TensorShape shift_shape = ctx->InputShape(1);
    const TensorShape axis_shape = ctx->InputShape(2);

    OP_REQUIRES(ctx, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(ctx, shift_shape.dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift_shape.DebugString()));
    OP_REQUIRES(
        ctx, shift_shape.dims() == axis_shape.dims(),
        errors::InvalidArgument("shift and axis must have the same size"));

    xla::Literal axis;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &axis));

    xla::XlaOp output = ctx->Input(0);
    xla::PrimitiveType shift_type = ctx->input_xla_type(1);
    int64 num_axes = axis_shape.dims() == 0 ? 1 : axis_shape.dim_size(0);
    for (int64 i = 0; i != num_axes; ++i) {
      auto cur_axis_status = axis_shape.dims() == 0
                                 ? axis.GetIntegralAsS64({})
                                 : axis.GetIntegralAsS64({i});
      OP_REQUIRES_OK(ctx, cur_axis_status.status());
      int64 cur_axis = cur_axis_status.ValueOrDie();

      xla::XlaOp offset =
          shift_shape.dims() == 0
              ? shift
              : xla::Reshape(xla::SliceInDim(shift, /*start_index=*/i,
                                             /*limit_index=*/i + 1,
                                             /*stride=*/1, /*dimno=*/0),
                             {});
      xla::XlaOp axis_size = xla::ConstantR0WithType(
          ctx->builder(), shift_type, input_shape.dim_size(cur_axis));
      // Adjust large offsets into [0, axis_size). This also makes negative
      // offsets positive.
      offset = ((offset % axis_size) + axis_size) % axis_size;

      // Stack two copies of the dimension, then slice from the calculated
      // offset.
      xla::XlaOp concat =
          xla::ConcatInDim(ctx->builder(), {output, output}, cur_axis);
      std::vector<xla::XlaOp> start_indices(
          input_shape.dims(), xla::Zero(ctx->builder(), shift_type));
      start_indices[cur_axis] = axis_size - offset;
      output =
          xla::DynamicSlice(concat, start_indices, input_shape.dim_sizes());
    }
    ctx->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RollOp);
};

REGISTER_XLA_OP(Name("Roll").CompileTimeConstantInput("axis"), RollOp);

}  // namespace
}  // namespace tensorflow
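The new RollOp handles one axis at a time: it normalizes the (possibly negative or oversized) shift into [0, axis_size), stacks two copies of the data along that axis, and takes a dynamic slice starting at axis_size - offset. The 1-D sketch below mirrors that concat-and-slice trick in plain C++; roll_1d is an illustrative name, not a TensorFlow helper.

// 1-D illustration of RollOp's concat-and-slice approach.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int> roll_1d(const std::vector<int>& input, int64_t shift) {
  const int64_t size = static_cast<int64_t>(input.size());
  // Normalize the shift into [0, size); this also maps negative shifts to
  // their positive equivalent, like the kernel's
  // ((offset % axis_size) + axis_size) % axis_size.
  const int64_t offset = ((shift % size) + size) % size;

  // "Concatenate" two copies, then slice `size` elements starting at
  // size - offset.
  std::vector<int> concat(input);
  concat.insert(concat.end(), input.begin(), input.end());
  const int64_t start = size - offset;
  return std::vector<int>(concat.begin() + start, concat.begin() + start + size);
}

int main() {
  std::vector<int> v = {0, 1, 2, 3, 4};
  for (int x : roll_1d(v, 2)) std::printf("%d ", x);   // 3 4 0 1 2
  std::printf("\n");
  for (int x : roll_1d(v, -1)) std::printf("%d ", x);  // 1 2 3 4 0
  std::printf("\n");
}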
@ -1,14 +1,14 @@
package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],  # Apache 2.0
)

load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
    "tf_gen_op_wrapper_py",
)

package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "xla_ops",
    srcs = ["xla_ops.cc"],
@ -633,12 +633,14 @@ REGISTER_OP("XlaEinsum")
      if (context->RankKnown(input_a)) {
        rank_a = context->Rank(input_a);
      } else {
        return errors::InvalidArgument("input 0's rank is unknown.");
        context->set_output(0, context->UnknownShape());
        return Status::OK();
      }
      if (context->RankKnown(input_b)) {
        rank_b = context->Rank(input_b);
      } else {
        return errors::InvalidArgument("input 1's rank is unknown.");
        context->set_output(0, context->UnknownShape());
        return Status::OK();
      }
      string equation;
      TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
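The XlaEinsum hunk above softens the shape function: when the rank of either operand is unknown, the op now reports an unknown output shape instead of returning an InvalidArgument error. A minimal sketch of that control-flow change, using std::optional to stand in for shape inference; the types and the EinsumOutputRank helper are illustrative, not the real InferenceContext API.

// Sketch of the behavioral change: unknown input rank now yields an
// "unknown" result rather than an error.
#include <algorithm>
#include <cstdio>
#include <optional>

// An empty optional models an unknown rank/shape.
using MaybeRank = std::optional<int>;

MaybeRank EinsumOutputRank(MaybeRank rank_a, MaybeRank rank_b) {
  if (!rank_a || !rank_b) {
    // Before the change: this case was an error.
    // After the change: propagate "unknown" and let the caller cope.
    return std::nullopt;
  }
  // Stand-in for the real equation-driven rank computation.
  return std::max(*rank_a, *rank_b);
}

int main() {
  auto known = EinsumOutputRank(2, 3);
  auto unknown = EinsumOutputRank(std::nullopt, 3);
  std::printf("known: %d, unknown known? %s\n", known.value_or(-1),
              unknown ? "yes" : "no");
}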
@ -303,10 +303,27 @@ Status BuildComputation(
      handle = identity_op(handle);

      // Set layout of the retval to device representation layout.
      if (resource->representation_shape().has_value()) {
        retval_index_and_layout.emplace_back(
            elems.size(), resource->representation_shape()->layout());
      absl::optional<xla::Shape> representation_shape;
      if (shape_representation_fn) {
        TF_ASSIGN_OR_RETURN(
            xla::Shape xla_shape,
            shape_representation_fn(resource->shape(), resource->type()));
        representation_shape = xla_shape;
      }
      if (resource->representation_shape().has_value()) {
        const xla::Shape& xla_shape = resource->representation_shape().value();
        if (representation_shape) {
          TF_RET_CHECK(
              xla::ShapeUtil::Compatible(*representation_shape, xla_shape));
        } else {
          representation_shape = xla_shape;
        }
      }
      if (representation_shape) {
        retval_index_and_layout.emplace_back(elems.size(),
                                             representation_shape->layout());
      }

      elems.push_back(handle);
    }
  }
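The BuildComputation hunk above now derives the retval layout from two possible sources: the shape_representation_fn callback (when the compiler was configured with one) and the shape already recorded on the resource. When both exist they must be compatible; otherwise whichever is present wins, and if neither is present no layout is recorded. A simplified model of that reconciliation; the plain Shape struct and PickRepresentationShape are illustrative stand-ins, not XLA types.

// Simplified model of the representation-shape reconciliation.
#include <cstdint>
#include <cstdio>
#include <optional>
#include <stdexcept>
#include <vector>

struct Shape {
  std::vector<int64_t> dims;
  bool operator==(const Shape& o) const { return dims == o.dims; }
};

std::optional<Shape> PickRepresentationShape(
    const std::optional<Shape>& from_callback,    // shape_representation_fn result
    const std::optional<Shape>& from_resource) {  // resource->representation_shape()
  std::optional<Shape> representation = from_callback;
  if (from_resource) {
    if (representation) {
      // Mirrors the TF_RET_CHECK(ShapeUtil::Compatible(...)) in the hunk.
      if (!(*representation == *from_resource)) {
        throw std::runtime_error("incompatible representation shapes");
      }
    } else {
      representation = from_resource;
    }
  }
  return representation;  // empty => no layout recorded for this retval
}

int main() {
  auto only_resource = PickRepresentationShape(std::nullopt, Shape{{2, 3}});
  auto both = PickRepresentationShape(Shape{{2, 3}}, Shape{{2, 3}});
  std::printf("%zu %zu\n", only_resource->dims.size(), both->dims.size());
}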
@ -553,6 +570,7 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
  GraphOptimizer::Options graph_optimizer_options;
  graph_optimizer_options.cf_consider_fn = cf_consider_fn;
  graph_optimizer_options.inline_multi_device_functions = true;
  graph_optimizer_options.inline_impl_selection_group_functions = true;
  optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                     /*device=*/nullptr, &graph, graph_optimizer_options);
Some files were not shown because too many files have changed in this diff.