merging master with upstream master

commit d391ba441b

WORKSPACE: 28
@@ -14,6 +14,33 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")

closure_repositories()

http_archive(
    name = "base_images_docker",
    sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9",
    strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6",
    urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"],
)

http_archive(
    name = "bazel_toolchains",
    sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb",
    strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b",
    urls = [
        "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz",
    ],
)

http_archive(
    name = "io_bazel_rules_docker",
    sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd",
    strip_prefix = "rules_docker-0.5.1",
    urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"],
)

load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace")

remote_config_workspace()

# We must check the bazel version before trying to parse any other BUILD
# files, in case the parsing of those build files depends on the bazel
# version we require here.
@@ -79,3 +106,4 @@ new_http_archive(
        "http://download.tensorflow.org/models/speech_commands_v0.01.zip",
    ],
)

@@ -43,7 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16]
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18]

_DEFAULT_PROMPT_ASK_ATTEMPTS = 10

@@ -1555,6 +1555,9 @@ def main():
  check_bazel_version('0.15.0')

  reset_tf_configure_bazelrc()
  # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later
  write_to_bazelrc('import %workspace%/tools/bazel.rc')

  cleanup_makefile()
  setup_python(environ_cp)

@ -352,6 +352,7 @@ package_group(
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/...",
|
||||
"//tensorflow_fold/llgtm/...",
|
||||
"//tensorflow_text/...",
|
||||
"//third_party/py/tensor2tensor/...",
|
||||
],
|
||||
)
|
||||
|
@ -95,6 +95,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
],
|
||||
}) + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
@ -199,7 +200,7 @@ tf_cuda_cc_test(
|
||||
size = "small",
|
||||
srcs = ["c_api_test.cc"],
|
||||
data = [
|
||||
":test_op.so",
|
||||
":test_op1.so",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
kernels = [":test_op_kernel"],
|
||||
@ -218,6 +219,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/cc:grad_ops",
|
||||
"//tensorflow/cc/saved_model:signature_constants",
|
||||
"//tensorflow/cc/saved_model:tag_constants",
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//tensorflow/core:framework",
|
||||
@ -284,8 +286,8 @@ tf_cc_test(
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "test_op.so",
|
||||
srcs = ["test_op.cc"],
|
||||
name = "test_op1.so",
|
||||
srcs = ["test_op1.cc"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -2810,4 +2810,71 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// TF_Server functions ----------------------------------------------
|
||||
|
||||
#ifndef __ANDROID__
|
||||
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
|
||||
: target(server->target()), server(std::move(server)) {}
|
||||
#endif // __ANDROID__
|
||||
|
||||
TF_Server* TF_NewServer(const void* proto, size_t proto_len,
|
||||
TF_Status* status) {
|
||||
#ifdef __ANDROID__
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"Server functionality is not supported in Android");
|
||||
return nullptr;
|
||||
#else
|
||||
tensorflow::ServerDef server_def;
|
||||
if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
|
||||
status->status = InvalidArgument(
|
||||
"Could not parse provided bytes into a ServerDef protocol buffer");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::ServerInterface> out_server;
|
||||
status->status = tensorflow::NewServer(server_def, &out_server);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
return new TF_Server(std::move(out_server));
|
||||
#endif
|
||||
}
|
||||
|
||||
void TF_ServerStart(TF_Server* server, TF_Status* status) {
|
||||
#ifdef __ANDROID__
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"Server functionality is not supported in Android");
|
||||
#else
|
||||
status->status = server->server->Start();
|
||||
#endif
|
||||
}
|
||||
|
||||
void TF_ServerStop(TF_Server* server, TF_Status* status) {
|
||||
#ifdef __ANDROID__
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"Server functionality is not supported in Android");
|
||||
#else
|
||||
status->status = server->server->Stop();
|
||||
#endif
|
||||
}
|
||||
|
||||
void TF_ServerJoin(TF_Server* server, TF_Status* status) {
|
||||
#ifdef __ANDROID__
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"Server functionality is not supported in Android");
|
||||
#else
|
||||
status->status = server->server->Join();
|
||||
#endif
|
||||
}
|
||||
|
||||
const char* TF_ServerTarget(TF_Server* server) {
|
||||
#ifdef __ANDROID__
|
||||
return nullptr;
|
||||
#else
|
||||
return server->target.c_str();
|
||||
#endif
|
||||
}
|
||||
|
||||
void TF_DeleteServer(TF_Server* server) { delete server; }
|
||||
|
||||
} // end extern "C"
|
||||
|
@@ -1668,6 +1668,47 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
    const char* name, TF_Status* status);

// --------------------------------------------------------------------------
// In-process TensorFlow server functionality, for use in distributed training.
// A Server instance encapsulates a set of devices and a Session target that
// can participate in distributed training. A server belongs to a cluster
// (specified by a ClusterSpec), and corresponds to a particular task in a
// named job. The server can communicate with any other server in the same
// cluster.

// In-process TensorFlow server.
typedef struct TF_Server TF_Server;

// Creates a new in-process TensorFlow server configured using a serialized
// ServerDef protocol buffer provided via `proto` and `proto_len`.
//
// The server will not serve any requests until TF_ServerStart is invoked.
// The server will stop serving requests once TF_ServerStop or
// TF_DeleteServer is invoked.
TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto,
                                              size_t proto_len,
                                              TF_Status* status);

// Starts an in-process TensorFlow server.
TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status);

// Stops an in-process TensorFlow server.
TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status);

// Blocks until the server has been successfully stopped (via TF_ServerStop or
// TF_ServerClose).
TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status);

// Returns the target string that can be provided to TF_SetTarget() to connect
// a TF_Session to `server`.
//
// The returned string is valid only until TF_DeleteServer is invoked.
TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server);

// Destroy an in-process TensorFlow server, frees memory. If server is running
// it will be stopped and joined.
TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);

#ifdef __cplusplus
} /* end extern "C" */
#endif

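Not part of this commit: the short sketch below shows how the TF_Server API declared above might be driven end to end. The function name and the `server_def_buf`/`server_def_len` parameters are hypothetical stand-ins for a serialized tensorflow.ServerDef obtained elsewhere; error handling is kept minimal.

// Illustrative usage sketch only -- not part of this commit. Assumes
// `server_def_buf`/`server_def_len` hold a serialized tensorflow.ServerDef
// (both names are hypothetical).
#include <cstdio>

#include "tensorflow/c/c_api.h"

void RunInProcessServer(const void* server_def_buf, size_t server_def_len) {
  TF_Status* status = TF_NewStatus();

  // Create the server from the serialized ServerDef.
  TF_Server* server = TF_NewServer(server_def_buf, server_def_len, status);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "TF_NewServer failed: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return;
  }

  // No requests are served until the server is started.
  TF_ServerStart(server, status);
  if (TF_GetCode(status) == TF_OK) {
    // Sessions can connect to this target while the server is alive.
    std::printf("server target: %s\n", TF_ServerTarget(server));
    // Block until the server is stopped (e.g. by TF_ServerStop or
    // TF_DeleteServer called from another thread).
    TF_ServerJoin(server, status);
  }

  // Stops and joins the server if it is still running, then frees it.
  TF_DeleteServer(server);
  TF_DeleteStatus(status);
}

Per the implementation added earlier in this diff, on Android builds these calls instead set an Unimplemented status (and TF_ServerTarget returns nullptr).
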
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#ifndef __ANDROID__
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
@ -179,6 +180,15 @@ struct TF_ApiDefMap {
|
||||
tensorflow::mutex lock;
|
||||
};
|
||||
|
||||
#ifndef __ANDROID__
|
||||
struct TF_Server {
|
||||
TF_Server(std::unique_ptr<tensorflow::ServerInterface> server);
|
||||
|
||||
const tensorflow::string target;
|
||||
std::unique_ptr<tensorflow::ServerInterface> server;
|
||||
};
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
|
@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||
// tf_cuda_cc_test() bazel rule and remove the next line.
|
||||
if (!GPUDeviceName().empty()) return;
|
||||
|
||||
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
{
|
||||
// Load the library.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op.so", status);
|
||||
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
|
||||
TF_Code code = TF_GetCode(status);
|
||||
string status_msg(TF_Message(status));
|
||||
TF_DeleteStatus(status);
|
||||
ASSERT_EQ(TF_OK, code) << status_msg;
|
||||
|
||||
// Test op list.
|
||||
TF_Buffer op_list_buf = TF_GetOpList(lib);
|
||||
tensorflow::OpList op_list;
|
||||
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
|
||||
ASSERT_EQ(op_list.op_size(), 1);
|
||||
EXPECT_EQ("TestCApi1", op_list.op(0).name());
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
{
|
||||
TF_Buffer* op_list_buffer = TF_GetAllOpList();
|
||||
tensorflow::OpList op_list;
|
||||
@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||
EXPECT_TRUE(found);
|
||||
TF_DeleteBuffer(op_list_buffer);
|
||||
}
|
||||
|
||||
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
{
|
||||
// Test op list.
|
||||
TF_Buffer op_list_buf = TF_GetOpList(lib);
|
||||
tensorflow::OpList op_list;
|
||||
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
|
||||
ASSERT_EQ(op_list.op_size(), 1);
|
||||
EXPECT_EQ("TestCApi", op_list.op(0).name());
|
||||
}
|
||||
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
|
||||
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
|
@ -69,7 +69,7 @@ tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
hdrs = ["c_api_internal.h"],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__pkg__",
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
|
@ -404,8 +404,7 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* d = nullptr;
|
||||
status->status = h->handle->OpDevice(&d);
|
||||
tensorflow::Device* d = h->handle->op_device();
|
||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
: d->name().c_str();
|
||||
}
|
||||
|
@ -57,13 +57,9 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::Device* device;
|
||||
status->status = handle->handle->Device(&device);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Device* device = handle->handle->device();
|
||||
|
||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||
tensorflow::XlaDevice* xla_device =
|
||||
dynamic_cast<tensorflow::XlaDevice*>(device);
|
||||
|
@ -79,10 +79,6 @@ struct TFE_TensorHandle {
|
||||
tensorflow::Device* op_device)
|
||||
: handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {}
|
||||
|
||||
TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
|
||||
tensorflow::EagerContext* ctx)
|
||||
: handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {}
|
||||
|
||||
TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}
|
||||
|
||||
tensorflow::TensorHandle* handle;
|
||||
|
tensorflow/c/test_op1.cc: 23 (new file)
@@ -0,0 +1,23 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc");

} // namespace tensorflow

@ -170,6 +170,7 @@ cc_library_with_android_deps(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -516,6 +517,8 @@ tf_gen_op_wrappers_cc(
|
||||
":array_ops",
|
||||
":const_op",
|
||||
":math_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -93,7 +93,7 @@ cc_library(
|
||||
":tfcompile_lib",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) {
|
||||
return errors::InvalidArgument("Must specify --cpp_class");
|
||||
}
|
||||
codegen_opts.gen_hlo_profile_printer_data =
|
||||
xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
|
||||
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
|
||||
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
|
||||
&codegen_opts.namespaces));
|
||||
|
||||
@ -132,7 +132,7 @@ int main(int argc, char** argv) {
|
||||
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
|
||||
xla::AppendDebugOptionsFlags(&flag_list);
|
||||
|
||||
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
|
||||
usage += tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
|
@ -21,7 +21,6 @@ package(
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
|
||||
@ -52,6 +51,7 @@ cc_library(
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
],
|
||||
@ -65,6 +65,7 @@ cc_library(
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||
]),
|
||||
alwayslink = 1,
|
||||
@ -190,6 +191,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:stack",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
@ -241,6 +243,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
@ -253,6 +256,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -263,6 +267,21 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_compilation_cache_test",
|
||||
srcs = [
|
||||
"xla_compilation_cache_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":xla_compilation_cache",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
@ -500,6 +519,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
@ -524,25 +544,6 @@ cc_library(
|
||||
hdrs = ["union_find.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "producer_consumer_queue",
|
||||
hdrs = ["producer_consumer_queue.h"],
|
||||
deps = ["//tensorflow/core:lib"],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "producer_consumer_queue_test",
|
||||
size = "small",
|
||||
srcs = ["producer_consumer_queue_test.cc"],
|
||||
deps = [
|
||||
":producer_consumer_queue",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "deadness_analysis_test",
|
||||
size = "small",
|
||||
@ -606,6 +607,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
|
||||
"//tensorflow/compiler/tf2xla/cc:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -648,31 +650,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_launch_util_test",
|
||||
size = "small",
|
||||
srcs = ["xla_launch_util_test.cc"],
|
||||
deps = [
|
||||
":common",
|
||||
":xla_compilation_cache",
|
||||
":xla_launch_util",
|
||||
":xla_tensor",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_fusion_optimizer",
|
||||
srcs = ["xla_fusion_optimizer.cc"],
|
||||
|
@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) {
|
||||
return errors::Internal("Could not find compilation device ",
|
||||
device_type.type());
|
||||
}
|
||||
*result = registration->requires_compilation;
|
||||
*result = registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root,
|
||||
Output loop_cond =
|
||||
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
|
||||
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
|
||||
latch.output_false);
|
||||
Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
|
||||
latch.output_true, increment_by);
|
||||
Output next_iteration =
|
||||
@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue(
|
||||
value, frame_name);
|
||||
ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
|
||||
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
|
||||
latch.output_false);
|
||||
Output next_iteration = ops::NextIteration(
|
||||
root.WithOpName(prefix + "/next_iteration"), latch.output_true);
|
||||
CHECK(root.graph()
|
||||
|
@ -117,6 +117,25 @@ Status PreprocessForEncapsulation(Graph* g,
|
||||
|
||||
// Information for XLA computation.
|
||||
struct XlaClusterInfo {
|
||||
// Add an explicitly-defined default constructor for this class.
|
||||
//
|
||||
// The compiler may delete the default constructor here because
|
||||
// host_compute_core is a const member whose type (std::map) doesn't
|
||||
// necessarily have a user provided constructor -- while libc++ and
|
||||
// libstdc++ 4.8 provide a user defined default constructor, libstdc++ at
|
||||
// least >= 7.3 does not. See also c++11 [class.ctor] p5.
|
||||
//
|
||||
// TODO(klimek): In c++17 we'll be able to initialize host_compute_core
|
||||
// without losing aggregate initialization, which allows us to get rid of
|
||||
// the constructor definitions again.
|
||||
XlaClusterInfo() {}
|
||||
XlaClusterInfo(const string& cluster_name,
|
||||
const NameAttrList& func_name_attrs, Node* node,
|
||||
const std::map<string, int>& host_compute_core)
|
||||
: cluster_name(cluster_name),
|
||||
func_name_attrs(func_name_attrs),
|
||||
node(node),
|
||||
host_compute_core(host_compute_core) {}
|
||||
// XLA cluster name. It might be different from `func_name`.
|
||||
const string cluster_name;
|
||||
// Name and attributes of XLA computation function.
|
||||
|
@ -394,8 +394,8 @@ Status ConstructHostGraph(
|
||||
for (const string& host_func : outside_compilation_host_graphs) {
|
||||
VLOG(4) << "Expanding host graph " << host_func;
|
||||
FunctionBody* host_fbody = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBodyHelper(*fld->Find(host_func), AttrSlice(), fld,
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
|
||||
*fld->Find(host_func), AttrSlice(), fld,
|
||||
[&](const string& op, const OpDef** sig) {
|
||||
return fld->LookUpOpDef(op, sig);
|
||||
},
|
||||
@ -411,7 +411,8 @@ Status ConstructHostGraph(
|
||||
node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
|
||||
node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
|
||||
Status s;
|
||||
ReverseDFS(*host_fbody->graph, /*enter=*/nullptr,
|
||||
ReverseDFS(
|
||||
*host_fbody->graph, /*enter=*/nullptr,
|
||||
[&](const Node* n) {
|
||||
if (!s.ok()) {
|
||||
return;
|
||||
@ -838,10 +839,15 @@ Status ExtractOutsideCompilationForFunction(
|
||||
FunctionDef shape_inference_fdef = *xla_fdef;
|
||||
shape_inference_fdef.mutable_signature()->set_name(
|
||||
shape_inference_graph);
|
||||
if (fld->Find(shape_inference_graph)) {
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph,
|
||||
shape_inference_fdef));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Node* n : outside_compilation_nodes) {
|
||||
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
|
||||
graph_out.get(), n, host_compute_core));
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
@@ -34,14 +35,30 @@ limitations under the License.

namespace tensorflow {
namespace {
Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) {
  TF_RET_CHECK(n->type_string() == "Const");

// StatusOrOptional<T> instances hold
//
// - A non-OK Status to indicate an error that needs to be propagated out of
// this pass (e.g. the Graph is malformed).
//
// - A nullopt to indicate the function that created the instance failed to do
// what it set out to do but this is not actually an error
// (e.g. TryToGetTensorFromConstOp was passed a non-Const node).
//
// - A T to indicate a successful operation.
template <class T>
using StatusOrOptional = xla::StatusOr<absl::optional<T>>;

StatusOrOptional<Tensor> TryToGetTensorFromConstOp(Node* n) {
  if (n->type_string() != "Const") {
    return {absl::nullopt};
  }

  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
  Tensor tensor(proto->dtype());
  TF_RET_CHECK(tensor.FromProto(*proto));
  *out_tensor = std::move(tensor);
  return Status::OK();
  return {tensor};
}

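Also not part of the commit: a self-contained sketch of the three-way StatusOrOptional<T> contract described above (hard error vs. nothing-to-do vs. value). The helper `TryParsePositiveInt` and its caller `UseIt` are hypothetical; only the xla::StatusOr / absl::optional composition and the return conventions mirror this pass.

// Illustrative only -- not part of this commit. `TryParsePositiveInt` and
// `UseIt` are hypothetical and exist only to show the three outcomes a
// StatusOrOptional<T> can carry.
#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace example {

template <class T>
using StatusOrOptional = xla::StatusOr<absl::optional<T>>;

// Error for malformed input, nullopt for input that simply does not apply,
// and a value on success.
StatusOrOptional<int> TryParsePositiveInt(const std::string& s) {
  if (s.empty()) {
    return tensorflow::errors::InvalidArgument("empty string");  // hard error
  }
  for (char c : s) {
    if (c < '0' || c > '9') {
      return {absl::nullopt};  // not applicable; the caller just moves on
    }
  }
  return {std::stoi(s)};  // success
}

tensorflow::Status UseIt(const std::string& s, int* out) {
  StatusOrOptional<int> parsed = TryParsePositiveInt(s);
  if (!parsed.ok()) return parsed.status();  // propagate real errors
  if (!parsed.ValueOrDie().has_value()) {
    *out = -1;  // nothing to do, but not an error
    return tensorflow::Status::OK();
  }
  *out = *parsed.ValueOrDie();
  return tensorflow::Status::OK();
}

}  // namespace example

The same shape is what lets GetSliceInputs and IsRewritableSlice further down in this file bail out with {absl::nullopt} instead of encoding "not a rewritable slice" as an error Status.
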
struct SliceInputs {
|
||||
@ -70,7 +87,7 @@ std::vector<int64> IntTensorAsVector(const Tensor& t) {
|
||||
|
||||
// Packages up the inputs to a Slice operation into an instance of
|
||||
// `SliceInputs`.
|
||||
Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) {
|
||||
StatusOrOptional<SliceInputs> GetSliceInputs(Node* slice) {
|
||||
const int kSliceInputIndex = 0;
|
||||
const int kSliceBeginIndex = 1;
|
||||
const int kSliceSizeIndex = 2;
|
||||
@ -81,23 +98,27 @@ Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) {
|
||||
TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge));
|
||||
const Edge* slice_begin_edge;
|
||||
TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge));
|
||||
slice_inputs->input =
|
||||
|
||||
SliceInputs slice_inputs;
|
||||
slice_inputs.input =
|
||||
Output(slice_input_edge->src(), slice_input_edge->src_output());
|
||||
slice_inputs->begin =
|
||||
slice_inputs.begin =
|
||||
Output(slice_begin_edge->src(), slice_begin_edge->src_output());
|
||||
slice_inputs->size =
|
||||
slice_inputs.size =
|
||||
Output(slice_size_edge->src(), slice_size_edge->src_output());
|
||||
|
||||
Tensor tf_slice_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size));
|
||||
|
||||
if (tf_slice_size.dims() != 1) {
|
||||
return errors::Internal("Expected vector for the slice size input.");
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> tf_slice_size,
|
||||
TryToGetTensorFromConstOp(slice_inputs.size.node()));
|
||||
if (!tf_slice_size.has_value()) {
|
||||
return {absl::nullopt};
|
||||
}
|
||||
|
||||
slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size);
|
||||
return Status::OK();
|
||||
if (tf_slice_size->dims() != 1) {
|
||||
return {absl::nullopt};
|
||||
}
|
||||
|
||||
slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size);
|
||||
return {slice_inputs};
|
||||
}
|
||||
|
||||
// Casts `x` to a DT_INT64 if it isn't one already.
|
||||
@ -263,36 +284,43 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns true if `n` is a slice we can rewrite to have a static shape
|
||||
// (i.e. have the output shape only depend on the "size" input). Fills in
|
||||
// `slice_inputs` in the process.
|
||||
bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) {
|
||||
// If `n` is a slice we can rewrite to have a static shape (i.e. have the output
|
||||
// shape only depend on the "size" input) then returns the a SliceInputs
|
||||
// representing the inputs to `n`. Otherwise returns nullopt.
|
||||
StatusOrOptional<SliceInputs> IsRewritableSlice(Node* n) {
|
||||
if (n->type_string() != "Slice") {
|
||||
return false;
|
||||
return {absl::nullopt};
|
||||
}
|
||||
|
||||
if (!GetXlaClusterForNode(*n).has_value()) {
|
||||
// There is no need to change slice ops outside XLA clusters.
|
||||
return false;
|
||||
return {absl::nullopt};
|
||||
}
|
||||
|
||||
if (!GetSliceInputs(n, slice_inputs).ok()) {
|
||||
// Could not parse slice inputs. E.g. the sizes input was not a constant.
|
||||
return false;
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
|
||||
GetSliceInputs(n));
|
||||
if (!slice_inputs.has_value()) {
|
||||
return {absl::nullopt};
|
||||
}
|
||||
|
||||
// If slice_size[i] < -1 for any i then executing the slice will throw an
|
||||
// error, and we don't do anything here.
|
||||
return absl::c_all_of(slice_inputs->size_as_vector,
|
||||
bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector,
|
||||
[](int64 size_i) { return size_i >= -1; });
|
||||
if (!slice_is_ok) {
|
||||
return {absl::nullopt};
|
||||
}
|
||||
|
||||
return slice_inputs;
|
||||
}
|
||||
|
||||
Status FindAndRewriteSlices(Graph* g, bool* changed) {
|
||||
std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite;
|
||||
for (Node* n : g->nodes()) {
|
||||
SliceInputs slice_inputs;
|
||||
if (IsRewritableSlice(n, &slice_inputs)) {
|
||||
slices_to_rewrite.push_back({n, std::move(slice_inputs)});
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
|
||||
IsRewritableSlice(n));
|
||||
if (slice_inputs.has_value()) {
|
||||
slices_to_rewrite.push_back({n, std::move(*slice_inputs)});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,11 +44,8 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
|
||||
MarkForCompilationPass);
|
||||
|
||||
// TODO(b/111210515): IncreaseDynamismForAutoJitPass creates slices with index
|
||||
// type DT_INT64 which do not have a kernel on GPU.
|
||||
//
|
||||
// REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
|
||||
// IncreaseDynamismForAutoJitPass);
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
|
||||
IncreaseDynamismForAutoJitPass);
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
|
||||
PartiallyDeclusterPass);
|
||||
|
@ -39,12 +39,22 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
|
||||
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
|
||||
// in error case, it returns RET instead of void.
|
||||
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
|
||||
do { \
|
||||
::tensorflow::Status _s(__VA_ARGS__); \
|
||||
if (!TF_PREDICT_TRUE(_s.ok())) { \
|
||||
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
|
||||
return RET; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status PlatformInfoFromContext(OpKernelConstruction* ctx,
|
||||
XlaPlatformInfo* result) {
|
||||
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
DeviceType device_type = ctx->device_type();
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx,
|
||||
}
|
||||
|
||||
if (!device_allocator) {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform* const platform,
|
||||
se::MultiPlatformManager::PlatformWithId(platform_id));
|
||||
xla::StatusOr<se::Platform*> maybe_platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_id);
|
||||
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
|
||||
|
||||
xla_allocator = absl::make_unique<XlaAllocator>(
|
||||
platform, ctx->device()->GetAllocator({}));
|
||||
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
|
||||
}
|
||||
|
||||
*result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
std::move(xla_allocator), device_allocator);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A closure describing how to run a compiled version of a TensorFlow function.
|
||||
@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
: OpKernel(ctx),
|
||||
constants_(constants),
|
||||
resources_(resources),
|
||||
function_(function) {
|
||||
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
|
||||
}
|
||||
function_(function),
|
||||
platform_info_(PlatformInfoFromContext(ctx)) {}
|
||||
|
||||
static Status BuildCompilationCache(OpKernelContext* ctx,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
@ -277,8 +286,10 @@ static Status CompileToLocalExecutable(
|
||||
// rather than a one-element tuple.
|
||||
compile_options.always_return_tuple = false;
|
||||
|
||||
return cache->Compile(options, function, constant_args, *variables, ctx,
|
||||
compile_options,
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_args, *variables, ctx, &args));
|
||||
return cache->Compile(options, function, args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
: XlaCompilationCache::CompileMode::kStrict,
|
||||
kernel, executable);
|
||||
@ -333,18 +344,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
|
||||
// in error case, it returns RET instead of void.
|
||||
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
|
||||
do { \
|
||||
::tensorflow::Status _s(__VA_ARGS__); \
|
||||
if (!TF_PREDICT_TRUE(_s.ok())) { \
|
||||
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
|
||||
return RET; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Helper static functions to construct parameters for
|
||||
// XlaLocalLaunchBase constructor from OpKernelConstruction.
|
||||
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
|
||||
@ -381,7 +380,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
|
||||
return *func;
|
||||
}
|
||||
|
||||
#undef OP_REQUIRES_OK_RETURN
|
||||
bool MustCompileAttr(OpKernelConstruction* ctx) {
|
||||
bool must_compile;
|
||||
OP_REQUIRES_OK_RETURN(ctx, false,
|
||||
ctx->GetAttr("must_compile", &must_compile));
|
||||
return must_compile;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
|
||||
@ -396,10 +400,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx),
|
||||
constants_(ConstantsVector(ctx)),
|
||||
resources_(ResourcesVector(ctx)),
|
||||
function_(FunctionAttr(ctx)) {
|
||||
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_));
|
||||
}
|
||||
function_(FunctionAttr(ctx)),
|
||||
platform_info_(PlatformInfoFromContext(ctx)),
|
||||
must_compile_(MustCompileAttr(ctx)) {}
|
||||
|
||||
void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaCompileOp " << def().name()
|
||||
@ -409,13 +412,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalExecutable* executable;
|
||||
std::map<int, OptionalTensor> variables;
|
||||
|
||||
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) {
|
||||
bool cannot_compile_cluster;
|
||||
{
|
||||
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
cannot_compile_cluster = cannot_compile_cluster_;
|
||||
}
|
||||
|
||||
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
|
||||
cannot_compile_cluster) {
|
||||
executable = nullptr;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, CompileToLocalExecutable(
|
||||
ctx, function_, platform_info_, resources_,
|
||||
constants_, /*lazy=*/!must_compile_, &client,
|
||||
&variables, &kernel, &executable));
|
||||
Status status = CompileToLocalExecutable(
|
||||
ctx, function_, platform_info_, resources_, constants_,
|
||||
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
|
||||
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
|
||||
|
||||
if (status.code() == error::UNIMPLEMENTED) {
|
||||
LOG(WARNING) << "Compilation failed:" << status.ToString()
|
||||
<< ". Falling back to TF function call.";
|
||||
executable = nullptr;
|
||||
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
cannot_compile_cluster_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
AllocatorAttributes host_alloc_attrs;
|
||||
@ -452,9 +472,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
ctx->set_output(1, compilation_successful);
|
||||
}
|
||||
|
||||
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
|
||||
}
|
||||
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
|
||||
|
||||
void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaRunOp " << def().name();
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
@ -33,6 +35,7 @@ namespace tensorflow {
|
||||
class XlaPlatformInfo {
|
||||
public:
|
||||
XlaPlatformInfo() : device_type_("") {}
|
||||
XlaPlatformInfo(XlaPlatformInfo&&) = default;
|
||||
explicit XlaPlatformInfo(const DeviceType device_type,
|
||||
se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel {
|
||||
|
||||
protected:
|
||||
// Indexes of compile-time constant inputs
|
||||
std::vector<int> constants_;
|
||||
const std::vector<int> constants_;
|
||||
// Indexes of resource inputs
|
||||
std::vector<int> resources_;
|
||||
const std::vector<int> resources_;
|
||||
|
||||
NameAttrList function_;
|
||||
XlaPlatformInfo platform_info_;
|
||||
const NameAttrList function_;
|
||||
const XlaPlatformInfo platform_info_;
|
||||
};
|
||||
|
||||
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
|
||||
@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel {
|
||||
|
||||
private:
|
||||
// Indexes of compile-time constant inputs
|
||||
std::vector<int> constants_;
|
||||
const std::vector<int> constants_;
|
||||
// Indexes of resource inputs
|
||||
std::vector<int> resources_;
|
||||
const std::vector<int> resources_;
|
||||
|
||||
NameAttrList function_;
|
||||
const NameAttrList function_;
|
||||
|
||||
XlaPlatformInfo platform_info_;
|
||||
|
||||
bool must_compile_;
|
||||
const bool must_compile_;
|
||||
|
||||
// cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
|
||||
// error when compiling the cluster this _XlaCompile is supposed to compile.
|
||||
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
|
||||
// on any future calls to _XlaCompile.
|
||||
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
|
||||
|
||||
mutex cannot_compile_cluster_mu_;
|
||||
};
|
||||
|
||||
class XlaRunOp : public OpKernel {
|
||||
@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
XlaPlatformInfo platform_info_;
|
||||
const XlaPlatformInfo platform_info_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -22,7 +22,7 @@ cc_library(
|
||||
hdrs = ["mark_for_compilation_pass_flags.h"],
|
||||
deps =
|
||||
[
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -34,7 +34,7 @@ cc_library(
|
||||
hdrs = ["xla_device_flags.h"],
|
||||
deps =
|
||||
[
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -46,7 +46,7 @@ cc_library(
|
||||
hdrs = ["build_xla_ops_pass_flags.h"],
|
||||
deps =
|
||||
[
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -58,7 +58,7 @@ cc_library(
|
||||
hdrs = ["xla_ops_common_flags.h"],
|
||||
deps =
|
||||
[
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include <mutex> // NOLINT
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -34,7 +34,7 @@ void AllocateAndParseFlags() {
|
||||
Flag("tf_xla_enable_lazy_compilation",
|
||||
&flags->tf_xla_enable_lazy_compilation, ""),
|
||||
});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
xla::ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -19,7 +19,8 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
@ -64,7 +65,18 @@ static void AllocateFlags() {
|
||||
Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only,
|
||||
"enable fusion of element-wise operations only using XLA when "
|
||||
"global_jit_level is ON*.")});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
xla::ParseFlagsFromEnv(*flag_list);
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Parsed MarkForCompilationPassFlags:";
|
||||
VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
|
||||
VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size;
|
||||
VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size;
|
||||
VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug;
|
||||
VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
|
||||
VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel;
|
||||
VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
|
||||
}
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with the XLA bridge's
|
||||
|
@ -33,7 +33,7 @@ void AppendMarkForCompilationPassFlags(
|
||||
|
||||
// The values of flags associated with the XLA bridge's
|
||||
// mark_for_compilation_pass module.
|
||||
typedef struct {
|
||||
struct MarkForCompilationPassFlags {
|
||||
int32 tf_xla_auto_jit; // Control compilation of operators into XLA
|
||||
// computations on CPU and GPU devices. 0 = use
|
||||
// ConfigProto setting; -1 = off; 1 = on for things
|
||||
@ -55,7 +55,7 @@ typedef struct {
|
||||
// is set to ON* and overrides its behavior. If
|
||||
// true, enable fusion of element-wise operations
|
||||
// only using XLA.
|
||||
} MarkForCompilationPassFlags;
|
||||
};
|
||||
|
||||
// Return a pointer to the MarkForCompilationPassFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
@ -41,7 +41,7 @@ static void AllocateFlags() {
|
||||
"Switch a device into 'on-demand' mode, where instead of "
|
||||
"autoclustering ops are compiled one by one just-in-time."),
|
||||
});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
xla::ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Return a pointer to the XlaDeviceFlags struct;
|
||||
|
@ -17,8 +17,8 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -35,7 +35,13 @@ void AllocateAndParseFlags() {
|
||||
Flag("tf_xla_always_defer_compilation",
|
||||
&flags->tf_xla_always_defer_compilation, ""),
|
||||
});
|
||||
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
|
||||
xla::ParseFlagsFromEnv(*flag_list);
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Parsed XlaOpsCommonFlags:";
|
||||
VLOG(1) << " tf_xla_always_defer_compilation = "
|
||||
<< flags->tf_xla_always_defer_compilation;
|
||||
}
|
||||
}
|
||||
|
||||
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
|
||||
|
@ -61,8 +61,23 @@ struct OperationFilter {
|
||||
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
|
||||
// auto-clustering stateful RNG ops.
|
||||
bool allow_stateful_rng_ops;
|
||||
|
||||
// TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
|
||||
// to cluster ControlTrigger because of how we use deadness analysis.
|
||||
bool allow_control_trigger;
|
||||
|
||||
// Whether ops with dummy implementations are allowed. We avoid
|
||||
// auto-clustering these ops so that the user is not surprised when XLA is
|
||||
// implicitly enabled. If the user explicitly specifies to use XLA, it is fine
|
||||
// to resort to a dummy implementation. Currently Assert and CheckNumerics ops
|
||||
// have dummy XLA implementations.
|
||||
bool allow_dummy_ops;
|
||||
};
|
||||
|
||||
bool IsDummyImplOp(absl::string_view op_name) {
|
||||
return op_name == "Assert" || op_name == "CheckNumerics";
|
||||
}
|
||||
|
||||
bool IsStatefulRandomOp(absl::string_view op_name) {
|
||||
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
|
||||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
|
||||
@ -225,6 +240,12 @@ bool IsCompilableCall(const NodeDef& call_def,
|
||||
IsStatefulRandomOp(node->type_string())) {
|
||||
return false;
|
||||
}
|
||||
if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
|
||||
return false;
|
||||
}
|
||||
if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
|
||||
return false;
|
||||
}
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
|
||||
lib_runtime)) {
|
||||
@ -452,7 +473,14 @@ Status FindCompilationCandidates(
|
||||
|
||||
OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops = registration->compile_resource_ops;
|
||||
op_filter.allow_stateful_rng_ops = registration->requires_compilation;
|
||||
op_filter.allow_stateful_rng_ops =
|
||||
(registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways);
|
||||
op_filter.allow_control_trigger =
|
||||
(registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways);
|
||||
op_filter.allow_dummy_ops = (registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways);
|
||||
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
|
||||
@ -467,6 +495,15 @@ Status FindCompilationCandidates(
|
||||
VLOG(2) << "Rejecting " << node->name() << ": stateful random operation";
|
||||
continue;
|
||||
}
|
||||
if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op";
|
||||
continue;
|
||||
}
|
||||
if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": dummy op ("
|
||||
<< node->type_string() << ")";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!op_filter.allow_resource_ops &&
|
||||
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
|
||||
@ -597,11 +634,14 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
|
||||
®istration));
|
||||
DeviceType jit_device_type(registration->compilation_device_name);
|
||||
|
||||
// We can always *compile* resource operations and stateful RNGs, even if we
|
||||
// are sometimes unable to auto-cluster them.
|
||||
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
||||
// even if we are sometimes unable to auto-cluster them.
|
||||
OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops = true;
|
||||
op_filter.allow_stateful_rng_ops = true;
|
||||
op_filter.allow_control_trigger = true;
|
||||
op_filter.allow_dummy_ops = true;
|
||||
|
||||
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
|
||||
}
|
||||
|
||||
@ -613,10 +653,8 @@ Status MarkForCompilationPass::Run(
|
||||
GetGlobalJitLevel(options);
|
||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
||||
legacy_flags::GetMarkForCompilationPassFlags();
|
||||
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
|
||||
bool fusion_only = flags->tf_xla_fusion_only;
|
||||
|
||||
VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
|
||||
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
|
||||
VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
|
||||
const FunctionLibraryDefinition* fld = options.flib_def;
|
||||
@ -635,9 +673,6 @@ Status MarkForCompilationPass::Run(
|
||||
return false;
|
||||
}
|
||||
|
||||
// If this device requires a JIT, we must say yes.
|
||||
if (registration->requires_compilation) return true;
|
||||
|
||||
// If there is a _XlaCompile annotation, use its value.
|
||||
bool compile = false;
|
||||
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
|
||||
@ -674,18 +709,21 @@ Status MarkForCompilationPass::Run(
|
||||
return false;
|
||||
}
|
||||
|
||||
// Otherwise use the value of global_jit_level.
|
||||
// Ignore enable_jit_by_default if global jit compilation for CPU
|
||||
// is explicitly requested via tf_xla_cpu_global_jit flag
|
||||
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
|
||||
// Otherwise use the value of global_jit_level and the device's
|
||||
// autoclustering policy.
|
||||
bool should_compile =
|
||||
(ignore_registration || registration->enable_jit_by_default) &&
|
||||
global_jit_level != OptimizerOptions::OFF;
|
||||
registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways ||
|
||||
(registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
|
||||
global_jit_level != OptimizerOptions::OFF);
|
||||
if (!should_compile) {
|
||||
if (global_jit_level == OptimizerOptions::OFF) {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
|
||||
} else {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
|
||||
VLOG(2)
|
||||
<< "Rejecting " << node->name()
|
||||
<< ": autoclustering for device only when requested explicitly.";
|
||||
}
|
||||
}
|
||||
return should_compile;
|
||||
@ -1073,12 +1111,10 @@ Status MarkForCompilationPass::RunImpl(
|
||||
XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration);
|
||||
|
||||
// Compile if this is a cluster of >= min_cluster_size compilable operators.
|
||||
// Also, always compile if the operator is placed on a device that requires
|
||||
// compilation, or if it contains at least one op that is marked for
|
||||
// Also, always compile if it contains at least one op that is marked for
|
||||
// compilation that is not an Identity op.
|
||||
if (effective_cluster_sizes[cluster] >= min_cluster_size ||
|
||||
(effective_cluster_sizes[cluster] > 0 && marked_for_compilation) ||
|
||||
registration->requires_compilation) {
|
||||
(effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) {
|
||||
string& name = cluster_names[cluster];
|
||||
|
||||
if (name.empty()) {
|
||||
|
@ -817,14 +817,10 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
|
||||
ASSERT_FALSE(clusters.empty());
|
||||
string cluster_name = clusters.begin()->second;
|
||||
|
||||
// ctrl_trigger_a has inputs with mismatching deadness so it won't be
|
||||
// clustered. ctrl_trigger_b is okay to cluster.
|
||||
std::unordered_map<string, string> expected_clusters(
|
||||
{{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}});
|
||||
EXPECT_EQ(clusters, expected_clusters);
|
||||
// TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
|
||||
// it won't be clustered. ctrl_trigger_b is okay to cluster but we don't
|
||||
// cluster it because of b/118970344.
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, RandomShape) {
|
||||
@ -923,9 +919,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["test/shape_rng"], "");
|
||||
EXPECT_NE(clusters["test/reshape"], "");
|
||||
EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
|
||||
EXPECT_EQ(clusters["test/shape_rng"], "");
|
||||
EXPECT_EQ(clusters["test/reshape"], "");
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
|
||||
@ -1088,7 +1083,7 @@ TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
|
||||
EXPECT_NE(clusters["test/c"], "");
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) {
|
||||
TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
|
||||
Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
|
||||
@ -1104,5 +1099,53 @@ TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) {
|
||||
EXPECT_EQ(clusters["test/a"], "");
|
||||
EXPECT_EQ(clusters["test/b"], "");
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
|
||||
absl::string_view xla_cpu_device =
|
||||
"/job:worker/replica:0/task:0/device:XLA_CPU:0";
|
||||
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
|
||||
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
|
||||
Output check =
|
||||
ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
|
||||
Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
|
||||
Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
|
||||
n->set_assigned_device_name(string(xla_cpu_device));
|
||||
}
|
||||
}
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["test/check"], "");
|
||||
EXPECT_NE(clusters["test/greaterequal"], "");
|
||||
EXPECT_NE(clusters["test/assert"], "");
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
|
||||
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
|
||||
Output check =
|
||||
ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
|
||||
Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
|
||||
Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(clusters["test/assert"], "");
|
||||
EXPECT_EQ(clusters["test/check"], "");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -133,6 +133,10 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
|
||||
graph->RemoveEdge(out_edge_to_clone);
|
||||
}
|
||||
|
||||
if (n->out_edges().empty()) {
|
||||
graph->RemoveNode(n);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -191,6 +195,10 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
|
||||
}
|
||||
}
|
||||
|
||||
// Recompute post order since PartiallyDeclusterNode may have deleted nodes.
|
||||
post_order.clear();
|
||||
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
|
||||
/*edge_filter=*/NotBackedge);
|
||||
nodes_to_partially_decluster.clear();
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
|
||||
@ -210,7 +218,8 @@ bool IsIntraClusterEdge(const Edge& edge) {
bool IsMustCompileDevice(const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return registration->requires_compilation;
return registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
}

return false;
|
@ -437,5 +437,32 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
|
||||
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
|
||||
}
|
||||
|
||||
TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
|
||||
const char* const kClusteredProducer0Name = "ClusteredProducer0";
|
||||
const char* const kClusteredProducer1Name = "ClusteredProducer1";
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* input =
|
||||
ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
|
||||
Node* clustered_producer_0 =
|
||||
ops::BinaryOp("FakeBinary", input, input,
|
||||
builder.opts().WithName(kClusteredProducer0Name));
|
||||
Node* clustered_producer_1 =
|
||||
ops::BinaryOp("FakeBinary", clustered_producer_0, input,
|
||||
builder.opts().WithName(kClusteredProducer1Name));
|
||||
ops::BinaryOp("FakeBinary", clustered_producer_1, input,
|
||||
builder.opts().WithName("UnclusteredConsumer"));
|
||||
clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
|
||||
clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
|
||||
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(PartiallyDecluster(&graph));
|
||||
EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr);
|
||||
EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,132 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
|
||||
|
||||
#include <deque>
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A thread-safe, first-in-first-out queue.
|
||||
template <typename T>
|
||||
class ProducerConsumerQueue {
|
||||
public:
|
||||
ProducerConsumerQueue()
|
||||
: capacity_(std::numeric_limits<std::size_t>::max()) {}
|
||||
~ProducerConsumerQueue() = default;
|
||||
|
||||
// Wait until the queue is non-full, then append a copy of v.
|
||||
void Put(const T &v);
|
||||
|
||||
// Wait until the queue is non-empty, then remove and return the head value.
|
||||
T Get();
|
||||
|
||||
// If the queue is non-empty, remove the head value, placing it in *pv, and
|
||||
// return true; otherwise return false.
|
||||
bool TryGet(T *pv);
|
||||
|
||||
// Set the capacity of the queue; the queue is full whenever count() >=
|
||||
// capacity(). The initial value is the maximum size_t. Requires size > 0.
|
||||
void set_capacity(std::size_t size);
|
||||
|
||||
// Return the capacity of the queue.
|
||||
std::size_t capacity() const;
|
||||
|
||||
// Return the number of elements in the queue.
|
||||
std::size_t count() const;
|
||||
|
||||
// Implementation details follow. Clients should ignore.
|
||||
private:
|
||||
mutable tensorflow::mutex mu_; // protects all fields below
|
||||
tensorflow::condition_variable non_empty_ GUARDED_BY(mu_);
|
||||
tensorflow::condition_variable non_full_ GUARDED_BY(mu_);
|
||||
std::size_t capacity_ GUARDED_BY(mu_);
|
||||
std::deque<T> queue_ GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue);
|
||||
};
|
||||
|
||||
// ------------------------------------------------------
|
||||
// Implementation details follow. Clients should ignore.
|
||||
|
||||
// Wait until the queue is non-full, then append a copy of v.
|
||||
template <typename T>
|
||||
void ProducerConsumerQueue<T>::Put(const T &v) {
|
||||
mutex_lock lock(mu_);
|
||||
while (queue_.size() >= capacity_) {
|
||||
non_full_.wait(lock);
|
||||
}
|
||||
queue_.push_back(v);
|
||||
non_empty_.notify_one();
|
||||
}
|
||||
|
||||
// Wait until the queue is non-empty, then remove and return the head value.
|
||||
template <typename T>
|
||||
T ProducerConsumerQueue<T>::Get() {
|
||||
mutex_lock lock(mu_);
|
||||
while (queue_.empty()) {
|
||||
non_empty_.wait(lock);
|
||||
}
|
||||
non_full_.notify_one();
|
||||
T result_value = queue_.front();
|
||||
queue_.pop_front();
|
||||
return result_value;
|
||||
}
|
||||
|
||||
// If the queue is non-empty, remove the head value, placing it in *pv, and
|
||||
// return true; otherwise return false.
|
||||
template <typename T>
|
||||
bool ProducerConsumerQueue<T>::TryGet(T *pv) {
|
||||
mutex_lock lock(mu_);
|
||||
bool got_element = !queue_.empty();
|
||||
if (got_element) {
|
||||
non_full_.notify_one();
|
||||
*pv = queue_.front();
|
||||
queue_.pop_front();
|
||||
}
|
||||
return got_element;
|
||||
}
|
||||
|
||||
// Set the capacity of the queue; the queue is full whenever count() >=
|
||||
// capacity(). The initial value is the maximum size_t. Requires size > 0.
|
||||
template <typename T>
|
||||
void ProducerConsumerQueue<T>::set_capacity(std::size_t size) {
|
||||
mutex_lock lock(mu_);
|
||||
CHECK_NE(size, 0);
|
||||
capacity_ = size;
|
||||
non_full_.notify_all();
|
||||
}
|
||||
|
||||
// Return the capacity of the queue.
|
||||
template <typename T>
|
||||
std::size_t ProducerConsumerQueue<T>::capacity() const {
|
||||
mutex_lock lock(mu_);
|
||||
std::size_t max_elements = capacity_;
|
||||
return max_elements;
|
||||
}
|
||||
|
||||
// Return the number of elements in the queue.
|
||||
template <typename T>
|
||||
std::size_t ProducerConsumerQueue<T>::count() const {
|
||||
mutex_lock lock(mu_);
|
||||
std::size_t num_elements = queue_.size();
|
||||
return num_elements;
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
|
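For context, the header deleted above declared a bounded, thread-safe FIFO. A minimal usage sketch, assuming the class as declared in the removed file (the deleted producer_consumer_queue_test.cc that follows is the real coverage):

#include <iostream>

#include "tensorflow/compiler/jit/producer_consumer_queue.h"

int main() {
  tensorflow::ProducerConsumerQueue<int> q;
  q.set_capacity(4);             // Put() blocks once count() reaches 4.
  q.Put(1);
  q.Put(2);
  std::cout << q.Get() << "\n";  // 1: values come back in FIFO order.
  int v = 0;
  while (q.TryGet(&v)) {         // Non-blocking drain of whatever remains.
    std::cout << v << "\n";      // 2
  }
  return 0;
}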
@ -1,139 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/producer_consumer_queue.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
typedef ProducerConsumerQueue<int> IntQueue;
|
||||
|
||||
// Insert integers between low inclusive and high exclusive into q.
|
||||
void PushRange(IntQueue *q, int low, int high) {
|
||||
while (low != high) {
|
||||
q->Put(low);
|
||||
VLOG(2) << "Pushing " << low;
|
||||
++low;
|
||||
}
|
||||
}
|
||||
|
||||
// Push the numbers between 0 and 999 inclusive from several threads in the
|
||||
// pool.
|
||||
void PushRanges(IntQueue *queue, thread::ThreadPool *pool) {
|
||||
VLOG(1) << "Adding 20-36";
|
||||
pool->Schedule([queue] { PushRange(queue, 20, 36); });
|
||||
VLOG(1) << "Adding 7-20";
|
||||
pool->Schedule([queue] { PushRange(queue, 7, 20); });
|
||||
VLOG(1) << "Adding 36-501";
|
||||
pool->Schedule([queue] { PushRange(queue, 36, 501); });
|
||||
VLOG(1) << "Adding 501-1000";
|
||||
pool->Schedule([queue] { PushRange(queue, 501, 1000); });
|
||||
VLOG(1) << "Adding 0-5";
|
||||
pool->Schedule([queue] { PushRange(queue, 0, 5); });
|
||||
VLOG(1) << "Adding 5-7";
|
||||
pool->Schedule([queue] { PushRange(queue, 5, 7); });
|
||||
}
|
||||
|
||||
// Pop elements from queue using Get(). Make sure that exactly <high> elements
|
||||
// were present and their values are all integers between 0 and high-1
|
||||
// inclusive.
|
||||
void GetRange(IntQueue *queue, int high) {
|
||||
VLOG(1) << "Testing Wait";
|
||||
std::vector<int> results;
|
||||
for (int i = 0; i != high; ++i) {
|
||||
int r = queue->Get();
|
||||
VLOG(2) << "Waited and got " << r;
|
||||
results.push_back(r);
|
||||
}
|
||||
CHECK_EQ(queue->count(), 0);
|
||||
std::sort(results.begin(), results.end());
|
||||
for (int i = 0; i != high; ++i) {
|
||||
CHECK(results[i] == i);
|
||||
}
|
||||
}
|
||||
|
||||
// Pop elements from queue using TryGet(). Make sure that exactly <high>
|
||||
// elements were present and their values are all integers between 0 and high-1
|
||||
// inclusive.
|
||||
void TryGetRange(IntQueue *queue, int high) {
|
||||
std::vector<int> results;
|
||||
// Give up if we don't get all the elements back from the queue
|
||||
// in 10 seconds.
|
||||
int timeout = 10;
|
||||
int r;
|
||||
for (int i = 0; i != high; ++i) {
|
||||
while (!queue->TryGet(&r)) {
|
||||
if (!timeout--) {
|
||||
LOG(FATAL) << "Can't find all elements in the queue";
|
||||
}
|
||||
VLOG(1) << "Sleeping for a second...";
|
||||
sleep(1);
|
||||
}
|
||||
VLOG(2) << "Popped " << r;
|
||||
results.push_back(r);
|
||||
}
|
||||
CHECK_EQ(queue->count(), 0);
|
||||
CHECK(!queue->TryGet(&r));
|
||||
std::sort(results.begin(), results.end());
|
||||
for (int i = 0; i != high; ++i) {
|
||||
CHECK_EQ(i, results[i]);
|
||||
}
|
||||
}
|
||||
|
||||
const int kNumThreads = 15;
|
||||
|
||||
TEST(ProducerConsumerQueue, GetRange) {
|
||||
IntQueue queue;
|
||||
{
|
||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
||||
PushRanges(&queue, &pool);
|
||||
}
|
||||
GetRange(&queue, 1000);
|
||||
}
|
||||
|
||||
TEST(ProducerConsumerQueue, TryGetRange) {
|
||||
IntQueue queue;
|
||||
{
|
||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
||||
PushRanges(&queue, &pool);
|
||||
}
|
||||
TryGetRange(&queue, 1000);
|
||||
}
|
||||
|
||||
TEST(ProducerConsumerQueue, ParallelGetRange) {
|
||||
IntQueue queue;
|
||||
{
|
||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
||||
pool.Schedule([&queue] { GetRange(&queue, 1000); });
|
||||
PushRanges(&queue, &pool);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ProducerConsumerQueue, ParallelTryGetRange) {
|
||||
IntQueue queue;
|
||||
{
|
||||
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
|
||||
pool.Schedule([&queue] { TryGetRange(&queue, 1000); });
|
||||
PushRanges(&queue, &pool);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
@ -65,14 +66,14 @@ string XlaCompilationCache::DebugString() {
|
||||
|
||||
// Compute a string signature which encodes the shapes of the
|
||||
// arguments in the supplied list.
|
||||
string XlaCompilationCache::SignatureDebugString(const Signature& sig) {
|
||||
string result = sig.name;
|
||||
for (const auto& a : sig.arg_types) {
|
||||
string XlaCompilationCache::Signature::HumanString() const {
|
||||
string result = name;
|
||||
for (const auto& a : arg_types) {
|
||||
absl::StrAppend(&result, ",", DataTypeString(a.first),
|
||||
a.second.DebugString());
|
||||
}
|
||||
|
||||
for (const auto& v : sig.arg_values) {
|
||||
for (const auto& v : arg_values) {
|
||||
absl::StrAppend(&result, "; ", v.DebugString());
|
||||
}
|
||||
return result;
|
||||
@ -84,7 +85,9 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const {

if (arg_values.size() != other.arg_values.size()) return false;
for (int i = 0; i < arg_values.size(); ++i) {
if (arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
arg_values[i].shape() != other.arg_values[i].shape() ||
arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
return false;
}
}
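The extra dtype and shape checks matter because tensor_data() compares raw bytes, and two empty constants of different types have identical (empty) byte ranges. A small sketch of that corner case, assuming a TensorFlow build environment (it mirrors the new SignatureEquality test added later in this diff):

#include <iostream>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

int main() {
  tensorflow::Tensor a(tensorflow::DT_INT32, tensorflow::TensorShape({4, 0}));
  tensorflow::Tensor b(tensorflow::DT_FLOAT, tensorflow::TensorShape({0, 4}));
  // Both tensors hold zero elements, so their backing byte ranges compare equal...
  std::cout << (a.tensor_data() == b.tensor_data()) << "\n";  // 1
  // ...yet they must not share a compilation cache entry, hence the added
  // dtype() and shape() comparisons in operator== above.
  std::cout << (a.dtype() == b.dtype()) << "\n";              // 0
  std::cout << (a.shape() == b.shape()) << "\n";              // 0
  return 0;
}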
@ -108,96 +111,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()(
|
||||
return h;
|
||||
}
|
||||
|
||||
Status XlaCompilationCache::BuildSignature(
|
||||
const NameAttrList& function, const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
Signature* signature) {
|
||||
signature->name = Canonicalize(function.name(), AttrSlice(&function.attr()));
|
||||
signature->arg_values.reserve(constant_args.size());
|
||||
|
||||
signature->arg_types.reserve(ctx->num_inputs() - constant_args.size());
|
||||
|
||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||
if (constant_args.count(i) > 0) {
|
||||
// Use the values of compile time constants in the signature.
|
||||
signature->arg_values.push_back(constant_args.at(i));
|
||||
} else if (variable_args.count(i) > 0) {
|
||||
const OptionalTensor& variable = variable_args.at(i);
|
||||
if (variable.present) {
|
||||
signature->arg_types.emplace_back(variable.value.dtype(),
|
||||
variable.value.shape());
|
||||
} else {
|
||||
signature->arg_types.emplace_back(DT_INVALID, TensorShape());
|
||||
}
|
||||
} else {
|
||||
signature->arg_types.emplace_back(ctx->input_dtype(i),
|
||||
ctx->input(i).shape());
|
||||
xla::StatusOr<XlaCompilationCache::Signature>
|
||||
XlaCompilationCache::BuildSignature(
|
||||
const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args) {
|
||||
Signature signature;
|
||||
signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
|
||||
for (const XlaCompiler::Argument& arg : args) {
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
signature.arg_values.push_back(arg.constant_value);
|
||||
break;
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
case XlaCompiler::Argument::kResource:
|
||||
signature.arg_types.emplace_back(arg.type, arg.shape);
|
||||
break;
|
||||
default:
|
||||
return errors::InvalidArgument(
|
||||
"Unhandled argument kind in XlaCompilationCache: ",
|
||||
arg.HumanString());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
return std::move(signature);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op.
|
||||
Status BuildArguments(const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args,
|
||||
OpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>* args) {
|
||||
args->resize(ctx->num_inputs());
|
||||
|
||||
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
|
||||
XlaCompiler::Argument& arg = (*args)[input_num];
|
||||
if (constant_args.count(input_num) > 0) {
|
||||
// Handles compile-time constants.
|
||||
const Tensor& input = constant_args.at(input_num);
|
||||
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.type = input.dtype();
|
||||
arg.shape = input.shape();
|
||||
arg.constant_value = input;
|
||||
} else if (variable_args.count(input_num) == 0) {
|
||||
// Handles the non-constant arguments.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
|
||||
if (input.NumElements() > 0) {
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
} else {
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.constant_value = input;
|
||||
}
|
||||
arg.type = input.dtype();
|
||||
arg.shape = input.shape();
|
||||
} else {
|
||||
// Handles resource variables.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
|
||||
const OptionalTensor& variable = variable_args.at(input_num);
|
||||
arg.name = variable.name;
|
||||
arg.kind = XlaCompiler::Argument::kResource;
|
||||
arg.resource_kind = XlaResource::kVariable;
|
||||
if (variable.present) {
|
||||
const Tensor& value = variable.value;
|
||||
arg.type = value.dtype();
|
||||
arg.shape = value.shape();
|
||||
arg.initialized = true;
|
||||
} else {
|
||||
// The values of uninitialized variables are not passed as inputs, since
|
||||
// they are meaningless. However, it is legal to assign to a resource
|
||||
// variable for the first time inside the XLA computation, so we do
|
||||
// permit uninitialized variables.
|
||||
arg.initialized = false;
|
||||
arg.type = DT_INVALID;
|
||||
arg.shape = TensorShape();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status XlaCompilationCache::BuildExecutable(
|
||||
const XlaCompiler::Options& options,
|
||||
const XlaCompiler::CompilationResult& result,
|
||||
@ -227,25 +164,38 @@ Status XlaCompilationCache::BuildExecutable(

Status XlaCompilationCache::Compile(
const XlaCompiler::Options& options, const NameAttrList& function,
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
absl::Span<const XlaCompiler::Argument> args,
const XlaCompiler::CompileOptions& compile_options,
CompileMode compile_mode,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) {
// Set the compile threshold to 1 to implement CompileMode::kStrict.
int64 compile_threshold =
compile_mode == CompileMode::kLazy ? kDefaultCompilationThreshold : 1;
return CompileImpl(options, function, constant_args, variable_args, ctx,
compile_options, /*compile_single_op=*/false,
absl::optional<int64> compile_threshold;
if (compile_mode == CompileMode::kLazy) {
compile_threshold = kDefaultCompilationThreshold;
}
auto compile_fn = [&](XlaCompiler* compiler,
XlaCompiler::CompilationResult* result) {
return compiler->CompileFunction(compile_options, function, args, result);
};
return CompileImpl(options, function, args, compile_fn,
/*compile_threshold=*/compile_threshold,
out_compilation_result, out_executable);
}
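Compile() now passes an absl::optional threshold instead of the old sentinel value of 1: kStrict leaves it empty (compile immediately), kLazy sets kDefaultCompilationThreshold. CompileImpl combines that with the first-execution and megamorphic checks further down; a self-contained sketch of the combined gate, with illustrative names only:

#include <cstdint>
#include <iostream>
#include <optional>

bool ShouldCompileNow(std::optional<int64_t> compile_threshold,
                      bool is_megamorphic, bool is_first_execution,
                      int64_t request_count) {
  if (!compile_threshold.has_value()) return true;  // kStrict: always compile.
  if (is_megamorphic) return false;                 // Too shape-dynamic to pay off.
  if (is_first_execution) return true;              // Warm the cache eagerly.
  return request_count >= *compile_threshold;       // kLazy: wait for reuse.
}

int main() {
  std::cout << ShouldCompileNow(std::nullopt, false, false, 1) << "\n";  // 1
  std::cout << ShouldCompileNow(2, true, true, 100) << "\n";             // 0
  std::cout << ShouldCompileNow(2, false, false, 1) << "\n";             // 0
  std::cout << ShouldCompileNow(2, false, false, 2) << "\n";             // 1
  return 0;
}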
static bool IsMegamorphic(int64 compile_count, int64 execution_count) {
const int64 kCompileThreshold = 10;
const int64 kMinExecutionsPerCompile = 50;

// This heuristic is trying to capture the following property: have we sunk a
// certain minimum amount of compile time into the cluster that didn't quite
// "pay off"?
return compile_count > kCompileThreshold &&
execution_count < kMinExecutionsPerCompile * compile_count;
}
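A worked example of the heuristic: with kCompileThreshold = 10 and kMinExecutionsPerCompile = 50, a cluster recompiled 12 times needs at least 600 executions before its compile time is considered to have paid off. Sketch (not part of the change):

#include <cstdint>
#include <iostream>

// Same arithmetic as IsMegamorphic above, restated with the constants inlined.
bool IsMegamorphicSketch(int64_t compile_count, int64_t execution_count) {
  return compile_count > 10 && execution_count < 50 * compile_count;
}

int main() {
  std::cout << IsMegamorphicSketch(12, 300) << "\n";   // 1: 300 < 50 * 12 = 600
  std::cout << IsMegamorphicSketch(12, 1200) << "\n";  // 0: enough reuse per compile
  std::cout << IsMegamorphicSketch(5, 3) << "\n";      // 0: too few compiles to judge
  return 0;
}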
|
||||
|
||||
Status XlaCompilationCache::CompileSingleOp(
|
||||
const XlaCompiler::Options& options,
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable) {
|
||||
@ -253,54 +203,41 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
NameAttrList name;
|
||||
name.set_name(def.op());
|
||||
*name.mutable_attr() = def.attr();
|
||||
return CompileImpl(options, name, constant_args, variable_args, ctx,
|
||||
compile_options,
|
||||
/*compile_single_op=*/true, /*compile_threshold=*/1,
|
||||
auto compile_op = [&](XlaCompiler* compiler,
|
||||
XlaCompiler::CompilationResult* result) {
|
||||
std::vector<DataType> result_dtypes(ctx->num_outputs());
|
||||
for (int i = 0; i < result_dtypes.size(); ++i) {
|
||||
result_dtypes[i] = ctx->expected_output_dtype(i);
|
||||
}
|
||||
return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
|
||||
args, result_dtypes, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
out_compilation_result, out_executable);
|
||||
}
|
||||
|
||||
Status XlaCompilationCache::CompileImpl(
|
||||
const XlaCompiler::Options& options, const NameAttrList& function,
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
const XlaCompiler::CompileOptions& compile_options, bool compile_single_op,
|
||||
int64 compile_threshold,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn,
|
||||
absl::optional<int64> compile_threshold,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable) {
|
||||
DCHECK_NE(out_executable, nullptr);
|
||||
VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "num_inputs=" << ctx->num_inputs()
|
||||
<< " num_constant_args=" << constant_args.size()
|
||||
<< " num_variable_args=" << variable_args.size();
|
||||
for (int i = 0; i < ctx->num_inputs(); i++) {
|
||||
TensorShape shape = ctx->input(i).shape();
|
||||
VLOG(2) << i << ": dtype=" << DataTypeString(ctx->input_dtype(i))
|
||||
<< " present=" << ctx->has_input(i)
|
||||
<< " shape=" << shape.DebugString();
|
||||
}
|
||||
for (auto& iterator : variable_args) {
|
||||
const OptionalTensor& variable = iterator.second;
|
||||
VLOG(2) << "variable present=" << variable.present
|
||||
<< " type=" << DataTypeString(variable.value.dtype())
|
||||
<< " shape=" << variable.value.shape().DebugString()
|
||||
<< " TF arg= " << iterator.first;
|
||||
}
|
||||
VLOG(2) << "num_outputs = " << ctx->num_outputs();
|
||||
for (int i = 0; i < ctx->num_outputs(); i++) {
|
||||
VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i);
|
||||
VLOG(2) << "num_inputs=" << args.size();
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
VLOG(2) << i << ": " << args[i].HumanString();
|
||||
}
|
||||
}
|
||||
|
||||
TF_RET_CHECK(constant_args.size() + variable_args.size() <=
|
||||
ctx->num_inputs());
|
||||
TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
|
||||
VLOG(2) << "Signature: " << signature.HumanString();
|
||||
|
||||
Signature signature;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BuildSignature(function, constant_args, variable_args, ctx, &signature));
|
||||
|
||||
VLOG(2) << "Signature: " << SignatureDebugString(signature);
|
||||
// The outer lock protects the existence of the cache entry. It does not
|
||||
// protect the contents of the cache entry.
|
||||
Entry* entry;
|
||||
@ -319,25 +256,67 @@ Status XlaCompilationCache::CompileImpl(
|
||||
// (since they get the benefit of XLA right away without waiting for warmup)
|
||||
// and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
|
||||
// most one cluster-compilation's worth of compile time).
|
||||
bool is_first_execution = [&] {
|
||||
bool is_first_execution;
|
||||
|
||||
// We avoid compiling clusters that have "gone megamorphic" i.e. have an
|
||||
// excessive amount of shape dynamism.
|
||||
bool is_megamorphic;
|
||||
|
||||
{
|
||||
mutex_lock lock(cluster_compile_stats_mu_);
|
||||
auto it =
|
||||
cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
|
||||
.first;
|
||||
return it->second.execution_count++ == 0;
|
||||
}();
|
||||
is_first_execution = it->second.execution_count++ == 0;
|
||||
|
||||
// The is_megamorphic bit is "sticky". We assume clusters that have been
|
||||
// observed to be megamorphic once stay megamorphic forever.
|
||||
it->second.is_megamorphic |=
|
||||
IsMegamorphic(/*compile_count=*/it->second.compile_count,
|
||||
/*execution_count=*/it->second.execution_count);
|
||||
is_megamorphic = it->second.is_megamorphic;
|
||||
}
|
||||
|
||||
// Acquire the cache entry lock and compile, if necessary.
|
||||
// TODO(phawkins): this locking will need to be restructured when we implement
|
||||
// cache eviction.
|
||||
mutex_lock entry_lock(entry->mu);
|
||||
int64 current_request_count = ++entry->request_count;
|
||||
if (!entry->compiled) {
|
||||
VLOG(2) << "Compilation cache miss for signature: "
|
||||
<< SignatureDebugString(signature) << " with request count "
|
||||
VLOG(2) << "Compilation cache entry hit: " << entry->compiled
|
||||
<< " signature: " << signature.HumanString() << " with request count "
|
||||
<< current_request_count << " and compile threshold "
|
||||
<< compile_threshold;
|
||||
if (!is_first_execution && current_request_count < compile_threshold) {
|
||||
<< compile_threshold.value_or(0);
|
||||
if (!entry->compiled) {
|
||||
const bool should_compile = [&] {
|
||||
if (!compile_threshold.has_value()) {
|
||||
// Lazy compilation is disabled.
|
||||
return true;
|
||||
}
|
||||
|
||||
if (is_megamorphic) {
|
||||
VLOG(3) << "Not compiling cluster " << function.name()
|
||||
<< " because it is megamorphic.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_first_execution) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool reached_compile_threshold =
|
||||
current_request_count >= *compile_threshold;
|
||||
if (!reached_compile_threshold) {
|
||||
VLOG(3)
|
||||
<< "Not compiling cluster " << function.name()
|
||||
<< " because it has not reached compile threshold; threshold is "
|
||||
<< *compile_threshold << " execution count "
|
||||
<< current_request_count << ".";
|
||||
}
|
||||
return reached_compile_threshold;
|
||||
}();
|
||||
|
||||
if (!should_compile) {
|
||||
VLOG(2) << "Not compiling for signature: " << signature.HumanString();
|
||||
*out_compilation_result = nullptr;
|
||||
*out_executable = nullptr;
|
||||
return Status::OK();
|
||||
@ -347,21 +326,12 @@ Status XlaCompilationCache::CompileImpl(
|
||||
const uint64 compile_start_us = env->NowMicros();
|
||||
// Do the actual JIT compilation without holding the lock (it can take
|
||||
// a long time.)
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BuildArguments(constant_args, variable_args, ctx, &args));
|
||||
|
||||
XlaCompiler compiler(options);
|
||||
entry->compiled = true;
|
||||
|
||||
if (compile_single_op) {
|
||||
entry->compilation_status =
|
||||
compiler.CompileSingleOp(compile_options, signature.name, ctx, args,
|
||||
&entry->compilation_result);
|
||||
} else {
|
||||
entry->compilation_status = compiler.CompileFunction(
|
||||
compile_options, function, args, &entry->compilation_result);
|
||||
}
|
||||
compile_fn(&compiler, &entry->compilation_result);
|
||||
TF_RETURN_IF_ERROR(entry->compilation_status);
|
||||
CHECK_EQ(entry->executable.get(), nullptr);
|
||||
entry->compilation_status =
|
||||
|
@ -17,9 +17,12 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -30,13 +33,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Struct that represents a possibly-absent Tensor.
|
||||
struct OptionalTensor {
|
||||
string name; // A descriptive name
|
||||
bool present = false; // Is the tensor present?
|
||||
Tensor value; // If present, what is the Tensor's value?
|
||||
};
|
||||
|
||||
// The XlaCompilationCache class caches the results of the XlaCompiler class,
|
||||
// which converts a Tensorflow graph into a compiled XLA compilation.
|
||||
//
|
||||
@ -58,11 +54,7 @@ class XlaCompilationCache : public ResourceBase {
|
||||
// Compiles a function into a XlaCompiler::CompilationResult that can be used
|
||||
// to execute an XLA Computation. Compilation results are cached.
|
||||
// `function` is the name of a Tensorflow function to compile.
|
||||
// `constant_args` is a map of tensorflow argument number to its constant
|
||||
// value.
|
||||
// `variable_args` is a snapshot of the current values of the
|
||||
// resource variable arguments to `function`; uninitialized variables are
|
||||
// represented by an absent OptionalTensor.
|
||||
// `args` is a description of the arguments to the computation.
|
||||
//
|
||||
// `compile_mode` controls the behavior of the compilation cache on a cache
|
||||
// miss. If `compile_mode` is `kLazy` then, based on some profitability
|
||||
@ -78,9 +70,7 @@ class XlaCompilationCache : public ResourceBase {
|
||||
// outputs.
|
||||
Status Compile(const XlaCompiler::Options& options,
|
||||
const NameAttrList& function,
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args,
|
||||
OpKernelContext* ctx,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
CompileMode compile_mode,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
@ -90,8 +80,7 @@ class XlaCompilationCache : public ResourceBase {
|
||||
// XlaCompiler::CompileFunction.
|
||||
Status CompileSingleOp(
|
||||
const XlaCompiler::Options& options,
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable);
|
||||
@ -101,26 +90,6 @@ class XlaCompilationCache : public ResourceBase {
|
||||
|
||||
string DebugString() override;
|
||||
|
||||
private:
|
||||
// Common implementation of Compile and CompileSingleOp.
|
||||
Status CompileImpl(
|
||||
const XlaCompiler::Options& options, const NameAttrList& function,
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
bool compile_single_op, int64 compile_threshold,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable);
|
||||
|
||||
// Takes `result` which has been compiled from a Tensorflow subgraph to a
|
||||
// XLA computation already, and generates an XLA LocalExecutable `executable`.
|
||||
Status BuildExecutable(const XlaCompiler::Options& options,
|
||||
const XlaCompiler::CompilationResult& result,
|
||||
std::unique_ptr<xla::LocalExecutable>* executable);
|
||||
|
||||
xla::LocalClient* const client_;
|
||||
const DeviceType device_type_;
|
||||
|
||||
// Describes the types, shapes and any compile-time constant arguments
|
||||
// to a kernel. Key that uniquely identifies a compilation output.
|
||||
struct Signature {
|
||||
@ -137,14 +106,35 @@ class XlaCompilationCache : public ResourceBase {
|
||||
struct Hash {
|
||||
uint64 operator()(const Signature& signature) const;
|
||||
};
|
||||
|
||||
// Returns a human-readable description of the signature.
|
||||
string HumanString() const;
|
||||
};
|
||||
static string SignatureDebugString(const Signature& sig);
|
||||
|
||||
// Builds the signature for a compilation.
|
||||
Status BuildSignature(const NameAttrList& function,
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args,
|
||||
OpKernelContext* ctx, Signature* signature);
|
||||
static xla::StatusOr<Signature> BuildSignature(
|
||||
const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args);
|
||||
|
||||
private:
|
||||
// Common implementation of Compile and CompileSingleOp.
|
||||
Status CompileImpl(
|
||||
const XlaCompiler::Options& options, const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn,
|
||||
absl::optional<int64> compile_threshold,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable);
|
||||
|
||||
// Takes `result` which has been compiled from a Tensorflow subgraph to a
|
||||
// XLA computation already, and generates an XLA LocalExecutable `executable`.
|
||||
Status BuildExecutable(const XlaCompiler::Options& options,
|
||||
const XlaCompiler::CompilationResult& result,
|
||||
std::unique_ptr<xla::LocalExecutable>* executable);
|
||||
|
||||
xla::LocalClient* const client_;
|
||||
const DeviceType device_type_;
|
||||
|
||||
// The value associated with a cache entry.
|
||||
struct Entry {
|
||||
@ -180,7 +170,13 @@ class XlaCompilationCache : public ResourceBase {
|
||||
|
||||
// Cumulative time spent compiling the cluster.
|
||||
int64 cumulative_compile_time_us = 0;
|
||||
|
||||
// True if we have decided that this cluster is too dynamic (i.e. its shapes
|
||||
// change too frequently) to profitably JIT compile. Once a cluster is
|
||||
// tagged megamorphic, it stays megamorphic forever.
|
||||
bool is_megamorphic = false;
|
||||
};
|
||||
|
||||
mutex cluster_compile_stats_mu_;
|
||||
|
||||
// Maps cluster names to compilation statistics for said cluster.
|
||||
|
54
tensorflow/compiler/jit/xla_compilation_cache_test.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(XlaCompilationCacheTest, SignatureEquality) {
|
||||
NameAttrList fn;
|
||||
fn.set_name("afunction");
|
||||
std::vector<XlaCompiler::Argument> args(1);
|
||||
args[0].kind = XlaCompiler::Argument::kConstant;
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({4, 0});
|
||||
args[0].constant_value = Tensor(DT_INT32, {4, 0});
|
||||
TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1,
|
||||
XlaCompilationCache::BuildSignature(fn, args));
|
||||
|
||||
args[0].type = DT_FLOAT;
|
||||
args[0].constant_value = Tensor(DT_FLOAT, {4, 0});
|
||||
TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2,
|
||||
XlaCompilationCache::BuildSignature(fn, args));
|
||||
|
||||
args[0].shape = TensorShape({0, 4});
|
||||
args[0].constant_value = Tensor(DT_FLOAT, {0, 4});
|
||||
TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s3,
|
||||
XlaCompilationCache::BuildSignature(fn, args));
|
||||
|
||||
std::vector<XlaCompilationCache::Signature> signatures = {s1, s2, s3};
|
||||
for (int i = 0; i < signatures.size(); ++i) {
|
||||
for (int j = 0; j < signatures.size(); ++j) {
|
||||
EXPECT_EQ(i == j, signatures[i] == signatures[j])
|
||||
<< signatures[i].HumanString() << " " << signatures[j].HumanString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -187,8 +187,13 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
compile_options.always_return_tuple = false;
|
||||
|
||||
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
|
||||
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
|
||||
compile_options, result, executable);
|
||||
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arguments, variable_args, ctx, &args));
|
||||
|
||||
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
|
||||
executable);
|
||||
}
|
||||
|
||||
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
|
||||
|
@ -42,8 +42,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
|
||||
registration.requires_compilation = !compile_on_demand;
|
||||
registration.enable_jit_by_default = false;
|
||||
registration.autoclustering_policy =
|
||||
compile_on_demand
|
||||
? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
|
||||
: XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
|
||||
|
||||
|
@ -446,7 +446,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
|
||||
// Any op assigned to the device that isn't rewritten by the graph rewriter
|
||||
// gets executed by an XlaCompileOnDemandOp, which compiles it and executes
|
||||
// it just-in-time.
|
||||
kernel_factory::OpKernelRegistrar::Factory factory =
|
||||
OpKernel* (*factory)(OpKernelConstruction*) =
|
||||
[](OpKernelConstruction* context) -> OpKernel* {
|
||||
return new XlaCompileOnDemandOp(context);
|
||||
};
|
||||
|
@ -112,6 +112,12 @@ class XlaDevice : public LocalDevice {
|
||||
// compute, host-to-device, and device-to-host communication.
|
||||
bool use_multiple_streams = false;
|
||||
|
||||
// A function that describes how the on-host shapes of
|
||||
// a) argument and return value, for entry computations
|
||||
// b) variables, for all computations,
|
||||
// should be represented in XLA. Parameters/return values will be shaped
|
||||
// according to this function, and reshaped back to/from their declared
|
||||
// shapes for computations. Must be non-null.
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn;
|
||||
|
||||
// If padded_shape_fn is empty, a default implementation that returns
|
||||
|
@ -70,9 +70,12 @@ XlaDeviceContext::XlaDeviceContext(
|
||||
CHECK(device_to_host_stream_ != nullptr);
|
||||
CHECK(stream_ != nullptr);
|
||||
if (!shape_representation_fn_) {
|
||||
shape_representation_fn_ =
|
||||
[](const TensorShape& shape,
|
||||
DataType dtype) -> xla::StatusOr<TensorShape> { return shape; };
|
||||
shape_representation_fn_ = [](const TensorShape& shape,
|
||||
DataType dtype) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
|
||||
return xla_shape;
|
||||
};
|
||||
}
|
||||
}
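The constructor above now falls back to a function that converts the TensorShape straight to an xla::Shape. A device that wants a different on-device representation would install its own function with the same signature; a hedged sketch follows (MakeIdentityShapeFn is a hypothetical name, only TensorShapeToXLAShape and the ShapeRepresentationFn typedef come from this diff, and the include list is abbreviated):

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"

tensorflow::XlaCompiler::ShapeRepresentationFn MakeIdentityShapeFn() {
  return [](const tensorflow::TensorShape& shape,
            tensorflow::DataType dtype) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(tensorflow::TensorShapeToXLAShape(dtype, shape, &xla_shape));
    // A custom implementation would adjust xla_shape here, e.g. to pad
    // dimensions or pick a device-preferred layout, before returning it.
    return xla_shape;
  };
}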
|
||||
|
||||
@ -99,7 +102,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
CHECK(xla_tensor);
|
||||
|
||||
Status status = [&]() -> Status {
|
||||
TF_ASSIGN_OR_RETURN(TensorShape shape,
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape shape,
|
||||
shape_representation_fn_(device_tensor->shape(),
|
||||
device_tensor->dtype()));
|
||||
|
||||
@ -111,9 +114,15 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
||||
stream_->parent()->device_ordinal()));
|
||||
|
||||
// The cpu_tensor and literal that we created here hold the data of host
|
||||
// tensor in descending layout. The layout could be different from layout in
|
||||
// device_tensor (but the logical shape has to be the same). The
|
||||
// transfer_manager is responsible to do corresponding transposing when
|
||||
// transferring the data to device.
|
||||
xla::BorrowingLiteral literal(
|
||||
static_cast<const char*>(DMAHelper::base(cpu_tensor)),
|
||||
xla_tensor->shaped_buffer().on_host_shape());
|
||||
xla::ShapeUtil::MakeShape(shape.element_type(),
|
||||
xla::AsInt64Slice(shape.dimensions())));
|
||||
|
||||
VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
@ -183,8 +192,15 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
|
||||
xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get());
|
||||
|
||||
// Transfer manager requires the shape of the shaped buffer to be the same as
|
||||
// literal shape except for the layout. Set the literal to use xla_tensor's
|
||||
// shape as it is derived from the cpu_tensor's shape using
|
||||
// shape_representation_fn_.
|
||||
xla::MutableBorrowingLiteral literal;
|
||||
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal));
|
||||
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
|
||||
xla::LayoutUtil::GetWithDefaultLayout(
|
||||
xla_tensor->shaped_buffer().on_host_shape()),
|
||||
cpu_tensor, &literal));
|
||||
|
||||
TensorReference ref(*device_tensor);
|
||||
transfer_manager_->TransferLiteralFromDevice(
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/resource_variable_ops.h"
|
||||
#include "tensorflow/core/kernels/sendrecv_ops.h"
|
||||
#include "tensorflow/core/kernels/shape_ops.h"
|
||||
#include "tensorflow/core/kernels/stack.h"
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<string>("T") \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp);
|
||||
RetvalOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("StackV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("max_size") \
|
||||
.HostMemory("handle"), \
|
||||
StackOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("handle") \
|
||||
.TypeConstraint("T", TYPES), \
|
||||
TemplatedStackPushOp</*allow_swapping=*/false>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("handle") \
|
||||
.TypeConstraint("elem_type", TYPES), \
|
||||
StackPopOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);
|
||||
|
||||
// TODO(phawkins): currently we do not register the QueueEnqueueMany,
|
||||
// TODO(b/118881356): currently we do not register the QueueEnqueueMany,
|
||||
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
|
||||
// and write the tensors they access in order to concatenate them into a batch.
|
||||
// We would need either to call out to an XLA computation to perform the
|
||||
|
@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
|
||||
std::vector<Device*>* devices) {
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.requires_compilation = true;
|
||||
registration.enable_jit_by_default = false;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
|
||||
|
||||
@ -53,24 +53,25 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
|
||||
XlaDevice::Options options;
|
||||
options.platform = platform.ValueOrDie();
|
||||
options.device_name_prefix = name_prefix;
|
||||
options.device_name = DEVICE_XLA_GPU;
|
||||
options.device_ordinal = 0;
|
||||
options.device_ordinal = i;
|
||||
options.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
options.use_multiple_streams = false;
|
||||
options.use_multiple_streams = true;
|
||||
auto device = absl::make_unique<XlaDevice>(session_options, options);
|
||||
|
||||
// TODO(b/78468222): Uncomment after fixing this bug
|
||||
// status = device->UseGpuDeviceInfo();
|
||||
// if (!status.ok()) {
|
||||
// errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
|
||||
// " device");
|
||||
// return status;
|
||||
// }
|
||||
Status status = device->UseGpuDeviceInfo();
|
||||
if (!status.ok()) {
|
||||
errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
|
||||
" device number ", i);
|
||||
return status;
|
||||
}
|
||||
|
||||
devices->push_back(device.release());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
|
||||
registration.requires_compilation = true;
|
||||
registration.enable_jit_by_default = false;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
|
||||
registration);
|
||||
|
@ -191,40 +191,6 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
// Return the 'index''th subtree of the given ShapedBuffer as a
|
||||
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
|
||||
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
|
||||
ScopedShapedBuffer ExtractSubShapedBuffer(
|
||||
ShapedBuffer* shaped_buffer, int index,
|
||||
xla::DeviceMemoryAllocator* allocator) {
|
||||
const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape(
|
||||
shaped_buffer->on_host_shape(), index);
|
||||
const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape(
|
||||
shaped_buffer->on_device_shape(), index);
|
||||
|
||||
ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
|
||||
shaped_buffer->platform(),
|
||||
shaped_buffer->device_ordinal());
|
||||
|
||||
auto& shape_tree = shaped_buffer->buffers();
|
||||
auto& sub_shape_tree = sub_shaped_buffer.buffers();
|
||||
sub_shape_tree.CopySubtreeFrom(shape_tree,
|
||||
/*source_base_index=*/{index},
|
||||
/*target_base_index=*/{});
|
||||
shape_tree.ForEachMutableElement(
|
||||
[index](const xla::ShapeIndex& shape_index,
|
||||
tensorflow::se::DeviceMemoryBase* data) {
|
||||
// shape_index is empty for the root node. Ignore that.
|
||||
if (!shape_index.empty() && shape_index[0] == index) {
|
||||
*data = tensorflow::se::DeviceMemoryBase(nullptr, 0);
|
||||
}
|
||||
});
|
||||
return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator);
|
||||
}
|
||||
} // namespace internal
|
||||
using internal::ExtractSubShapedBuffer;
|
||||
|
||||
XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
|
||||
bool allocate_xla_tensors, bool use_multiple_streams)
|
||||
@ -391,8 +357,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
|
||||
if (xla_tensor) {
|
||||
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
|
||||
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
|
||||
xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
|
||||
if (use_multiple_streams_) {
|
||||
xla_tensor->ResetDefinitionEvent(definition_event, stream);
|
||||
}
|
||||
@ -445,7 +410,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
|
||||
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
|
||||
return errors::Internal("Mismatched type in variable write");
|
||||
@ -455,18 +419,20 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
Tensor output_tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(write.type, write.shape, &output_tensor));
|
||||
if (write.shape.num_elements() > 0) {
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
|
||||
CHECK(xla_tensor);
|
||||
xla_tensor->set_shaped_buffer(
|
||||
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
|
||||
xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
|
||||
if (use_multiple_streams_) {
|
||||
xla_tensor->ResetDefinitionEvent(definition_event, stream);
|
||||
}
|
||||
}
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
write.type, write.shape, buffer, allocator);
|
||||
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
}
|
||||
++output_num;
|
||||
@ -474,4 +440,60 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
    const std::map<int, Tensor>& constant_args,
    const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
    std::vector<XlaCompiler::Argument>* args) {
  args->resize(ctx->num_inputs());

  for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
    XlaCompiler::Argument& arg = (*args)[input_num];
    if (constant_args.count(input_num) > 0) {
      // Handles compile-time constants.
      const Tensor& input = constant_args.at(input_num);
      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input.dtype();
      arg.shape = input.shape();
      arg.constant_value = input;
    } else if (variable_args.count(input_num) == 0) {
      // Handles the non-constant arguments.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
      if (input.NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = input;
      }
      arg.type = input.dtype();
      arg.shape = input.shape();
    } else {
      // Handles resource variables.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() == DT_RESOURCE);
      const OptionalTensor& variable = variable_args.at(input_num);
      arg.name = variable.name;
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.present) {
        const Tensor& value = variable.value;
        arg.type = value.dtype();
        arg.shape = value.shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }
    }
  }

  return Status::OK();
}

} // namespace tensorflow

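Note: BuildXlaCompilerArguments above sorts each op input into one of three argument kinds: compile-time constant, runtime parameter, or resource variable. The pure-Python sketch below mirrors that decision logic for illustration only; the names (Argument, classify_argument) are invented and are not part of any TensorFlow API.

# Illustrative only: a pure-Python analogue of the argument classification in
# XlaComputationLaunchContext::BuildXlaCompilerArguments. All names here are
# hypothetical.
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class Argument:
    kind: str = "invalid"      # "constant", "parameter", or "resource"
    constant_value: Any = None
    initialized: bool = False

def classify_argument(input_num: int,
                      constant_args: Dict[int, Any],
                      variable_args: Dict[int, Optional[Any]],
                      runtime_value: Any) -> Argument:
    arg = Argument()
    if input_num in constant_args:
        # Compile-time constants are baked into the compiled computation.
        arg.kind = "constant"
        arg.constant_value = constant_args[input_num]
    elif input_num not in variable_args:
        # Ordinary inputs become runtime parameters; empty tensors are folded
        # into constants since their value carries no information.
        has_elements = getattr(runtime_value, "size", 1) > 0
        arg.kind = "parameter" if has_elements else "constant"
        if arg.kind == "constant":
            arg.constant_value = runtime_value
    else:
        # Resource variables: an absent snapshot means the variable is still
        # uninitialized, which is allowed if the computation assigns to it.
        snapshot = variable_args[input_num]
        arg.kind = "resource"
        arg.initialized = snapshot is not None
        arg.constant_value = snapshot
    return arg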
@@ -35,6 +35,13 @@ limitations under the License.
namespace tensorflow {
class XlaAllocator;

// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
  string name;           // A descriptive name
  bool present = false;  // Is the tensor present?
  Tensor value;          // If present, what is the Tensor's value?
};

// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
@@ -139,6 +146,13 @@ class XlaComputationLaunchContext {
                              bool allocate_xla_tensors,
                              bool use_multiple_streams);

  // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
  // op.
  static Status BuildXlaCompilerArguments(
      const std::map<int, Tensor>& constant_args,
      const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
      std::vector<XlaCompiler::Argument>* args);

  // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
  // `variables` is a map from TensorFlow argument number to resource variable.
  //
@@ -223,17 +237,6 @@ class XlaTensorBuffer : public TensorBuffer {
  Allocator* allocator_;
};

// Exposed in this header file for microbenchmarking purposes, but this is an
// internal implementation detail.
namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
xla::ScopedShapedBuffer ExtractSubShapedBuffer(
    xla::ShapedBuffer* shaped_buffer, int index,
    xla::DeviceMemoryAllocator* allocator);
} // namespace internal

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_

@ -1,64 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Contains microbenchmarks for performance critical functions in
|
||||
// xla_launch_util.cc.
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs
|
||||
// (cardinality of each non-leaf node's children).
|
||||
void BM_ExtractSubBuffer(int iters, int depth, int fan_out) {
|
||||
tensorflow::testing::StopTiming();
|
||||
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128});
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
std::vector<xla::Shape> shapes(fan_out, shape);
|
||||
shape = xla::ShapeUtil::MakeTupleShape(shapes);
|
||||
}
|
||||
xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr,
|
||||
/*device_ordinal=*/0);
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
// Extract a buffer from approximately the middle of the first level of the
|
||||
// tree.
|
||||
(void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
|
||||
/*index=*/fan_out / 2,
|
||||
/*allocator=*/nullptr)
|
||||
.release();
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ExtractSubBuffer)
|
||||
->ArgPair(1, 4)
|
||||
->ArgPair(1, 8)
|
||||
->ArgPair(1, 32)
|
||||
->ArgPair(1, 64)
|
||||
->ArgPair(1, 128)
|
||||
->ArgPair(1, 256)
|
||||
->ArgPair(1, 512)
|
||||
->ArgPair(2, 4)
|
||||
->ArgPair(2, 8)
|
||||
->ArgPair(2, 32)
|
||||
->ArgPair(2, 64)
|
||||
->ArgPair(2, 128);
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
tensorflow::testing::RunBenchmarks();
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@@ -43,11 +43,10 @@ namespace tensorflow {
  }
}

Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
Status XlaTensor::AllocateShapedBuffer(DataType dtype,
                                       const xla::Shape& on_host_shape,
                                       xla::LocalClient* client,
                                       int device_ordinal) {
  xla::Shape on_host_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape));
  xla::Shape on_device_shape =
      client->backend().transfer_manager()->HostShapeToDeviceShape(
          on_host_shape);

@@ -50,7 +50,7 @@ class XlaTensor {
  // Assign the internal ShapedBuffer to new memory for the given dtype and
  // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it
  // is replaced and the managed memory deallocated.
  Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
  Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_host_shape,
                              xla::LocalClient* client, int device_ordinal);

  // Some Tensors can have complex on-device shapes, including tuple shapes. To

@ -470,12 +470,12 @@ tf_xla_py_test(
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/contrib/signal:signal_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:spectral_ops",
|
||||
"//tensorflow/python/ops/signal",
|
||||
],
|
||||
)
|
||||
|
||||
@ -837,8 +837,6 @@ tf_xla_py_test(
|
||||
name = "stack_ops_test",
|
||||
size = "small",
|
||||
srcs = ["stack_ops_test.py"],
|
||||
# Stack ops are not implemented in the on-demand compilation model yet.
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
@ -967,7 +969,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([42], dtype=dtype),
|
||||
np.int32(0),
|
||||
np.array([0], dtype=np.int64),
|
||||
expected=np.array([[42]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
@ -994,15 +996,21 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([[[1, 2], [3, 4]]], dtype=dtype),
|
||||
np.int32(3),
|
||||
expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.expand_dims,
|
||||
np.array([[[1, 2], [3, 4]]], dtype=dtype),
|
||||
np.array([2], dtype=np.int64),
|
||||
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
|
||||
|
||||
def testPad(self):
|
||||
for dtype in self.numeric_types:
|
||||
for dtype, pad_type in itertools.product(
|
||||
self.numeric_types, [np.int32, np.int64]):
|
||||
self._testBinary(
|
||||
array_ops.pad,
|
||||
np.array(
|
||||
[[1, 2, 3], [4, 5, 6]], dtype=dtype),
|
||||
np.array(
|
||||
[[1, 2], [2, 1]], dtype=np.int32),
|
||||
[[1, 2], [2, 1]], dtype=pad_type),
|
||||
expected=np.array(
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 2, 3, 0],
|
||||
@ -1016,7 +1024,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array(
|
||||
[[1, 2, 3], [4, 5, 6]], dtype=dtype),
|
||||
np.array(
|
||||
[[0, 3], [2, 1]], dtype=np.int32),
|
||||
[[0, 3], [2, 1]], dtype=pad_type),
|
||||
expected=np.array(
|
||||
[[7, 7, 1, 2, 3, 7],
|
||||
[7, 7, 4, 5, 6, 7],
|
||||
|
@ -24,10 +24,10 @@ import numpy as np
|
||||
import scipy.signal as sps
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import signal
|
||||
from tensorflow.python.ops import spectral_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
@ -593,6 +593,67 @@ class LazyCompilationTest(test.TestCase):
|
||||
self.assertFalse(
|
||||
InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun"))
|
||||
|
||||
def testIsMegamorphic(self):
|
||||
|
||||
@function.Defun(compiled=True)
|
||||
def CompiledFunction(x):
|
||||
return math_ops.log(x)
|
||||
|
||||
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
y = CompiledFunction(x)
|
||||
|
||||
# Make the cluster go megamorphic by running it with lots of shape
|
||||
# signatures where the cluster is executed with each signature only a few
|
||||
# times. Then check that we don't compile the cluster ever again.
|
||||
|
||||
for shape in range(10, 50):
|
||||
for _ in range(0, 49):
|
||||
sess.run(y, feed_dict={x: [0.] * shape})
|
||||
|
||||
for _ in range(0, 50):
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
sess.run(
|
||||
y,
|
||||
feed_dict={x: [0.] * 60},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assertTrue(
|
||||
InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
|
||||
self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))
|
||||
|
||||
def testIsNotMegamorphic(self):
|
||||
|
||||
@function.Defun(compiled=True)
|
||||
def CompiledFunction(x):
|
||||
return math_ops.log(x)
|
||||
|
||||
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
y = CompiledFunction(x)
|
||||
|
||||
# Run the cluster with lots of shape signatures, but in a way that it
|
||||
# isn't megamorphic (i.e. each shape signature sees a lot of executions).
|
||||
# Then check that the cluster has not been marked as megamorphic.
|
||||
|
||||
for shape in range(10, 50):
|
||||
for _ in range(0, 1000):
|
||||
sess.run(y, feed_dict={x: [0.] * shape})
|
||||
|
||||
for _ in range(0, 10):
|
||||
sess.run(y, feed_dict={x: [0.] * 60})
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
sess.run(
|
||||
y,
|
||||
feed_dict={x: [0.] * 60},
|
||||
run_metadata=run_metadata,
|
||||
options=config_pb2.RunOptions(
|
||||
trace_level=config_pb2.RunOptions.FULL_TRACE))
|
||||
self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
|
||||
self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " +
|
||||
|
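Note: the two tests above exercise the lazy-compilation policy that stops recompiling a cluster once it has seen many shape signatures with only a few executions each (see the comments inside the tests). The exact thresholds live in the C++ JIT code; the sketch below is a simplified, hypothetical model of that "megamorphic" decision, with invented limits, shown only to illustrate the idea.

# Hypothetical sketch of a "megamorphic cluster" heuristic: if a cluster keeps
# seeing new shape signatures and each signature runs only a handful of times,
# further compilation is not worth the cost. Thresholds here are invented.
from collections import Counter

class ClusterCompileStats:
    SIGNATURE_LIMIT = 30          # assumed: "many" distinct shapes seen
    MIN_EXECUTIONS_PER_SIG = 50   # assumed: "few" runs per shape

    def __init__(self):
        self.executions_per_signature = Counter()

    def record_execution(self, shape_signature):
        self.executions_per_signature[shape_signature] += 1

    def is_megamorphic(self):
        sigs = self.executions_per_signature
        if len(sigs) < self.SIGNATURE_LIMIT:
            return False
        avg = sum(sigs.values()) / len(sigs)
        return avg < self.MIN_EXECUTIONS_PER_SIG

stats = ClusterCompileStats()
for shape in range(10, 50):      # 40 signatures, 49 runs each
    for _ in range(49):
        stats.record_execution((shape,))
print(stats.is_megamorphic())    # True under these invented thresholds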
@ -2466,20 +2466,21 @@ TEST_F(OpTest, Pack) {
|
||||
});
|
||||
}
|
||||
|
||||
// TODO(b/31741898): crashes on GPU.
|
||||
TEST_F(OpTest, Pad) {
|
||||
Repeatedly([this]() {
|
||||
auto type = Choose<DataType>(kAllXlaTypes);
|
||||
std::vector<int64> t_dims = RandomDims();
|
||||
|
||||
// TODO(b/31741996): re-enable DT_INT64 when bug is fixed.
|
||||
// DataType tpaddings = Choose<DataType>({DT_INT32, DT_INT64});
|
||||
DataType tpaddings = DT_INT32;
|
||||
DataType tpaddings = Choose<DataType>({DT_INT32, DT_INT64});
|
||||
std::vector<int64> paddings_vec;
|
||||
std::uniform_int_distribution<int> distribution(0, 7);
|
||||
for (int i = 0; i < t_dims.size(); ++i) {
|
||||
paddings_vec.push_back(distribution(generator()));
|
||||
paddings_vec.push_back(distribution(generator()));
|
||||
std::uniform_int_distribution<int> pad_distribution(0, t_dims[i]);
|
||||
int pad_size = pad_distribution(generator());
|
||||
std::uniform_int_distribution<int> lower_distribution(0, pad_size);
|
||||
int low_pad_size = lower_distribution(generator());
|
||||
paddings_vec.push_back(low_pad_size);
|
||||
paddings_vec.push_back(pad_size - low_pad_size);
|
||||
t_dims[i] -= pad_size;
|
||||
}
|
||||
Tensor paddings;
|
||||
CHECK(
|
||||
|
@@ -37,7 +37,7 @@ class ResamplerOpsTest(xla_test.XLATestCase):
      out = sess.run(resampled, {input_image: image_np, warp: warp_np})

      self.assertAllCloseAccordingToType(
          expected, out, half_rtol=1e-2, bfloat16_rtol=3e-2)
          expected, out, rtol=5e-3, half_rtol=1e-2, bfloat16_rtol=3e-2)

  def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np,
                                       expected_grad_data, expected_grad_warp):

@@ -40,6 +40,19 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
class VariableOpsTest(xla_test.XLATestCase):
  """Test cases for resource variable operators."""

  def testWriteEmptyShape(self):
    # Verifies that we can pass an uninitialized variable with an empty shape,
    # assign it a value, and successfully return it.
    for dtype in self.numeric_types:
      with self.test_session() as sess, self.test_scope():
        zeros = np.zeros([3, 0], dtype=dtype)
        v = resource_variable_ops.ResourceVariable(zeros)
        p = array_ops.placeholder(dtype)
        x = v.assign(p)
        with ops.control_dependencies([x]):
          y = v.read_value()
        self.assertAllClose(zeros, sess.run(y, {p: zeros}))

  def testOneWriteOneOutput(self):
    # Regression test for a bug where computations with one non-constant
    # output and one variable update were mishandled.

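Note: testWriteEmptyShape above deals with tensors that have zero elements but a well-defined shape (shape [3, 0]). A quick NumPy-only check of what such a value looks like, for reference:

# NumPy-only illustration of the empty-shape case used by testWriteEmptyShape:
# a [3, 0] array holds no data, but shape and dtype are still well defined.
import numpy as np

zeros = np.zeros([3, 0], dtype=np.float32)
print(zeros.shape)  # (3, 0)
print(zeros.size)   # 0 -- assignment/read round-trips must preserve the shape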
@ -166,6 +166,7 @@ cc_library(
|
||||
"xla_compilation_device.cc",
|
||||
"xla_compiler.cc",
|
||||
"xla_context.cc",
|
||||
"xla_expression.cc",
|
||||
"xla_helpers.cc",
|
||||
"xla_op_kernel.cc",
|
||||
"xla_op_registry.cc",
|
||||
@ -180,6 +181,7 @@ cc_library(
|
||||
"xla_compilation_device.h",
|
||||
"xla_compiler.h",
|
||||
"xla_context.h",
|
||||
"xla_expression.h",
|
||||
"xla_helpers.h",
|
||||
"xla_op_kernel.h",
|
||||
"xla_op_registry.h",
|
||||
@ -194,6 +196,7 @@ cc_library(
|
||||
":side_effect_util",
|
||||
":tf2xla_util",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -217,6 +220,7 @@ cc_library(
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -362,8 +366,12 @@ tf_cc_test(
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_compiler_test",
|
||||
srcs = ["xla_compiler_test.cc"],
|
||||
srcs = [
|
||||
"xla_compiler_test.cc",
|
||||
"xla_expression_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":side_effect_util",
|
||||
":xla_compiler",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
@ -386,6 +394,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -435,7 +444,7 @@ cc_library(
|
||||
"dump_graph.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
|
@@ -19,7 +19,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

@@ -41,7 +41,7 @@ static void AllocateFlags() {
               "Path prefix to which graphs dumped during debugging should be "
               "written."),
  });
  xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
  xla::ParseFlagsFromEnv(*flag_list);
}

// Append to *append_to flag definitions associated with the XLA bridge's

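Note: the hunk above only moves ParseFlagsFromEnv out of the xla::legacy_flags namespace; the call site otherwise stays the same. The general pattern, reading extra flag definitions from an environment variable and parsing them like command-line arguments, is sketched below in Python; the environment variable name MY_TOOL_FLAGS and the flag are invented for illustration.

# Minimal sketch of "parse flags from an environment variable", analogous in
# spirit to xla::ParseFlagsFromEnv. MY_TOOL_FLAGS is a made-up variable name.
import argparse
import os
import shlex

parser = argparse.ArgumentParser()
parser.add_argument("--tf_dump_graph_prefix", default="",
                    help="Path prefix for graphs dumped during debugging.")

# Split the environment variable like a shell would, then parse the pieces as
# extra command-line flags, ignoring anything unknown.
env_flags = shlex.split(os.environ.get("MY_TOOL_FLAGS", ""))
args, _ = parser.parse_known_args(env_flags)
print(args.tf_dump_graph_prefix)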
@ -242,8 +242,6 @@ Status FunctionalizeControlFlowPass::Run(
|
||||
continue;
|
||||
}
|
||||
const string func_attr = it->second;
|
||||
if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) !=
|
||||
kNodeTypeToFunctionAttrMapping->end()) {
|
||||
NameAttrList func;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
|
||||
VLOG(2) << "Graph has node " << n->type_string()
|
||||
@ -260,7 +258,6 @@ Status FunctionalizeControlFlowPass::Run(
|
||||
n->AddAttr(func_attr, func);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(4)) {
|
||||
dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph,
|
||||
|
@ -23,9 +23,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -51,12 +52,11 @@ namespace {
|
||||
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
||||
const std::vector<const XlaExpression*>& expressions,
|
||||
std::vector<XlaCompiler::Argument>* args) {
|
||||
auto builder = ctx->builder();
|
||||
auto client = ctx->compiler()->client();
|
||||
std::vector<bool> compile_time_constant_flags(expressions.size());
|
||||
std::vector<bool> arg_must_be_compile_time_constant(expressions.size());
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
BackwardsConstAnalysis(*graph, &compile_time_constant_flags,
|
||||
BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant,
|
||||
/*compile_time_const_nodes=*/nullptr));
|
||||
|
||||
args->resize(expressions.size());
|
||||
@ -65,25 +65,32 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
||||
arg.type = ctx->input_type(i);
|
||||
arg.shape = ctx->InputShape(i);
|
||||
|
||||
if (arg.type == DT_RESOURCE) {
|
||||
return errors::InvalidArgument(
|
||||
"Resource as function argument is not yet implemented.");
|
||||
} else if (expressions[i]->has_constant_value()) {
|
||||
switch (expressions[i]->kind()) {
|
||||
case XlaExpression::Kind::kConstant:
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.constant_value = expressions[i]->constant_value();
|
||||
} else if (compile_time_constant_flags[i]) {
|
||||
break;
|
||||
case XlaExpression::Kind::kXlaOp:
|
||||
if (arg_must_be_compile_time_constant[i]) {
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
|
||||
expressions[i]->ResolveConstant(client));
|
||||
if (!value.has_value()) {
|
||||
return errors::InvalidArgument(
|
||||
"Argument to function must be a compile-time constant, but "
|
||||
"unable to resolve argument value to a constant.");
|
||||
}
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
TF_RET_CHECK(expressions[i]->resource() == nullptr)
|
||||
<< "Input with resource is not yet implemented.";
|
||||
TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
|
||||
expressions[i]->handle()));
|
||||
TF_ASSIGN_OR_RETURN(auto literal,
|
||||
client->ComputeConstant(constant_graph));
|
||||
TF_RETURN_IF_ERROR(
|
||||
LiteralToHostTensor(literal, arg.type, &arg.constant_value));
|
||||
arg.constant_value = *value;
|
||||
} else {
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
}
|
||||
break;
|
||||
case XlaExpression::Kind::kResource:
|
||||
return errors::Unimplemented(
|
||||
"Resource as function argument is not yet implemented.");
|
||||
case XlaExpression::Kind::kInvalid:
|
||||
return errors::InvalidArgument("Invalid function argument");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -14,11 +14,13 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
|
||||
if (arg.resource() != nullptr) {
|
||||
ctx->SetResourceOutput(0, arg.resource());
|
||||
} else if (arg.has_constant_value()) {
|
||||
ctx->SetConstantOutput(0, arg.constant_value());
|
||||
} else {
|
||||
ctx->SetOutput(0, arg.handle());
|
||||
}
|
||||
OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
|
||||
errors::InvalidArgument("Invalid/missing argument expression"));
|
||||
ctx->SetOutputExpression(0, arg);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -94,14 +94,10 @@ class BCastGradArgsOp : public XlaOpKernel {
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
|
||||
errors::InvalidArgument("In[", i, "] must be a vector.",
|
||||
in_shape.DebugString()));
|
||||
xla::Literal literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal));
|
||||
std::vector<int64> vec;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec));
|
||||
|
||||
BCast::Vec vec;
|
||||
for (int64 i = 0; i < in_shape.num_elements(); ++i) {
|
||||
vec.push_back(literal.Get<int>({i}));
|
||||
}
|
||||
shapes.push_back(vec);
|
||||
shapes.push_back(BCast::Vec(vec.begin(), vec.end()));
|
||||
}
|
||||
BCast bcast(shapes[0], shapes[1]);
|
||||
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
@ -45,15 +46,13 @@ class ConcatBaseOp : public XlaOpKernel {
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_);
|
||||
OP_REQUIRES(
|
||||
ctx, IsLegacyScalar(concat_dim_tensor_shape),
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape),
|
||||
errors::InvalidArgument(
|
||||
"Concat dim tensor should be a scalar integer, but got shape ",
|
||||
"Concat dim tensor should be a scalar, but got shape ",
|
||||
concat_dim_tensor_shape.DebugString()));
|
||||
xla::Literal literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal));
|
||||
// TODO(annarev): add a helper to support int64 input.
|
||||
const int32 concat_dim = literal.Get<int>({});
|
||||
int64 concat_dim;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim));
|
||||
|
||||
std::vector<xla::XlaOp> values;
|
||||
std::vector<TensorShape> shapes;
|
||||
@ -63,9 +62,7 @@ class ConcatBaseOp : public XlaOpKernel {
|
||||
const TensorShape& input_shape = shapes[0];
|
||||
|
||||
int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
|
||||
OP_REQUIRES(ctx,
|
||||
(0 <= axis && axis < input_dims) ||
|
||||
(allow_legacy_scalars() && concat_dim == 0),
|
||||
OP_REQUIRES(ctx, 0 <= axis && axis < input_dims,
|
||||
errors::InvalidArgument(
|
||||
"ConcatOp : Expected concatenating dimensions in the range "
|
||||
"[",
|
||||
@ -75,14 +72,11 @@ class ConcatBaseOp : public XlaOpKernel {
|
||||
// elements.
|
||||
std::vector<xla::XlaOp> input_data;
|
||||
int output_concat_dim = 0;
|
||||
const bool input_is_scalar = IsLegacyScalar(input_shape);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
xla::XlaOp handle = values[i];
|
||||
const TensorShape& in_shape = shapes[i];
|
||||
const bool in_is_scalar = IsLegacyScalar(in_shape);
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
|
||||
ctx, in_shape.dims() == input_dims,
|
||||
errors::InvalidArgument(
|
||||
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
|
||||
input_shape.DebugString(), " vs. shape[", i,
|
||||
@ -131,10 +125,9 @@ class ConcatOffsetOp : public XlaOpKernel {
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape concat_dim_shape = ctx->InputShape(0);
|
||||
OP_REQUIRES(
|
||||
ctx, IsLegacyScalar(concat_dim_shape),
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape),
|
||||
errors::InvalidArgument(
|
||||
"Concat dim tensor should be a scalar integer, but got shape ",
|
||||
"Concat dim tensor should be a scalar, but got shape ",
|
||||
concat_dim_shape.DebugString()));
|
||||
for (int i = 1; i < ctx->num_inputs(); ++i) {
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)),
|
||||
@ -162,39 +155,38 @@ class ConcatOffsetOp : public XlaOpKernel {
|
||||
// [0, 5, 0, 0]
|
||||
const int32 N = ctx->num_inputs() - 1;
|
||||
const TensorShape inp0_shape = ctx->InputShape(1);
|
||||
xla::Literal inp0_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal));
|
||||
const int64 dims = inp0_shape.num_elements();
|
||||
std::vector<int64> inp0_dims;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims));
|
||||
const int64 inp0_rank = inp0_shape.num_elements();
|
||||
|
||||
xla::Literal concat_dim_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal));
|
||||
const int64 cdim = concat_dim_literal.Get<int>({});
|
||||
int64 cdim;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim));
|
||||
|
||||
VLOG(1) << "ConcatOffset " << cdim << "," << dims;
|
||||
int32 axis = cdim < 0 ? cdim + dims : cdim;
|
||||
OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
|
||||
VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank;
|
||||
int32 axis = cdim < 0 ? cdim + inp0_rank : cdim;
|
||||
OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank),
|
||||
errors::InvalidArgument("Concat dim is out of range: ", axis,
|
||||
" vs. ", dims));
|
||||
" vs. ", inp0_rank));
|
||||
int32 offset = 0;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
const TensorShape inp_shape = ctx->InputShape(1 + i);
|
||||
OP_REQUIRES(ctx, dims == inp_shape.num_elements(),
|
||||
errors::InvalidArgument("input ", i, " should contain ", dims,
|
||||
" elements, but got ",
|
||||
OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(),
|
||||
errors::InvalidArgument("input ", i, " should contain ",
|
||||
inp0_rank, " elements, but got ",
|
||||
inp_shape.num_elements()));
|
||||
xla::Literal inp_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal));
|
||||
std::vector<int64> inp_dims;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims));
|
||||
|
||||
Tensor out_constant(DT_INT32, TensorShape({dims}));
|
||||
Tensor out_constant(DT_INT32, TensorShape({inp0_rank}));
|
||||
auto out_vec = out_constant.vec<int32>();
|
||||
for (int64 j = 0; j < dims; ++j) {
|
||||
for (int64 j = 0; j < inp0_rank; ++j) {
|
||||
if (j == axis) {
|
||||
out_vec(j) = offset;
|
||||
offset += inp_literal.Get<int>({j});
|
||||
offset += inp_dims[j];
|
||||
} else {
|
||||
const int32 inp0_element = inp0_literal.Get<int>({j});
|
||||
const int32 inp_element = inp_literal.Get<int>({j});
|
||||
OP_REQUIRES(ctx, (inp0_element == inp_element),
|
||||
const int32 inp0_element = inp0_dims[j];
|
||||
const int32 inp_element = inp_dims[j];
|
||||
OP_REQUIRES(ctx, inp0_element == inp_element,
|
||||
errors::InvalidArgument("input[", i, ",", j,
|
||||
"] mismatch: ", inp0_element,
|
||||
" vs. ", inp_element));
|
||||
|
@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape shape(proto_.tensor_shape());
|
||||
|
||||
if (proto_.dtype() == DT_STRING) {
|
||||
LOG(WARNING) << "Not computing Const of type DT_STRING";
|
||||
ctx->SetInvalidOutput(0);
|
||||
return;
|
||||
}
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
|
||||
// To avoid blowups for large constants filled with the same value,
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -33,39 +34,20 @@ class FillOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
// The output of this Op is a tensor of shape 'dims_shape' with each
|
||||
// element set to the scalar 'dims_literal'.
|
||||
const TensorShape dims_shape = ctx->InputShape(0);
|
||||
const TensorShape value_shape = ctx->InputShape(1);
|
||||
const TensorShape dims_shape = ctx->InputShape("dims");
|
||||
const TensorShape value_shape = ctx->InputShape("value");
|
||||
OP_REQUIRES(
|
||||
ctx, IsLegacyVector(dims_shape),
|
||||
ctx, TensorShapeUtils::IsVector(dims_shape),
|
||||
errors::InvalidArgument("dims must be a vector of int32, got shape ",
|
||||
dims_shape.DebugString()));
|
||||
OP_REQUIRES(ctx, IsLegacyScalar(value_shape),
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(value_shape),
|
||||
errors::InvalidArgument("value must be a scalar, got shape ",
|
||||
value_shape.DebugString()));
|
||||
// Evaluate the 'dims' constant input, reshaping to a vector if it
|
||||
// was a 'legacy' vector (secretly a scalar).
|
||||
xla::Literal dims_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(
|
||||
0, {dims_shape.num_elements()}, &dims_literal));
|
||||
|
||||
// Convert the dims literal into a vector that we can pass to
|
||||
// XlaBuilder.
|
||||
std::vector<int64> broadcast;
|
||||
broadcast.reserve(dims_literal.shape().dimensions(0));
|
||||
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
|
||||
broadcast.push_back(dims_literal.Get<int>({i}));
|
||||
}
|
||||
// Look up the value input, reshaping to a scalar if it was a
|
||||
// 'legacy' scalar (secretly a vector).
|
||||
xla::XlaOp data = ctx->Input(1);
|
||||
if (value_shape.dims() > 0) {
|
||||
CHECK_EQ(value_shape.dims(), 1);
|
||||
data = xla::Reshape(data, {});
|
||||
}
|
||||
// Emit the actual computation, which broadcasts the scalar to the
|
||||
// desired shape.
|
||||
auto result = xla::Broadcast(data, broadcast);
|
||||
std::vector<int64> dims;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims));
|
||||
|
||||
auto result = xla::Broadcast(ctx->Input("value"), dims);
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
};
|
||||
|
@@ -48,9 +48,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
    // We require that the dimension argument is a constant, since it lets us
    // dispatch to a specialized custom-call function without any run-time
    // overhead, when compiling ahead-of-time.
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
    const int32 dim = literal.Get<int32>({});
    int64 dim;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
    OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
    OP_REQUIRES(
        ctx, dim < input_shape.dims(),

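Note: the ArgMax change above only alters how the constant dimension operand is read (ConstantInputAsIntScalar instead of pulling a raw literal); the reduction itself is an ordinary argmax over one fixed axis, e.g.:

# The reduction the kernel implements, shown with NumPy: argmax along a
# compile-time-constant axis.
import numpy as np

x = np.array([[1, 9, 3],
              [7, 2, 8]])
print(np.argmax(x, axis=0))  # [1 0 1] -- index of the max in each column
print(np.argmax(x, axis=1))  # [1 2]   -- index of the max in each row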
@ -41,10 +41,8 @@ class MirrorPadOp : public XlaOpKernel {
|
||||
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
|
||||
--dimno) {
|
||||
auto t_rev = xla::Rev(accum, {dimno});
|
||||
TF_ASSIGN_OR_RETURN(int64 lhs_padding,
|
||||
pad_literal.GetIntegralAsS64({dimno, 0}));
|
||||
TF_ASSIGN_OR_RETURN(int64 rhs_padding,
|
||||
pad_literal.GetIntegralAsS64({dimno, 1}));
|
||||
int64 lhs_padding = pad_literal.Get<int64>({dimno, 0});
|
||||
int64 rhs_padding = pad_literal.Get<int64>({dimno, 1});
|
||||
int64 dim_size = original_shape.dimensions(dimno);
|
||||
|
||||
// Padding amounts on each side must be no more than the size of the
|
||||
@ -65,8 +63,8 @@ class MirrorPadOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
const TensorShape pad_shape = ctx->InputShape(1);
|
||||
const TensorShape input_shape = ctx->InputShape("input");
|
||||
const TensorShape pad_shape = ctx->InputShape("paddings");
|
||||
|
||||
MirrorPadMode mode;
|
||||
OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
|
||||
@ -81,23 +79,19 @@ class MirrorPadOp : public XlaOpKernel {
|
||||
TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
|
||||
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
|
||||
pad_shape.DebugString()));
|
||||
const int fixed_dims =
|
||||
(allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
|
||||
? 1
|
||||
: dims;
|
||||
OP_REQUIRES(
|
||||
ctx, fixed_dims == pad_shape.dim_size(0),
|
||||
ctx, dims == pad_shape.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"The first dimension of paddings must be the rank of inputs",
|
||||
pad_shape.DebugString(), " ", input_shape.DebugString()));
|
||||
|
||||
// Evaluate the 'padding' constant input, reshaping to a matrix.
|
||||
xla::Literal pad_literal;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->ConstantInputAsInt64Literal("paddings", &pad_literal));
|
||||
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
auto in0 = ctx->Input(0);
|
||||
auto in0 = ctx->Input("input");
|
||||
xla::StatusOr<xla::Shape> in0_shape = b->GetShape(in0);
|
||||
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
|
||||
xla::StatusOr<xla::XlaOp> accum_status =
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -29,40 +30,36 @@ class PadOp : public XlaOpKernel {
|
||||
explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
const TensorShape pad_shape = ctx->InputShape(1);
|
||||
const TensorShape input_shape = ctx->InputShape("input");
|
||||
const TensorShape pad_shape = ctx->InputShape("paddings");
|
||||
const int dims = input_shape.dims();
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
|
||||
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
|
||||
pad_shape.DebugString()));
|
||||
const int fixed_dims =
|
||||
(allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
|
||||
? 1
|
||||
: dims;
|
||||
OP_REQUIRES(
|
||||
ctx, fixed_dims == pad_shape.dim_size(0),
|
||||
ctx, dims == pad_shape.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"The first dimension of paddings must be the rank of inputs",
|
||||
pad_shape.DebugString(), " ", input_shape.DebugString()));
|
||||
|
||||
if (fixed_dims == 0) {
|
||||
xla::XlaOp input = ctx->Input("input");
|
||||
if (dims == 0) {
|
||||
// Tensor is rank 0. Return it unchanged.
|
||||
ctx->SetOutput(0, ctx->Input(0));
|
||||
ctx->SetOutput(0, input);
|
||||
return;
|
||||
}
|
||||
|
||||
// Evaluate the 'padding' constant input, reshaping to a matrix.
|
||||
xla::Literal pad_literal;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->ConstantInputAsInt64Literal("paddings", &pad_literal));
|
||||
|
||||
xla::PaddingConfig config;
|
||||
for (int i = 0; i < fixed_dims; ++i) {
|
||||
for (int i = 0; i < dims; ++i) {
|
||||
auto* dim = config.add_dimensions();
|
||||
int before = pad_literal.Get<int32>({i, 0});
|
||||
int after = pad_literal.Get<int32>({i, 1});
|
||||
int before = pad_literal.Get<int64>({i, 0});
|
||||
int after = pad_literal.Get<int64>({i, 1});
|
||||
OP_REQUIRES(ctx, before >= 0 && after >= 0,
|
||||
errors::InvalidArgument(
|
||||
"Paddings must be non-negative: ", before, " ", after));
|
||||
@ -73,12 +70,13 @@ class PadOp : public XlaOpKernel {
|
||||
// PadV2 added a "constant_values" input that indicates the pad value.
|
||||
xla::XlaOp constant_values;
|
||||
if (ctx->num_inputs() == 3) {
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)),
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")),
|
||||
errors::InvalidArgument("constant_values must be a scalar."));
|
||||
ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config));
|
||||
ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config));
|
||||
} else {
|
||||
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
|
||||
ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config));
|
||||
ctx->SetOutput(0, xla::Pad(input, zero, config));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {
@@ -36,7 +37,7 @@ class ReshapeOp : public XlaOpKernel {
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape sizes_shape = ctx->InputShape(1);
    // Preliminary validation of sizes.
    OP_REQUIRES(ctx, IsLegacyVector(sizes_shape),
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape),
                errors::InvalidArgument("sizes input must be 1-D, not shape ",
                                        sizes_shape.DebugString()));
    const int64 num_dims = sizes_shape.num_elements();

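Note: the Reshape hunk above tightens the check on the sizes input from IsLegacyVector to a strict vector check; the reshape semantics themselves, including the usual single inferred -1 dimension, are unchanged. For reference:

# Reshape semantics for reference: `sizes` must be a 1-D vector, and at most
# one entry may be -1, which is inferred from the remaining dimensions.
import numpy as np

x = np.arange(12)
print(x.reshape([3, 4]).shape)   # (3, 4)
print(x.reshape([2, -1]).shape)  # (2, 6) -- the -1 is inferred as 12 / 2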
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -46,61 +47,8 @@ class RetvalOp : public XlaOpKernel {
|
||||
// compilation.
|
||||
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
|
||||
} else {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
DataType input_type = ctx->input_type(0);
|
||||
XlaContext& tc = XlaContext::Get(ctx);
|
||||
|
||||
if (input_type == DT_RESOURCE) {
|
||||
XlaResource* resource;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
|
||||
ctx->SetStatus(tc.AddResourceRetval(index_, resource));
|
||||
return;
|
||||
}
|
||||
|
||||
auto is_constant = ctx->builder()->IsConstant(input);
|
||||
if (!is_constant.ok()) {
|
||||
ctx->SetStatus(is_constant.status());
|
||||
return;
|
||||
}
|
||||
|
||||
if (tc.resolve_compile_time_constants() &&
|
||||
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
|
||||
xla::Literal literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
|
||||
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
|
||||
} else {
|
||||
TensorShape shape = ctx->InputShape(0);
|
||||
ctx->SetStatus(is_constant.status());
|
||||
TensorShape representation_shape;
|
||||
if (tc.is_entry_computation()) {
|
||||
xla::StatusOr<TensorShape> shape_or_status =
|
||||
tc.RepresentationShape(shape, ctx->input_type(0));
|
||||
if (!shape_or_status.ok()) {
|
||||
ctx->SetStatus(shape_or_status.status());
|
||||
return;
|
||||
} else {
|
||||
representation_shape = shape_or_status.ValueOrDie();
|
||||
}
|
||||
} else {
|
||||
representation_shape = shape;
|
||||
}
|
||||
|
||||
xla::XlaOp output = input;
|
||||
if (tc.is_entry_computation()) {
|
||||
output = xla::Reshape(input, representation_shape.dim_sizes());
|
||||
} else {
|
||||
// The core from which a return value is returned depends on the
|
||||
// device assignment of the input to the retval. Since we can't change
|
||||
// the device assignment of "input" at this point, we must always
|
||||
// introduce an operator here, even if the shape does not change.
|
||||
// TODO(b/76097077): propagate device assignments onto arguments and
|
||||
// return values of functions, and then reshape unconditionally.
|
||||
output =
|
||||
xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0);
|
||||
}
|
||||
tc.AddRetval(index_, dtype_, shape, output);
|
||||
}
|
||||
XlaContext& xla_context = XlaContext::Get(ctx);
|
||||
xla_context.SetRetval(index_, ctx->InputExpression(0));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel {
|
||||
}
|
||||
// XlaBuilder::Rev() requires concrete values for dimensions arg.
|
||||
xla::Literal lax;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
|
||||
std::vector<bool> revdims(x_shape.dims());
|
||||
std::copy(lax.data<bool>().begin(), lax.data<bool>().end(),
|
||||
revdims.begin());
|
||||
std::vector<int64> dimensions;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax));
|
||||
|
||||
std::vector<int64> dimensions;
|
||||
for (int d = 0; d < x_shape.dims(); ++d) {
|
||||
if (revdims[d]) {
|
||||
if (lax.Get<bool>({d})) {
|
||||
dimensions.push_back(d);
|
||||
}
|
||||
}
|
||||
|
@ -30,31 +30,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
Status GetValue(int index, XlaOpKernelContext* ctx, T* value) {
|
||||
xla::Literal literal;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
|
||||
*value = literal.Get<T>({});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
|
||||
xla::Literal literal;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
|
||||
switch (literal.shape().element_type()) {
|
||||
case xla::S32:
|
||||
*value = literal.Get<int32>({});
|
||||
break;
|
||||
case xla::S64:
|
||||
*value = literal.Get<int64>({});
|
||||
break;
|
||||
default:
|
||||
return errors::InvalidArgument("Invalid argument type for argument",
|
||||
index);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The type-specific part of the implementation of Range.
|
||||
template <typename T>
|
||||
xla::StatusOr<xla::XlaOp> CreateRangeTensor(
|
||||
@ -98,13 +73,13 @@ class RangeOp : public XlaOpKernel {
|
||||
const TensorShape start_in_shape = ctx->InputShape(0);
|
||||
const TensorShape limit_in_shape = ctx->InputShape(1);
|
||||
const TensorShape delta_in_shape = ctx->InputShape(2);
|
||||
OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape),
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
|
||||
errors::InvalidArgument("start must be a scalar, not shape ",
|
||||
start_in_shape.DebugString()));
|
||||
OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape),
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape),
|
||||
errors::InvalidArgument("limit must be a scalar, not shape ",
|
||||
limit_in_shape.DebugString()));
|
||||
OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape),
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape),
|
||||
errors::InvalidArgument("delta must be a scalar, not shape ",
|
||||
delta_in_shape.DebugString()));
|
||||
xla::Literal start, limit, delta;
|
||||
@ -147,9 +122,9 @@ class LinSpaceOp : public XlaOpKernel {
|
||||
explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape start_in_shape = ctx->InputShape(0);
|
||||
const TensorShape stop_in_shape = ctx->InputShape(1);
|
||||
const TensorShape num_in_shape = ctx->InputShape(2);
|
||||
const TensorShape start_in_shape = ctx->InputShape("start");
|
||||
const TensorShape stop_in_shape = ctx->InputShape("stop");
|
||||
const TensorShape num_in_shape = ctx->InputShape("num");
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
|
||||
errors::InvalidArgument("start must be a scalar, not shape ",
|
||||
start_in_shape.DebugString()));
|
||||
@ -163,16 +138,20 @@ class LinSpaceOp : public XlaOpKernel {
|
||||
DataType type = ctx->input_type(0);
|
||||
|
||||
int64 num;
|
||||
OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num));
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num));
|
||||
OP_REQUIRES(ctx, num > 0,
|
||||
errors::InvalidArgument("Requires num > 0: ", num));
|
||||
Tensor out_constant(type, TensorShape({num}));
|
||||
|
||||
xla::Literal start_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal));
|
||||
xla::Literal stop_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal));
|
||||
|
||||
switch (type) {
|
||||
case DT_FLOAT: {
|
||||
float start, stop;
|
||||
OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start));
|
||||
OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
|
||||
float start = start_literal.GetFirstElement<float>();
|
||||
float stop = stop_literal.GetFirstElement<float>();
|
||||
auto flat = out_constant.flat<float>();
|
||||
if (num == 1) {
|
||||
flat(0) = start;
|
||||
@ -185,9 +164,8 @@ class LinSpaceOp : public XlaOpKernel {
|
||||
break;
|
||||
}
|
||||
case DT_DOUBLE: {
|
||||
double start, stop;
|
||||
OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start));
|
||||
OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
|
||||
double start = start_literal.GetFirstElement<double>();
|
||||
double stop = stop_literal.GetFirstElement<double>();
|
||||
auto flat = out_constant.flat<double>();
|
||||
if (num == 1) {
|
||||
flat(0) = start;
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -108,21 +109,16 @@ class ExpandDimsOp : public XlaOpKernel {
|
||||
explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
const TensorShape dim_shape = ctx->InputShape(1);
|
||||
const TensorShape input_shape = ctx->InputShape("input");
|
||||
const TensorShape dim_shape = ctx->InputShape("dim");
|
||||
|
||||
// TODO(phawkins): the standard implementation of ExpandDimsOp seems to
|
||||
// accept legacy scalars, even when they should be forbidden by the graphdef
|
||||
// version.
|
||||
OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
|
||||
std::vector<int64> dims;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
|
||||
OP_REQUIRES(ctx, dims.size() == 1,
|
||||
errors::InvalidArgument(absl::StrCat(
|
||||
"dim input to ExpandDims must be a scalar; got ",
|
||||
dim_shape.DebugString())));
|
||||
|
||||
xla::Literal literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
|
||||
|
||||
int dim = literal.data<int32>()[0];
|
||||
int dim = dims[0];
|
||||
|
||||
OP_REQUIRES(ctx,
|
||||
(dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
|
||||
@ -148,7 +144,7 @@ class ExpandDimsOp : public XlaOpKernel {
|
||||
dim = std::min<int32>(dim, existing_dims_size);
|
||||
new_shape.emplace(new_shape.begin() + dim, 1);
|
||||
|
||||
ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
|
||||
ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
@ -42,8 +43,8 @@ class SliceOp : public XlaOpKernel {
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
IsLegacyVector(begin_tensor_shape) &&
|
||||
IsLegacyVector(size_tensor_shape) &&
|
||||
TensorShapeUtils::IsVector(begin_tensor_shape) &&
|
||||
TensorShapeUtils::IsVector(size_tensor_shape) &&
|
||||
begin_tensor_shape.num_elements() == input_shape.dims() &&
|
||||
size_tensor_shape.num_elements() == input_shape.dims(),
|
||||
errors::InvalidArgument(
|
||||
|
@ -35,26 +35,16 @@ class SplitOp : public XlaOpKernel {
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const int32 num_split = num_outputs();
|
||||
const TensorShape index_shape = ctx->InputShape(0);
|
||||
const TensorShape split_dim_shape = ctx->InputShape("split_dim");
|
||||
const TensorShape input_shape = ctx->InputShape(1);
|
||||
|
||||
xla::Literal literal_index;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index));
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsScalar(split_dim_shape),
|
||||
errors::InvalidArgument("split_dim must be a scalar but has rank ",
|
||||
split_dim_shape.dims()));
|
||||
int64 split_dim_orig;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig));
|
||||
|
||||
int32 split_dim_orig;
|
||||
if (index_shape.dims() == 0) {
|
||||
split_dim_orig = literal_index.Get<int>({});
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
ctx, index_shape.dims() == 1,
|
||||
errors::InvalidArgument("split_index input to Split Op must be a "
|
||||
"scalar or a vector with 1 element"));
|
||||
OP_REQUIRES(
|
||||
ctx, index_shape.dim_size(0) == 1,
|
||||
errors::InvalidArgument("split_index input to Split Op must be a "
|
||||
"scalar or a vector with 1 element"));
|
||||
split_dim_orig = literal_index.Get<int>({0});
|
||||
}
|
||||
int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
|
||||
: split_dim_orig;
|
||||
OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
|
||||
@ -138,7 +128,6 @@ class SplitVOp : public XlaOpKernel {
|
||||
// Check that sizes are correct.
|
||||
int total_split_size = 0;
|
||||
int neg_one_dim = -1;
|
||||
std::vector<int64> split_sizes_vec(num_split, -1);
|
||||
const TensorShape split_size_shape = ctx->InputShape(1);
|
||||
OP_REQUIRES(ctx,
|
||||
split_size_shape.dims() == 1 &&
|
||||
@ -150,12 +139,11 @@ class SplitVOp : public XlaOpKernel {
|
||||
split_size_shape.dims(), "-D and ",
|
||||
split_size_shape.num_elements(), " elements"));
|
||||
// Get the dimension of this split.
|
||||
xla::Literal split_size_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal));
|
||||
std::vector<int64> split_sizes;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes));
|
||||
|
||||
for (int i = 0; i < num_split; ++i) {
|
||||
int slice_size;
|
||||
slice_size = split_size_literal.Get<int>({i});
|
||||
int64 slice_size = split_sizes[i];
|
||||
if (slice_size == -1) {
|
||||
OP_REQUIRES(
|
||||
ctx, neg_one_dim == -1,
|
||||
@ -164,7 +152,6 @@ class SplitVOp : public XlaOpKernel {
|
||||
i));
|
||||
neg_one_dim = i;
|
||||
} else {
|
||||
split_sizes_vec[i] = slice_size;
|
||||
total_split_size += slice_size;
|
||||
}
|
||||
}
|
||||
@ -183,7 +170,7 @@ class SplitVOp : public XlaOpKernel {
|
||||
total_split_size));
|
||||
|
||||
if (neg_one_dim >= 0) {
|
||||
split_sizes_vec[neg_one_dim] =
|
||||
split_sizes[neg_one_dim] =
|
||||
input_shape.dim_size(split_dim) - total_split_size;
|
||||
}
|
||||
|
||||
@ -195,7 +182,7 @@ class SplitVOp : public XlaOpKernel {
|
||||
std::vector<int64> strides(input_shape.dims(), 1);
|
||||
for (int i = 0; i < num_split; ++i) {
|
||||
TensorShape output_shape(input_shape);
|
||||
int slice_size = split_sizes_vec[i];
|
||||
int slice_size = split_sizes[i];
|
||||
output_shape.set_dim(split_dim, slice_size);
|
||||
|
||||
// Slice out the ith split from the split dimension.
|
||||
|
@@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp);
REGISTER_XLA_OP(
    Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(),
    StackOp);

class StackPushOp : public XlaOpKernel {
 public:
@@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};

REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);
REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp);

class StackPopOp : public XlaOpKernel {
 public:
@@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};

REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);
REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp);

class StackCloseOp : public XlaOpKernel {
 public:
@@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};

REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);
REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp);

} // anonymous namespace
} // namespace tensorflow

@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
@@ -44,7 +45,7 @@ class TileOp : public XlaOpKernel {
    const TensorShape multiples_shape = ctx->InputShape("multiples");

    OP_REQUIRES(
        ctx, IsLegacyVector(multiples_shape),
        ctx, TensorShapeUtils::IsVector(multiples_shape),
        errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
                                multiples_shape.DebugString()));
    OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(),

@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel {
: XlaOpKernel(ctx), conjugate_(conjugate) {}

void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape perm_tensor_shape = ctx->InputShape(1);
const TensorShape input_shape = ctx->InputShape("x");
const TensorShape perm_tensor_shape = ctx->InputShape("perm");

// Preliminary validation of sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel {
". But input(1) is a vector of size ",
perm_tensor_shape.num_elements()));

xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));

std::vector<int32> perm(dims);
std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
perm.begin());
std::vector<int64> perm;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm));

std::vector<int64> transposed_order;
// Check whether permutation is a permutation of integers of [0 .. dims).
absl::InlinedVector<bool, 8> bits(dims);
bool is_identity = true;
for (int i = 0; i < dims; ++i) {
const int32 d = perm[i];
const int64 d = perm[i];
OP_REQUIRES(
ctx, 0 <= d && d < dims,
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel {
xla::XlaOp transposed;
// 0-D, 1-D, and identity transposes do nothing.
if (dims <= 1 || is_identity) {
transposed = ctx->Input(0);
transposed = ctx->Input("x");
} else {
transposed = xla::Transpose(ctx->Input(0), transposed_order);
transposed = xla::Transpose(ctx->Input("x"), transposed_order);
}

// Conjugate the transposed result if this is ConjugateTransposeOp.
@ -80,24 +80,8 @@ XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
XLAJIT_MAKE_UNARY(Neg, -x);

// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
xla::XlaOp RoundToEven(xla::XlaOp x) {
auto half = xla::ScalarLike(x, 0.5);
auto one = xla::ScalarLike(x, 1.0);
auto two = xla::ScalarLike(x, 2.0);

auto round_val = xla::Floor(x);
auto fraction = x - round_val;
auto nearest_even_int = round_val - two * xla::Floor(half * x);
auto is_odd = xla::Eq(nearest_even_int, one);
return xla::Select(xla::Or(xla::Gt(fraction, half),
xla::And(xla::Eq(fraction, half), is_odd)),
round_val + one, round_val);
}

XLAJIT_MAKE_UNARY(Rint, RoundToEven(x));
XLAJIT_MAKE_UNARY(Round, RoundToEven(x));
XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x));
XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x));

XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x));

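The hunk above drops the local RoundToEven helper in favor of xla::RoundToEven. A minimal standalone sketch of the same round-half-to-even rule in plain C++, mirroring the removed helper's floor/fraction logic (illustrative only, not TensorFlow code):

    #include <cmath>
    #include <cstdio>

    // Round-half-to-even ("banker's rounding"): ties go to the even integer.
    double RoundToEven(double x) {
      const double round_val = std::floor(x);
      const double fraction = x - round_val;
      // floor(x) - 2*floor(x/2) is 1 when floor(x) is odd, 0 when it is even.
      const bool is_odd = (round_val - 2.0 * std::floor(0.5 * x)) == 1.0;
      if (fraction > 0.5 || (fraction == 0.5 && is_odd)) {
        return round_val + 1.0;
      }
      return round_val;
    }

    int main() {
      std::printf("%g %g %g\n", RoundToEven(0.5), RoundToEven(1.5),
                  RoundToEven(2.5));  // prints: 0 2 2
      return 0;
    }
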
@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
return Status::OK();
}

xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor) {
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal));
return literal.Clone();
}

Status HostTensorToMutableBorrowingLiteral(
Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
xla::Shape xla_shape;
|
||||
// 'host_tensor'.
|
||||
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
|
||||
xla::BorrowingLiteral* literal);
|
||||
|
||||
// Returns a Literal with the contents of 'host_tensor', backed by its own
|
||||
// storage (i.e., not reusing 'host_tensor's buffers.)
|
||||
xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor);
|
||||
|
||||
// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
|
||||
// owned by 'host_tensor', but is mutable via the xla::Literal methods.
|
||||
Status HostTensorToMutableBorrowingLiteral(
|
||||
|
@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//learning/deepmind/public/wavenet/python:__subpackages__",
"//learning/deepmind/research/alphastar:__subpackages__",
"//learning/tfx:__subpackages__",
"//tensorflow:internal",
],
@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto(
"XLACompilationDevice::MakeTensorFromProto should not be called");
}

XlaExpression::XlaExpression() = default;

void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; }

void XlaExpression::set_constant_value(Tensor value) {
has_constant_value_ = true;
constant_value_ = std::move(value);
}

} // namespace tensorflow
@ -18,9 +18,6 @@ limitations under the License.

#include <memory>

#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
@ -38,8 +35,8 @@ class XlaCompilationAllocator;
// This is a 'dummy' TensorFlow device that is only used to execute a
// subgraph of XLA compilation Ops to construct a compiled version
// of the subgraph's computation. It has a 'dummy' allocator that
// backs each Tensor with metadata indicating the computation the
// Tensor represents.
// backs each Tensor with an XlaExpression. The shape of the Tensor
// matches the shape of XlaExpression.
//
// We deliberately don't register a device factory because we *never*
// want placement to put Ops on a compilation device. The device is created
@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice {
std::unique_ptr<XlaCompilationAllocator> allocator_;
};

// A XlaExpression wraps an XLA computation. Each Tensor on an
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
// matches the shape of the subcomputation in the XlaOp. Each
// expression is either a constant, or a function of previously-compiled
// expressions.
class XlaExpression {
public:
XlaExpression();

// handle() stores the XLA handle of the computation that the
// expression represents.
void set_handle(const xla::XlaOp& h);
const xla::XlaOp& handle() const { return handle_; }

void set_constant_value(Tensor value);
bool has_constant_value() const { return has_constant_value_; }
const Tensor& constant_value() const { return constant_value_; }

void set_resource(XlaResource* resource) { resource_ = resource; }
XlaResource* resource() const { return resource_; }

private:
// The XLA handle of the expression's computation.
xla::XlaOp handle_;

// If this expression is a constant with a known value, 'constant_value' is a
// host-memory Tensor containing the value. Used to avoid invoking XLA for
// expressions that are trivially constant.
bool has_constant_value_ = false;
Tensor constant_value_;

XlaResource* resource_ = nullptr; // Not owned.
};

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
@ -36,10 +36,13 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"

@ -48,7 +51,7 @@ namespace {

// Checks that arguments `args` match types `types`.
Status CheckSignature(const DataTypeVector& types,
const std::vector<XlaCompiler::Argument>& args) {
absl::Span<const XlaCompiler::Argument> args) {
if (args.size() != types.size()) {
return errors::Internal("Compilation arguments have ", args.size(),
" elements while function has ", types.size());
@ -63,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types,
return Status::OK();
}

// Uses the _Arg and _Retval nodes in the graph to determine a core assignment
// for each argument and return value.
xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
ComputeArgAndRetvalCores(const Graph& graph) {
auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
return sharding.value().tile_assignment_devices(0);
} else {
return -1;
}
};
std::map<int, int> arg_cores;
std::map<int, int> retval_cores;
for (const Node* n : graph.nodes()) {
if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
if (core < 0) continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0) << "Negative _Arg index";
arg_cores[index] = core;
} else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
if (core < 0) continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0) << "Negative _Retval index";
TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n));
retval_cores[index] = core;
}
}
return std::make_pair(std::move(arg_cores), std::move(retval_cores));
}

Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
int64 step_id) {
// Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
// resource manager takes ownership via Create, and unrefs via Cleanup. We
// explicitly add a reference to ensure the refcount at entry is maintained at
// all exit points; Create and Cleanup are always called in this function.
//
// The Executor requires us to use ScopedStepContainer. We wrap it in a
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));

GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
return Status::OK();
}

// Builds the XLA computation.
// - `args` is the list of input arguments
// - `retvals` is the list of retvals produced by _Retval operators, in index
// order.
// - `args_core` and `retval_cores` are mapping from arg/return indices to core
// assignments.
// - If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// - Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// - `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a ResourceUpdate, whose `index` is the index of a
// resource variable argument to the computation to be updated, and `type` is
// the type of the final output.
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<XlaExpression>& retvals,
const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
const std::vector<std::unique_ptr<XlaResource>>& resources,
std::unique_ptr<xla::XlaOp> token_output,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
bool return_updated_values_for_all_resources, bool always_return_tuple,
xla::XlaBuilder* builder, xla::XlaComputation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
xla::OpMetadata retval_metadata;
retval_metadata.set_op_name("XLA_Retvals");
builder->SetOpMetadata(retval_metadata);
auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });

// Builds a no-op XLA computation. We need to set the sharding of outputs, but
// cannot change the sharding of the existing output op. To do this, we build
// a new identity op to which shardings can be applied.
auto identity_op = [builder](xla::XlaOp op) {
return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
};

std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
const XlaExpression& retval = retvals[i];
output.type = retval.dtype();
switch (retval.kind()) {
case XlaExpression::Kind::kConstant:
output.is_constant = true;
output.constant_value = retval.constant_value();
output.shape = output.constant_value.shape();
break;

case XlaExpression::Kind::kXlaOp: {
output.is_constant = false;
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
xla::XlaOp value = retval.handle();
auto it = retval_cores.find(i);
xla::XlaScopedShardingAssignment assign_sharding(
builder, it == retval_cores.end()
? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(it->second));
if (shape_representation_fn) {
// If there is a shape representation function, reshape the output
// tensor to the shape given by the representation shape function.
TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
output.shape, output.type));
value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
} else if (it != retval_cores.end()) {
// Apply the sharding to the output, if there is a core assignment.
value = identity_op(value);
}
elems.push_back(value);
break;
}

case XlaExpression::Kind::kResource:
output.is_constant = false;
output.input_index = retval.resource()->arg_num();
output.shape = retval.resource()->shape();
break;

case XlaExpression::Kind::kInvalid:
return errors::InvalidArgument(
"Invalid expression returned by computation. "
"This probably means a return value was not set.");
}
}
*num_nonconst_outputs = elems.size();

// Add return values for resources whose values have changed.
std::vector<const XlaResource*> arg_resources;
arg_resources.reserve(resources.size());
for (const auto& resource : resources) {
if (resource->arg_num() >= 0) {
arg_resources.push_back(resource.get());
}
}
std::sort(arg_resources.begin(), arg_resources.end(),
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num() < b->arg_num();
});

for (const XlaResource* resource : arg_resources) {
DCHECK_LT(resource->arg_num(), args.size());
const XlaCompiler::Argument& arg = args[resource->arg_num()];
auto it = arg_cores.find(resource->arg_num());
const int core = it == arg_cores.end() ? -1 : it->second;
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
modified =
modified ||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
}

// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));

xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));

// Ensures the correct sharding is applied to the output.
handle = identity_op(handle);

elems.push_back(handle);
}
}

// If we have token output, append it as the last one.
if (token_output) {
elems.push_back(*token_output);
}

*num_computation_outputs = elems.size();

// Builds the XLA computation. We *always* form a tuple here to ensure that
// the output value is the last thing added into the XLA computation, even
// if there is only one output value.
auto tuple = xla::Tuple(builder, elems);
if (!always_return_tuple && elems.size() == 1) {
xla::GetTupleElement(tuple, 0);
}

xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
return Status::OK();
}

} // namespace

bool XlaCompiler::Argument::operator==(
@ -83,6 +320,39 @@ bool XlaCompiler::Argument::operator==(
return constant_value.tensor_data() == other.constant_value.tensor_data();
}

string XlaCompiler::Argument::HumanString() const {
string common;
if (!name.empty()) {
common = absl::StrCat(" name=", name);
}
absl::StrAppend(&common, " type=", DataTypeString(type),
" shape=", shape.DebugString());
switch (kind) {
case kInvalid:
return "invalid";
case kConstant:
return absl::StrCat("kind=constant", common,
" value=", constant_value.DebugString());
case kResource: {
string output = absl::StrCat("kind=resource", common, " resource_kind=",
XlaResource::KindToString(resource_kind),
" initialized=", initialized);
if (tensor_array_size >= 0) {
absl::StrAppend(&output, " tensor_array_size=", tensor_array_size);
}
if (!tensor_array_gradients.empty()) {
absl::StrAppend(&output, " tensor_array_gradients=",
absl::StrJoin(tensor_array_gradients, ","));
}
return output;
}
case kParameter:
return absl::StrCat("kind=parameter", common);
case kToken:
return absl::StrCat("token", common);
}
}

XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(options),
initialization_status_(Status::OK()),
@ -110,8 +380,13 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)

// The default shape representation function is the identity.
if (!options_.shape_representation_fn) {
options_.shape_representation_fn = [](const TensorShape& shape,
DataType type) { return shape; };
options_.shape_representation_fn =
[](const TensorShape& shape,
DataType dtype) -> xla::StatusOr<xla::Shape> {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
return xla_shape;
};
}
}

@ -171,15 +446,16 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
return graph;
}

Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
const NameAttrList& function,
std::vector<XlaCompiler::Argument> args,
Status XlaCompiler::CompileFunction(
const XlaCompiler::CompileOptions& options, const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args,
XlaCompiler::CompilationResult* result) {
const string function_id =
Canonicalize(function.name(), AttrSlice(&function.attr()));
VLOG(1) << "XlaCompiler::CompileFunction " << function_id;

auto it = cache_.find({function_id, args});
const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
auto it = cache_.find({function_id, arg_vector});
if (it != cache_.end()) {
*result = it->second;
return Status::OK();
@ -212,14 +488,16 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// lowest-numbered core that consumes the argument. We choose the
// lowest-numbered core so the assignment is deterministic.
for (Node* n : graph->nodes()) {
if (absl::string_view(n->type_string()) == "_Arg") {
if (absl::string_view(n->type_string()) ==
FunctionLibraryDefinition::kArgOp) {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
}
}
// Do _Retval as a second loop, in case the retval's input is an _Arg (which
// may have gotten a device assignment from the first loop).
for (Node* n : graph->nodes()) {
if (absl::string_view(n->type_string()) == "_Retval") {
if (absl::string_view(n->type_string()) ==
FunctionLibraryDefinition::kRetOp) {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
}
}
@ -235,7 +513,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
CompileGraph(options, function_id, std::move(graph), args, result));
VLOG(1) << "====================================================";

cache_[{function_id, args}] = *result;
cache_[{function_id, arg_vector}] = *result;
return Status::OK();
}

@ -247,25 +525,24 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
case XlaCompiler::Argument::kParameter: {
TensorShape shape;
if (is_entry_computation) {
TF_ASSIGN_OR_RETURN(
shape, options_.shape_representation_fn(arg.shape, arg.type));
*xla_shape, options_.shape_representation_fn(arg.shape, arg.type));
} else {
shape = arg.shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(arg.type, arg.shape, xla_shape));
}
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
return Status::OK();
}
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);

switch (arg.resource_kind) {
case XlaResource::kVariable: {
TF_ASSIGN_OR_RETURN(
TensorShape representation_shape,
options_.shape_representation_fn(arg.shape, arg.type));
return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape);
TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
arg.shape, arg.type));

return Status::OK();
}
case XlaResource::kTensorArray: {
if (arg.tensor_array_size < 0) {
@ -314,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
}
}

namespace {

Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
int64 step_id) {
// Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
// resource manager takes ownership via Create, and unrefs via Cleanup. We
// explicitly add a reference to ensure the refcount at entry is maintained at
// all exit points; Create and Cleanup are always called in this function.
//
// The Executor requires us to use ScopedStepContainer. We wrap it in a
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));

GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
return Status::OK();
}

// Builds the XLA computation.
// `args` is the list of input arguments, `retvals` is the list of retvals
// produced by _Retval operators, in index order.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a (input_index, type) pair, where `input_index` is the
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<int>& arg_cores,
const std::vector<XlaContext::Retval>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
std::unique_ptr<xla::XlaOp> token_output,
bool return_updated_values_for_all_resources, bool always_return_tuple,
xla::XlaBuilder* builder, xla::XlaComputation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
output.type = retvals[i].type;
output.shape = retvals[i].shape;
const XlaExpression& retval = retvals[i].expression;
if (retval.has_constant_value()) {
output.is_constant = true;
output.constant_value = retval.constant_value();
} else if (retval.resource() != nullptr) {
output.is_constant = false;
output.input_index = retval.resource()->arg_num();
} else {
output.is_constant = false;
elems.push_back(retval.handle());
}
}
*num_nonconst_outputs = elems.size();

// Add return values for resources whose values have changed.
std::vector<const XlaResource*> arg_resources;
arg_resources.reserve(resources.size());
for (const auto& resource : resources) {
if (resource->arg_num() >= 0) {
arg_resources.push_back(resource.get());
}
}
std::sort(arg_resources.begin(), arg_resources.end(),
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num() < b->arg_num();
});

// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
xla::OpMetadata retval_metadata;
retval_metadata.set_op_name("XLA_Retvals");
builder->SetOpMetadata(retval_metadata);

for (const XlaResource* resource : arg_resources) {
const XlaCompiler::Argument& arg = args[resource->arg_num()];
const int core = arg_cores[resource->arg_num()];
DCHECK_LT(resource->arg_num(), arg_cores.size());
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
modified =
modified ||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
}

// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));

xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));

// Since we can't change the sharding metadata of <value> at this point,
// create a tuple/get-tuple-element combination so that sharding
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
elems.push_back(handle);
}
}

// If we have token output, append it as the last one.
if (token_output) {
elems.push_back(*token_output);
}

*num_computation_outputs = elems.size();

// Builds the XLA computation. We *always* form a tuple here to ensure that
// the output value is the last thing added into the XLA computation, even
// if there is only one output value.
auto tuple = xla::Tuple(builder, elems);
if (!always_return_tuple && elems.size() == 1) {
xla::GetTupleElement(tuple, 0);
}
builder->ClearOpMetadata();

xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
return Status::OK();
}

} // namespace

// Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation.
Status XlaCompiler::BuildArguments(
const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
const std::map<int, int>& arg_cores,
std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
bool is_entry_computation) {
arg_expressions->resize(args.size());
*arg_cores = std::vector<int>(args.size(), -1);

// Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
@ -504,7 +622,7 @@ Status XlaCompiler::BuildArguments(
arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(),
/*tensor_array_size=*/arg.tensor_array_size,
/*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource);
arg_expression = XlaExpression::Resource(resource);
if (arg.initialized) {
input_mapping->push_back(i);
}
@ -516,7 +634,7 @@ Status XlaCompiler::BuildArguments(
break;
}
case XlaCompiler::Argument::kConstant:
arg_expression.set_constant_value(arg.constant_value);
arg_expression = XlaExpression::Constant(arg.constant_value);
break;
case XlaCompiler::Argument::kInvalid:
return errors::Internal(
@ -541,26 +659,6 @@ Status XlaCompiler::BuildArguments(
*input_shapes = arg_shapes;
}

// Use the _Arg nodes in the graph to resolve core assignments.
for (const Node* n : graph.nodes()) {
if (absl::string_view(n->type_string()) != "_Arg") continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < args.size())
<< "_Arg out of bounds: " << index << " vs " << args.size();
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
const int core = sharding.value().tile_assignment_devices(0);
if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
(*arg_cores)[index] = core;
}
}
}

// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
@ -576,11 +674,10 @@ Status XlaCompiler::BuildArguments(
xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
for (int64 parameter : *input_mapping) {
const int core = (*arg_cores)[parameter];
const int root_device = 0;
auto it = arg_cores.find(parameter);
const int core = it == arg_cores.end() ? 0 : it->second;
*tuple_sharding.add_tuple_shardings() =
core == -1 ? xla::sharding_builder::AssignDevice(root_device)
: xla::sharding_builder::AssignDevice(core);
xla::sharding_builder::AssignDevice(core);
}
xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
tuple_sharding);
@ -589,7 +686,8 @@ Status XlaCompiler::BuildArguments(
tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
}
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
auto it = arg_cores.find(i);
const int core = it == arg_cores.end() ? -1 : it->second;
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
@ -597,7 +695,8 @@ Status XlaCompiler::BuildArguments(
}
} else {
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
auto it = arg_cores.find(i);
const int core = it == arg_cores.end() ? -1 : it->second;
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
@ -632,14 +731,14 @@ Status XlaCompiler::BuildArguments(
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
if (is_entry_computation) {
arg_expression.set_handle(
xla::Reshape(arg_handles[i], arg.shape.dim_sizes()));
arg_expression = XlaExpression::XlaOp(
xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type);
} else {
arg_expression.set_handle(arg_handles[i]);
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
}
break;
case XlaCompiler::Argument::kToken: {
arg_expression.set_handle(arg_handles[i]);
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
break;
}
case XlaCompiler::Argument::kConstant:
@ -653,26 +752,28 @@ Status XlaCompiler::BuildArguments(
}

Status XlaCompiler::CompileSingleOp(
const XlaCompiler::CompileOptions& options, string const& name,
OpKernelContext* ctx, const std::vector<XlaCompiler::Argument>& args,
CompilationResult* result) {
const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
absl::Span<const XlaCompiler::Argument> args,
absl::Span<const DataType> result_types, CompilationResult* result) {
// TODO(b/74182462): We implement this by creating a new dummy Graph including
// _Arg nodes, and let CompileGraph walk it. This could be optimized.
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

Status status;
// First create the actual node we care about computing.
Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status);
Node* main_node = graph->AddNode(node_def, &status);
TF_RETURN_IF_ERROR(status);

// Create dummy _Arg nodes. Link these to `node` and also via a control
// dependency edge to the _SOURCE node.
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
for (int64 i = 0; i < args.size(); ++i) {
Node* node;
string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
Status status = NodeBuilder(name, "_Arg")
string arg_name = absl::StrCat("_arg", i);
Status status =
NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
.ControlInput(graph->source_node())
.Attr("T", ctx->input_dtype(i))
.Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
: args[i].type)
.Attr("index", i)
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
@ -680,19 +781,19 @@ Status XlaCompiler::CompileSingleOp(
}

// Similarly with return values, create dummy _Retval nodes fed by `node`.
for (int64 i = 0; i < ctx->num_outputs(); ++i) {
for (int64 i = 0; i < result_types.size(); ++i) {
Node* node;
string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
Status status = NodeBuilder(name, "_Retval")
string retval_name = absl::StrCat("_retval", i);
Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
.Input(main_node, i)
.Attr("T", ctx->expected_output_dtype(i))
.Attr("T", result_types[i])
.Attr("index", i)
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
}
FixupSourceAndSinkEdges(graph.get());

return CompileGraph(options, name, std::move(graph), args, result);
return CompileGraph(options, node_def.name(), std::move(graph), args, result);
}

namespace {
@ -747,12 +848,38 @@ Status ValidateGraph(const Graph* graph,
return Status::OK();
}

// Converts the value of any expressions whose values are known at compile-time
// to constants.
Status ResolveConstantExpressionsToConstants(
xla::Client* client, absl::Span<XlaExpression> expressions) {
for (XlaExpression& expression : expressions) {
if (expression.kind() == XlaExpression::Kind::kXlaOp) {
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant,
expression.ResolveConstant(client));
if (constant.has_value()) {
expression = XlaExpression::Constant(*constant);
}
}
}
return Status::OK();
}

void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
absl::Span<XlaExpression> expressions) {
for (XlaExpression& expression : expressions) {
if (expression.kind() == XlaExpression::Kind::kConstant) {
expression =
XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
}
}
}

} // namespace

Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
string const& name,
std::unique_ptr<Graph> graph,
const std::vector<XlaCompiler::Argument>& args,
absl::Span<const XlaCompiler::Argument> args,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";

@ -774,13 +901,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
options_.device_type, name));

xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(
this, &builder, options_.allow_cpu_custom_calls,
options.resolve_compile_time_constants, options.is_entry_computation,
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);

std::vector<XlaCompiler::Argument> real_args(args);
std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
int token_input_index = -1;
std::unique_ptr<xla::XlaOp> token_output;
if (options.add_token_input_output) {
@ -792,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
real_args.push_back(token_arg);
}

std::map<int, int> arg_cores;
std::map<int, int> retval_cores;
TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
ComputeArgAndRetvalCores(*graph));

std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
TF_RETURN_IF_ERROR(BuildArguments(
*graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
*graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
&arg_expressions, &result->input_mapping, &result->xla_input_shapes,
options.is_entry_computation));
context->set_args(std::move(arg_expressions));
@ -843,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
std::vector<XlaExpression> retvals = context->retvals();
if (options.resolve_compile_time_constants) {
TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants(
client(), absl::Span<XlaExpression>(retvals)));
} else {
ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
}
TF_RETURN_IF_ERROR(BuildComputation(
real_args, arg_cores, context->retvals(), context->resources(),
std::move(token_output), options.return_updated_values_for_all_resources,
real_args, retvals, arg_cores, retval_cores, context->resources(),
std::move(token_output),
options.is_entry_computation ? options_.shape_representation_fn
: ShapeRepresentationFn{},
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
&result->resource_updates));
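
The new BuildComputation above restores the builder's op metadata through gtl::MakeCleanup, a scope-guard helper. A minimal standalone sketch of that idiom in plain C++ (the class name and printed strings are illustrative, not TensorFlow APIs):

    #include <cstdio>
    #include <functional>
    #include <utility>

    // Runs the captured callable when it leaves scope, so the "clear" step
    // happens on every return path, as with the MakeCleanup call above.
    class Cleanup {
     public:
      explicit Cleanup(std::function<void()> f) : f_(std::move(f)) {}
      ~Cleanup() { f_(); }
      Cleanup(const Cleanup&) = delete;
      Cleanup& operator=(const Cleanup&) = delete;

     private:
      std::function<void()> f_;
    };

    int main() {
      std::printf("SetOpMetadata\n");
      {
        Cleanup clear_metadata([] { std::printf("ClearOpMetadata\n"); });
        std::printf("emit retval ops\n");
      }  // clear_metadata's destructor runs here.
      return 0;
    }
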
@ -18,10 +18,13 @@ limitations under the License.

#include <stack>

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
@ -118,7 +121,7 @@ class XlaCompiler {

// The type of the argument. If the argument is a resource, this
// is the type of the variable's value, not DT_RESOURCE.
DataType type;
DataType type = DT_INVALID;

// The shape of the argument. For:
// * a parameter: the shape of the parameter.
@ -155,6 +158,9 @@ class XlaCompiler {
std::set<string> tensor_array_gradients;

bool operator==(const Argument& other) const;

// Returns a human-readable summary of the argument.
string HumanString() const;
};

// Options pertaining to an individual call to CompileGraph() or
@ -259,8 +265,7 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation;
};

typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&,
DataType)>
typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>
ShapeRepresentationFn;
struct Options {
// Name of the compilation device to use. It must be set by the caller.
@ -316,22 +321,23 @@ class XlaCompiler {

Status CompileFunction(const CompileOptions& options,
const NameAttrList& fn_name_attrs,
std::vector<Argument> args, CompilationResult* result);
absl::Span<const Argument> args,
CompilationResult* result);

// Compiles a tensorflow::Graph into an xla::XlaComputation.
// Similar to CompileFunction, but takes a Graph as input rather than a
// function.
Status CompileGraph(const CompileOptions& options, string const& name,
std::unique_ptr<Graph> graph,
const std::vector<Argument>& args,
absl::Span<const Argument> args,
CompilationResult* result);

// Compiles a single Op, given by an OpKernelContext, into an
// Compiles a single Op, given by `node_def`, into an
// xla::XlaComputation. Similar to CompileFunction but takes a single Op as
// input.
Status CompileSingleOp(const CompileOptions& options, string const& name,
OpKernelContext* ctx,
const std::vector<Argument>& args,
Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def,
absl::Span<const Argument> args,
absl::Span<const DataType> result_types,
CompilationResult* result);

// Returns the shape of the XLA parameter for an argument 'arg'.
@ -411,7 +417,8 @@ class XlaCompiler {
Status BuildArguments(const Graph& graph,
const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::XlaBuilder* builder,
XlaContext* context, std::vector<int>* arg_cores,
XlaContext* context,
const std::map<int, int>& arg_cores,
std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping,
std::vector<xla::Shape>* input_shapes,
@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -1018,9 +1019,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {

// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.shape_representation_fn = [](const TensorShape& shape,
DataType type) {
return TensorShape({shape.num_elements()});
options.shape_representation_fn =
[](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
xla::PrimitiveType ptype;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
};
XlaCompiler compiler(options);

@ -1086,9 +1089,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {

// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.shape_representation_fn = [](const TensorShape& shape,
DataType type) {
return TensorShape({shape.num_elements()});
options.shape_representation_fn =
[](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
xla::PrimitiveType ptype;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
};
XlaCompiler compiler(options);

@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {

XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
const std::function<xla::StatusOr<TensorShape>(
bool allow_cpu_custom_calls,
const std::function<xla::StatusOr<xla::Shape>(
const TensorShape&, DataType)>* shape_representation_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
resolve_compile_time_constants_(resolve_compile_time_constants),
is_entry_computation_(is_entry_computation),
shape_representation_fn_(shape_representation_fn) {}

string XlaContext::DebugString() { return "TLA JIT context"; }

// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void XlaContext::AddRetval(int retval_index, DataType type,
const TensorShape& shape, const xla::XlaOp& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up.
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
void XlaContext::SetRetval(int index, const XlaExpression& expression) {
if (retvals_.size() <= index) {
retvals_.resize(index + 1);
}
XlaExpression e;
e.set_handle(handle);
retvals_[retval_index] = Retval{type, shape, e};
retvals_[index] = expression;
}

Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal) {
VLOG(1) << "Adding retval index " << retval_index
<< " with non-data-dependent tensor to XLA computation";
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
Tensor value;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
XlaExpression e;
e.set_constant_value(value);
retvals_[retval_index] = Retval{dtype, value.shape(), e};
return Status::OK();
}

Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
VLOG(1) << "Adding retval index " << retval_index << " with resource "
<< resource->name() << ":" << resource->shape().DebugString()
<< " to XLA computation";
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
XlaExpression e;
e.set_resource(resource);
retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e};
return Status::OK();
}

xla::XlaBuilder* XlaContext::builder() { return builder_; }

Status XlaContext::CreateResource(
XlaResource::Kind kind, int arg_num, string name, DataType type,
TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size,
@ -133,7 +93,7 @@ Status XlaContext::CreateResource(
return Status::OK();
}

xla::StatusOr<TensorShape> XlaContext::RepresentationShape(
xla::StatusOr<xla::Shape> XlaContext::RepresentationShape(
const TensorShape& shape, DataType type) const {
return (*shape_representation_fn_)(shape, type);
}
@ -20,8 +20,8 @@ limitations under the License.

#include <vector>

#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -46,9 +46,8 @@ class XlaContext : public ResourceBase {
// Creates a new XlaContext. See the documentation on the class data fields
// for descriptions of the arguments.
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
const std::function<xla::StatusOr<TensorShape>(
bool allow_cpu_custom_calls,
const std::function<xla::StatusOr<xla::Shape>(
const TensorShape&, DataType)>* shape_representation_fn);

// Virtual method defined by ResourceBase.
@ -57,37 +56,19 @@ class XlaContext : public ResourceBase {
XlaCompiler* compiler() const { return compiler_; }

// Returns the XlaBuilder that Ops use for compiling new expressions.
xla::XlaBuilder* builder();
xla::XlaBuilder* builder() { return builder_; }

bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }

bool resolve_compile_time_constants() const {
return resolve_compile_time_constants_;
}
bool is_entry_computation() const { return is_entry_computation_; }

const std::vector<XlaExpression>& args() const { return args_; }
void set_args(std::vector<XlaExpression> args);

struct Retval {
DataType type;
TensorShape shape;
// An XlaExpression representing the Retval's value.
XlaExpression expression;
};
const std::vector<Retval>& retvals() { return retvals_; }
const std::vector<XlaExpression>& retvals() { return retvals_; }

// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void AddRetval(int retval_index, DataType type, const TensorShape& shape,
const xla::XlaOp& handle);

// As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal);

// As for Retval, but for return values that are resource handles.
Status AddResourceRetval(int retval_index, XlaResource* resource);
// Sets a return value.
// Since we do not always know in advance how many return values there are,
// grows the return values vector to size index+1 if it is smaller.
void SetRetval(int index, const XlaExpression& expression);

// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
@ -105,7 +86,7 @@ class XlaContext : public ResourceBase {

// Returns the XLA shape to be used to represent a variable of TF `shape`
// and `type`, or of an argument or return value of a top-level computation.
xla::StatusOr<TensorShape> RepresentationShape(const TensorShape& shape,
xla::StatusOr<xla::Shape> RepresentationShape(const TensorShape& shape,
DataType type) const;

// Get an XLA lambda to compute Max. This is cached in the
@ -140,31 +121,19 @@ class XlaContext : public ResourceBase {
// Allow ops to emit CustomCall operations for CPU.
const bool allow_cpu_custom_calls_;

// If true, constant return values are returned as Tensors instead of
// run-time computation outputs.
const bool resolve_compile_time_constants_;

// Arguments to the Tensorflow graph, indexed by _Arg index.
// Includes both compile-time constant arguments and runtime parameters.
std::vector<XlaExpression> args_;

// Return values of the Tensorflow graph, indexed by _Retval index.
std::vector<Retval> retvals_;
std::vector<XlaExpression> retvals_;

// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;

// Is this a top-level computation, or an inner computation (e.g., a while
// body)?
|
||||
const bool is_entry_computation_;
|
||||
|
||||
// A function that describes how the shapes of
|
||||
// a) argument and return value, for entry computations
|
||||
// b) variables, for all computations,
|
||||
// should be represented in XLA. Parameters/return values will be shaped
|
||||
// according to this function, and reshaped back to/from their declared shapes
|
||||
// for computations. Must be non-null.
|
||||
const std::function<xla::StatusOr<TensorShape>(const TensorShape&, DataType)>*
|
||||
// Describes the on-host shapes of parameters and return values. Also see:
|
||||
// XlaDevice::Options::shape_representation_fn.
|
||||
const std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>*
|
||||
shape_representation_fn_;
|
||||
|
||||
// Cache of prebuilt computations indexed by their type.
|
||||
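A minimal sketch of how a retval-producing kernel is expected to record its output after this change (the helper name and wiring below are illustrative, not part of the diff): the three type-specific AddRetval/AddConstRetval/AddResourceRetval entry points collapse into a single SetRetval call that takes an XlaExpression.

// Sketch only: `ctx` and `retval_index` are assumed to come from the kernel.
void EmitRetvalSketch(XlaOpKernelContext* ctx, int retval_index) {
  XlaContext& xla_context = XlaContext::Get(ctx);
  // Before: xla_context.AddRetval(retval_index, dtype, shape, handle), or the
  // AddConstRetval/AddResourceRetval variants depending on the value kind.
  // After: wrap the value in an XlaExpression once and hand it over.
  xla_context.SetRetval(
      retval_index, XlaExpression::XlaOp(ctx->Input(0), ctx->input_type(0)));
}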
145 tensorflow/compiler/tf2xla/xla_expression.cc Normal file
@ -0,0 +1,145 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_expression.h"

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

XlaExpression::XlaExpression() = default;

XlaExpression XlaExpression::Invalid() {
XlaExpression e;
e.kind_ = Kind::kInvalid;
return e;
}

XlaExpression XlaExpression::Constant(Tensor value) {
XlaExpression e;
e.kind_ = Kind::kConstant;
e.dtype_ = value.dtype();
e.constant_value_ = value;
return e;
}

XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
XlaExpression e;
e.kind_ = Kind::kXlaOp;
e.dtype_ = dtype;
e.handle_ = value;
return e;
}

XlaExpression XlaExpression::Resource(XlaResource* resource) {
XlaExpression e;
e.kind_ = Kind::kResource;
e.dtype_ = DT_RESOURCE;
e.resource_ = resource;
return e;
}

string XlaExpression::HumanString() const {
switch (kind_) {
case Kind::kInvalid:
return "invalid";
case Kind::kConstant:
return "constant";
case Kind::kXlaOp:
return "xla_op";
case Kind::kResource:
return "resource";
}
}

xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
switch (kind_) {
case Kind::kConstant: {
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(
HostTensorToBorrowingLiteral(constant_value_, &literal));
return xla::ConstantLiteral(builder, literal);
}
case Kind::kXlaOp:
if (builder != handle_.builder()) {
return errors::InvalidArgument(
"Mismatched builders in XlaExpression::AsXlaOp");
}
return handle_;
default:
return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
HumanString());
}
});
}

xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
xla::Client* client) const {
switch (kind()) {
case Kind::kConstant:
return {constant_value()};
case Kind::kXlaOp:
break;
case Kind::kResource:
case Kind::kInvalid:
return errors::InvalidArgument(
"ResolveConstant called on XlaExpression: ", HumanString());
}

TF_ASSIGN_OR_RETURN(bool is_constant,
handle().builder()->IsConstant(handle()));
if (!is_constant) return {absl::nullopt};

TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
handle().builder()->BuildConstantSubGraph(handle()));

TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

// The XLA layout is specified minor to major, and TensorFlow uses a major to
// minor order.
std::vector<int64> layout_indices(shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
TF_ASSIGN_OR_RETURN(xla::Literal literal,
client->ComputeConstant(constant_graph, &layout));
Tensor tensor;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
return {tensor};
}

xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
switch (kind_) {
case Kind::kConstant:
return constant_value().shape();
case Kind::kXlaOp: {
TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
handle().builder()->GetShape(handle()));
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
return shape;
}
case Kind::kResource:
return TensorShape({});
case Kind::kInvalid:
return errors::InvalidArgument(
"GetShape() called on invalid XlaExpression");
}
}

} // namespace tensorflow
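As a compact illustration of the two conversion paths implemented above (sketch only; `builder` and `client` are assumed to be valid and error handling is omitted): AsXlaOp materializes a kConstant as a ConstantLiteral on the given builder, while ResolveConstant goes the other way and folds a kXlaOp back into a host Tensor when XLA can prove it is constant.

// Sketch only, not from this commit; no error handling.
void ExpressionRoundTripSketch(xla::XlaBuilder* builder, xla::Client* client) {
  Tensor three(DT_INT32, TensorShape({}));
  three.scalar<int32>()() = 3;
  // Constant expression -> XLA constant op, emitted lazily on `builder`.
  XlaExpression c = XlaExpression::Constant(three);
  xla::XlaOp op = c.AsXlaOp(builder);
  // XlaOp expression -> host Tensor, via BuildConstantSubGraph/ComputeConstant.
  XlaExpression e = XlaExpression::XlaOp(op, DT_INT32);
  xla::StatusOr<absl::optional<Tensor>> folded = e.ResolveConstant(client);
  // On success, *folded.ValueOrDie() is a host scalar Tensor holding 3.
}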
115 tensorflow/compiler/tf2xla/xla_expression.h Normal file
@ -0,0 +1,115 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA
// compilation.
// An expression is one of:
// * a constant tensor.
// * an xla::XlaOp, representing a symbolic XLA value.
// * a resource, e.g., a variable, represented as an XlaResource pointer.
//
// Constant tensors are mostly an optimization to avoid passing large constants
// to XLA, but are also sometimes used to represent tensors that have no XLA
// representation, for example, DT_STRING tensors. A canonical use case might be
// an error message string.
class XlaExpression {
public:
enum class Kind {
kInvalid,
kConstant,
kXlaOp,
kResource,
};

XlaExpression();
XlaExpression(const XlaExpression&) = default;
XlaExpression& operator=(const XlaExpression&) = default;

// Builds an invalid expression. (Same as the default constructor, but makes
// the intent clearer.)
static XlaExpression Invalid();

// Builds a constant XLA expression.
static XlaExpression Constant(Tensor value);

// Builds a XlaOp expression. Since the mapping from TF data types to XLA
// types is not 1-1, the TF type must also be provided; in general it cannot
// be derived from the XLA type.
static XlaExpression XlaOp(xla::XlaOp value, DataType dtype);

// Builds a resource expression.
static XlaExpression Resource(XlaResource* resource);

Kind kind() const { return kind_; }

DataType dtype() const { return dtype_; }

// handle() returns the XlaOp that backs a kXlaOp expression.
const xla::XlaOp& handle() const { return handle_; }

const Tensor& constant_value() const { return constant_value_; }

XlaResource* resource() const { return resource_; }

// Returns a human-readable summary of the expression.
string HumanString() const;

// Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
// an erroneous XlaOp if the expression is not a constant or an expression.
xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;

// If a kXlaOp or kConstant expression can be resolved to a compile-time
// constant, returns the value as a host-memory Tensor. Returns an empty
// optional if it cannot be resolved. Returns an error if passed a resource
// expression.
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
xla::Client* client) const;

// Returns the shape of the tensor.
// The shape of a resource is the shape of a resource handle (i.e., a scalar),
// not the shape of the resource's value.
xla::StatusOr<TensorShape> GetShape() const;

private:
Kind kind_ = Kind::kInvalid;

DataType dtype_ = DT_INVALID;

// The XLA handle of the expression's computation, if kind_ == kXlaOp.
xla::XlaOp handle_;

// The value of the constant, if kind_ == kConstant.
Tensor constant_value_;

// The resource, if kind_ == kResource. Not owned.
XlaResource* resource_ = nullptr;
};

} // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
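The comment on XlaExpression::XlaOp above is why the factory takes an explicit DataType: several TF dtypes lower to the same XLA primitive type, so the TF type cannot be recovered from the xla::XlaOp alone. A small sketch, assuming the usual tf2xla type mapping (DT_INT32 and DT_QINT32 both lowering to S32); the function name is illustrative:

// Sketch only: the same XLA op can back expressions of different TF dtypes,
// so the expression must remember which TF dtype it represents.
void DtypeSketch(xla::XlaBuilder* builder) {
  xla::XlaOp op = xla::ConstantR0<int32>(builder, 1);
  XlaExpression as_int32 = XlaExpression::XlaOp(op, DT_INT32);
  XlaExpression as_qint32 = XlaExpression::XlaOp(op, DT_QINT32);
  CHECK(as_int32.dtype() != as_qint32.dtype());
}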
135 tensorflow/compiler/tf2xla/xla_expression_test.cc Normal file
@ -0,0 +1,135 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

class XlaExpressionTest : public ::testing::Test {
protected:
void SetUp() override {
client_ = xla::ClientLibrary::LocalClientOrDie();
builder_ = absl::make_unique<xla::XlaBuilder>("acomputation");
constant_ = test::AsScalar<int32>(42);
op_ = xla::ConstantR0<int32>(builder_.get(), 7);
non_constant_op_ = xla::Parameter(
builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
resource_ = absl::make_unique<XlaResource>(
XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
/*tensor_array_gradients=*/std::set<string>(),
/*tensor_array_multiple_writes_aggregate=*/false);
}

xla::Client* client_;
std::unique_ptr<xla::XlaBuilder> builder_;
Tensor constant_;
xla::XlaOp op_;
xla::XlaOp non_constant_op_;
std::unique_ptr<XlaResource> resource_;
};

TEST_F(XlaExpressionTest, Kind) {
EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind());
EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind());
EXPECT_TRUE(XlaExpression::Kind::kConstant ==
XlaExpression::Constant(constant_).kind());
EXPECT_TRUE(XlaExpression::Kind::kXlaOp ==
XlaExpression::XlaOp(op_, DT_INT32).kind());
EXPECT_TRUE(XlaExpression::Kind::kResource ==
XlaExpression::Resource(resource_.get()).kind());
}

TEST_F(XlaExpressionTest, HumanString) {
EXPECT_EQ("invalid", XlaExpression().HumanString());
EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString());
EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString());
EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString());
EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString());
}

TEST_F(XlaExpressionTest, AsXlaOp) {
xla::XlaOp op_as_op =
XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get());
EXPECT_TRUE(op_.IsIdenticalTo(op_as_op));

xla::XlaOp const_as_op =
XlaExpression::Constant(constant_).AsXlaOp(builder_.get());
TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation,
builder_->BuildConstantSubGraph(const_as_op));
TF_ASSERT_OK_AND_ASSIGN(xla::Literal value,
client_->ComputeConstant(computation));
EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0<int32>(42),
value));
}

TEST_F(XlaExpressionTest, GetShape) {
EXPECT_FALSE(XlaExpression().GetShape().ok());
EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok());

TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape,
XlaExpression::Resource(resource_.get()).GetShape());
EXPECT_EQ(TensorShape({}), resource_shape);

TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape,
XlaExpression::XlaOp(op_, DT_INT32).GetShape());
EXPECT_EQ(TensorShape({}), op_shape);

TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape,
XlaExpression::Constant(constant_).GetShape());
EXPECT_EQ(TensorShape({}), constant_shape);
}

TEST_F(XlaExpressionTest, ResolveConstant) {
EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
EXPECT_FALSE(
XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());

TF_ASSERT_OK_AND_ASSIGN(
absl::optional<Tensor> op_constant,
XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_));
ASSERT_TRUE(op_constant.has_value());
test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *op_constant);

TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> op_nonconstant,
XlaExpression::XlaOp(non_constant_op_, DT_FLOAT)
.ResolveConstant(client_));
EXPECT_FALSE(op_nonconstant.has_value());

TF_ASSERT_OK_AND_ASSIGN(
absl::optional<Tensor> constant_constant,
XlaExpression::Constant(constant_).ResolveConstant(client_));
ASSERT_TRUE(constant_constant.has_value());
test::ExpectTensorEqual<int32>(constant_, *constant_constant);
}

} // namespace
} // namespace tensorflow
@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const {
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK(expression->handle().valid() || expression->resource() != nullptr);
VLOG(1) << "Fetched T" << expression->handle();
CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
<< expression->HumanString();
return expression;
}

// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
// Assigns an XlaExpression to a tensor on an XLA compilation device.
static void AssignExpressionToTensor(Tensor* tensor,
const XlaExpression& value) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
CHECK(!expression->handle().valid());
return const_cast<XlaExpression*>(expression);
CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
<< expression->HumanString();
*const_cast<XlaExpression*>(expression) = value;
}

// Retrieves the XlaOp from an input Tensor to an Op. This computation was
// constructed by an Op that executed previously and created the output Tensor
// using CreateOutputTensorFromComputation or CreateConstantOutputTensor.
static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
return *CastExpressionFromTensor(context_->input(index));
}

const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
const XlaExpression& XlaOpKernelContext::InputExpression(
absl::string_view name) {
return *CastExpressionFromTensor(GetInputTensorByName(name));
}

const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) {
return GetComputationFromTensor(GetInputTensorByName(name));
xla::XlaOp XlaOpKernelContext::Input(int index) {
return InputExpression(index).AsXlaOp(builder());
}

xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
return InputExpression(name).AsXlaOp(builder());
}

TensorShape XlaOpKernelContext::InputShape(int index) {
@ -125,77 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name,
Status XlaOpKernelContext::ConstantInputReshaped(
int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal) {
const Tensor& tensor = context_->input(index);
TensorShape new_shape(new_dims);
if (tensor.NumElements() != new_shape.num_elements()) {
return errors::InvalidArgument(
context_->op_kernel().name(), " input ", index, " has shape ",
tensor.shape().DebugString(),
" but was asked to be reshaped to incompatible shape ",
new_shape.DebugString());
}
const XlaExpression* expression = CastExpressionFromTensor(tensor);

auto copy_tensor_to_literal = [](const Tensor& tensor,
xla::Literal* literal) {
xla::Shape literal_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape));

*literal = xla::Literal(literal_shape);

// memcpy over the payload ...
// TODO(phawkins): handle string types.
size_t total_bytes = tensor.TotalBytes();
if (total_bytes > 0) {
void* dst_ptr = literal->untyped_data();
const void* src_ptr = DMAHelper::base(&tensor);
memcpy(dst_ptr, src_ptr, total_bytes);
}
return Status::OK();
};

// If the tensor has a known constant value, there is no need to invoke XLA.
if (expression->has_constant_value()) {
Tensor temp(tensor.dtype());
if (!temp.CopyFrom(expression->constant_value(), new_shape)) {
// This should never happen. The constant should have a shape compatible
// with the enclosing Tensor.
return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
}

return copy_tensor_to_literal(temp, constant_literal);
}

// Make sure we treat zero-element tensors as constant.
if (new_shape.num_elements() == 0) {
Tensor temp(tensor.dtype(), new_shape);

return copy_tensor_to_literal(temp, constant_literal);
}

xla::XlaOp handle = expression->handle();
if (new_shape != tensor.shape()) {
// Reshape the handle to the desired shape.
handle = xla::Reshape(handle, new_shape.dim_sizes());
}

// The XLA layout is specified minor to major, and TensorFlow's minor
// dimension is the last one.
std::vector<int64> layout_indices(new_shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);

xla::StatusOr<bool> is_constant = builder()->IsConstant(handle);
if (!is_constant.ok()) {
Status status = is_constant.status();
XlaExpression e = InputExpression(index);
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
e.ResolveConstant(compiler()->client());
if (!constant_or_status.ok()) {
Status status = constant_or_status.status();
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
context_->op_kernel().type_string(),
" operator as a compile-time constant.");
return status;
}

if (!is_constant.ValueOrDie()) {
absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
if (!constant.has_value()) {
return errors::InvalidArgument(
"Input ", index, " to ", context_->op_kernel().type_string(),
" operator must be a compile-time constant.\n"
@ -208,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped(
"stateful operation such as a random number generator.");
}

// Ask the XLA compiler to evaluate the data handle to a literal.
xla::StatusOr<xla::XlaComputation> constant_graph =
builder()->BuildConstantSubGraph(handle);
if (!constant_graph.ok()) {
return errors::Internal(
"Error getting a compile-time constant graph for ",
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
Tensor temp(constant->dtype());
if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
return errors::InvalidArgument(
context_->op_kernel().name(), " input ", index, " has shape ",
constant->shape().DebugString(),
" but was asked to be reshaped to incompatible shape ",
TensorShape(new_dims).DebugString());
}
xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
" as a compile-time constant.\nError: ",
computed.status().error_message());
}
*constant_literal = std::move(computed).ValueOrDie();

TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
return Status::OK();
}

@ -322,6 +259,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
return LiteralToInt64Vector(literal, out);
}

Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
absl::string_view name, std::vector<int64>* out) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInputReshaped(
index, {InputShape(index).num_elements()}, &literal));
return LiteralToInt64Vector(literal, out);
}

Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
xla::Literal* out) {
xla::Literal literal;
@ -372,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
handles->clear();
shapes->clear();
for (const Tensor& input : inputs) {
handles->push_back(GetComputationFromTensor(input));
handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
shapes->push_back(input.shape());
}
return Status::OK();
@ -413,9 +359,12 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,

XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN(
TensorShape representation_shape,
xla::Shape representation_shape,
xla_context.RepresentationShape(variable->shape(), variable->type()));
if (representation_shape == variable->shape()) {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
*value = variable->value();
} else {
*value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
@ -455,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}

Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
Tensor** output) {
void XlaOpKernelContext::SetOutputExpression(int index,
const XlaExpression& expression) {
Status status = [&] {
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
if (expected_output_dtype(index) == DT_VARIANT) {
Tensor* output = nullptr;
// Provides a special behavior for DT_VARIANT: a variant is treated as
// DT_UINT8 scalar as the type to allow mapping for variant to more generic
// types.
if (expression.dtype() == DT_VARIANT) {
// tensor_data() is not supported for variant Tensor (i.e.,
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
// XlaExpression inside the Tensor's tensor_data() does not work for
// variant. Instead construct a uint8 tensor and store the expression in its
// value.
// variant. Instead construct a uint8 tensor and store the expression in
// its value.
// TODO(jpienaar): This should be refactored to stop masquerading
// XlaExpressions as Tensors.
*output = new Tensor();
output = new Tensor();
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(
context_->allocate_temp(DT_UINT8, tensor_shape, *output));
context_->set_output(index, **output);
context_->allocate_temp(DT_UINT8, tensor_shape, output));
context_->set_output(index, *output);
} else {
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
}
AssignExpressionToTensor(output, expression);
return Status::OK();
}();
if (!status.ok()) {
SetStatus(status);
}
}

void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
auto shape_or = builder()->GetShape(handle);
if (!shape_or.ok()) {
SetStatus(shape_or.status());
return;
}

OP_REQUIRES_OK(context_,
allocate_output(index, shape_or.ValueOrDie(), &output));

// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
SetOutputExpression(
index,
XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
}

void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
const TensorShape& shape = constant.shape();

xla::BorrowingLiteral literal;
OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));

xla::XlaOp handle = xla::ConstantLiteral(builder(), literal);
CHECK(handle.valid());

// Make the Tensor that will refer to the expression.
Tensor* output = nullptr;
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));

// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
expression->set_constant_value(constant);
}

void XlaOpKernelContext::SetInvalidOutput(int index) {
Tensor* output = nullptr;
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape({}), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
xla::XlaOp handle;
expression->set_handle(handle);
SetOutputExpression(index, XlaExpression::Constant(constant));
}

void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
Tensor* output = nullptr;
// The shape of the output tensor is the shape of the resource itself
// (i.e., a scalar), not the shape of the resource's value.
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape(), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_resource(resource);
SetOutputExpression(index, XlaExpression::Resource(resource));
}

Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
@ -570,10 +482,13 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));

XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN(TensorShape representation_shape,
TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) {
handle = xla::Reshape(handle, representation_shape.dim_sizes());
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
handle = xla::Reshape(handle,
xla::AsInt64Slice(representation_shape.dimensions()));
}
return variable->SetValue(handle);
}
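A rough sketch of the constant-input path after this change (the function and input index are illustrative; only the XlaOpKernelContext methods shown in the diff are assumed): instead of hand-building a constant subgraph inside ConstantInputReshaped, the context fetches the input's XlaExpression and asks it to resolve itself against the compiler's client.

// Sketch only: code of this shape would live inside some
// XlaOpKernel::Compile(XlaOpKernelContext* ctx).
void ReadConstantInputSketch(XlaOpKernelContext* ctx) {
  // What the helpers now do internally: fetch the expression and fold it.
  const XlaExpression& expr = ctx->InputExpression(1);
  xla::StatusOr<absl::optional<Tensor>> folded =
      expr.ResolveConstant(ctx->compiler()->client());
  // Kernels normally keep using the public wrappers, e.g.:
  int64 size = 0;
  Status s = ctx->ConstantInputAsIntScalar(1, &size);
}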
@ -88,9 +88,9 @@ class XlaOpKernelContext {
// Returns input `index` as a XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor.
const xla::XlaOp& Input(int index);
xla::XlaOp Input(int index);
// Returns input `name` as a XlaOp.
const xla::XlaOp& Input(absl::string_view name);
xla::XlaOp Input(absl::string_view name);

// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@ -111,14 +111,6 @@ class XlaOpKernelContext {
Status ConstantInput(int index, xla::Literal* constant_literal);
Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);

// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);

// Converts a constant scalar int32 or int64 tensor into an int64.
Status ConstantInputAsIntScalar(int index, int64* out);
Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
@ -134,6 +126,8 @@ class XlaOpKernelContext {
// Reshapes and converts a constant int32 or int64 tensor into a vector of
// int64s.
Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
Status ConstantInputReshapedToIntVector(absl::string_view name,
std::vector<int64>* out);

// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
@ -148,6 +142,10 @@ class XlaOpKernelContext {
Status ConstantInputList(absl::string_view name,
std::vector<xla::Literal>* literals);

// Returns an XlaExpression describing the value of 'index'.
const XlaExpression& InputExpression(int index);
const XlaExpression& InputExpression(absl::string_view name);

// Outputs

int num_outputs() const { return context_->num_outputs(); }
@ -165,9 +163,8 @@ class XlaOpKernelContext {
// SetConstantOutput where possible.
void SetConstantOutput(int index, const Tensor& host_tensor);

// Sets output `index` to an invalid value.
// Any subsequent attempt to consume this output will cause an error.
void SetInvalidOutput(int index);
// Returns an XlaExpression describing the value of 'index'.
void SetOutputExpression(int index, const XlaExpression& expression);

// Status handling.
void SetStatus(const Status& status) { context_->SetStatus(status); }
@ -255,10 +252,13 @@ class XlaOpKernelContext {
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name);

// Wraps OpKernelContext's allocate_output method while providing special
// behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the
// type to allow mapping for variant to more generic types.
Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);

OpKernelContext* const context_;
};
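On the output side, a toy kernel (the op name is made up) illustrates how SetOutput and SetConstantOutput become thin wrappers after this change: they build an XlaExpression and pass it to SetOutputExpression, which also owns the allocation details, including the DT_VARIANT special case handled in the .cc hunk above.

// Sketch only: a hypothetical pass-through kernel.
class IdentitySketchOp : public XlaOpKernel {
 public:
  explicit IdentitySketchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    // Input() now returns an xla::XlaOp by value, materialized from the stored
    // XlaExpression; SetOutput() wraps the op back into an expression.
    ctx->SetOutput(0, ctx->Input(0));
    // Spelled out via the expression-level API, this is roughly equivalent to:
    //   ctx->SetOutputExpression(
    //       0, XlaExpression::XlaOp(ctx->Input(0), ctx->input_type(0)));
  }
};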
@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include <memory>

#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -129,21 +130,27 @@ XlaOpRegistry::~XlaOpRegistry() = default;
// Lazily register the CPU and GPU JIT devices the first time
// GetCompilationDevice is called.
static void* registration_init = [&registry]() {
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;

mutex_lock lock(registry.mutex_);
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_CPU];
registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
registration.requires_compilation = false;
registration.enable_jit_by_default = false;
registration.autoclustering_policy =
cpu_global_jit
? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
registration.compile_resource_ops = false;
}
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_GPU];
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.requires_compilation = false;
registration.enable_jit_by_default = true;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
registration.compile_resource_ops = false;
}
return nullptr;
@ -66,19 +66,26 @@ class XlaOpRegistry {
public:
typedef OpKernel* (*Factory)(OpKernelConstruction*);

enum class AutoclusteringPolicy {
// Enable autoclustering if the user requests it, e.g., via
// experimental_jit_scope. Does not autocluster if the JIT is enabled
// globally (e.g., via the OptimizerOptions in the TF session
// configuration.)
kIfExplicitlyRequested,
// Enable autoclustering if explicitly requested, or if the JIT is enabled
// globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
kIfEnabledGlobally,
// Always try to autocluster ops placed on this device.
kAlways,
};

// Describes how to compile operators assigned to a device.
struct DeviceRegistration {
// The name of an XLA compilation device to use to compile code.
string compilation_device_name;

// Do operators assigned to this device require compilation?
bool requires_compilation;

// If !requires_compilation, should we try to JIT operators on this device
// when XLA JIT compilation is enabled globally via the SessionOptions?
// (It is still possible to explicitly mark operators to JIT compile, even
// if enable_jit_by_default is false.)
bool enable_jit_by_default;
// When should we autocluster operators assigned to this device?
AutoclusteringPolicy autoclustering_policy;

// Enable compilation of operators that use DT_RESOURCE types?
bool compile_resource_ops = false;
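For illustration (the lambda in the .cc hunk above does exactly this for the real CPU and GPU devices), filling in a DeviceRegistration now sets the new autoclustering_policy field in place of the old enable_jit_by_default boolean; the snippet below is a sketch, not code from this commit.

// Sketch only: describing compilation for a device, mirroring the GPU block.
void RegisterSketch(XlaOpRegistry::DeviceRegistration* registration) {
  registration->compilation_device_name = DEVICE_GPU_XLA_JIT;
  registration->requires_compilation = false;
  // The boolean enable_jit_by_default is replaced by a three-way policy.
  registration->autoclustering_policy =
      XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
  registration->compile_resource_ops = false;
}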
Some files were not shown because too many files have changed in this diff.