merging master with upstream master

This commit is contained in:
Steve Nesae 2018-11-13 10:25:27 -06:00
commit d391ba441b
1179 changed files with 48908 additions and 17415 deletions

View File

@ -14,6 +14,33 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
closure_repositories() closure_repositories()
http_archive(
name = "base_images_docker",
sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9",
strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6",
urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"],
)
http_archive(
name = "bazel_toolchains",
sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb",
strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b",
urls = [
"https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz",
],
)
http_archive(
name = "io_bazel_rules_docker",
sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd",
strip_prefix = "rules_docker-0.5.1",
urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"],
)
load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace")
remote_config_workspace()
# We must check the bazel version before trying to parse any other BUILD # We must check the bazel version before trying to parse any other BUILD
# files, in case the parsing of those build files depends on the bazel # files, in case the parsing of those build files depends on the bazel
# version we require here. # version we require here.
@ -79,3 +106,4 @@ new_http_archive(
"http://download.tensorflow.org/models/speech_commands_v0.01.zip", "http://download.tensorflow.org/models/speech_commands_v0.01.zip",
], ],
) )

View File

@ -43,7 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
_TF_OPENCL_VERSION = '1.2' _TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 _DEFAULT_PROMPT_ASK_ATTEMPTS = 10
@ -1555,6 +1555,9 @@ def main():
check_bazel_version('0.15.0') check_bazel_version('0.15.0')
reset_tf_configure_bazelrc() reset_tf_configure_bazelrc()
# Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later
write_to_bazelrc('import %workspace%/tools/bazel.rc')
cleanup_makefile() cleanup_makefile()
setup_python(environ_cp) setup_python(environ_cp)

View File

@ -352,6 +352,7 @@ package_group(
"//tensorflow/...", "//tensorflow/...",
"//tensorflow_estimator/...", "//tensorflow_estimator/...",
"//tensorflow_fold/llgtm/...", "//tensorflow_fold/llgtm/...",
"//tensorflow_text/...",
"//third_party/py/tensor2tensor/...", "//third_party/py/tensor2tensor/...",
], ],
) )

View File

@ -95,6 +95,7 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:server_lib",
], ],
}) + select({ }) + select({
"//tensorflow:with_xla_support": [ "//tensorflow:with_xla_support": [
@ -199,7 +200,7 @@ tf_cuda_cc_test(
size = "small", size = "small",
srcs = ["c_api_test.cc"], srcs = ["c_api_test.cc"],
data = [ data = [
":test_op.so", ":test_op1.so",
"//tensorflow/cc/saved_model:saved_model_half_plus_two", "//tensorflow/cc/saved_model:saved_model_half_plus_two",
], ],
kernels = [":test_op_kernel"], kernels = [":test_op_kernel"],
@ -218,6 +219,7 @@ tf_cuda_cc_test(
"//tensorflow/cc:grad_ops", "//tensorflow/cc:grad_ops",
"//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/cc/saved_model:tag_constants", "//tensorflow/cc/saved_model:tag_constants",
"//tensorflow/compiler/jit",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:direct_session", "//tensorflow/core:direct_session",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -284,8 +286,8 @@ tf_cc_test(
) )
tf_custom_op_library( tf_custom_op_library(
name = "test_op.so", name = "test_op1.so",
srcs = ["test_op.cc"], srcs = ["test_op1.cc"],
) )
tf_kernel_library( tf_kernel_library(

View File

@ -2810,4 +2810,71 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
} }
return ret; return ret;
} }
// TF_Server functions ----------------------------------------------
#ifndef __ANDROID__
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
: target(server->target()), server(std::move(server)) {}
#endif // __ANDROID__
TF_Server* TF_NewServer(const void* proto, size_t proto_len,
TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Server functionality is not supported in Android");
return nullptr;
#else
tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
status->status = InvalidArgument(
"Could not parse provided bytes into a ServerDef protocol buffer");
return nullptr;
}
std::unique_ptr<tensorflow::ServerInterface> out_server;
status->status = tensorflow::NewServer(server_def, &out_server);
if (!status->status.ok()) return nullptr;
return new TF_Server(std::move(out_server));
#endif
}
void TF_ServerStart(TF_Server* server, TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Server functionality is not supported in Android");
#else
status->status = server->server->Start();
#endif
}
void TF_ServerStop(TF_Server* server, TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Server functionality is not supported in Android");
#else
status->status = server->server->Stop();
#endif
}
void TF_ServerJoin(TF_Server* server, TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Server functionality is not supported in Android");
#else
status->status = server->server->Join();
#endif
}
const char* TF_ServerTarget(TF_Server* server) {
#ifdef __ANDROID__
return nullptr;
#else
return server->target.c_str();
#endif
}
void TF_DeleteServer(TF_Server* server) { delete server; }
} // end extern "C" } // end extern "C"

View File

@ -1668,6 +1668,47 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
const char* name, TF_Status* status); const char* name, TF_Status* status);
// --------------------------------------------------------------------------
// In-process TensorFlow server functionality, for use in distributed training.
// A Server instance encapsulates a set of devices and a Session target that
// can participate in distributed training. A server belongs to a cluster
// (specified by a ClusterSpec), and corresponds to a particular task in a
// named job. The server can communicate with any other server in the same
// cluster.
// In-process TensorFlow server.
typedef struct TF_Server TF_Server;
// Creates a new in-process TensorFlow server configured using a serialized
// ServerDef protocol buffer provided via `proto` and `proto_len`.
//
// The server will not serve any requests until TF_ServerStart is invoked.
// The server will stop serving requests once TF_ServerStop or
// TF_DeleteServer is invoked.
TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto,
size_t proto_len,
TF_Status* status);
// Starts an in-process TensorFlow server.
TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status);
// Stops an in-process TensorFlow server.
TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status);
// Blocks until the server has been successfully stopped (via TF_ServerStop or
// TF_ServerClose).
TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status);
// Returns the target string that can be provided to TF_SetTarget() to connect
// a TF_Session to `server`.
//
// The returned string is valid only until TF_DeleteServer is invoked.
TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server);
// Destroy an in-process TensorFlow server, frees memory. If server is running
// it will be stopped and joined.
TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
#ifdef __cplusplus #ifdef __cplusplus
} /* end extern "C" */ } /* end extern "C" */
#endif #endif

View File

@ -25,6 +25,7 @@ limitations under the License.
#include <vector> #include <vector>
#ifndef __ANDROID__ #ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/framework/op_gen_lib.h"
#endif #endif
#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/common_runtime/shape_refiner.h"
@ -179,6 +180,15 @@ struct TF_ApiDefMap {
tensorflow::mutex lock; tensorflow::mutex lock;
}; };
#ifndef __ANDROID__
struct TF_Server {
TF_Server(std::unique_ptr<tensorflow::ServerInterface> server);
const tensorflow::string target;
std::unique_ptr<tensorflow::ServerInterface> server;
};
#endif
namespace tensorflow { namespace tensorflow {
class TensorCApi { class TensorCApi {

View File

@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) {
// tf_cuda_cc_test() bazel rule and remove the next line. // tf_cuda_cc_test() bazel rule and remove the next line.
if (!GPUDeviceName().empty()) return; if (!GPUDeviceName().empty()) return;
// Load the library. #if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
TF_Status* status = TF_NewStatus(); {
TF_Library* lib = // Load the library.
TF_LoadLibrary("tensorflow/c/test_op.so", status); TF_Status* status = TF_NewStatus();
TF_Code code = TF_GetCode(status); TF_Library* lib =
string status_msg(TF_Message(status)); TF_LoadLibrary("tensorflow/c/test_op1.so", status);
TF_DeleteStatus(status); TF_Code code = TF_GetCode(status);
ASSERT_EQ(TF_OK, code) << status_msg; string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
// Test op list.
TF_Buffer op_list_buf = TF_GetOpList(lib);
tensorflow::OpList op_list;
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
ASSERT_EQ(op_list.op_size(), 1);
EXPECT_EQ("TestCApi1", op_list.op(0).name());
TF_DeleteLibraryHandle(lib);
}
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
{ {
TF_Buffer* op_list_buffer = TF_GetAllOpList(); TF_Buffer* op_list_buffer = TF_GetAllOpList();
tensorflow::OpList op_list; tensorflow::OpList op_list;
@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) {
EXPECT_TRUE(found); EXPECT_TRUE(found);
TF_DeleteBuffer(op_list_buffer); TF_DeleteBuffer(op_list_buffer);
} }
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
{
// Test op list.
TF_Buffer op_list_buf = TF_GetOpList(lib);
tensorflow::OpList op_list;
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
ASSERT_EQ(op_list.op_size(), 1);
EXPECT_EQ("TestCApi", op_list.op(0).name());
}
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
TF_DeleteLibraryHandle(lib);
} }
void TestEncodeDecode(int line, const std::vector<string>& data) { void TestEncodeDecode(int line, const std::vector<string>& data) {

View File

@ -69,7 +69,7 @@ tf_cuda_library(
name = "c_api_internal", name = "c_api_internal",
hdrs = ["c_api_internal.h"], hdrs = ["c_api_internal.h"],
visibility = [ visibility = [
"//learning/deepmind/courier:__pkg__", "//learning/deepmind/courier:__subpackages__",
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [

View File

@ -404,8 +404,7 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
"The passed in handle is a nullptr"); "The passed in handle is a nullptr");
return nullptr; return nullptr;
} }
tensorflow::Device* d = nullptr; tensorflow::Device* d = h->handle->op_device();
status->status = h->handle->OpDevice(&d);
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str(); : d->name().c_str();
} }

View File

@ -57,13 +57,9 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
return nullptr; return nullptr;
} }
tensorflow::Device* device;
status->status = handle->handle->Device(&device);
if (!status->status.ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA #ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle->handle->device();
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn. // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device = tensorflow::XlaDevice* xla_device =
dynamic_cast<tensorflow::XlaDevice*>(device); dynamic_cast<tensorflow::XlaDevice*>(device);

View File

@ -79,10 +79,6 @@ struct TFE_TensorHandle {
tensorflow::Device* op_device) tensorflow::Device* op_device)
: handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {}
TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
tensorflow::EagerContext* ctx)
: handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {}
TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}
tensorflow::TensorHandle* handle; tensorflow::TensorHandle* handle;

23
tensorflow/c/test_op1.cc Normal file
View File

@ -0,0 +1,23 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc");
} // namespace tensorflow

View File

@ -170,6 +170,7 @@ cc_library_with_android_deps(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
], ],
) )
@ -516,6 +517,8 @@ tf_gen_op_wrappers_cc(
":array_ops", ":array_ops",
":const_op", ":const_op",
":math_ops", ":math_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
], ],
) )

View File

@ -93,7 +93,7 @@ cc_library(
":tfcompile_lib", ":tfcompile_lib",
"//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) {
return errors::InvalidArgument("Must specify --cpp_class"); return errors::InvalidArgument("Must specify --cpp_class");
} }
codegen_opts.gen_hlo_profile_printer_data = codegen_opts.gen_hlo_profile_printer_data =
xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces)); &codegen_opts.namespaces));
@ -132,7 +132,7 @@ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list; std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags); AppendMainFlags(&flag_list, &flags);
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::AppendDebugOptionsFlags(&flag_list);
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list); usage += tensorflow::Flags::Usage(argv[0], flag_list);

View File

@ -21,7 +21,6 @@ package(
) )
load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
@ -52,6 +51,7 @@ cc_library(
deps = [ deps = [
":jit_compilation_passes", ":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:cpu_plugin",
], ],
@ -65,6 +65,7 @@ cc_library(
":jit_compilation_passes", ":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service:gpu_plugin",
]), ]),
alwayslink = 1, alwayslink = 1,
@ -190,6 +191,7 @@ cc_library(
"//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:stack",
"//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:iterator_ops",
@ -241,6 +243,7 @@ cc_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
], ],
) )
@ -253,6 +256,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
@ -263,6 +267,21 @@ cc_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "xla_compilation_cache_test",
srcs = [
"xla_compilation_cache_test.cc",
],
deps = [
":xla_compilation_cache",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
], ],
) )
@ -500,6 +519,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
], ],
) )
@ -524,25 +544,6 @@ cc_library(
hdrs = ["union_find.h"], hdrs = ["union_find.h"],
) )
cc_library(
name = "producer_consumer_queue",
hdrs = ["producer_consumer_queue.h"],
deps = ["//tensorflow/core:lib"],
)
tf_cc_test(
name = "producer_consumer_queue_test",
size = "small",
srcs = ["producer_consumer_queue_test.cc"],
deps = [
":producer_consumer_queue",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test( tf_cc_test(
name = "deadness_analysis_test", name = "deadness_analysis_test",
size = "small", size = "small",
@ -606,6 +607,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -648,31 +650,6 @@ tf_cc_test(
], ],
) )
tf_cc_test(
name = "xla_launch_util_test",
size = "small",
srcs = ["xla_launch_util_test.cc"],
deps = [
":common",
":xla_compilation_cache",
":xla_launch_util",
":xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/kernels:variable_ops",
],
)
cc_library( cc_library(
name = "xla_fusion_optimizer", name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"], srcs = ["xla_fusion_optimizer.cc"],

View File

@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) {
return errors::Internal("Could not find compilation device ", return errors::Internal("Could not find compilation device ",
device_type.type()); device_type.type());
} }
*result = registration->requires_compilation; *result = registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
return Status::OK(); return Status::OK();
} }

View File

@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root,
Output loop_cond = Output loop_cond =
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
latch.output_false);
Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
latch.output_true, increment_by); latch.output_true, increment_by);
Output next_iteration = Output next_iteration =
@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue(
value, frame_name); value, frame_name);
ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
latch.output_false);
Output next_iteration = ops::NextIteration( Output next_iteration = ops::NextIteration(
root.WithOpName(prefix + "/next_iteration"), latch.output_true); root.WithOpName(prefix + "/next_iteration"), latch.output_true);
CHECK(root.graph() CHECK(root.graph()

View File

@ -117,6 +117,25 @@ Status PreprocessForEncapsulation(Graph* g,
// Information for XLA computation. // Information for XLA computation.
struct XlaClusterInfo { struct XlaClusterInfo {
// Add an explicitly-defined default constructor for this class.
//
// The compiler may delete the default constructor here because
// host_compute_core is a const member whose type (std::map) doesn't
// necessarily have a user provided constructor -- while libc++ and
// libstdc++ 4.8 provide a user defined default constructor, libstdc++ at
// least >= 7.3 does not. See also c++11 [class.ctor] p5.
//
// TODO(klimek): In c++17 we'll be able to initialize host_compute_core
// without losing aggregate initialization, which allows us to get rid of
// the constructor definitions again.
XlaClusterInfo() {}
XlaClusterInfo(const string& cluster_name,
const NameAttrList& func_name_attrs, Node* node,
const std::map<string, int>& host_compute_core)
: cluster_name(cluster_name),
func_name_attrs(func_name_attrs),
node(node),
host_compute_core(host_compute_core) {}
// XLA cluster name. It might be different from `func_name`. // XLA cluster name. It might be different from `func_name`.
const string cluster_name; const string cluster_name;
// Name and attributes of XLA computation function. // Name and attributes of XLA computation function.

View File

@ -394,12 +394,12 @@ Status ConstructHostGraph(
for (const string& host_func : outside_compilation_host_graphs) { for (const string& host_func : outside_compilation_host_graphs) {
VLOG(4) << "Expanding host graph " << host_func; VLOG(4) << "Expanding host graph " << host_func;
FunctionBody* host_fbody = nullptr; FunctionBody* host_fbody = nullptr;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
FunctionDefToBodyHelper(*fld->Find(host_func), AttrSlice(), fld, *fld->Find(host_func), AttrSlice(), fld,
[&](const string& op, const OpDef** sig) { [&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig); return fld->LookUpOpDef(op, sig);
}, },
&host_fbody)); &host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody); std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
// We use ReverseDFS() to copy nodes. Make sure all nodes are reverse // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
@ -411,52 +411,53 @@ Status ConstructHostGraph(
node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
Status s; Status s;
ReverseDFS(*host_fbody->graph, /*enter=*/nullptr, ReverseDFS(
[&](const Node* n) { *host_fbody->graph, /*enter=*/nullptr,
if (!s.ok()) { [&](const Node* n) {
return; if (!s.ok()) {
} return;
}
Node* copy; Node* copy;
if (node_map.find(n) != node_map.end()) { if (node_map.find(n) != node_map.end()) {
// Already copied this node. // Already copied this node.
copy = node_map.at(n); copy = node_map.at(n);
} else if (IsKeyPlaceholderNode(*n)) { } else if (IsKeyPlaceholderNode(*n)) {
// Change a). // Change a).
copy = key_placeholder; copy = key_placeholder;
node_map[n] = copy; node_map[n] = copy;
} else { } else {
// Copy the node. // Copy the node.
NodeDef copy_def = n->def(); NodeDef copy_def = n->def();
// Change c). // Change c).
copy_def.clear_device(); copy_def.clear_device();
copy = (*host_graph)->AddNode(copy_def, &s); copy = (*host_graph)->AddNode(copy_def, &s);
if (!s.ok()) { if (!s.ok()) {
return; return;
} }
node_map[n] = copy; node_map[n] = copy;
} }
// Only handle input edges. Output edges will be added later as // Only handle input edges. Output edges will be added later as
// its output nodes' input edges. // its output nodes' input edges.
for (auto e : n->in_edges()) { for (auto e : n->in_edges()) {
if (node_map.find(e->src()) == node_map.end()) { if (node_map.find(e->src()) == node_map.end()) {
s = errors::Internal("Cannot find node image for ", s = errors::Internal("Cannot find node image for ",
e->src()->DebugString()); e->src()->DebugString());
return; return;
} }
(*host_graph) (*host_graph)
->AddEdge(node_map[e->src()], e->src_output(), copy, ->AddEdge(node_map[e->src()], e->src_output(), copy,
e->dst_input()); e->dst_input());
} }
// Change b). // Change b).
if (copy->type_string() == "_XlaRecvAtHost" || if (copy->type_string() == "_XlaRecvAtHost" ||
copy->type_string() == "_XlaSendFromHost") { copy->type_string() == "_XlaSendFromHost") {
(*host_graph)->AddControlEdge(copy, sequencer); (*host_graph)->AddControlEdge(copy, sequencer);
} }
}, },
NodeComparatorID()); NodeComparatorID());
if (!s.ok()) { if (!s.ok()) {
return s; return s;
} }
@ -838,7 +839,12 @@ Status ExtractOutsideCompilationForFunction(
FunctionDef shape_inference_fdef = *xla_fdef; FunctionDef shape_inference_fdef = *xla_fdef;
shape_inference_fdef.mutable_signature()->set_name( shape_inference_fdef.mutable_signature()->set_name(
shape_inference_graph); shape_inference_graph);
TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); if (fld->Find(shape_inference_graph)) {
TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph,
shape_inference_fdef));
} else {
TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef));
}
} }
} }
} }

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h" #include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "absl/types/optional.h"
#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/const_op.h"
@ -34,14 +35,30 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) {
TF_RET_CHECK(n->type_string() == "Const"); // StatusOrOptional<T> instances hold
//
// - A non-OK Status to indicate an error that needs to be propagated out of
// this pass (e.g. the Graph is malformed).
//
// - A nullopt to indicate the function that created the instance failed to do
// what it set out to do but this is not actually an error
// (e.g. TryToGetTensorFromConstOp was passed a non-Const node).
//
// - A T to indicate a successful operation.
template <class T>
using StatusOrOptional = xla::StatusOr<absl::optional<T>>;
StatusOrOptional<Tensor> TryToGetTensorFromConstOp(Node* n) {
if (n->type_string() != "Const") {
return {absl::nullopt};
}
const TensorProto* proto = nullptr; const TensorProto* proto = nullptr;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
Tensor tensor(proto->dtype()); Tensor tensor(proto->dtype());
TF_RET_CHECK(tensor.FromProto(*proto)); TF_RET_CHECK(tensor.FromProto(*proto));
*out_tensor = std::move(tensor); return {tensor};
return Status::OK();
} }
struct SliceInputs { struct SliceInputs {
@ -70,7 +87,7 @@ std::vector<int64> IntTensorAsVector(const Tensor& t) {
// Packages up the inputs to a Slice operation into an instance of // Packages up the inputs to a Slice operation into an instance of
// `SliceInputs`. // `SliceInputs`.
Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) { StatusOrOptional<SliceInputs> GetSliceInputs(Node* slice) {
const int kSliceInputIndex = 0; const int kSliceInputIndex = 0;
const int kSliceBeginIndex = 1; const int kSliceBeginIndex = 1;
const int kSliceSizeIndex = 2; const int kSliceSizeIndex = 2;
@ -81,23 +98,27 @@ Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) {
TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge)); TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge));
const Edge* slice_begin_edge; const Edge* slice_begin_edge;
TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge)); TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge));
slice_inputs->input =
SliceInputs slice_inputs;
slice_inputs.input =
Output(slice_input_edge->src(), slice_input_edge->src_output()); Output(slice_input_edge->src(), slice_input_edge->src_output());
slice_inputs->begin = slice_inputs.begin =
Output(slice_begin_edge->src(), slice_begin_edge->src_output()); Output(slice_begin_edge->src(), slice_begin_edge->src_output());
slice_inputs->size = slice_inputs.size =
Output(slice_size_edge->src(), slice_size_edge->src_output()); Output(slice_size_edge->src(), slice_size_edge->src_output());
Tensor tf_slice_size; TF_ASSIGN_OR_RETURN(absl::optional<Tensor> tf_slice_size,
TF_RETURN_IF_ERROR( TryToGetTensorFromConstOp(slice_inputs.size.node()));
GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size)); if (!tf_slice_size.has_value()) {
return {absl::nullopt};
if (tf_slice_size.dims() != 1) {
return errors::Internal("Expected vector for the slice size input.");
} }
slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size); if (tf_slice_size->dims() != 1) {
return Status::OK(); return {absl::nullopt};
}
slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size);
return {slice_inputs};
} }
// Casts `x` to a DT_INT64 if it isn't one already. // Casts `x` to a DT_INT64 if it isn't one already.
@ -263,36 +284,43 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
return Status::OK(); return Status::OK();
} }
// Returns true if `n` is a slice we can rewrite to have a static shape // If `n` is a slice we can rewrite to have a static shape (i.e. have the output
// (i.e. have the output shape only depend on the "size" input). Fills in // shape only depend on the "size" input) then returns the a SliceInputs
// `slice_inputs` in the process. // representing the inputs to `n`. Otherwise returns nullopt.
bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) { StatusOrOptional<SliceInputs> IsRewritableSlice(Node* n) {
if (n->type_string() != "Slice") { if (n->type_string() != "Slice") {
return false; return {absl::nullopt};
} }
if (!GetXlaClusterForNode(*n).has_value()) { if (!GetXlaClusterForNode(*n).has_value()) {
// There is no need to change slice ops outside XLA clusters. // There is no need to change slice ops outside XLA clusters.
return false; return {absl::nullopt};
} }
if (!GetSliceInputs(n, slice_inputs).ok()) { TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
// Could not parse slice inputs. E.g. the sizes input was not a constant. GetSliceInputs(n));
return false; if (!slice_inputs.has_value()) {
return {absl::nullopt};
} }
// If slice_size[i] < -1 for any i then executing the slice will throw an // If slice_size[i] < -1 for any i then executing the slice will throw an
// error, and we don't do anything here. // error, and we don't do anything here.
return absl::c_all_of(slice_inputs->size_as_vector, bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector,
[](int64 size_i) { return size_i >= -1; }); [](int64 size_i) { return size_i >= -1; });
if (!slice_is_ok) {
return {absl::nullopt};
}
return slice_inputs;
} }
Status FindAndRewriteSlices(Graph* g, bool* changed) { Status FindAndRewriteSlices(Graph* g, bool* changed) {
std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite; std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite;
for (Node* n : g->nodes()) { for (Node* n : g->nodes()) {
SliceInputs slice_inputs; TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
if (IsRewritableSlice(n, &slice_inputs)) { IsRewritableSlice(n));
slices_to_rewrite.push_back({n, std::move(slice_inputs)}); if (slice_inputs.has_value()) {
slices_to_rewrite.push_back({n, std::move(*slice_inputs)});
} }
} }

View File

@ -44,11 +44,8 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass); MarkForCompilationPass);
// TODO(b/111210515): IncreaseDynamismForAutoJitPass creates slices with index REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
// type DT_INT64 which do not have a kernel on GPU. IncreaseDynamismForAutoJitPass);
//
// REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
// IncreaseDynamismForAutoJitPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
PartiallyDeclusterPass); PartiallyDeclusterPass);

View File

@ -39,12 +39,22 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h" #include "tensorflow/core/util/stream_executor_util.h"
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
do { \
::tensorflow::Status _s(__VA_ARGS__); \
if (!TF_PREDICT_TRUE(_s.ok())) { \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
return RET; \
} \
} while (0)
namespace tensorflow { namespace tensorflow {
namespace { namespace {
Status PlatformInfoFromContext(OpKernelConstruction* ctx, XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
XlaPlatformInfo* result) {
DeviceType device_type = ctx->device_type(); DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr; se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr;
@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx,
} }
if (!device_allocator) { if (!device_allocator) {
TF_ASSIGN_OR_RETURN(se::Platform* const platform, xla::StatusOr<se::Platform*> maybe_platform =
se::MultiPlatformManager::PlatformWithId(platform_id)); se::MultiPlatformManager::PlatformWithId(platform_id);
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
xla_allocator = absl::make_unique<XlaAllocator>( xla_allocator = absl::make_unique<XlaAllocator>(
platform, ctx->device()->GetAllocator({})); maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
} }
*result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
std::move(xla_allocator), device_allocator); std::move(xla_allocator), device_allocator);
return Status::OK();
} }
// A closure describing how to run a compiled version of a TensorFlow function. // A closure describing how to run a compiled version of a TensorFlow function.
@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
: OpKernel(ctx), : OpKernel(ctx),
constants_(constants), constants_(constants),
resources_(resources), resources_(resources),
function_(function) { function_(function),
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); platform_info_(PlatformInfoFromContext(ctx)) {}
}
static Status BuildCompilationCache(OpKernelContext* ctx, static Status BuildCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, const XlaPlatformInfo& platform_info,
@ -277,8 +286,10 @@ static Status CompileToLocalExecutable(
// rather than a one-element tuple. // rather than a one-element tuple.
compile_options.always_return_tuple = false; compile_options.always_return_tuple = false;
return cache->Compile(options, function, constant_args, *variables, ctx, std::vector<XlaCompiler::Argument> args;
compile_options, TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_args, *variables, ctx, &args));
return cache->Compile(options, function, args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict, : XlaCompilationCache::CompileMode::kStrict,
kernel, executable); kernel, executable);
@ -333,18 +344,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
} }
namespace { namespace {
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
do { \
::tensorflow::Status _s(__VA_ARGS__); \
if (!TF_PREDICT_TRUE(_s.ok())) { \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
return RET; \
} \
} while (0)
// Helper static functions to construct parameters for // Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction. // XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) { std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
@ -381,7 +380,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
return *func; return *func;
} }
#undef OP_REQUIRES_OK_RETURN bool MustCompileAttr(OpKernelConstruction* ctx) {
bool must_compile;
OP_REQUIRES_OK_RETURN(ctx, false,
ctx->GetAttr("must_compile", &must_compile));
return must_compile;
}
} // namespace } // namespace
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
@ -396,10 +400,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
: OpKernel(ctx), : OpKernel(ctx),
constants_(ConstantsVector(ctx)), constants_(ConstantsVector(ctx)),
resources_(ResourcesVector(ctx)), resources_(ResourcesVector(ctx)),
function_(FunctionAttr(ctx)) { function_(FunctionAttr(ctx)),
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); platform_info_(PlatformInfoFromContext(ctx)),
OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_)); must_compile_(MustCompileAttr(ctx)) {}
}
void XlaCompileOp::Compute(OpKernelContext* ctx) { void XlaCompileOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaCompileOp " << def().name() VLOG(3) << "XlaCompileOp " << def().name()
@ -409,13 +412,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable; xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables; std::map<int, OptionalTensor> variables;
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) { bool cannot_compile_cluster;
{
mutex_lock guard(cannot_compile_cluster_mu_);
cannot_compile_cluster = cannot_compile_cluster_;
}
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
cannot_compile_cluster) {
executable = nullptr; executable = nullptr;
} else { } else {
OP_REQUIRES_OK(ctx, CompileToLocalExecutable( Status status = CompileToLocalExecutable(
ctx, function_, platform_info_, resources_, ctx, function_, platform_info_, resources_, constants_,
constants_, /*lazy=*/!must_compile_, &client, /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
&variables, &kernel, &executable)); if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}
if (status.code() == error::UNIMPLEMENTED) {
LOG(WARNING) << "Compilation failed:" << status.ToString()
<< ". Falling back to TF function call.";
executable = nullptr;
mutex_lock guard(cannot_compile_cluster_mu_);
cannot_compile_cluster_ = true;
}
} }
AllocatorAttributes host_alloc_attrs; AllocatorAttributes host_alloc_attrs;
@ -452,9 +472,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
ctx->set_output(1, compilation_successful); ctx->set_output(1, compilation_successful);
} }
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
}
void XlaRunOp::Compute(OpKernelContext* ctx) { void XlaRunOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaRunOp " << def().name(); VLOG(3) << "XlaRunOp " << def().name();

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ #ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#include <atomic>
#include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_launch_util.h"
@ -33,6 +35,7 @@ namespace tensorflow {
class XlaPlatformInfo { class XlaPlatformInfo {
public: public:
XlaPlatformInfo() : device_type_("") {} XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type, explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id, se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata, const XlaDevice::Metadata* xla_device_metadata,
@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel {
protected: protected:
// Indexes of compile-time constant inputs // Indexes of compile-time constant inputs
std::vector<int> constants_; const std::vector<int> constants_;
// Indexes of resource inputs // Indexes of resource inputs
std::vector<int> resources_; const std::vector<int> resources_;
NameAttrList function_; const NameAttrList function_;
XlaPlatformInfo platform_info_; const XlaPlatformInfo platform_info_;
}; };
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel {
private: private:
// Indexes of compile-time constant inputs // Indexes of compile-time constant inputs
std::vector<int> constants_; const std::vector<int> constants_;
// Indexes of resource inputs // Indexes of resource inputs
std::vector<int> resources_; const std::vector<int> resources_;
NameAttrList function_; const NameAttrList function_;
XlaPlatformInfo platform_info_; XlaPlatformInfo platform_info_;
bool must_compile_; const bool must_compile_;
// cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
// error when compiling the cluster this _XlaCompile is supposed to compile.
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
// on any future calls to _XlaCompile.
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
mutex cannot_compile_cluster_mu_;
}; };
class XlaRunOp : public OpKernel { class XlaRunOp : public OpKernel {
@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel {
void Compute(OpKernelContext* ctx) override; void Compute(OpKernelContext* ctx) override;
private: private:
XlaPlatformInfo platform_info_; const XlaPlatformInfo platform_info_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -22,7 +22,7 @@ cc_library(
hdrs = ["mark_for_compilation_pass_flags.h"], hdrs = ["mark_for_compilation_pass_flags.h"],
deps = deps =
[ [
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", "//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
@ -34,7 +34,7 @@ cc_library(
hdrs = ["xla_device_flags.h"], hdrs = ["xla_device_flags.h"],
deps = deps =
[ [
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", "//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
@ -46,7 +46,7 @@ cc_library(
hdrs = ["build_xla_ops_pass_flags.h"], hdrs = ["build_xla_ops_pass_flags.h"],
deps = deps =
[ [
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", "//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
@ -58,7 +58,7 @@ cc_library(
hdrs = ["xla_ops_common_flags.h"], hdrs = ["xla_ops_common_flags.h"],
deps = deps =
[ [
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", "//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],

View File

@ -16,7 +16,7 @@ limitations under the License.
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" #include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow { namespace tensorflow {
@ -34,7 +34,7 @@ void AllocateAndParseFlags() {
Flag("tf_xla_enable_lazy_compilation", Flag("tf_xla_enable_lazy_compilation",
&flags->tf_xla_enable_lazy_compilation, ""), &flags->tf_xla_enable_lazy_compilation, ""),
}); });
xla::legacy_flags::ParseFlagsFromEnv(*flag_list); xla::ParseFlagsFromEnv(*flag_list);
} }
} // namespace } // namespace

View File

@ -19,7 +19,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
@ -64,7 +65,18 @@ static void AllocateFlags() {
Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only,
"enable fusion of element-wise operations only using XLA when " "enable fusion of element-wise operations only using XLA when "
"global_jit_level is ON*.")}); "global_jit_level is ON*.")});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list); xla::ParseFlagsFromEnv(*flag_list);
if (VLOG_IS_ON(1)) {
VLOG(1) << "Parsed MarkForCompilationPassFlags:";
VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size;
VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size;
VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug;
VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel;
VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
}
} }
// Append to *append_to flag definitions associated with the XLA bridge's // Append to *append_to flag definitions associated with the XLA bridge's

View File

@ -33,7 +33,7 @@ void AppendMarkForCompilationPassFlags(
// The values of flags associated with the XLA bridge's // The values of flags associated with the XLA bridge's
// mark_for_compilation_pass module. // mark_for_compilation_pass module.
typedef struct { struct MarkForCompilationPassFlags {
int32 tf_xla_auto_jit; // Control compilation of operators into XLA int32 tf_xla_auto_jit; // Control compilation of operators into XLA
// computations on CPU and GPU devices. 0 = use // computations on CPU and GPU devices. 0 = use
// ConfigProto setting; -1 = off; 1 = on for things // ConfigProto setting; -1 = off; 1 = on for things
@ -55,7 +55,7 @@ typedef struct {
// is set to ON* and overrides its behavior. If // is set to ON* and overrides its behavior. If
// true, enable fusion of element-wise operations // true, enable fusion of element-wise operations
// only using XLA. // only using XLA.
} MarkForCompilationPassFlags; };
// Return a pointer to the MarkForCompilationPassFlags struct; // Return a pointer to the MarkForCompilationPassFlags struct;
// repeated calls return the same pointer. // repeated calls return the same pointer.

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
@ -41,7 +41,7 @@ static void AllocateFlags() {
"Switch a device into 'on-demand' mode, where instead of " "Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."), "autoclustering ops are compiled one by one just-in-time."),
}); });
xla::legacy_flags::ParseFlagsFromEnv(*flag_list); xla::ParseFlagsFromEnv(*flag_list);
} }
// Return a pointer to the XlaDeviceFlags struct; // Return a pointer to the XlaDeviceFlags struct;

View File

@ -17,8 +17,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" #include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow { namespace tensorflow {
@ -35,7 +35,13 @@ void AllocateAndParseFlags() {
Flag("tf_xla_always_defer_compilation", Flag("tf_xla_always_defer_compilation",
&flags->tf_xla_always_defer_compilation, ""), &flags->tf_xla_always_defer_compilation, ""),
}); });
xla::legacy_flags::ParseFlagsFromEnv(*flag_list); xla::ParseFlagsFromEnv(*flag_list);
if (VLOG_IS_ON(1)) {
VLOG(1) << "Parsed XlaOpsCommonFlags:";
VLOG(1) << " tf_xla_always_defer_compilation = "
<< flags->tf_xla_always_defer_compilation;
}
} }
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {

View File

@ -61,8 +61,23 @@ struct OperationFilter {
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
// auto-clustering stateful RNG ops. // auto-clustering stateful RNG ops.
bool allow_stateful_rng_ops; bool allow_stateful_rng_ops;
// TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
// to cluster ControlTrigger because of how we use deadness analysis.
bool allow_control_trigger;
// Whether ops with dummy implementations are allowed. We avoid
// auto-clustering these ops so that the user is not surprised when XLA is
// implicitly enabled. If the user explicitly specifies to use XLA, it is fine
// to resort to a dummy implementation. Currently Assert and CheckNumerics ops
// have dummy XLA implementations.
bool allow_dummy_ops;
}; };
bool IsDummyImplOp(absl::string_view op_name) {
return op_name == "Assert" || op_name == "CheckNumerics";
}
bool IsStatefulRandomOp(absl::string_view op_name) { bool IsStatefulRandomOp(absl::string_view op_name) {
return op_name == "RandomUniform" || op_name == "RandomShuffle" || return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
@ -225,6 +240,12 @@ bool IsCompilableCall(const NodeDef& call_def,
IsStatefulRandomOp(node->type_string())) { IsStatefulRandomOp(node->type_string())) {
return false; return false;
} }
if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
return false;
}
if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
return false;
}
if (!HasXLAKernel(*node, jit_device_type) && if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
lib_runtime)) { lib_runtime)) {
@ -452,7 +473,14 @@ Status FindCompilationCandidates(
OperationFilter op_filter; OperationFilter op_filter;
op_filter.allow_resource_ops = registration->compile_resource_ops; op_filter.allow_resource_ops = registration->compile_resource_ops;
op_filter.allow_stateful_rng_ops = registration->requires_compilation; op_filter.allow_stateful_rng_ops =
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);
op_filter.allow_control_trigger =
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);
op_filter.allow_dummy_ops = (registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);
if (!HasXLAKernel(*node, jit_device_type) && if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0, !IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
@ -467,6 +495,15 @@ Status FindCompilationCandidates(
VLOG(2) << "Rejecting " << node->name() << ": stateful random operation"; VLOG(2) << "Rejecting " << node->name() << ": stateful random operation";
continue; continue;
} }
if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op";
continue;
}
if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
VLOG(2) << "Rejecting " << node->name() << ": dummy op ("
<< node->type_string() << ")";
continue;
}
if (!op_filter.allow_resource_ops && if (!op_filter.allow_resource_ops &&
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
@ -597,11 +634,14 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
&registration)); &registration));
DeviceType jit_device_type(registration->compilation_device_name); DeviceType jit_device_type(registration->compilation_device_name);
// We can always *compile* resource operations and stateful RNGs, even if we // We can always *compile* resource operations, stateful RNGs and dummy ops,
// are sometimes unable to auto-cluster them. // even if we are sometimes unable to auto-cluster them.
OperationFilter op_filter; OperationFilter op_filter;
op_filter.allow_resource_ops = true; op_filter.allow_resource_ops = true;
op_filter.allow_stateful_rng_ops = true; op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
op_filter.allow_dummy_ops = true;
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
} }
@ -613,10 +653,8 @@ Status MarkForCompilationPass::Run(
GetGlobalJitLevel(options); GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags(); legacy_flags::GetMarkForCompilationPassFlags();
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
bool fusion_only = flags->tf_xla_fusion_only; bool fusion_only = flags->tf_xla_fusion_only;
VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
const FunctionLibraryDefinition* fld = options.flib_def; const FunctionLibraryDefinition* fld = options.flib_def;
@ -635,9 +673,6 @@ Status MarkForCompilationPass::Run(
return false; return false;
} }
// If this device requires a JIT, we must say yes.
if (registration->requires_compilation) return true;
// If there is a _XlaCompile annotation, use its value. // If there is a _XlaCompile annotation, use its value.
bool compile = false; bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
@ -674,18 +709,21 @@ Status MarkForCompilationPass::Run(
return false; return false;
} }
// Otherwise use the value of global_jit_level. // Otherwise use the value of global_jit_level and the device's
// Ignore enable_jit_by_default if global jit compilation for CPU // autoclustering policy.
// is explicitly requested via tf_xla_cpu_global_jit flag
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
bool should_compile = bool should_compile =
(ignore_registration || registration->enable_jit_by_default) && registration->autoclustering_policy ==
global_jit_level != OptimizerOptions::OFF; XlaOpRegistry::AutoclusteringPolicy::kAlways ||
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
global_jit_level != OptimizerOptions::OFF);
if (!should_compile) { if (!should_compile) {
if (global_jit_level == OptimizerOptions::OFF) { if (global_jit_level == OptimizerOptions::OFF) {
VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
} else { } else {
VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; VLOG(2)
<< "Rejecting " << node->name()
<< ": autoclustering for device only when requested explicitly.";
} }
} }
return should_compile; return should_compile;
@ -1073,12 +1111,10 @@ Status MarkForCompilationPass::RunImpl(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration); XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);
// Compile if this is a cluster of >= min_cluster_size compilable operators. // Compile if this is a cluster of >= min_cluster_size compilable operators.
// Also, always compile if the operator is placed on a device that requires // Also, always compile if it contains at least one op that is marked for
// compilation, or if it contains at least one op that is marked for
// compilation that is not an Identity op. // compilation that is not an Identity op.
if (effective_cluster_sizes[cluster] >= min_cluster_size || if (effective_cluster_sizes[cluster] >= min_cluster_size ||
(effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || (effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) {
registration->requires_compilation) {
string& name = cluster_names[cluster]; string& name = cluster_names[cluster];
if (name.empty()) { if (name.empty()) {

View File

@ -817,14 +817,10 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
std::unordered_map<string, string> clusters = GetClusters(*graph); std::unordered_map<string, string> clusters = GetClusters(*graph);
ASSERT_FALSE(clusters.empty()); // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
string cluster_name = clusters.begin()->second; // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't
// cluster it because of b/118970344.
// ctrl_trigger_a has inputs with mismatching deadness so it won't be EXPECT_TRUE(clusters.empty());
// clustered. ctrl_trigger_b is okay to cluster.
std::unordered_map<string, string> expected_clusters(
{{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}});
EXPECT_EQ(clusters, expected_clusters);
} }
TEST(XlaCompilationTest, RandomShape) { TEST(XlaCompilationTest, RandomShape) {
@ -923,9 +919,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph); std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["test/shape_rng"], ""); EXPECT_EQ(clusters["test/shape_rng"], "");
EXPECT_NE(clusters["test/reshape"], ""); EXPECT_EQ(clusters["test/reshape"], "");
EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
} }
TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
@ -1088,7 +1083,7 @@ TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
EXPECT_NE(clusters["test/c"], ""); EXPECT_NE(clusters["test/c"], "");
} }
TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
Scope root = Scope::NewRootScope().ExitOnError(); Scope root = Scope::NewRootScope().ExitOnError();
Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
@ -1104,5 +1099,53 @@ TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) {
EXPECT_EQ(clusters["test/a"], ""); EXPECT_EQ(clusters["test/a"], "");
EXPECT_EQ(clusters["test/b"], ""); EXPECT_EQ(clusters["test/b"], "");
} }
TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
absl::string_view xla_cpu_device =
"/job:worker/replica:0/task:0/device:XLA_CPU:0";
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
Output check =
ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
for (Node* n : graph->nodes()) {
if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
n->set_assigned_device_name(string(xla_cpu_device));
}
}
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["test/check"], "");
EXPECT_NE(clusters["test/greaterequal"], "");
EXPECT_NE(clusters["test/assert"], "");
}
TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
Output check =
ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_EQ(clusters["test/assert"], "");
EXPECT_EQ(clusters["test/check"], "");
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -133,6 +133,10 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
graph->RemoveEdge(out_edge_to_clone); graph->RemoveEdge(out_edge_to_clone);
} }
if (n->out_edges().empty()) {
graph->RemoveNode(n);
}
return Status::OK(); return Status::OK();
} }
@ -191,6 +195,10 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
} }
} }
// Recompute post order since PartiallyDeclusterNode may have deleted nodes.
post_order.clear();
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
nodes_to_partially_decluster.clear(); nodes_to_partially_decluster.clear();
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
@ -210,7 +218,8 @@ bool IsIntraClusterEdge(const Edge& edge) {
bool IsMustCompileDevice(const DeviceType& device_type) { bool IsMustCompileDevice(const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration; const XlaOpRegistry::DeviceRegistration* registration;
if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return registration->requires_compilation; return registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
} }
return false; return false;

View File

@ -437,5 +437,32 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
} }
TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
const char* const kClusteredProducer0Name = "ClusteredProducer0";
const char* const kClusteredProducer1Name = "ClusteredProducer1";
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* input =
ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
Node* clustered_producer_0 =
ops::BinaryOp("FakeBinary", input, input,
builder.opts().WithName(kClusteredProducer0Name));
Node* clustered_producer_1 =
ops::BinaryOp("FakeBinary", clustered_producer_0, input,
builder.opts().WithName(kClusteredProducer1Name));
ops::BinaryOp("FakeBinary", clustered_producer_1, input,
builder.opts().WithName("UnclusteredConsumer"));
clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr);
EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -1,132 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_
#include <deque>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
// A thread-safe, first-in-first-out queue.
template <typename T>
class ProducerConsumerQueue {
public:
ProducerConsumerQueue()
: capacity_(std::numeric_limits<std::size_t>::max()) {}
~ProducerConsumerQueue() = default;
// Wait until the queue is non-full, then append a copy of v.
void Put(const T &v);
// Wait until the queue is non-empty, then remove and return the head value.
T Get();
// If the queue is non-empty, remove the head value, placing it in *pv, and
// return true; otherwise return false.
bool TryGet(T *pv);
// Set the capacity of the queue; the queue is full whenever count() >=
// capacity(). The initial value is the maximum size_t. Requires size > 0.
void set_capacity(std::size_t size);
// Return the capacity of the queue.
std::size_t capacity() const;
// Return the number of elements in the queue.
std::size_t count() const;
// Implementation details follow. Clients should ignore.
private:
mutable tensorflow::mutex mu_; // protects all fields below
tensorflow::condition_variable non_empty_ GUARDED_BY(mu_);
tensorflow::condition_variable non_full_ GUARDED_BY(mu_);
std::size_t capacity_ GUARDED_BY(mu_);
std::deque<T> queue_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue);
};
// ------------------------------------------------------
// Implementation details follow. Clients should ignore.
// Wait until the queue is non-full, then append a copy of v.
template <typename T>
void ProducerConsumerQueue<T>::Put(const T &v) {
mutex_lock lock(mu_);
while (queue_.size() >= capacity_) {
non_full_.wait(lock);
}
queue_.push_back(v);
non_empty_.notify_one();
}
// Wait until the queue is non-empty, then remove and return the head value.
template <typename T>
T ProducerConsumerQueue<T>::Get() {
mutex_lock lock(mu_);
while (queue_.empty()) {
non_empty_.wait(lock);
}
non_full_.notify_one();
T result_value = queue_.front();
queue_.pop_front();
return result_value;
}
// If the queue is non-empty, remove the head value, placing it in *pv, and
// return true; otherwise return false.
template <typename T>
bool ProducerConsumerQueue<T>::TryGet(T *pv) {
mutex_lock lock(mu_);
bool got_element = !queue_.empty();
if (got_element) {
non_full_.notify_one();
*pv = queue_.front();
queue_.pop_front();
}
return got_element;
}
// Set the capacity of the queue; the queue is full whenever count() >=
// capacity(). The initial value is the maximum size_t. Requires size > 0.
template <typename T>
void ProducerConsumerQueue<T>::set_capacity(std::size_t size) {
mutex_lock lock(mu_);
CHECK_NE(size, 0);
capacity_ = size;
non_full_.notify_all();
}
// Return the capacity of the queue.
template <typename T>
std::size_t ProducerConsumerQueue<T>::capacity() const {
mutex_lock lock(mu_);
std::size_t max_elements = capacity_;
return max_elements;
}
// Return the number of elements in the queue.
template <typename T>
std::size_t ProducerConsumerQueue<T>::count() const {
mutex_lock lock(mu_);
std::size_t num_elements = queue_.size();
return num_elements;
}
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_

View File

@ -1,139 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/producer_consumer_queue.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
typedef ProducerConsumerQueue<int> IntQueue;
// Insert integers between low inclusive and high exclusive into q.
void PushRange(IntQueue *q, int low, int high) {
while (low != high) {
q->Put(low);
VLOG(2) << "Pushing " << low;
++low;
}
}
// Push the numbers between 0 and 999 inclusive from several threads in the
// pool.
void PushRanges(IntQueue *queue, thread::ThreadPool *pool) {
VLOG(1) << "Adding 20-36";
pool->Schedule([queue] { PushRange(queue, 20, 36); });
VLOG(1) << "Adding 7-20";
pool->Schedule([queue] { PushRange(queue, 7, 20); });
VLOG(1) << "Adding 36-501";
pool->Schedule([queue] { PushRange(queue, 36, 501); });
VLOG(1) << "Adding 501-1000";
pool->Schedule([queue] { PushRange(queue, 501, 1000); });
VLOG(1) << "Adding 0-5";
pool->Schedule([queue] { PushRange(queue, 0, 5); });
VLOG(1) << "Adding 5-7";
pool->Schedule([queue] { PushRange(queue, 5, 7); });
}
// Pop elements from queue using Get(). Make sure that exactly <high> elements
// were present and their values are all integers between 0 and high-1
// inclusive.
void GetRange(IntQueue *queue, int high) {
VLOG(1) << "Testing Wait";
std::vector<int> results;
for (int i = 0; i != high; ++i) {
int r = queue->Get();
VLOG(2) << "Waited and got " << r;
results.push_back(r);
}
CHECK_EQ(queue->count(), 0);
std::sort(results.begin(), results.end());
for (int i = 0; i != high; ++i) {
CHECK(results[i] == i);
}
}
// Pop elements from queue using TryGet(). Make sure that exactly <high>
// elements were present and their values are all integers between 0 and high-1
// inclusive.
void TryGetRange(IntQueue *queue, int high) {
std::vector<int> results;
// Give up if we don't get all the elements back from the queue
// in 10 seconds.
int timeout = 10;
int r;
for (int i = 0; i != high; ++i) {
while (!queue->TryGet(&r)) {
if (!timeout--) {
LOG(FATAL) << "Can't find all elements in the queue";
}
VLOG(1) << "Sleeping for a second...";
sleep(1);
}
VLOG(2) << "Popped " << r;
results.push_back(r);
}
CHECK_EQ(queue->count(), 0);
CHECK(!queue->TryGet(&r));
std::sort(results.begin(), results.end());
for (int i = 0; i != high; ++i) {
CHECK_EQ(i, results[i]);
}
}
const int kNumThreads = 15;
TEST(ProducerConsumerQueue, GetRange) {
IntQueue queue;
{
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
PushRanges(&queue, &pool);
}
GetRange(&queue, 1000);
}
TEST(ProducerConsumerQueue, TryGetRange) {
IntQueue queue;
{
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
PushRanges(&queue, &pool);
}
TryGetRange(&queue, 1000);
}
TEST(ProducerConsumerQueue, ParallelGetRange) {
IntQueue queue;
{
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
pool.Schedule([&queue] { GetRange(&queue, 1000); });
PushRanges(&queue, &pool);
}
}
TEST(ProducerConsumerQueue, ParallelTryGetRange) {
IntQueue queue;
{
thread::ThreadPool pool(Env::Default(), "test", kNumThreads);
pool.Schedule([&queue] { TryGetRange(&queue, 1000); });
PushRanges(&queue, &pool);
}
}
} // namespace
} // namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <numeric> #include <numeric>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
@ -65,14 +66,14 @@ string XlaCompilationCache::DebugString() {
// Compute a string signature which encodes the shapes of the // Compute a string signature which encodes the shapes of the
// arguments in the supplied list. // arguments in the supplied list.
string XlaCompilationCache::SignatureDebugString(const Signature& sig) { string XlaCompilationCache::Signature::HumanString() const {
string result = sig.name; string result = name;
for (const auto& a : sig.arg_types) { for (const auto& a : arg_types) {
absl::StrAppend(&result, ",", DataTypeString(a.first), absl::StrAppend(&result, ",", DataTypeString(a.first),
a.second.DebugString()); a.second.DebugString());
} }
for (const auto& v : sig.arg_values) { for (const auto& v : arg_values) {
absl::StrAppend(&result, "; ", v.DebugString()); absl::StrAppend(&result, "; ", v.DebugString());
} }
return result; return result;
@ -84,7 +85,9 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
if (arg_values.size() != other.arg_values.size()) return false; if (arg_values.size() != other.arg_values.size()) return false;
for (int i = 0; i < arg_values.size(); ++i) { for (int i = 0; i < arg_values.size(); ++i) {
if (arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
arg_values[i].shape() != other.arg_values[i].shape() ||
arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
return false; return false;
} }
} }
@ -108,96 +111,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()(
return h; return h;
} }
Status XlaCompilationCache::BuildSignature( xla::StatusOr<XlaCompilationCache::Signature>
const NameAttrList& function, const std::map<int, Tensor>& constant_args, XlaCompilationCache::BuildSignature(
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, const NameAttrList& function,
Signature* signature) { absl::Span<const XlaCompiler::Argument> args) {
signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); Signature signature;
signature->arg_values.reserve(constant_args.size()); signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
for (const XlaCompiler::Argument& arg : args) {
signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
for (int i = 0; i < ctx->num_inputs(); ++i) { signature.arg_values.push_back(arg.constant_value);
if (constant_args.count(i) > 0) { break;
// Use the values of compile time constants in the signature. case XlaCompiler::Argument::kParameter:
signature->arg_values.push_back(constant_args.at(i)); case XlaCompiler::Argument::kResource:
} else if (variable_args.count(i) > 0) { signature.arg_types.emplace_back(arg.type, arg.shape);
const OptionalTensor& variable = variable_args.at(i); break;
if (variable.present) { default:
signature->arg_types.emplace_back(variable.value.dtype(), return errors::InvalidArgument(
variable.value.shape()); "Unhandled argument kind in XlaCompilationCache: ",
} else { arg.HumanString());
signature->arg_types.emplace_back(DT_INVALID, TensorShape());
}
} else {
signature->arg_types.emplace_back(ctx->input_dtype(i),
ctx->input(i).shape());
} }
} }
return Status::OK(); return std::move(signature);
} }
namespace {
// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op.
Status BuildArguments(const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args) {
args->resize(ctx->num_inputs());
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
XlaCompiler::Argument& arg = (*args)[input_num];
if (constant_args.count(input_num) > 0) {
// Handles compile-time constants.
const Tensor& input = constant_args.at(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input.dtype();
arg.shape = input.shape();
arg.constant_value = input;
} else if (variable_args.count(input_num) == 0) {
// Handles the non-constant arguments.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
if (input.NumElements() > 0) {
arg.kind = XlaCompiler::Argument::kParameter;
} else {
arg.kind = XlaCompiler::Argument::kConstant;
arg.constant_value = input;
}
arg.type = input.dtype();
arg.shape = input.shape();
} else {
// Handles resource variables.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
const OptionalTensor& variable = variable_args.at(input_num);
arg.name = variable.name;
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
if (variable.present) {
const Tensor& value = variable.value;
arg.type = value.dtype();
arg.shape = value.shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since
// they are meaningless. However, it is legal to assign to a resource
// variable for the first time inside the XLA computation, so we do
// permit uninitialized variables.
arg.initialized = false;
arg.type = DT_INVALID;
arg.shape = TensorShape();
}
}
}
return Status::OK();
}
} // namespace
Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::BuildExecutable(
const XlaCompiler::Options& options, const XlaCompiler::Options& options,
const XlaCompiler::CompilationResult& result, const XlaCompiler::CompilationResult& result,
@ -227,25 +164,38 @@ Status XlaCompilationCache::BuildExecutable(
Status XlaCompilationCache::Compile( Status XlaCompilationCache::Compile(
const XlaCompiler::Options& options, const NameAttrList& function, const XlaCompiler::Options& options, const NameAttrList& function,
const std::map<int, Tensor>& constant_args, absl::Span<const XlaCompiler::Argument> args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompileOptions& compile_options,
CompileMode compile_mode, CompileMode compile_mode,
const XlaCompiler::CompilationResult** out_compilation_result, const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) { xla::LocalExecutable** out_executable) {
// Set the compile threshold to 1 to implement CompileMode::kStrict. absl::optional<int64> compile_threshold;
int64 compile_threshold = if (compile_mode == CompileMode::kLazy) {
compile_mode == CompileMode::kLazy ? kDefaultCompilationThreshold : 1; compile_threshold = kDefaultCompilationThreshold;
return CompileImpl(options, function, constant_args, variable_args, ctx, }
compile_options, /*compile_single_op=*/false, auto compile_fn = [&](XlaCompiler* compiler,
XlaCompiler::CompilationResult* result) {
return compiler->CompileFunction(compile_options, function, args, result);
};
return CompileImpl(options, function, args, compile_fn,
/*compile_threshold=*/compile_threshold, /*compile_threshold=*/compile_threshold,
out_compilation_result, out_executable); out_compilation_result, out_executable);
} }
static bool IsMegamorphic(int64 compile_count, int64 execution_count) {
const int64 kCompileThreshold = 10;
const int64 kMinExecutionsPerCompile = 50;
// This heuristic is trying to capture the following property: have we sunk a
// certain minimum amount of compile time into the cluster that didn't quite
// "pay off"?
return compile_count > kCompileThreshold &&
execution_count < kMinExecutionsPerCompile * compile_count;
}
Status XlaCompilationCache::CompileSingleOp( Status XlaCompilationCache::CompileSingleOp(
const XlaCompiler::Options& options, const XlaCompiler::Options& options,
const std::map<int, Tensor>& constant_args, absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompileOptions& compile_options,
const XlaCompiler::CompilationResult** out_compilation_result, const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) { xla::LocalExecutable** out_executable) {
@ -253,54 +203,41 @@ Status XlaCompilationCache::CompileSingleOp(
NameAttrList name; NameAttrList name;
name.set_name(def.op()); name.set_name(def.op());
*name.mutable_attr() = def.attr(); *name.mutable_attr() = def.attr();
return CompileImpl(options, name, constant_args, variable_args, ctx, auto compile_op = [&](XlaCompiler* compiler,
compile_options, XlaCompiler::CompilationResult* result) {
/*compile_single_op=*/true, /*compile_threshold=*/1, std::vector<DataType> result_dtypes(ctx->num_outputs());
for (int i = 0; i < result_dtypes.size(); ++i) {
result_dtypes[i] = ctx->expected_output_dtype(i);
}
return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
args, result_dtypes, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,
out_compilation_result, out_executable); out_compilation_result, out_executable);
} }
Status XlaCompilationCache::CompileImpl( Status XlaCompilationCache::CompileImpl(
const XlaCompiler::Options& options, const NameAttrList& function, const XlaCompiler::Options& options, const NameAttrList& function,
const std::map<int, Tensor>& constant_args, absl::Span<const XlaCompiler::Argument> args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, const std::function<Status(XlaCompiler* compiler,
const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, XlaCompiler::CompilationResult*)>& compile_fn,
int64 compile_threshold, absl::optional<int64> compile_threshold,
const XlaCompiler::CompilationResult** out_compilation_result, const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) { xla::LocalExecutable** out_executable) {
DCHECK_NE(out_executable, nullptr); DCHECK_NE(out_executable, nullptr);
VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
VLOG(2) << "num_inputs=" << ctx->num_inputs() VLOG(2) << "num_inputs=" << args.size();
<< " num_constant_args=" << constant_args.size() for (int i = 0; i < args.size(); i++) {
<< " num_variable_args=" << variable_args.size(); VLOG(2) << i << ": " << args[i].HumanString();
for (int i = 0; i < ctx->num_inputs(); i++) {
TensorShape shape = ctx->input(i).shape();
VLOG(2) << i << ": dtype=" << DataTypeString(ctx->input_dtype(i))
<< " present=" << ctx->has_input(i)
<< " shape=" << shape.DebugString();
}
for (auto& iterator : variable_args) {
const OptionalTensor& variable = iterator.second;
VLOG(2) << "variable present=" << variable.present
<< " type=" << DataTypeString(variable.value.dtype())
<< " shape=" << variable.value.shape().DebugString()
<< " TF arg= " << iterator.first;
}
VLOG(2) << "num_outputs = " << ctx->num_outputs();
for (int i = 0; i < ctx->num_outputs(); i++) {
VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i);
} }
} }
TF_RET_CHECK(constant_args.size() + variable_args.size() <= TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
ctx->num_inputs()); VLOG(2) << "Signature: " << signature.HumanString();
Signature signature;
TF_RETURN_IF_ERROR(
BuildSignature(function, constant_args, variable_args, ctx, &signature));
VLOG(2) << "Signature: " << SignatureDebugString(signature);
// The outer lock protects the existence of the cache entry. It does not // The outer lock protects the existence of the cache entry. It does not
// protect the contents of the cache entry. // protect the contents of the cache entry.
Entry* entry; Entry* entry;
@ -319,25 +256,67 @@ Status XlaCompilationCache::CompileImpl(
// (since they get the benefit of XLA right away without waiting for warmup) // (since they get the benefit of XLA right away without waiting for warmup)
// and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
// most one cluster-compilation's worth of compile time). // most one cluster-compilation's worth of compile time).
bool is_first_execution = [&] { bool is_first_execution;
// We avoid compiling clusters that have "gone megamorphic" i.e. have an
// excessive amount of shape dynamism.
bool is_megamorphic;
{
mutex_lock lock(cluster_compile_stats_mu_); mutex_lock lock(cluster_compile_stats_mu_);
auto it = auto it =
cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
.first; .first;
return it->second.execution_count++ == 0; is_first_execution = it->second.execution_count++ == 0;
}();
// The is_megamorphic bit is "sticky". We assume clusters that have been
// observed to be megamorphic once stay megamorphic forever.
it->second.is_megamorphic |=
IsMegamorphic(/*compile_count=*/it->second.compile_count,
/*execution_count=*/it->second.execution_count);
is_megamorphic = it->second.is_megamorphic;
}
// Acquire the cache entry lock and compile, if necessary. // Acquire the cache entry lock and compile, if necessary.
// TODO(phawkins): this locking will need to be restructured when we implement // TODO(phawkins): this locking will need to be restructured when we implement
// cache eviction. // cache eviction.
mutex_lock entry_lock(entry->mu); mutex_lock entry_lock(entry->mu);
int64 current_request_count = ++entry->request_count; int64 current_request_count = ++entry->request_count;
VLOG(2) << "Compilation cache entry hit: " << entry->compiled
<< " signature: " << signature.HumanString() << " with request count "
<< current_request_count << " and compile threshold "
<< compile_threshold.value_or(0);
if (!entry->compiled) { if (!entry->compiled) {
VLOG(2) << "Compilation cache miss for signature: " const bool should_compile = [&] {
<< SignatureDebugString(signature) << " with request count " if (!compile_threshold.has_value()) {
<< current_request_count << " and compile threshold " // Lazy compilation is disabled.
<< compile_threshold; return true;
if (!is_first_execution && current_request_count < compile_threshold) { }
if (is_megamorphic) {
VLOG(3) << "Not compiling cluster " << function.name()
<< " because it is megamorphic.";
return false;
}
if (is_first_execution) {
return true;
}
bool reached_compile_threshold =
current_request_count >= *compile_threshold;
if (!reached_compile_threshold) {
VLOG(3)
<< "Not compiling cluster " << function.name()
<< " because it has not reached compile threshold; threshold is "
<< *compile_threshold << " execution count "
<< current_request_count << ".";
}
return reached_compile_threshold;
}();
if (!should_compile) {
VLOG(2) << "Not compiling for signature: " << signature.HumanString();
*out_compilation_result = nullptr; *out_compilation_result = nullptr;
*out_executable = nullptr; *out_executable = nullptr;
return Status::OK(); return Status::OK();
@ -347,21 +326,12 @@ Status XlaCompilationCache::CompileImpl(
const uint64 compile_start_us = env->NowMicros(); const uint64 compile_start_us = env->NowMicros();
// Do the actual JIT compilation without holding the lock (it can take // Do the actual JIT compilation without holding the lock (it can take
// a long time.) // a long time.)
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(
BuildArguments(constant_args, variable_args, ctx, &args));
XlaCompiler compiler(options); XlaCompiler compiler(options);
entry->compiled = true; entry->compiled = true;
if (compile_single_op) { entry->compilation_status =
entry->compilation_status = compile_fn(&compiler, &entry->compilation_result);
compiler.CompileSingleOp(compile_options, signature.name, ctx, args,
&entry->compilation_result);
} else {
entry->compilation_status = compiler.CompileFunction(
compile_options, function, args, &entry->compilation_result);
}
TF_RETURN_IF_ERROR(entry->compilation_status); TF_RETURN_IF_ERROR(entry->compilation_status);
CHECK_EQ(entry->executable.get(), nullptr); CHECK_EQ(entry->executable.get(), nullptr);
entry->compilation_status = entry->compilation_status =

View File

@ -17,9 +17,12 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
@ -30,13 +33,6 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
string name; // A descriptive name
bool present = false; // Is the tensor present?
Tensor value; // If present, what is the Tensor's value?
};
// The XlaCompilationCache class caches the results of the XlaCompiler class, // The XlaCompilationCache class caches the results of the XlaCompiler class,
// which converts a Tensorflow graph into a compiled XLA compilation. // which converts a Tensorflow graph into a compiled XLA compilation.
// //
@ -58,11 +54,7 @@ class XlaCompilationCache : public ResourceBase {
// Compiles a function into a XlaCompiler::CompilationResult that can be used // Compiles a function into a XlaCompiler::CompilationResult that can be used
// to execute an XLA Computation. Compilation results are cached. // to execute an XLA Computation. Compilation results are cached.
// `function` is the name of a Tensorflow function to compile. // `function` is the name of a Tensorflow function to compile.
// `constant_args` is a map of tensorflow argument number to its constant // `args` is a description of the arguments to the computation.
// value.
// `variable_args` is a snapshot of the current values of the
// resource variable arguments to `function`; uninitialized variables are
// represented by an absent OptionalTensor.
// //
// `compile_mode` controls the behavior of the compilation cache on a cache // `compile_mode` controls the behavior of the compilation cache on a cache
// miss. If `compile_mode` is `kLazy` then, based on some profitability // miss. If `compile_mode` is `kLazy` then, based on some profitability
@ -78,9 +70,7 @@ class XlaCompilationCache : public ResourceBase {
// outputs. // outputs.
Status Compile(const XlaCompiler::Options& options, Status Compile(const XlaCompiler::Options& options,
const NameAttrList& function, const NameAttrList& function,
const std::map<int, Tensor>& constant_args, absl::Span<const XlaCompiler::Argument> args,
const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,
const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompileOptions& compile_options,
CompileMode compile_mode, CompileMode compile_mode,
const XlaCompiler::CompilationResult** out_compilation_result, const XlaCompiler::CompilationResult** out_compilation_result,
@ -90,8 +80,7 @@ class XlaCompilationCache : public ResourceBase {
// XlaCompiler::CompileFunction. // XlaCompiler::CompileFunction.
Status CompileSingleOp( Status CompileSingleOp(
const XlaCompiler::Options& options, const XlaCompiler::Options& options,
const std::map<int, Tensor>& constant_args, absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompileOptions& compile_options,
const XlaCompiler::CompilationResult** out_compilation_result, const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable); xla::LocalExecutable** out_executable);
@ -101,26 +90,6 @@ class XlaCompilationCache : public ResourceBase {
string DebugString() override; string DebugString() override;
private:
// Common implementation of Compile and CompileSingleOp.
Status CompileImpl(
const XlaCompiler::Options& options, const NameAttrList& function,
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompileOptions& compile_options,
bool compile_single_op, int64 compile_threshold,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable);
// Takes `result` which has been compiled from a Tensorflow subgraph to a
// XLA computation already, and generates an XLA LocalExecutable `executable`.
Status BuildExecutable(const XlaCompiler::Options& options,
const XlaCompiler::CompilationResult& result,
std::unique_ptr<xla::LocalExecutable>* executable);
xla::LocalClient* const client_;
const DeviceType device_type_;
// Describes the types, shapes and any compile-time constant arguments // Describes the types, shapes and any compile-time constant arguments
// to a kernel. Key that uniquely identifies a compilation output. // to a kernel. Key that uniquely identifies a compilation output.
struct Signature { struct Signature {
@ -137,14 +106,35 @@ class XlaCompilationCache : public ResourceBase {
struct Hash { struct Hash {
uint64 operator()(const Signature& signature) const; uint64 operator()(const Signature& signature) const;
}; };
// Returns a human-readable description of the signature.
string HumanString() const;
}; };
static string SignatureDebugString(const Signature& sig);
// Builds the signature for a compilation. // Builds the signature for a compilation.
Status BuildSignature(const NameAttrList& function, static xla::StatusOr<Signature> BuildSignature(
const std::map<int, Tensor>& constant_args, const NameAttrList& function,
const std::map<int, OptionalTensor>& variable_args, absl::Span<const XlaCompiler::Argument> args);
OpKernelContext* ctx, Signature* signature);
private:
// Common implementation of Compile and CompileSingleOp.
Status CompileImpl(
const XlaCompiler::Options& options, const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args,
const std::function<Status(XlaCompiler* compiler,
XlaCompiler::CompilationResult*)>& compile_fn,
absl::optional<int64> compile_threshold,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable);
// Takes `result` which has been compiled from a Tensorflow subgraph to a
// XLA computation already, and generates an XLA LocalExecutable `executable`.
Status BuildExecutable(const XlaCompiler::Options& options,
const XlaCompiler::CompilationResult& result,
std::unique_ptr<xla::LocalExecutable>* executable);
xla::LocalClient* const client_;
const DeviceType device_type_;
// The value associated with a cache entry. // The value associated with a cache entry.
struct Entry { struct Entry {
@ -180,7 +170,13 @@ class XlaCompilationCache : public ResourceBase {
// Cumulative time spent compiling the cluster. // Cumulative time spent compiling the cluster.
int64 cumulative_compile_time_us = 0; int64 cumulative_compile_time_us = 0;
// True if we have decided that this cluster is too dynamic (i.e. its shapes
// change too frequently) to profitably JIT compile. Once a cluster is
// tagged megamorphic, it stays megamorphic forever.
bool is_megamorphic = false;
}; };
mutex cluster_compile_stats_mu_; mutex cluster_compile_stats_mu_;
// Maps cluster names to compilation statistics for said cluster. // Maps cluster names to compilation statistics for said cluster.

View File

@ -0,0 +1,54 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(XlaCompilationCacheTest, SignatureEquality) {
NameAttrList fn;
fn.set_name("afunction");
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kConstant;
args[0].type = DT_INT32;
args[0].shape = TensorShape({4, 0});
args[0].constant_value = Tensor(DT_INT32, {4, 0});
TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1,
XlaCompilationCache::BuildSignature(fn, args));
args[0].type = DT_FLOAT;
args[0].constant_value = Tensor(DT_FLOAT, {4, 0});
TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2,
XlaCompilationCache::BuildSignature(fn, args));
args[0].shape = TensorShape({0, 4});
args[0].constant_value = Tensor(DT_FLOAT, {0, 4});
TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s3,
XlaCompilationCache::BuildSignature(fn, args));
std::vector<XlaCompilationCache::Signature> signatures = {s1, s2, s3};
for (int i = 0; i < signatures.size(); ++i) {
for (int j = 0; j < signatures.size(); ++j) {
EXPECT_EQ(i == j, signatures[i] == signatures[j])
<< signatures[i].HumanString() << " " << signatures[j].HumanString();
}
}
}
} // namespace
} // namespace tensorflow

View File

@ -187,8 +187,13 @@ Status XlaCompileOnDemandOp::Compile(
compile_options.always_return_tuple = false; compile_options.always_return_tuple = false;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx); std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
compile_options, result, executable); std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_args, ctx, &args));
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
} }
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {

View File

@ -42,8 +42,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_CPU_XLA_JIT; registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
registration.requires_compilation = !compile_on_demand; registration.autoclustering_policy =
registration.enable_jit_by_default = false; compile_on_demand
? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
: XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_resource_ops = true; registration.compile_resource_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);

View File

@ -446,7 +446,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
// Any op assigned to the device that isn't rewritten by the graph rewriter // Any op assigned to the device that isn't rewritten by the graph rewriter
// gets executed by a n XlaCompileOnDemandOp, which compiles it and executes // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes
// it just-in-time. // it just-in-time.
kernel_factory::OpKernelRegistrar::Factory factory = OpKernel* (*factory)(OpKernelConstruction*) =
[](OpKernelConstruction* context) -> OpKernel* { [](OpKernelConstruction* context) -> OpKernel* {
return new XlaCompileOnDemandOp(context); return new XlaCompileOnDemandOp(context);
}; };

View File

@ -112,6 +112,12 @@ class XlaDevice : public LocalDevice {
// compute, host-to-device, and device-to-host communication. // compute, host-to-device, and device-to-host communication.
bool use_multiple_streams = false; bool use_multiple_streams = false;
// A function that describes how the on-host shapes of
// a) argument and return value, for entry computations
// b) variables, for all computations,
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared
// shapes for computations. Must be non-null.
XlaCompiler::ShapeRepresentationFn shape_representation_fn; XlaCompiler::ShapeRepresentationFn shape_representation_fn;
// If padded_shape_fn is empty, a default implementation that returns // If padded_shape_fn is empty, a default implementation that returns

View File

@ -70,9 +70,12 @@ XlaDeviceContext::XlaDeviceContext(
CHECK(device_to_host_stream_ != nullptr); CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr); CHECK(stream_ != nullptr);
if (!shape_representation_fn_) { if (!shape_representation_fn_) {
shape_representation_fn_ = shape_representation_fn_ = [](const TensorShape& shape,
[](const TensorShape& shape, DataType dtype) -> xla::StatusOr<xla::Shape> {
DataType dtype) -> xla::StatusOr<TensorShape> { return shape; }; xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
return xla_shape;
};
} }
} }
@ -99,7 +102,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
CHECK(xla_tensor); CHECK(xla_tensor);
Status status = [&]() -> Status { Status status = [&]() -> Status {
TF_ASSIGN_OR_RETURN(TensorShape shape, TF_ASSIGN_OR_RETURN(xla::Shape shape,
shape_representation_fn_(device_tensor->shape(), shape_representation_fn_(device_tensor->shape(),
device_tensor->dtype())); device_tensor->dtype()));
@ -111,9 +114,15 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal())); stream_->parent()->device_ordinal()));
// The cpu_tensor and literal that we created here hold the data of host
// tensor in descending layout. The layout could be different from layout in
// device_tensor (but the logical shape has to be the same). The
// transfer_manager is responsible to do corresponding transposing when
// transferring the data to device.
xla::BorrowingLiteral literal( xla::BorrowingLiteral literal(
static_cast<const char*>(DMAHelper::base(cpu_tensor)), static_cast<const char*>(DMAHelper::base(cpu_tensor)),
xla_tensor->shaped_buffer().on_host_shape()); xla::ShapeUtil::MakeShape(shape.element_type(),
xla::AsInt64Slice(shape.dimensions())));
VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
<< xla_tensor->shaped_buffer().ToString(); << xla_tensor->shaped_buffer().ToString();
@ -183,8 +192,15 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get());
// Transfer manager requires the shape of the shaped buffer to be the same as
// literal shape except for the layout. Set the literal to use xla_tensor's
// shape as it is derived from the cpu_tensor's shape using
// shape_representation_fn_.
xla::MutableBorrowingLiteral literal; xla::MutableBorrowingLiteral literal;
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal)); TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
xla::LayoutUtil::GetWithDefaultLayout(
xla_tensor->shaped_buffer().on_host_shape()),
cpu_tensor, &literal));
TensorReference ref(*device_tensor); TensorReference ref(*device_tensor);
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h" #include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/stack.h"
#include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow { namespace tensorflow {
@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel {
.Device(DEVICE) \ .Device(DEVICE) \
.TypeConstraint<string>("T") \ .TypeConstraint<string>("T") \
.HostMemory("input"), \ .HostMemory("input"), \
RetvalOp); RetvalOp); \
\
REGISTER_KERNEL_BUILDER(Name("StackV2") \
.Device(DEVICE) \
.HostMemory("max_size") \
.HostMemory("handle"), \
StackOp); \
REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
.Device(DEVICE) \
.HostMemory("handle") \
.TypeConstraint("T", TYPES), \
TemplatedStackPushOp</*allow_swapping=*/false>); \
REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
.Device(DEVICE) \
.HostMemory("handle") \
.TypeConstraint("elem_type", TYPES), \
StackPopOp); \
REGISTER_KERNEL_BUILDER( \
Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);
// TODO(phawkins): currently we do not register the QueueEnqueueMany, // TODO(b/118881356): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
// and write the tensors they access in order to concatenate them into a batch. // and write the tensors they access in order to concatenate them into a batch.
// We would need either to call out to an XLA computation to perform the // We would need either to call out to an XLA computation to perform the

View File

@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
std::vector<Device*>* devices) { std::vector<Device*>* devices) {
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.requires_compilation = true; registration.autoclustering_policy =
registration.enable_jit_by_default = false; XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_resource_ops = true; registration.compile_resource_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
@ -53,24 +53,25 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
return Status::OK(); return Status::OK();
} }
XlaDevice::Options options; for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
options.platform = platform.ValueOrDie(); XlaDevice::Options options;
options.device_name_prefix = name_prefix; options.platform = platform.ValueOrDie();
options.device_name = DEVICE_XLA_GPU; options.device_name_prefix = name_prefix;
options.device_ordinal = 0; options.device_name = DEVICE_XLA_GPU;
options.compilation_device_name = DEVICE_GPU_XLA_JIT; options.device_ordinal = i;
options.use_multiple_streams = false; options.compilation_device_name = DEVICE_GPU_XLA_JIT;
auto device = absl::make_unique<XlaDevice>(session_options, options); options.use_multiple_streams = true;
auto device = absl::make_unique<XlaDevice>(session_options, options);
// TODO(b/78468222): Uncomment after fixing this bug Status status = device->UseGpuDeviceInfo();
// status = device->UseGpuDeviceInfo(); if (!status.ok()) {
// if (!status.ok()) { errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
// errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, " device number ", i);
// " device"); return status;
// return status; }
// }
devices->push_back(device.release()); devices->push_back(device.release());
}
return Status::OK(); return Status::OK();
} }

View File

@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
registration.requires_compilation = true; registration.autoclustering_policy =
registration.enable_jit_by_default = false; XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_resource_ops = true; registration.compile_resource_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration); registration);

View File

@ -191,40 +191,6 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
return Status::OK(); return Status::OK();
} }
namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
ScopedShapedBuffer ExtractSubShapedBuffer(
ShapedBuffer* shaped_buffer, int index,
xla::DeviceMemoryAllocator* allocator) {
const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_host_shape(), index);
const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_device_shape(), index);
ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
shaped_buffer->platform(),
shaped_buffer->device_ordinal());
auto& shape_tree = shaped_buffer->buffers();
auto& sub_shape_tree = sub_shaped_buffer.buffers();
sub_shape_tree.CopySubtreeFrom(shape_tree,
/*source_base_index=*/{index},
/*target_base_index=*/{});
shape_tree.ForEachMutableElement(
[index](const xla::ShapeIndex& shape_index,
tensorflow::se::DeviceMemoryBase* data) {
// shape_index is empty for the root node. Ignore that.
if (!shape_index.empty() && shape_index[0] == index) {
*data = tensorflow::se::DeviceMemoryBase(nullptr, 0);
}
});
return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator);
}
} // namespace internal
using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext( XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors, bool use_multiple_streams) bool allocate_xla_tensors, bool use_multiple_streams)
@ -391,8 +357,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) { if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer( xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) { if (use_multiple_streams_) {
xla_tensor->ResetDefinitionEvent(definition_event, stream); xla_tensor->ResetDefinitionEvent(definition_event, stream);
} }
@ -445,7 +410,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) { for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({}); Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (variable_infos[i].var()->tensor()->dtype() != write.type) { if (variable_infos[i].var()->tensor()->dtype() != write.type) {
return errors::Internal("Mismatched type in variable write"); return errors::Internal("Mismatched type in variable write");
@ -455,18 +419,20 @@ Status XlaComputationLaunchContext::PopulateOutputs(
Tensor output_tensor; Tensor output_tensor;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ctx->allocate_temp(write.type, write.shape, &output_tensor)); ctx->allocate_temp(write.type, write.shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); if (write.shape.num_elements() > 0) {
CHECK(xla_tensor); XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
xla_tensor->set_shaped_buffer( CHECK(xla_tensor);
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
if (use_multiple_streams_) { if (use_multiple_streams_) {
xla_tensor->ResetDefinitionEvent(definition_event, stream); xla_tensor->ResetDefinitionEvent(definition_event, stream);
}
} }
*variable_infos[i].var()->tensor() = output_tensor; *variable_infos[i].var()->tensor() = output_tensor;
} else { } else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
Tensor output_tensor = XlaTensorBuffer::MakeTensor( Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator); write.type, write.shape, buffer, allocator);
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
*variable_infos[i].var()->tensor() = output_tensor; *variable_infos[i].var()->tensor() = output_tensor;
} }
++output_num; ++output_num;
@ -474,4 +440,60 @@ Status XlaComputationLaunchContext::PopulateOutputs(
return Status::OK(); return Status::OK();
} }
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args) {
args->resize(ctx->num_inputs());
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
XlaCompiler::Argument& arg = (*args)[input_num];
if (constant_args.count(input_num) > 0) {
// Handles compile-time constants.
const Tensor& input = constant_args.at(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input.dtype();
arg.shape = input.shape();
arg.constant_value = input;
} else if (variable_args.count(input_num) == 0) {
// Handles the non-constant arguments.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
if (input.NumElements() > 0) {
arg.kind = XlaCompiler::Argument::kParameter;
} else {
arg.kind = XlaCompiler::Argument::kConstant;
arg.constant_value = input;
}
arg.type = input.dtype();
arg.shape = input.shape();
} else {
// Handles resource variables.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
const OptionalTensor& variable = variable_args.at(input_num);
arg.name = variable.name;
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
if (variable.present) {
const Tensor& value = variable.value;
arg.type = value.dtype();
arg.shape = value.shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since
// they are meaningless. However, it is legal to assign to a resource
// variable for the first time inside the XLA computation, so we do
// permit uninitialized variables.
arg.initialized = false;
arg.type = DT_INVALID;
arg.shape = TensorShape();
}
}
}
return Status::OK();
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -35,6 +35,13 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
class XlaAllocator; class XlaAllocator;
// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
string name; // A descriptive name
bool present = false; // Is the tensor present?
Tensor value; // If present, what is the Tensor's value?
};
// Takes a snapshot of the values of resource variable arguments, whose indices // Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back // are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is // resource variables since concurrent updates may modify the shape, and it is
@ -139,6 +146,13 @@ class XlaComputationLaunchContext {
bool allocate_xla_tensors, bool allocate_xla_tensors,
bool use_multiple_streams); bool use_multiple_streams);
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
// op.
static Status BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()). // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable. // `variables` is a map from TensorFlow argument number to resource variable.
// //
@ -223,17 +237,6 @@ class XlaTensorBuffer : public TensorBuffer {
Allocator* allocator_; Allocator* allocator_;
}; };
// Exposed in this header file for microbenchmarking purposes, but this is an
// internal implementation detail.
namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
xla::ScopedShapedBuffer ExtractSubShapedBuffer(
xla::ShapedBuffer* shaped_buffer, int index,
xla::DeviceMemoryAllocator* allocator);
} // namespace internal
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ #endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_

View File

@ -1,64 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Contains microbenchmarks for performance critical functions in
// xla_launch_util.cc.
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs
// (cardinality of each non-leaf node's children).
void BM_ExtractSubBuffer(int iters, int depth, int fan_out) {
tensorflow::testing::StopTiming();
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128});
for (int i = 0; i < depth; ++i) {
std::vector<xla::Shape> shapes(fan_out, shape);
shape = xla::ShapeUtil::MakeTupleShape(shapes);
}
xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr,
/*device_ordinal=*/0);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
// Extract a buffer from approximately the middle of the first level of the
// tree.
(void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
/*index=*/fan_out / 2,
/*allocator=*/nullptr)
.release();
}
}
BENCHMARK(BM_ExtractSubBuffer)
->ArgPair(1, 4)
->ArgPair(1, 8)
->ArgPair(1, 32)
->ArgPair(1, 64)
->ArgPair(1, 128)
->ArgPair(1, 256)
->ArgPair(1, 512)
->ArgPair(2, 4)
->ArgPair(2, 8)
->ArgPair(2, 32)
->ArgPair(2, 64)
->ArgPair(2, 128);
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
tensorflow::testing::RunBenchmarks();
return RUN_ALL_TESTS();
}

View File

@ -43,11 +43,10 @@ namespace tensorflow {
} }
} }
Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, Status XlaTensor::AllocateShapedBuffer(DataType dtype,
const xla::Shape& on_host_shape,
xla::LocalClient* client, xla::LocalClient* client,
int device_ordinal) { int device_ordinal) {
xla::Shape on_host_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape));
xla::Shape on_device_shape = xla::Shape on_device_shape =
client->backend().transfer_manager()->HostShapeToDeviceShape( client->backend().transfer_manager()->HostShapeToDeviceShape(
on_host_shape); on_host_shape);

View File

@ -50,7 +50,7 @@ class XlaTensor {
// Assign the internal ShapedBuffer to new memory for the given dtype and // Assign the internal ShapedBuffer to new memory for the given dtype and
// shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it
// is replaced and the managed memory deallocated. // is replaced and the managed memory deallocated.
Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_host_shape,
xla::LocalClient* client, int device_ordinal); xla::LocalClient* client, int device_ordinal);
// Some Tensors can have complex on-device shapes, including tuple shapes. To // Some Tensors can have complex on-device shapes, including tuple shapes. To

View File

@ -470,12 +470,12 @@ tf_xla_py_test(
tags = ["optonly"], tags = ["optonly"],
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/contrib/signal:signal_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:spectral_ops", "//tensorflow/python:spectral_ops",
"//tensorflow/python/ops/signal",
], ],
) )
@ -837,8 +837,6 @@ tf_xla_py_test(
name = "stack_ops_test", name = "stack_ops_test",
size = "small", size = "small",
srcs = ["stack_ops_test.py"], srcs = ["stack_ops_test.py"],
# Stack ops are not implemented in the on-demand compilation model yet.
disabled_backends = ["cpu_ondemand"],
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import itertools
import numpy as np import numpy as np
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
@ -967,7 +969,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
self._testBinary( self._testBinary(
array_ops.expand_dims, array_ops.expand_dims,
np.array([42], dtype=dtype), np.array([42], dtype=dtype),
np.int32(0), np.array([0], dtype=np.int64),
expected=np.array([[42]], dtype=dtype)) expected=np.array([[42]], dtype=dtype))
self._testBinary( self._testBinary(
array_ops.expand_dims, array_ops.expand_dims,
@ -994,15 +996,21 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([[[1, 2], [3, 4]]], dtype=dtype), np.array([[[1, 2], [3, 4]]], dtype=dtype),
np.int32(3), np.int32(3),
expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype)) expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype))
self._testBinary(
array_ops.expand_dims,
np.array([[[1, 2], [3, 4]]], dtype=dtype),
np.array([2], dtype=np.int64),
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
def testPad(self): def testPad(self):
for dtype in self.numeric_types: for dtype, pad_type in itertools.product(
self.numeric_types, [np.int32, np.int64]):
self._testBinary( self._testBinary(
array_ops.pad, array_ops.pad,
np.array( np.array(
[[1, 2, 3], [4, 5, 6]], dtype=dtype), [[1, 2, 3], [4, 5, 6]], dtype=dtype),
np.array( np.array(
[[1, 2], [2, 1]], dtype=np.int32), [[1, 2], [2, 1]], dtype=pad_type),
expected=np.array( expected=np.array(
[[0, 0, 0, 0, 0, 0], [[0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 3, 0], [0, 0, 1, 2, 3, 0],
@ -1016,7 +1024,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array( np.array(
[[1, 2, 3], [4, 5, 6]], dtype=dtype), [[1, 2, 3], [4, 5, 6]], dtype=dtype),
np.array( np.array(
[[0, 3], [2, 1]], dtype=np.int32), [[0, 3], [2, 1]], dtype=pad_type),
expected=np.array( expected=np.array(
[[7, 7, 1, 2, 3, 7], [[7, 7, 1, 2, 3, 7],
[7, 7, 4, 5, 6, 7], [7, 7, 4, 5, 6, 7],

View File

@ -24,10 +24,10 @@ import numpy as np
import scipy.signal as sps import scipy.signal as sps
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import signal
from tensorflow.python.ops import spectral_ops from tensorflow.python.ops import spectral_ops
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest

View File

@ -593,6 +593,67 @@ class LazyCompilationTest(test.TestCase):
self.assertFalse( self.assertFalse(
InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun")) InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun"))
def testIsMegamorphic(self):
@function.Defun(compiled=True)
def CompiledFunction(x):
return math_ops.log(x)
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
x = array_ops.placeholder(dtypes.float32)
y = CompiledFunction(x)
# Make the cluster go megamorphic by running it with lots of shape
# signatures where the cluster is executed with each signature only a few
# times. Then check that we don't compile the cluster ever again.
for shape in range(10, 50):
for _ in range(0, 49):
sess.run(y, feed_dict={x: [0.] * shape})
for _ in range(0, 50):
run_metadata = config_pb2.RunMetadata()
sess.run(
y,
feed_dict={x: [0.] * 60},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assertTrue(
InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))
def testIsNotMegamorphic(self):
@function.Defun(compiled=True)
def CompiledFunction(x):
return math_ops.log(x)
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
x = array_ops.placeholder(dtypes.float32)
y = CompiledFunction(x)
# Run the cluster with lots of shape signatures, but in a way that it
# isn't megamorphic (i.e. each shape signature sees a lot of executions).
# Then check that the cluster has not been marked as megamorphic.
for shape in range(10, 50):
for _ in range(0, 1000):
sess.run(y, feed_dict={x: [0.] * shape})
for _ in range(0, 10):
sess.run(y, feed_dict={x: [0.] * 60})
run_metadata = config_pb2.RunMetadata()
sess.run(
y,
feed_dict={x: [0.] * 60},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))
if __name__ == "__main__": if __name__ == "__main__":
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " +

View File

@ -2466,20 +2466,21 @@ TEST_F(OpTest, Pack) {
}); });
} }
// TODO(b/31741898): crashes on GPU.
TEST_F(OpTest, Pad) { TEST_F(OpTest, Pad) {
Repeatedly([this]() { Repeatedly([this]() {
auto type = Choose<DataType>(kAllXlaTypes); auto type = Choose<DataType>(kAllXlaTypes);
std::vector<int64> t_dims = RandomDims(); std::vector<int64> t_dims = RandomDims();
// TODO(b/31741996): re-enable DT_INT64 when bug is fixed. DataType tpaddings = Choose<DataType>({DT_INT32, DT_INT64});
// DataType tpaddings = Choose<DataType>({DT_INT32, DT_INT64});
DataType tpaddings = DT_INT32;
std::vector<int64> paddings_vec; std::vector<int64> paddings_vec;
std::uniform_int_distribution<int> distribution(0, 7);
for (int i = 0; i < t_dims.size(); ++i) { for (int i = 0; i < t_dims.size(); ++i) {
paddings_vec.push_back(distribution(generator())); std::uniform_int_distribution<int> pad_distribution(0, t_dims[i]);
paddings_vec.push_back(distribution(generator())); int pad_size = pad_distribution(generator());
std::uniform_int_distribution<int> lower_distribution(0, pad_size);
int low_pad_size = lower_distribution(generator());
paddings_vec.push_back(low_pad_size);
paddings_vec.push_back(pad_size - low_pad_size);
t_dims[i] -= pad_size;
} }
Tensor paddings; Tensor paddings;
CHECK( CHECK(

View File

@ -37,7 +37,7 @@ class ResamplerOpsTest(xla_test.XLATestCase):
out = sess.run(resampled, {input_image: image_np, warp: warp_np}) out = sess.run(resampled, {input_image: image_np, warp: warp_np})
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
expected, out, half_rtol=1e-2, bfloat16_rtol=3e-2) expected, out, rtol=5e-3, half_rtol=1e-2, bfloat16_rtol=3e-2)
def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np, def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np,
expected_grad_data, expected_grad_warp): expected_grad_data, expected_grad_warp):

View File

@ -40,6 +40,19 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
class VariableOpsTest(xla_test.XLATestCase): class VariableOpsTest(xla_test.XLATestCase):
"""Test cases for resource variable operators.""" """Test cases for resource variable operators."""
def testWriteEmptyShape(self):
# Verifies that we can pass an uninitialized variable with an empty shape,
# assign it a value, and successfully return it.
for dtype in self.numeric_types:
with self.test_session() as sess, self.test_scope():
zeros = np.zeros([3, 0], dtype=dtype)
v = resource_variable_ops.ResourceVariable(zeros)
p = array_ops.placeholder(dtype)
x = v.assign(p)
with ops.control_dependencies([x]):
y = v.read_value()
self.assertAllClose(zeros, sess.run(y, {p: zeros}))
def testOneWriteOneOutput(self): def testOneWriteOneOutput(self):
# Regression test for a bug where computations with one non-constant # Regression test for a bug where computations with one non-constant
# output and one variable update were mishandled. # output and one variable update were mishandled.

View File

@ -166,6 +166,7 @@ cc_library(
"xla_compilation_device.cc", "xla_compilation_device.cc",
"xla_compiler.cc", "xla_compiler.cc",
"xla_context.cc", "xla_context.cc",
"xla_expression.cc",
"xla_helpers.cc", "xla_helpers.cc",
"xla_op_kernel.cc", "xla_op_kernel.cc",
"xla_op_registry.cc", "xla_op_registry.cc",
@ -180,6 +181,7 @@ cc_library(
"xla_compilation_device.h", "xla_compilation_device.h",
"xla_compiler.h", "xla_compiler.h",
"xla_context.h", "xla_context.h",
"xla_expression.h",
"xla_helpers.h", "xla_helpers.h",
"xla_op_kernel.h", "xla_op_kernel.h",
"xla_op_registry.h", "xla_op_registry.h",
@ -194,6 +196,7 @@ cc_library(
":side_effect_util", ":side_effect_util",
":tf2xla_util", ":tf2xla_util",
"//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -217,6 +220,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
alwayslink = 1, alwayslink = 1,
@ -362,8 +366,12 @@ tf_cc_test(
tf_cc_test( tf_cc_test(
name = "xla_compiler_test", name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"], srcs = [
"xla_compiler_test.cc",
"xla_expression_test.cc",
],
deps = [ deps = [
":common",
":side_effect_util", ":side_effect_util",
":xla_compiler", ":xla_compiler",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
@ -386,6 +394,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -435,7 +444,7 @@ cc_library(
"dump_graph.h", "dump_graph.h",
], ],
deps = [ deps = [
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", "//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" #include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
@ -41,7 +41,7 @@ static void AllocateFlags() {
"Path prefix to which graphs dumped during debugging should be " "Path prefix to which graphs dumped during debugging should be "
"written."), "written."),
}); });
xla::legacy_flags::ParseFlagsFromEnv(*flag_list); xla::ParseFlagsFromEnv(*flag_list);
} }
// Append to *append_to flag definitions associated with the XLA bridge's // Append to *append_to flag definitions associated with the XLA bridge's

View File

@ -242,23 +242,20 @@ Status FunctionalizeControlFlowPass::Run(
continue; continue;
} }
const string func_attr = it->second; const string func_attr = it->second;
if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != NameAttrList func;
kNodeTypeToFunctionAttrMapping->end()) { TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
NameAttrList func; VLOG(2) << "Graph has node " << n->type_string()
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); << ". Corresponding function: " << func.name();
VLOG(2) << "Graph has node " << n->type_string() string new_func_name = options.flib_def->UniqueFunctionName(
<< ". Corresponding function: " << func.name(); absl::StrCat(func.name(), "_f15n_"));
string new_func_name = options.flib_def->UniqueFunctionName( bool modified;
absl::StrCat(func.name(), "_f15n_")); TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
bool modified; func.name(), new_func_name, func.attr(), options.flib_def, flr,
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( &canonicalized_name_to_new_name, &modified));
func.name(), new_func_name, func.attr(), options.flib_def, flr, if (modified) {
&canonicalized_name_to_new_name, &modified)); n->ClearAttr(func_attr);
if (modified) { func.set_name(new_func_name);
n->ClearAttr(func_attr); n->AddAttr(func_attr, func);
func.set_name(new_func_name);
n->AddAttr(func_attr, func);
}
} }
} }

View File

@ -23,9 +23,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -51,12 +52,11 @@ namespace {
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
const std::vector<const XlaExpression*>& expressions, const std::vector<const XlaExpression*>& expressions,
std::vector<XlaCompiler::Argument>* args) { std::vector<XlaCompiler::Argument>* args) {
auto builder = ctx->builder();
auto client = ctx->compiler()->client(); auto client = ctx->compiler()->client();
std::vector<bool> compile_time_constant_flags(expressions.size()); std::vector<bool> arg_must_be_compile_time_constant(expressions.size());
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
BackwardsConstAnalysis(*graph, &compile_time_constant_flags, BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant,
/*compile_time_const_nodes=*/nullptr)); /*compile_time_const_nodes=*/nullptr));
args->resize(expressions.size()); args->resize(expressions.size());
@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.type = ctx->input_type(i); arg.type = ctx->input_type(i);
arg.shape = ctx->InputShape(i); arg.shape = ctx->InputShape(i);
if (arg.type == DT_RESOURCE) { switch (expressions[i]->kind()) {
return errors::InvalidArgument( case XlaExpression::Kind::kConstant:
"Resource as function argument is not yet implemented."); arg.kind = XlaCompiler::Argument::kConstant;
} else if (expressions[i]->has_constant_value()) { arg.constant_value = expressions[i]->constant_value();
arg.kind = XlaCompiler::Argument::kConstant; break;
arg.constant_value = expressions[i]->constant_value(); case XlaExpression::Kind::kXlaOp:
} else if (compile_time_constant_flags[i]) { if (arg_must_be_compile_time_constant[i]) {
arg.kind = XlaCompiler::Argument::kConstant; TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
TF_RET_CHECK(expressions[i]->resource() == nullptr) expressions[i]->ResolveConstant(client));
<< "Input with resource is not yet implemented."; if (!value.has_value()) {
TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( return errors::InvalidArgument(
expressions[i]->handle())); "Argument to function must be a compile-time constant, but "
TF_ASSIGN_OR_RETURN(auto literal, "unable to resolve argument value to a constant.");
client->ComputeConstant(constant_graph)); }
TF_RETURN_IF_ERROR( arg.kind = XlaCompiler::Argument::kConstant;
LiteralToHostTensor(literal, arg.type, &arg.constant_value)); arg.constant_value = *value;
} else { } else {
arg.kind = XlaCompiler::Argument::kParameter; arg.kind = XlaCompiler::Argument::kParameter;
}
break;
case XlaExpression::Kind::kResource:
return errors::Unimplemented(
"Resource as function argument is not yet implemented.");
case XlaExpression::Kind::kInvalid:
return errors::InvalidArgument("Invalid function argument");
} }
} }
return Status::OK(); return Status::OK();

View File

@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow { namespace tensorflow {
@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel {
} }
const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
if (arg.resource() != nullptr) { OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
ctx->SetResourceOutput(0, arg.resource()); errors::InvalidArgument("Invalid/missing argument expression"));
} else if (arg.has_constant_value()) { ctx->SetOutputExpression(0, arg);
ctx->SetConstantOutput(0, arg.constant_value());
} else {
ctx->SetOutput(0, arg.handle());
}
} }
private: private:

View File

@ -94,14 +94,10 @@ class BCastGradArgsOp : public XlaOpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
errors::InvalidArgument("In[", i, "] must be a vector.", errors::InvalidArgument("In[", i, "] must be a vector.",
in_shape.DebugString())); in_shape.DebugString()));
xla::Literal literal; std::vector<int64> vec;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec));
BCast::Vec vec; shapes.push_back(BCast::Vec(vec.begin(), vec.end()));
for (int64 i = 0; i < in_shape.num_elements(); ++i) {
vec.push_back(literal.Get<int>({i}));
}
shapes.push_back(vec);
} }
BCast bcast(shapes[0], shapes[1]); BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(), OP_REQUIRES(ctx, bcast.IsValid(),

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/bounds_check.h"
@ -45,15 +46,13 @@ class ConcatBaseOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_);
OP_REQUIRES( OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape),
ctx, IsLegacyScalar(concat_dim_tensor_shape), errors::InvalidArgument(
errors::InvalidArgument( "Concat dim tensor should be a scalar, but got shape ",
"Concat dim tensor should be a scalar integer, but got shape ", concat_dim_tensor_shape.DebugString()));
concat_dim_tensor_shape.DebugString())); int64 concat_dim;
xla::Literal literal; OP_REQUIRES_OK(ctx,
OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim));
// TODO(annarev): add a helper to support int64 input.
const int32 concat_dim = literal.Get<int>({});
std::vector<xla::XlaOp> values; std::vector<xla::XlaOp> values;
std::vector<TensorShape> shapes; std::vector<TensorShape> shapes;
@ -63,9 +62,7 @@ class ConcatBaseOp : public XlaOpKernel {
const TensorShape& input_shape = shapes[0]; const TensorShape& input_shape = shapes[0];
int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
OP_REQUIRES(ctx, OP_REQUIRES(ctx, 0 <= axis && axis < input_dims,
(0 <= axis && axis < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
errors::InvalidArgument( errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range " "ConcatOp : Expected concatenating dimensions in the range "
"[", "[",
@ -75,14 +72,11 @@ class ConcatBaseOp : public XlaOpKernel {
// elements. // elements.
std::vector<xla::XlaOp> input_data; std::vector<xla::XlaOp> input_data;
int output_concat_dim = 0; int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
xla::XlaOp handle = values[i]; xla::XlaOp handle = values[i];
const TensorShape& in_shape = shapes[i]; const TensorShape& in_shape = shapes[i];
const bool in_is_scalar = IsLegacyScalar(in_shape);
OP_REQUIRES( OP_REQUIRES(
ctx, ctx, in_shape.dims() == input_dims,
in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
errors::InvalidArgument( errors::InvalidArgument(
"ConcatOp : Ranks of all input tensors should match: shape[0] = ", "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
input_shape.DebugString(), " vs. shape[", i, input_shape.DebugString(), " vs. shape[", i,
@ -131,11 +125,10 @@ class ConcatOffsetOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape concat_dim_shape = ctx->InputShape(0); const TensorShape concat_dim_shape = ctx->InputShape(0);
OP_REQUIRES( OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape),
ctx, IsLegacyScalar(concat_dim_shape), errors::InvalidArgument(
errors::InvalidArgument( "Concat dim tensor should be a scalar, but got shape ",
"Concat dim tensor should be a scalar integer, but got shape ", concat_dim_shape.DebugString()));
concat_dim_shape.DebugString()));
for (int i = 1; i < ctx->num_inputs(); ++i) { for (int i = 1; i < ctx->num_inputs(); ++i) {
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)),
errors::InvalidArgument("input ", i, errors::InvalidArgument("input ", i,
@ -162,39 +155,38 @@ class ConcatOffsetOp : public XlaOpKernel {
// [0, 5, 0, 0] // [0, 5, 0, 0]
const int32 N = ctx->num_inputs() - 1; const int32 N = ctx->num_inputs() - 1;
const TensorShape inp0_shape = ctx->InputShape(1); const TensorShape inp0_shape = ctx->InputShape(1);
xla::Literal inp0_literal; std::vector<int64> inp0_dims;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims));
const int64 dims = inp0_shape.num_elements(); const int64 inp0_rank = inp0_shape.num_elements();
xla::Literal concat_dim_literal; int64 cdim;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim));
const int64 cdim = concat_dim_literal.Get<int>({});
VLOG(1) << "ConcatOffset " << cdim << "," << dims; VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank;
int32 axis = cdim < 0 ? cdim + dims : cdim; int32 axis = cdim < 0 ? cdim + inp0_rank : cdim;
OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank),
errors::InvalidArgument("Concat dim is out of range: ", axis, errors::InvalidArgument("Concat dim is out of range: ", axis,
" vs. ", dims)); " vs. ", inp0_rank));
int32 offset = 0; int32 offset = 0;
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
const TensorShape inp_shape = ctx->InputShape(1 + i); const TensorShape inp_shape = ctx->InputShape(1 + i);
OP_REQUIRES(ctx, dims == inp_shape.num_elements(), OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(),
errors::InvalidArgument("input ", i, " should contain ", dims, errors::InvalidArgument("input ", i, " should contain ",
" elements, but got ", inp0_rank, " elements, but got ",
inp_shape.num_elements())); inp_shape.num_elements()));
xla::Literal inp_literal; std::vector<int64> inp_dims;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims));
Tensor out_constant(DT_INT32, TensorShape({dims})); Tensor out_constant(DT_INT32, TensorShape({inp0_rank}));
auto out_vec = out_constant.vec<int32>(); auto out_vec = out_constant.vec<int32>();
for (int64 j = 0; j < dims; ++j) { for (int64 j = 0; j < inp0_rank; ++j) {
if (j == axis) { if (j == axis) {
out_vec(j) = offset; out_vec(j) = offset;
offset += inp_literal.Get<int>({j}); offset += inp_dims[j];
} else { } else {
const int32 inp0_element = inp0_literal.Get<int>({j}); const int32 inp0_element = inp0_dims[j];
const int32 inp_element = inp_literal.Get<int>({j}); const int32 inp_element = inp_dims[j];
OP_REQUIRES(ctx, (inp0_element == inp_element), OP_REQUIRES(ctx, inp0_element == inp_element,
errors::InvalidArgument("input[", i, ",", j, errors::InvalidArgument("input[", i, ",", j,
"] mismatch: ", inp0_element, "] mismatch: ", inp0_element,
" vs. ", inp_element)); " vs. ", inp_element));

View File

@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape(proto_.tensor_shape()); TensorShape shape(proto_.tensor_shape());
if (proto_.dtype() == DT_STRING) {
LOG(WARNING) << "Not computing Const of type DT_STRING";
ctx->SetInvalidOutput(0);
return;
}
xla::XlaBuilder* b = ctx->builder(); xla::XlaBuilder* b = ctx->builder();
// To avoid blowups for large constants filled with the same value, // To avoid blowups for large constants filled with the same value,

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -33,39 +34,20 @@ class FillOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
// The output of this Op is a tensor of shape 'dims_shape' with each // The output of this Op is a tensor of shape 'dims_shape' with each
// element set to the scalar 'dims_literal'. // element set to the scalar 'dims_literal'.
const TensorShape dims_shape = ctx->InputShape(0); const TensorShape dims_shape = ctx->InputShape("dims");
const TensorShape value_shape = ctx->InputShape(1); const TensorShape value_shape = ctx->InputShape("value");
OP_REQUIRES( OP_REQUIRES(
ctx, IsLegacyVector(dims_shape), ctx, TensorShapeUtils::IsVector(dims_shape),
errors::InvalidArgument("dims must be a vector of int32, got shape ", errors::InvalidArgument("dims must be a vector of int32, got shape ",
dims_shape.DebugString())); dims_shape.DebugString()));
OP_REQUIRES(ctx, IsLegacyScalar(value_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(value_shape),
errors::InvalidArgument("value must be a scalar, got shape ", errors::InvalidArgument("value must be a scalar, got shape ",
value_shape.DebugString())); value_shape.DebugString()));
// Evaluate the 'dims' constant input, reshaping to a vector if it
// was a 'legacy' vector (secretly a scalar).
xla::Literal dims_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(
0, {dims_shape.num_elements()}, &dims_literal));
// Convert the dims literal into a vector that we can pass to std::vector<int64> dims;
// XlaBuilder. OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims));
std::vector<int64> broadcast;
broadcast.reserve(dims_literal.shape().dimensions(0));
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
broadcast.push_back(dims_literal.Get<int>({i}));
}
// Look up the value input, reshaping to a scalar if it was a
// 'legacy' scalar (secretly a vector).
xla::XlaOp data = ctx->Input(1);
if (value_shape.dims() > 0) {
CHECK_EQ(value_shape.dims(), 1);
data = xla::Reshape(data, {});
}
// Emit the actual computation, which broadcasts the scalar to the
// desired shape.
auto result = xla::Broadcast(data, broadcast);
auto result = xla::Broadcast(ctx->Input("value"), dims);
ctx->SetOutput(0, result); ctx->SetOutput(0, result);
} }
}; };

View File

@ -48,9 +48,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
// We require that the dimension argument is a constant, since it lets us // We require that the dimension argument is a constant, since it lets us
// dispatch to a specialized custom-call function without any run-time // dispatch to a specialized custom-call function without any run-time
// overhead, when compiling ahead-of-time. // overhead, when compiling ahead-of-time.
xla::Literal literal; int64 dim;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
const int32 dim = literal.Get<int32>({});
OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
OP_REQUIRES( OP_REQUIRES(
ctx, dim < input_shape.dims(), ctx, dim < input_shape.dims(),

View File

@ -41,10 +41,8 @@ class MirrorPadOp : public XlaOpKernel {
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) { --dimno) {
auto t_rev = xla::Rev(accum, {dimno}); auto t_rev = xla::Rev(accum, {dimno});
TF_ASSIGN_OR_RETURN(int64 lhs_padding, int64 lhs_padding = pad_literal.Get<int64>({dimno, 0});
pad_literal.GetIntegralAsS64({dimno, 0})); int64 rhs_padding = pad_literal.Get<int64>({dimno, 1});
TF_ASSIGN_OR_RETURN(int64 rhs_padding,
pad_literal.GetIntegralAsS64({dimno, 1}));
int64 dim_size = original_shape.dimensions(dimno); int64 dim_size = original_shape.dimensions(dimno);
// Padding amounts on each side must be no more than the size of the // Padding amounts on each side must be no more than the size of the
@ -65,8 +63,8 @@ class MirrorPadOp : public XlaOpKernel {
} }
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0); const TensorShape input_shape = ctx->InputShape("input");
const TensorShape pad_shape = ctx->InputShape(1); const TensorShape pad_shape = ctx->InputShape("paddings");
MirrorPadMode mode; MirrorPadMode mode;
OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
@ -81,23 +79,19 @@ class MirrorPadOp : public XlaOpKernel {
TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
errors::InvalidArgument("paddings must be a matrix with 2 columns: ", errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
pad_shape.DebugString())); pad_shape.DebugString()));
const int fixed_dims =
(allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
? 1
: dims;
OP_REQUIRES( OP_REQUIRES(
ctx, fixed_dims == pad_shape.dim_size(0), ctx, dims == pad_shape.dim_size(0),
errors::InvalidArgument( errors::InvalidArgument(
"The first dimension of paddings must be the rank of inputs", "The first dimension of paddings must be the rank of inputs",
pad_shape.DebugString(), " ", input_shape.DebugString())); pad_shape.DebugString(), " ", input_shape.DebugString()));
// Evaluate the 'padding' constant input, reshaping to a matrix. // Evaluate the 'padding' constant input, reshaping to a matrix.
xla::Literal pad_literal; xla::Literal pad_literal;
OP_REQUIRES_OK( OP_REQUIRES_OK(ctx,
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); ctx->ConstantInputAsInt64Literal("paddings", &pad_literal));
xla::XlaBuilder* b = ctx->builder(); xla::XlaBuilder* b = ctx->builder();
auto in0 = ctx->Input(0); auto in0 = ctx->Input("input");
xla::StatusOr<xla::Shape> in0_shape = b->GetShape(in0); xla::StatusOr<xla::Shape> in0_shape = b->GetShape(in0);
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
xla::StatusOr<xla::XlaOp> accum_status = xla::StatusOr<xla::XlaOp> accum_status =

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -29,40 +30,36 @@ class PadOp : public XlaOpKernel {
explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0); const TensorShape input_shape = ctx->InputShape("input");
const TensorShape pad_shape = ctx->InputShape(1); const TensorShape pad_shape = ctx->InputShape("paddings");
const int dims = input_shape.dims(); const int dims = input_shape.dims();
OP_REQUIRES( OP_REQUIRES(
ctx, ctx,
TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
errors::InvalidArgument("paddings must be a matrix with 2 columns: ", errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
pad_shape.DebugString())); pad_shape.DebugString()));
const int fixed_dims =
(allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
? 1
: dims;
OP_REQUIRES( OP_REQUIRES(
ctx, fixed_dims == pad_shape.dim_size(0), ctx, dims == pad_shape.dim_size(0),
errors::InvalidArgument( errors::InvalidArgument(
"The first dimension of paddings must be the rank of inputs", "The first dimension of paddings must be the rank of inputs",
pad_shape.DebugString(), " ", input_shape.DebugString())); pad_shape.DebugString(), " ", input_shape.DebugString()));
if (fixed_dims == 0) { xla::XlaOp input = ctx->Input("input");
if (dims == 0) {
// Tensor is rank 0. Return it unchanged. // Tensor is rank 0. Return it unchanged.
ctx->SetOutput(0, ctx->Input(0)); ctx->SetOutput(0, input);
return; return;
} }
// Evaluate the 'padding' constant input, reshaping to a matrix.
xla::Literal pad_literal; xla::Literal pad_literal;
OP_REQUIRES_OK( OP_REQUIRES_OK(ctx,
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); ctx->ConstantInputAsInt64Literal("paddings", &pad_literal));
xla::PaddingConfig config; xla::PaddingConfig config;
for (int i = 0; i < fixed_dims; ++i) { for (int i = 0; i < dims; ++i) {
auto* dim = config.add_dimensions(); auto* dim = config.add_dimensions();
int before = pad_literal.Get<int32>({i, 0}); int before = pad_literal.Get<int64>({i, 0});
int after = pad_literal.Get<int32>({i, 1}); int after = pad_literal.Get<int64>({i, 1});
OP_REQUIRES(ctx, before >= 0 && after >= 0, OP_REQUIRES(ctx, before >= 0 && after >= 0,
errors::InvalidArgument( errors::InvalidArgument(
"Paddings must be non-negative: ", before, " ", after)); "Paddings must be non-negative: ", before, " ", after));
@ -73,12 +70,13 @@ class PadOp : public XlaOpKernel {
// PadV2 added a "constant_values" input that indicates the pad value. // PadV2 added a "constant_values" input that indicates the pad value.
xla::XlaOp constant_values; xla::XlaOp constant_values;
if (ctx->num_inputs() == 3) { if (ctx->num_inputs() == 3) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), OP_REQUIRES(
errors::InvalidArgument("constant_values must be a scalar.")); ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")),
ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); errors::InvalidArgument("constant_values must be a scalar."));
ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config));
} else { } else {
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); ctx->SetOutput(0, xla::Pad(input, zero, config));
} }
} }
}; };

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -36,7 +37,7 @@ class ReshapeOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0); const TensorShape input_shape = ctx->InputShape(0);
const TensorShape sizes_shape = ctx->InputShape(1); const TensorShape sizes_shape = ctx->InputShape(1);
// Preliminary validation of sizes. // Preliminary validation of sizes.
OP_REQUIRES(ctx, IsLegacyVector(sizes_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape),
errors::InvalidArgument("sizes input must be 1-D, not shape ", errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes_shape.DebugString())); sizes_shape.DebugString()));
const int64 num_dims = sizes_shape.num_elements(); const int64 num_dims = sizes_shape.num_elements();

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -46,61 +47,8 @@ class RetvalOp : public XlaOpKernel {
// compilation. // compilation.
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
} else { } else {
xla::XlaOp input = ctx->Input(0); XlaContext& xla_context = XlaContext::Get(ctx);
const TensorShape input_shape = ctx->InputShape(0); xla_context.SetRetval(index_, ctx->InputExpression(0));
DataType input_type = ctx->input_type(0);
XlaContext& tc = XlaContext::Get(ctx);
if (input_type == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
ctx->SetStatus(tc.AddResourceRetval(index_, resource));
return;
}
auto is_constant = ctx->builder()->IsConstant(input);
if (!is_constant.ok()) {
ctx->SetStatus(is_constant.status());
return;
}
if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
TensorShape shape = ctx->InputShape(0);
ctx->SetStatus(is_constant.status());
TensorShape representation_shape;
if (tc.is_entry_computation()) {
xla::StatusOr<TensorShape> shape_or_status =
tc.RepresentationShape(shape, ctx->input_type(0));
if (!shape_or_status.ok()) {
ctx->SetStatus(shape_or_status.status());
return;
} else {
representation_shape = shape_or_status.ValueOrDie();
}
} else {
representation_shape = shape;
}
xla::XlaOp output = input;
if (tc.is_entry_computation()) {
output = xla::Reshape(input, representation_shape.dim_sizes());
} else {
// The core from which a return value is returned depends on the
// device assignment of the input to the retval. Since we can't change
// the device assignment of "input" at this point, we must always
// introduce an operator here, even if the shape does not change.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
output =
xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0);
}
tc.AddRetval(index_, dtype_, shape, output);
}
} }
} }

View File

@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel {
} }
// XlaBuilder::Rev() requires concrete values for dimensions arg. // XlaBuilder::Rev() requires concrete values for dimensions arg.
xla::Literal lax; xla::Literal lax;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax));
std::vector<bool> revdims(x_shape.dims());
std::copy(lax.data<bool>().begin(), lax.data<bool>().end(),
revdims.begin());
std::vector<int64> dimensions;
std::vector<int64> dimensions;
for (int d = 0; d < x_shape.dims(); ++d) { for (int d = 0; d < x_shape.dims(); ++d) {
if (revdims[d]) { if (lax.Get<bool>({d})) {
dimensions.push_back(d); dimensions.push_back(d);
} }
} }

View File

@ -30,31 +30,6 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
template <typename T>
Status GetValue(int index, XlaOpKernelContext* ctx, T* value) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
*value = literal.Get<T>({});
return Status::OK();
}
Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
switch (literal.shape().element_type()) {
case xla::S32:
*value = literal.Get<int32>({});
break;
case xla::S64:
*value = literal.Get<int64>({});
break;
default:
return errors::InvalidArgument("Invalid argument type for argument",
index);
}
return Status::OK();
}
// The type-specific part of the implementation of Range. // The type-specific part of the implementation of Range.
template <typename T> template <typename T>
xla::StatusOr<xla::XlaOp> CreateRangeTensor( xla::StatusOr<xla::XlaOp> CreateRangeTensor(
@ -98,13 +73,13 @@ class RangeOp : public XlaOpKernel {
const TensorShape start_in_shape = ctx->InputShape(0); const TensorShape start_in_shape = ctx->InputShape(0);
const TensorShape limit_in_shape = ctx->InputShape(1); const TensorShape limit_in_shape = ctx->InputShape(1);
const TensorShape delta_in_shape = ctx->InputShape(2); const TensorShape delta_in_shape = ctx->InputShape(2);
OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
errors::InvalidArgument("start must be a scalar, not shape ", errors::InvalidArgument("start must be a scalar, not shape ",
start_in_shape.DebugString())); start_in_shape.DebugString()));
OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape),
errors::InvalidArgument("limit must be a scalar, not shape ", errors::InvalidArgument("limit must be a scalar, not shape ",
limit_in_shape.DebugString())); limit_in_shape.DebugString()));
OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape),
errors::InvalidArgument("delta must be a scalar, not shape ", errors::InvalidArgument("delta must be a scalar, not shape ",
delta_in_shape.DebugString())); delta_in_shape.DebugString()));
xla::Literal start, limit, delta; xla::Literal start, limit, delta;
@ -147,9 +122,9 @@ class LinSpaceOp : public XlaOpKernel {
explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape start_in_shape = ctx->InputShape(0); const TensorShape start_in_shape = ctx->InputShape("start");
const TensorShape stop_in_shape = ctx->InputShape(1); const TensorShape stop_in_shape = ctx->InputShape("stop");
const TensorShape num_in_shape = ctx->InputShape(2); const TensorShape num_in_shape = ctx->InputShape("num");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
errors::InvalidArgument("start must be a scalar, not shape ", errors::InvalidArgument("start must be a scalar, not shape ",
start_in_shape.DebugString())); start_in_shape.DebugString()));
@ -163,16 +138,20 @@ class LinSpaceOp : public XlaOpKernel {
DataType type = ctx->input_type(0); DataType type = ctx->input_type(0);
int64 num; int64 num;
OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num));
OP_REQUIRES(ctx, num > 0, OP_REQUIRES(ctx, num > 0,
errors::InvalidArgument("Requires num > 0: ", num)); errors::InvalidArgument("Requires num > 0: ", num));
Tensor out_constant(type, TensorShape({num})); Tensor out_constant(type, TensorShape({num}));
xla::Literal start_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal));
xla::Literal stop_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal));
switch (type) { switch (type) {
case DT_FLOAT: { case DT_FLOAT: {
float start, stop; float start = start_literal.GetFirstElement<float>();
OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); float stop = stop_literal.GetFirstElement<float>();
OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
auto flat = out_constant.flat<float>(); auto flat = out_constant.flat<float>();
if (num == 1) { if (num == 1) {
flat(0) = start; flat(0) = start;
@ -185,9 +164,8 @@ class LinSpaceOp : public XlaOpKernel {
break; break;
} }
case DT_DOUBLE: { case DT_DOUBLE: {
double start, stop; double start = start_literal.GetFirstElement<double>();
OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); double stop = stop_literal.GetFirstElement<double>();
OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
auto flat = out_constant.flat<double>(); auto flat = out_constant.flat<double>();
if (num == 1) { if (num == 1) {
flat(0) = start; flat(0) = start;

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/bounds_check.h"
namespace tensorflow { namespace tensorflow {
@ -108,21 +109,16 @@ class ExpandDimsOp : public XlaOpKernel {
explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0); const TensorShape input_shape = ctx->InputShape("input");
const TensorShape dim_shape = ctx->InputShape(1); const TensorShape dim_shape = ctx->InputShape("dim");
// TODO(phawkins): the standard implementation of ExpandDimsOp seems to std::vector<int64> dims;
// accept legacy scalars, even when they should be forbidden by the graphdef OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
// version. OP_REQUIRES(ctx, dims.size() == 1,
OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
errors::InvalidArgument(absl::StrCat( errors::InvalidArgument(absl::StrCat(
"dim input to ExpandDims must be a scalar; got ", "dim input to ExpandDims must be a scalar; got ",
dim_shape.DebugString()))); dim_shape.DebugString())));
int dim = dims[0];
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
int dim = literal.data<int32>()[0];
OP_REQUIRES(ctx, OP_REQUIRES(ctx,
(dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
@ -148,7 +144,7 @@ class ExpandDimsOp : public XlaOpKernel {
dim = std::min<int32>(dim, existing_dims_size); dim = std::min<int32>(dim, existing_dims_size);
new_shape.emplace(new_shape.begin() + dim, 1); new_shape.emplace(new_shape.begin() + dim, 1);
ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
} }
}; };
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mem.h"
@ -42,8 +43,8 @@ class SliceOp : public XlaOpKernel {
OP_REQUIRES( OP_REQUIRES(
ctx, ctx,
IsLegacyVector(begin_tensor_shape) && TensorShapeUtils::IsVector(begin_tensor_shape) &&
IsLegacyVector(size_tensor_shape) && TensorShapeUtils::IsVector(size_tensor_shape) &&
begin_tensor_shape.num_elements() == input_shape.dims() && begin_tensor_shape.num_elements() == input_shape.dims() &&
size_tensor_shape.num_elements() == input_shape.dims(), size_tensor_shape.num_elements() == input_shape.dims(),
errors::InvalidArgument( errors::InvalidArgument(

View File

@ -35,26 +35,16 @@ class SplitOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const int32 num_split = num_outputs(); const int32 num_split = num_outputs();
const TensorShape index_shape = ctx->InputShape(0); const TensorShape split_dim_shape = ctx->InputShape("split_dim");
const TensorShape input_shape = ctx->InputShape(1); const TensorShape input_shape = ctx->InputShape(1);
xla::Literal literal_index; OP_REQUIRES(
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); ctx, TensorShapeUtils::IsScalar(split_dim_shape),
errors::InvalidArgument("split_dim must be a scalar but has rank ",
split_dim_shape.dims()));
int64 split_dim_orig;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig));
int32 split_dim_orig;
if (index_shape.dims() == 0) {
split_dim_orig = literal_index.Get<int>({});
} else {
OP_REQUIRES(
ctx, index_shape.dims() == 1,
errors::InvalidArgument("split_index input to Split Op must be a "
"scalar or a vector with 1 element"));
OP_REQUIRES(
ctx, index_shape.dim_size(0) == 1,
errors::InvalidArgument("split_index input to Split Op must be a "
"scalar or a vector with 1 element"));
split_dim_orig = literal_index.Get<int>({0});
}
int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
: split_dim_orig; : split_dim_orig;
OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
@ -138,7 +128,6 @@ class SplitVOp : public XlaOpKernel {
// Check that sizes are correct. // Check that sizes are correct.
int total_split_size = 0; int total_split_size = 0;
int neg_one_dim = -1; int neg_one_dim = -1;
std::vector<int64> split_sizes_vec(num_split, -1);
const TensorShape split_size_shape = ctx->InputShape(1); const TensorShape split_size_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, OP_REQUIRES(ctx,
split_size_shape.dims() == 1 && split_size_shape.dims() == 1 &&
@ -150,12 +139,11 @@ class SplitVOp : public XlaOpKernel {
split_size_shape.dims(), "-D and ", split_size_shape.dims(), "-D and ",
split_size_shape.num_elements(), " elements")); split_size_shape.num_elements(), " elements"));
// Get the dimension of this split. // Get the dimension of this split.
xla::Literal split_size_literal; std::vector<int64> split_sizes;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes));
for (int i = 0; i < num_split; ++i) { for (int i = 0; i < num_split; ++i) {
int slice_size; int64 slice_size = split_sizes[i];
slice_size = split_size_literal.Get<int>({i});
if (slice_size == -1) { if (slice_size == -1) {
OP_REQUIRES( OP_REQUIRES(
ctx, neg_one_dim == -1, ctx, neg_one_dim == -1,
@ -164,7 +152,6 @@ class SplitVOp : public XlaOpKernel {
i)); i));
neg_one_dim = i; neg_one_dim = i;
} else { } else {
split_sizes_vec[i] = slice_size;
total_split_size += slice_size; total_split_size += slice_size;
} }
} }
@ -183,7 +170,7 @@ class SplitVOp : public XlaOpKernel {
total_split_size)); total_split_size));
if (neg_one_dim >= 0) { if (neg_one_dim >= 0) {
split_sizes_vec[neg_one_dim] = split_sizes[neg_one_dim] =
input_shape.dim_size(split_dim) - total_split_size; input_shape.dim_size(split_dim) - total_split_size;
} }
@ -195,7 +182,7 @@ class SplitVOp : public XlaOpKernel {
std::vector<int64> strides(input_shape.dims(), 1); std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < num_split; ++i) { for (int i = 0; i < num_split; ++i) {
TensorShape output_shape(input_shape); TensorShape output_shape(input_shape);
int slice_size = split_sizes_vec[i]; int slice_size = split_sizes[i];
output_shape.set_dim(split_dim, slice_size); output_shape.set_dim(split_dim, slice_size);
// Slice out the ith split from the split dimension. // Slice out the ith split from the split dimension.

View File

@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackOp); TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
}; };
REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp); REGISTER_XLA_OP(
Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(),
StackOp);
class StackPushOp : public XlaOpKernel { class StackPushOp : public XlaOpKernel {
public: public:
@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp); TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
}; };
REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp); REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp);
class StackPopOp : public XlaOpKernel { class StackPopOp : public XlaOpKernel {
public: public:
@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp); TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
}; };
REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp); REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp);
class StackCloseOp : public XlaOpKernel { class StackCloseOp : public XlaOpKernel {
public: public:
@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp); TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
}; };
REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp); REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp);
} // anonymous namespace } // anonymous namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
@ -44,7 +45,7 @@ class TileOp : public XlaOpKernel {
const TensorShape multiples_shape = ctx->InputShape("multiples"); const TensorShape multiples_shape = ctx->InputShape("multiples");
OP_REQUIRES( OP_REQUIRES(
ctx, IsLegacyVector(multiples_shape), ctx, TensorShapeUtils::IsVector(multiples_shape),
errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
multiples_shape.DebugString())); multiples_shape.DebugString()));
OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(), OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(),

View File

@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel {
: XlaOpKernel(ctx), conjugate_(conjugate) {} : XlaOpKernel(ctx), conjugate_(conjugate) {}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0); const TensorShape input_shape = ctx->InputShape("x");
const TensorShape perm_tensor_shape = ctx->InputShape(1); const TensorShape perm_tensor_shape = ctx->InputShape("perm");
// Preliminary validation of sizes. // Preliminary validation of sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel {
". But input(1) is a vector of size ", ". But input(1) is a vector of size ",
perm_tensor_shape.num_elements())); perm_tensor_shape.num_elements()));
xla::Literal literal; std::vector<int64> perm;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm));
std::vector<int32> perm(dims);
std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
perm.begin());
std::vector<int64> transposed_order; std::vector<int64> transposed_order;
// Check whether permutation is a permutation of integers of [0 .. dims). // Check whether permutation is a permutation of integers of [0 .. dims).
absl::InlinedVector<bool, 8> bits(dims); absl::InlinedVector<bool, 8> bits(dims);
bool is_identity = true; bool is_identity = true;
for (int i = 0; i < dims; ++i) { for (int i = 0; i < dims; ++i) {
const int32 d = perm[i]; const int64 d = perm[i];
OP_REQUIRES( OP_REQUIRES(
ctx, 0 <= d && d < dims, ctx, 0 <= d && d < dims,
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel {
xla::XlaOp transposed; xla::XlaOp transposed;
// 0-D, 1-D, and identity transposes do nothing. // 0-D, 1-D, and identity transposes do nothing.
if (dims <= 1 || is_identity) { if (dims <= 1 || is_identity) {
transposed = ctx->Input(0); transposed = ctx->Input("x");
} else { } else {
transposed = xla::Transpose(ctx->Input(0), transposed_order); transposed = xla::Transpose(ctx->Input("x"), transposed_order);
} }
// Conjugate the transposed result if this is ConjugateTransposeOp. // Conjugate the transposed result if this is ConjugateTransposeOp.

View File

@ -80,24 +80,8 @@ XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
XLAJIT_MAKE_UNARY(Neg, -x); XLAJIT_MAKE_UNARY(Neg, -x);
// Implements Banker's rounding: numbers that are equidistant between two XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x));
// integers are rounded towards even. XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x));
xla::XlaOp RoundToEven(xla::XlaOp x) {
auto half = xla::ScalarLike(x, 0.5);
auto one = xla::ScalarLike(x, 1.0);
auto two = xla::ScalarLike(x, 2.0);
auto round_val = xla::Floor(x);
auto fraction = x - round_val;
auto nearest_even_int = round_val - two * xla::Floor(half * x);
auto is_odd = xla::Eq(nearest_even_int, one);
return xla::Select(xla::Or(xla::Gt(fraction, half),
xla::And(xla::Eq(fraction, half), is_odd)),
round_val + one, round_val);
}
XLAJIT_MAKE_UNARY(Rint, RoundToEven(x));
XLAJIT_MAKE_UNARY(Round, RoundToEven(x));
XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x));

View File

@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
return Status::OK(); return Status::OK();
} }
xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor) {
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal));
return literal.Clone();
}
Status HostTensorToMutableBorrowingLiteral( Status HostTensorToMutableBorrowingLiteral(
Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
xla::Shape xla_shape; xla::Shape xla_shape;

View File

@ -30,6 +30,11 @@ namespace tensorflow {
// 'host_tensor'. // 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal); xla::BorrowingLiteral* literal);
// Returns a Literal with the contents of 'host_tensor', backed by its own
// storage (i.e., not reusing 'host_tensor's buffers.)
xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor);
// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer // Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
// owned by 'host_tensor', but is mutable via the xla::Literal methods. // owned by 'host_tensor', but is mutable via the xla::Literal methods.
Status HostTensorToMutableBorrowingLiteral( Status HostTensorToMutableBorrowingLiteral(

View File

@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = [ default_visibility = [
"//learning/deepmind/public/wavenet/python:__subpackages__", "//learning/deepmind/public/wavenet/python:__subpackages__",
"//learning/deepmind/research/alphastar:__subpackages__",
"//learning/tfx:__subpackages__", "//learning/tfx:__subpackages__",
"//tensorflow:internal", "//tensorflow:internal",
], ],

View File

@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto(
"XLACompilationDevice::MakeTensorFromProto should not be called"); "XLACompilationDevice::MakeTensorFromProto should not be called");
} }
XlaExpression::XlaExpression() = default;
void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; }
void XlaExpression::set_constant_value(Tensor value) {
has_constant_value_ = true;
constant_value_ = std::move(value);
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,9 +18,6 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
@ -38,8 +35,8 @@ class XlaCompilationAllocator;
// This is a 'dummy' TensorFlow device that is only used to execute a // This is a 'dummy' TensorFlow device that is only used to execute a
// subgraph of XLA compilation Ops to construct a compiled version // subgraph of XLA compilation Ops to construct a compiled version
// of the subgraph's computation. It has a 'dummy' allocator that // of the subgraph's computation. It has a 'dummy' allocator that
// backs each Tensor with metadata indicating the computation the // backs each Tensor with an XlaExpression. The shape of the Tensor
// Tensor represents. // matches the shape of XlaExpression.
// //
// We deliberately don't register a device factory because we *never* // We deliberately don't register a device factory because we *never*
// want placement to put Ops on a compilation device. The device is created // want placement to put Ops on a compilation device. The device is created
@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice {
std::unique_ptr<XlaCompilationAllocator> allocator_; std::unique_ptr<XlaCompilationAllocator> allocator_;
}; };
// A XlaExpression wraps an XLA computation. Each Tensor on an
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
// matches the shape of the subcomputation in the XlaOp. Each
// expression is either a constant, or a function of previously-compiled
// expressions.
class XlaExpression {
public:
XlaExpression();
// handle() stores the XLA handle of the computation that the
// expression represents.
void set_handle(const xla::XlaOp& h);
const xla::XlaOp& handle() const { return handle_; }
void set_constant_value(Tensor value);
bool has_constant_value() const { return has_constant_value_; }
const Tensor& constant_value() const { return constant_value_; }
void set_resource(XlaResource* resource) { resource_ = resource; }
XlaResource* resource() const { return resource_; }
private:
// The XLA handle of the expression's computation.
xla::XlaOp handle_;
// If this expression is a constant with a known value, 'constant_value' is a
// host-memory Tensor containing the value. Used to avoid invoking XLA for
// expressions that are trivially constant.
bool has_constant_value_ = false;
Tensor constant_value_;
XlaResource* resource_ = nullptr; // Not owned.
};
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_

View File

@ -36,10 +36,13 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -48,7 +51,7 @@ namespace {
// Checks that arguments `args` match types `types`. // Checks that arguments `args` match types `types`.
Status CheckSignature(const DataTypeVector& types, Status CheckSignature(const DataTypeVector& types,
const std::vector<XlaCompiler::Argument>& args) { absl::Span<const XlaCompiler::Argument> args) {
if (args.size() != types.size()) { if (args.size() != types.size()) {
return errors::Internal("Compilation arguments have ", args.size(), return errors::Internal("Compilation arguments have ", args.size(),
" elements while function has ", types.size()); " elements while function has ", types.size());
@ -63,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types,
return Status::OK(); return Status::OK();
} }
// Uses the _Arg and _Retval nodes in the graph to determine a core assignment
// for each argument and return value.
xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
ComputeArgAndRetvalCores(const Graph& graph) {
auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
return sharding.value().tile_assignment_devices(0);
} else {
return -1;
}
};
std::map<int, int> arg_cores;
std::map<int, int> retval_cores;
for (const Node* n : graph.nodes()) {
if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
if (core < 0) continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0) << "Negative _Arg index";
arg_cores[index] = core;
} else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
if (core < 0) continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0) << "Negative _Retval index";
TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n));
retval_cores[index] = core;
}
}
return std::make_pair(std::move(arg_cores), std::move(retval_cores));
}
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
int64 step_id) {
// Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the
// resource manager takes ownership via Create, and unrefs via Cleanup. We
// explicitly add a reference to ensure the refcount at entry is maintained at
// all exit points; Create and Cleanup are always called in this function.
//
// The Executor requires us to use ScopedStepContainer. We wrap it in a
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
return Status::OK();
}
// Builds the XLA computation.
// - `args` is the list of input arguments
// - `retvals` is the list of retvals produced by _Retval operators, in index
// order.
// - `args_core` and `retval_cores` are mapping from arg/return indices to core
// assignments.
// - If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// - Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// - `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a ResourceUpdate, whose `index` is the index of a
// resource variable argument to the computation to be updated, and `type` is
// the type of the final output.
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<XlaExpression>& retvals,
const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
const std::vector<std::unique_ptr<XlaResource>>& resources,
std::unique_ptr<xla::XlaOp> token_output,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
bool return_updated_values_for_all_resources, bool always_return_tuple,
xla::XlaBuilder* builder, xla::XlaComputation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
xla::OpMetadata retval_metadata;
retval_metadata.set_op_name("XLA_Retvals");
builder->SetOpMetadata(retval_metadata);
auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
// Builds a no-op XLA computation. We need to set the sharding of outputs, but
// cannot change the sharding of the existing output op. To do this, we build
// a new identity op to which shardings can be applied.
auto identity_op = [builder](xla::XlaOp op) {
return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
};
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
const XlaExpression& retval = retvals[i];
output.type = retval.dtype();
switch (retval.kind()) {
case XlaExpression::Kind::kConstant:
output.is_constant = true;
output.constant_value = retval.constant_value();
output.shape = output.constant_value.shape();
break;
case XlaExpression::Kind::kXlaOp: {
output.is_constant = false;
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
xla::XlaOp value = retval.handle();
auto it = retval_cores.find(i);
xla::XlaScopedShardingAssignment assign_sharding(
builder, it == retval_cores.end()
? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(it->second));
if (shape_representation_fn) {
// If there is a shape representation function, reshape the output
// tensor to the shape given by the representation shape function.
TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
output.shape, output.type));
value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
} else if (it != retval_cores.end()) {
// Apply the sharding to the output, if there is a core assignment.
value = identity_op(value);
}
elems.push_back(value);
break;
}
case XlaExpression::Kind::kResource:
output.is_constant = false;
output.input_index = retval.resource()->arg_num();
output.shape = retval.resource()->shape();
break;
case XlaExpression::Kind::kInvalid:
return errors::InvalidArgument(
"Invalid expression returned by computation. "
"This probably means a return value was not set.");
}
}
*num_nonconst_outputs = elems.size();
// Add return values for resources whose values have changed.
std::vector<const XlaResource*> arg_resources;
arg_resources.reserve(resources.size());
for (const auto& resource : resources) {
if (resource->arg_num() >= 0) {
arg_resources.push_back(resource.get());
}
}
std::sort(arg_resources.begin(), arg_resources.end(),
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num() < b->arg_num();
});
for (const XlaResource* resource : arg_resources) {
DCHECK_LT(resource->arg_num(), args.size());
const XlaCompiler::Argument& arg = args[resource->arg_num()];
auto it = arg_cores.find(resource->arg_num());
const int core = it == arg_cores.end() ? -1 : it->second;
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
modified =
modified ||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
}
// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
// Ensures the correct sharding is applied to the output.
handle = identity_op(handle);
elems.push_back(handle);
}
}
// If we have token output, append it as the last one.
if (token_output) {
elems.push_back(*token_output);
}
*num_computation_outputs = elems.size();
// Builds the XLA computation. We *always* form a tuple here to ensure that
// the output value is the last thing added into the XLA computation, even
// if there is only one output value.
auto tuple = xla::Tuple(builder, elems);
if (!always_return_tuple && elems.size() == 1) {
xla::GetTupleElement(tuple, 0);
}
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
return Status::OK();
}
} // namespace } // namespace
bool XlaCompiler::Argument::operator==( bool XlaCompiler::Argument::operator==(
@ -83,6 +320,39 @@ bool XlaCompiler::Argument::operator==(
return constant_value.tensor_data() == other.constant_value.tensor_data(); return constant_value.tensor_data() == other.constant_value.tensor_data();
} }
string XlaCompiler::Argument::HumanString() const {
string common;
if (!name.empty()) {
common = absl::StrCat(" name=", name);
}
absl::StrAppend(&common, " type=", DataTypeString(type),
" shape=", shape.DebugString());
switch (kind) {
case kInvalid:
return "invalid";
case kConstant:
return absl::StrCat("kind=constant", common,
" value=", constant_value.DebugString());
case kResource: {
string output = absl::StrCat("kind=resource", common, " resource_kind=",
XlaResource::KindToString(resource_kind),
" initialized=", initialized);
if (tensor_array_size >= 0) {
absl::StrAppend(&output, " tensor_array_size=", tensor_array_size);
}
if (!tensor_array_gradients.empty()) {
absl::StrAppend(&output, " tensor_array_gradients=",
absl::StrJoin(tensor_array_gradients, ","));
}
return output;
}
case kParameter:
return absl::StrCat("kind=parameter", common);
case kToken:
return absl::StrCat("token", common);
}
}
XlaCompiler::XlaCompiler(XlaCompiler::Options options) XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(options), : options_(options),
initialization_status_(Status::OK()), initialization_status_(Status::OK()),
@ -110,8 +380,13 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
// The default shape representation function is the identity. // The default shape representation function is the identity.
if (!options_.shape_representation_fn) { if (!options_.shape_representation_fn) {
options_.shape_representation_fn = [](const TensorShape& shape, options_.shape_representation_fn =
DataType type) { return shape; }; [](const TensorShape& shape,
DataType dtype) -> xla::StatusOr<xla::Shape> {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
return xla_shape;
};
} }
} }
@ -171,15 +446,16 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
return graph; return graph;
} }
Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, Status XlaCompiler::CompileFunction(
const NameAttrList& function, const XlaCompiler::CompileOptions& options, const NameAttrList& function,
std::vector<XlaCompiler::Argument> args, absl::Span<const XlaCompiler::Argument> args,
XlaCompiler::CompilationResult* result) { XlaCompiler::CompilationResult* result) {
const string function_id = const string function_id =
Canonicalize(function.name(), AttrSlice(&function.attr())); Canonicalize(function.name(), AttrSlice(&function.attr()));
VLOG(1) << "XlaCompiler::CompileFunction " << function_id; VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
auto it = cache_.find({function_id, args}); const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
auto it = cache_.find({function_id, arg_vector});
if (it != cache_.end()) { if (it != cache_.end()) {
*result = it->second; *result = it->second;
return Status::OK(); return Status::OK();
@ -212,14 +488,16 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// lowest-numbered core that consumes the argument. We choose the // lowest-numbered core that consumes the argument. We choose the
// lowest-numbered core so the assignment is deterministic. // lowest-numbered core so the assignment is deterministic.
for (Node* n : graph->nodes()) { for (Node* n : graph->nodes()) {
if (absl::string_view(n->type_string()) == "_Arg") { if (absl::string_view(n->type_string()) ==
FunctionLibraryDefinition::kArgOp) {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
} }
} }
// Do _Retval as a second loop, in case the retval's input is an _Arg (which // Do _Retval as a second loop, in case the retval's input is an _Arg (which
// may have gotten a device assignment from the first loop). // may have gotten a device assignment from the first loop).
for (Node* n : graph->nodes()) { for (Node* n : graph->nodes()) {
if (absl::string_view(n->type_string()) == "_Retval") { if (absl::string_view(n->type_string()) ==
FunctionLibraryDefinition::kRetOp) {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
} }
} }
@ -235,7 +513,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
CompileGraph(options, function_id, std::move(graph), args, result)); CompileGraph(options, function_id, std::move(graph), args, result));
VLOG(1) << "===================================================="; VLOG(1) << "====================================================";
cache_[{function_id, args}] = *result; cache_[{function_id, arg_vector}] = *result;
return Status::OK(); return Status::OK();
} }
@ -247,25 +525,24 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case"; LOG(FATAL) << "Unreachable case";
case XlaCompiler::Argument::kParameter: { case XlaCompiler::Argument::kParameter: {
TensorShape shape;
if (is_entry_computation) { if (is_entry_computation) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
shape, options_.shape_representation_fn(arg.shape, arg.type)); *xla_shape, options_.shape_representation_fn(arg.shape, arg.type));
} else { } else {
shape = arg.shape; TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(arg.type, arg.shape, xla_shape));
} }
return TensorShapeToXLAShape(arg.type, shape, xla_shape); return Status::OK();
} }
case XlaCompiler::Argument::kResource: { case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized); TF_RET_CHECK(arg.initialized);
switch (arg.resource_kind) { switch (arg.resource_kind) {
case XlaResource::kVariable: { case XlaResource::kVariable: {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
TensorShape representation_shape, arg.shape, arg.type));
options_.shape_representation_fn(arg.shape, arg.type));
return TensorShapeToXLAShape(arg.type, representation_shape, return Status::OK();
xla_shape);
} }
case XlaResource::kTensorArray: { case XlaResource::kTensorArray: {
if (arg.tensor_array_size < 0) { if (arg.tensor_array_size < 0) {
@ -314,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
} }
} }
namespace {
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
int64 step_id) {
// Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the
// resource manager takes ownership via Create, and unrefs via Cleanup. We
// explicitly add a reference to ensure the refcount at entry is maintained at
// all exit points; Create and Cleanup are always called in this function.
//
// The Executor requires us to use ScopedStepContainer. We wrap it in a
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
return Status::OK();
}
// Builds the XLA computation.
// `args` is the list of input arguments, `retvals` is the list of retvals
// produced by _Retval operators, in index order.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a (input_index, type) pair, where `input_index` is the
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<int>& arg_cores,
const std::vector<XlaContext::Retval>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
std::unique_ptr<xla::XlaOp> token_output,
bool return_updated_values_for_all_resources, bool always_return_tuple,
xla::XlaBuilder* builder, xla::XlaComputation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
output.type = retvals[i].type;
output.shape = retvals[i].shape;
const XlaExpression& retval = retvals[i].expression;
if (retval.has_constant_value()) {
output.is_constant = true;
output.constant_value = retval.constant_value();
} else if (retval.resource() != nullptr) {
output.is_constant = false;
output.input_index = retval.resource()->arg_num();
} else {
output.is_constant = false;
elems.push_back(retval.handle());
}
}
*num_nonconst_outputs = elems.size();
// Add return values for resources whose values have changed.
std::vector<const XlaResource*> arg_resources;
arg_resources.reserve(resources.size());
for (const auto& resource : resources) {
if (resource->arg_num() >= 0) {
arg_resources.push_back(resource.get());
}
}
std::sort(arg_resources.begin(), arg_resources.end(),
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num() < b->arg_num();
});
// Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names.
xla::OpMetadata retval_metadata;
retval_metadata.set_op_name("XLA_Retvals");
builder->SetOpMetadata(retval_metadata);
for (const XlaResource* resource : arg_resources) {
const XlaCompiler::Argument& arg = args[resource->arg_num()];
const int core = arg_cores[resource->arg_num()];
DCHECK_LT(resource->arg_num(), arg_cores.size());
bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
modified =
modified ||
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
}
// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
// Since we can't change the sharding metadata of <value> as this point,
// create a tuple/get-tuple-element combination so that sharding
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
elems.push_back(handle);
}
}
// If we have token output, append it as the last one.
if (token_output) {
elems.push_back(*token_output);
}
*num_computation_outputs = elems.size();
// Builds the XLA computation. We *always* form a tuple here to ensure that
// the output value is the last thing added into the XLA computation, even
// if there is only one output value.
auto tuple = xla::Tuple(builder, elems);
if (!always_return_tuple && elems.size() == 1) {
xla::GetTupleElement(tuple, 0);
}
builder->ClearOpMetadata();
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
*computation = computation_status.ConsumeValueOrDie();
return Status::OK();
}
} // namespace
// Builds XLA computations for each of the arguments to the computation. // Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation. // `args` are the arguments to the computation.
Status XlaCompiler::BuildArguments( Status XlaCompiler::BuildArguments(
const Graph& graph, const std::vector<XlaCompiler::Argument>& args, const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions, const std::map<int, int>& arg_cores,
std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes, std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
bool is_entry_computation) { bool is_entry_computation) {
arg_expressions->resize(args.size()); arg_expressions->resize(args.size());
*arg_cores = std::vector<int>(args.size(), -1);
// Argument numbers of arguments and resources that are to be passed to the // Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters. // XLA computation as runtime parameters.
@ -504,7 +622,7 @@ Status XlaCompiler::BuildArguments(
arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(),
/*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_size=*/arg.tensor_array_size,
/*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource); arg_expression = XlaExpression::Resource(resource);
if (arg.initialized) { if (arg.initialized) {
input_mapping->push_back(i); input_mapping->push_back(i);
} }
@ -516,7 +634,7 @@ Status XlaCompiler::BuildArguments(
break; break;
} }
case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kConstant:
arg_expression.set_constant_value(arg.constant_value); arg_expression = XlaExpression::Constant(arg.constant_value);
break; break;
case XlaCompiler::Argument::kInvalid: case XlaCompiler::Argument::kInvalid:
return errors::Internal( return errors::Internal(
@ -541,26 +659,6 @@ Status XlaCompiler::BuildArguments(
*input_shapes = arg_shapes; *input_shapes = arg_shapes;
} }
// Use the _Arg nodes in the graph to resolve core assignments.
for (const Node* n : graph.nodes()) {
if (absl::string_view(n->type_string()) != "_Arg") continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < args.size())
<< "_Arg out of bounds: " << index << " vs " << args.size();
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
const int core = sharding.value().tile_assignment_devices(0);
if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
(*arg_cores)[index] = core;
}
}
}
// Attach a common operator name as metadata. This has no semantic effect — it // Attach a common operator name as metadata. This has no semantic effect — it
// merely makes the HLO graph more readable when visualized via TensorBoard, // merely makes the HLO graph more readable when visualized via TensorBoard,
// since TensorBoard forms groups out of operators with similar names. // since TensorBoard forms groups out of operators with similar names.
@ -576,11 +674,10 @@ Status XlaCompiler::BuildArguments(
xla::OpSharding tuple_sharding; xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
for (int64 parameter : *input_mapping) { for (int64 parameter : *input_mapping) {
const int core = (*arg_cores)[parameter]; auto it = arg_cores.find(parameter);
const int root_device = 0; const int core = it == arg_cores.end() ? 0 : it->second;
*tuple_sharding.add_tuple_shardings() = *tuple_sharding.add_tuple_shardings() =
core == -1 ? xla::sharding_builder::AssignDevice(root_device) xla::sharding_builder::AssignDevice(core);
: xla::sharding_builder::AssignDevice(core);
} }
xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
tuple_sharding); tuple_sharding);
@ -589,7 +686,8 @@ Status XlaCompiler::BuildArguments(
tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
} }
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)]; auto it = arg_cores.find(i);
const int core = it == arg_cores.end() ? -1 : it->second;
xla::XlaScopedShardingAssignment assign_sharding( xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>() builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core)); : xla::sharding_builder::AssignDevice(core));
@ -597,7 +695,8 @@ Status XlaCompiler::BuildArguments(
} }
} else { } else {
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)]; auto it = arg_cores.find(i);
const int core = it == arg_cores.end() ? -1 : it->second;
xla::XlaScopedShardingAssignment assign_sharding( xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? absl::optional<xla::OpSharding>() builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core)); : xla::sharding_builder::AssignDevice(core));
@ -632,14 +731,14 @@ Status XlaCompiler::BuildArguments(
// TODO(b/76097077): propagate device assignments onto arguments and // TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally. // return values of functions, and then reshape unconditionally.
if (is_entry_computation) { if (is_entry_computation) {
arg_expression.set_handle( arg_expression = XlaExpression::XlaOp(
xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type);
} else { } else {
arg_expression.set_handle(arg_handles[i]); arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
} }
break; break;
case XlaCompiler::Argument::kToken: { case XlaCompiler::Argument::kToken: {
arg_expression.set_handle(arg_handles[i]); arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
break; break;
} }
case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kConstant:
@ -653,46 +752,48 @@ Status XlaCompiler::BuildArguments(
} }
Status XlaCompiler::CompileSingleOp( Status XlaCompiler::CompileSingleOp(
const XlaCompiler::CompileOptions& options, string const& name, const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
OpKernelContext* ctx, const std::vector<XlaCompiler::Argument>& args, absl::Span<const XlaCompiler::Argument> args,
CompilationResult* result) { absl::Span<const DataType> result_types, CompilationResult* result) {
// TODO(b/74182462): We implement this by creating a new dummy Graph including // TODO(b/74182462): We implement this by creating a new dummy Graph including
// _Arg nodes, and let CompileGraph walk it. This could be optimized. // _Arg nodes, and let CompileGraph walk it. This could be optimized.
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Status status; Status status;
// First create the actual node we care about computing. // First create the actual node we care about computing.
Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); Node* main_node = graph->AddNode(node_def, &status);
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
// Create dummy _Arg nodes. Link these to `node` and also via a control // Create dummy _Arg nodes. Link these to `node` and also via a control
// dependency edge to the _SOURCE node. // dependency edge to the _SOURCE node.
for (int64 i = 0; i < ctx->num_inputs(); ++i) { for (int64 i = 0; i < args.size(); ++i) {
Node* node; Node* node;
string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); string arg_name = absl::StrCat("_arg", i);
Status status = NodeBuilder(name, "_Arg") Status status =
.ControlInput(graph->source_node()) NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
.Attr("T", ctx->input_dtype(i)) .ControlInput(graph->source_node())
.Attr("index", i) .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
.Finalize(graph.get(), &node); : args[i].type)
.Attr("index", i)
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
graph->AddEdge(node, 0, main_node, i); graph->AddEdge(node, 0, main_node, i);
} }
// Similarly with return values, create dummy _Retval nodes fed by `node`. // Similarly with return values, create dummy _Retval nodes fed by `node`.
for (int64 i = 0; i < ctx->num_outputs(); ++i) { for (int64 i = 0; i < result_types.size(); ++i) {
Node* node; Node* node;
string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); string retval_name = absl::StrCat("_retval", i);
Status status = NodeBuilder(name, "_Retval") Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
.Input(main_node, i) .Input(main_node, i)
.Attr("T", ctx->expected_output_dtype(i)) .Attr("T", result_types[i])
.Attr("index", i) .Attr("index", i)
.Finalize(graph.get(), &node); .Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
} }
FixupSourceAndSinkEdges(graph.get()); FixupSourceAndSinkEdges(graph.get());
return CompileGraph(options, name, std::move(graph), args, result); return CompileGraph(options, node_def.name(), std::move(graph), args, result);
} }
namespace { namespace {
@ -747,12 +848,38 @@ Status ValidateGraph(const Graph* graph,
return Status::OK(); return Status::OK();
} }
// Converts the value of any expressions whose values are known at compile-time
// to constants.
Status ResolveConstantExpressionsToConstants(
xla::Client* client, absl::Span<XlaExpression> expressions) {
for (XlaExpression& expression : expressions) {
if (expression.kind() == XlaExpression::Kind::kXlaOp) {
TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant,
expression.ResolveConstant(client));
if (constant.has_value()) {
expression = XlaExpression::Constant(*constant);
}
}
}
return Status::OK();
}
void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
absl::Span<XlaExpression> expressions) {
for (XlaExpression& expression : expressions) {
if (expression.kind() == XlaExpression::Kind::kConstant) {
expression =
XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
}
}
}
} // namespace } // namespace
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
string const& name, string const& name,
std::unique_ptr<Graph> graph, std::unique_ptr<Graph> graph,
const std::vector<XlaCompiler::Argument>& args, absl::Span<const XlaCompiler::Argument> args,
CompilationResult* result) { CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
@ -774,13 +901,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
options_.device_type, name)); options_.device_type, name));
xla::XlaBuilder builder(name); xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext( XlaContext* context =
this, &builder, options_.allow_cpu_custom_calls, new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
options.resolve_compile_time_constants, options.is_entry_computation, &options_.shape_representation_fn);
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context); core::ScopedUnref context_unref(context);
std::vector<XlaCompiler::Argument> real_args(args); std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
int token_input_index = -1; int token_input_index = -1;
std::unique_ptr<xla::XlaOp> token_output; std::unique_ptr<xla::XlaOp> token_output;
if (options.add_token_input_output) { if (options.add_token_input_output) {
@ -792,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
real_args.push_back(token_arg); real_args.push_back(token_arg);
} }
std::map<int, int> arg_cores;
std::map<int, int> retval_cores;
TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
ComputeArgAndRetvalCores(*graph));
std::vector<XlaExpression> arg_expressions; std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
TF_RETURN_IF_ERROR(BuildArguments( TF_RETURN_IF_ERROR(BuildArguments(
*graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
&arg_expressions, &result->input_mapping, &result->xla_input_shapes, &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
options.is_entry_computation)); options.is_entry_computation));
context->set_args(std::move(arg_expressions)); context->set_args(std::move(arg_expressions));
@ -843,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_computation_outputs; int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>(); result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size()); result->outputs.resize(context->retvals().size());
std::vector<XlaExpression> retvals = context->retvals();
if (options.resolve_compile_time_constants) {
TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants(
client(), absl::Span<XlaExpression>(retvals)));
} else {
ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
}
TF_RETURN_IF_ERROR(BuildComputation( TF_RETURN_IF_ERROR(BuildComputation(
real_args, arg_cores, context->retvals(), context->resources(), real_args, retvals, arg_cores, retval_cores, context->resources(),
std::move(token_output), options.return_updated_values_for_all_resources, std::move(token_output),
options.is_entry_computation ? options_.shape_representation_fn
: ShapeRepresentationFn{},
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(), options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs, &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
&result->resource_updates)); &result->resource_updates));

View File

@ -18,10 +18,13 @@ limitations under the License.
#include <stack> #include <stack>
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
@ -118,7 +121,7 @@ class XlaCompiler {
// The type of the argument. If the argument is a resource, this // The type of the argument. If the argument is a resource, this
// is the type of the variable's value, not DT_RESOURCE. // is the type of the variable's value, not DT_RESOURCE.
DataType type; DataType type = DT_INVALID;
// The shape of the argument. For: // The shape of the argument. For:
// * a parameter: the shape of the parameter. // * a parameter: the shape of the parameter.
@ -155,6 +158,9 @@ class XlaCompiler {
std::set<string> tensor_array_gradients; std::set<string> tensor_array_gradients;
bool operator==(const Argument& other) const; bool operator==(const Argument& other) const;
// Returns a human-readable summary of the argument.
string HumanString() const;
}; };
// Options pertaining to an individual call to CompileGraph() or // Options pertaining to an individual call to CompileGraph() or
@ -259,8 +265,7 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation; std::shared_ptr<xla::XlaComputation> computation;
}; };
typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&, typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>
DataType)>
ShapeRepresentationFn; ShapeRepresentationFn;
struct Options { struct Options {
// Name of the compilation device to use. It must be set by the caller. // Name of the compilation device to use. It must be set by the caller.
@ -316,22 +321,23 @@ class XlaCompiler {
Status CompileFunction(const CompileOptions& options, Status CompileFunction(const CompileOptions& options,
const NameAttrList& fn_name_attrs, const NameAttrList& fn_name_attrs,
std::vector<Argument> args, CompilationResult* result); absl::Span<const Argument> args,
CompilationResult* result);
// Compiles a tensorflow::Graph into an xla::XlaComputation. // Compiles a tensorflow::Graph into an xla::XlaComputation.
// Similar to CompileFunction, but takes a Graph as input rather than a // Similar to CompileFunction, but takes a Graph as input rather than a
// function. // function.
Status CompileGraph(const CompileOptions& options, string const& name, Status CompileGraph(const CompileOptions& options, string const& name,
std::unique_ptr<Graph> graph, std::unique_ptr<Graph> graph,
const std::vector<Argument>& args, absl::Span<const Argument> args,
CompilationResult* result); CompilationResult* result);
// Compiles a single Op, given by an OpKernelContext, into an // Compiles a single Op, given by `node_def`, into an
// xla::XlaComputation. Similar to CompileFunction but takes a single Op as // xla::XlaComputation. Similar to CompileFunction but takes a single Op as
// input. // input.
Status CompileSingleOp(const CompileOptions& options, string const& name, Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def,
OpKernelContext* ctx, absl::Span<const Argument> args,
const std::vector<Argument>& args, absl::Span<const DataType> result_types,
CompilationResult* result); CompilationResult* result);
// Returns the shape of the XLA parameter for an argument 'arg'. // Returns the shape of the XLA parameter for an argument 'arg'.
@ -411,7 +417,8 @@ class XlaCompiler {
Status BuildArguments(const Graph& graph, Status BuildArguments(const Graph& graph,
const std::vector<XlaCompiler::Argument>& args, const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::XlaBuilder* builder, bool use_tuple_arg, xla::XlaBuilder* builder,
XlaContext* context, std::vector<int>* arg_cores, XlaContext* context,
const std::map<int, int>& arg_cores,
std::vector<XlaExpression>* arg_expressions, std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping, std::vector<int>* input_mapping,
std::vector<xla::Shape>* input_shapes, std::vector<xla::Shape>* input_shapes,

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
@ -1018,9 +1019,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
// Compiles the graph. // Compiles the graph.
XlaCompiler::Options options = DefaultOptions(); XlaCompiler::Options options = DefaultOptions();
options.shape_representation_fn = [](const TensorShape& shape, options.shape_representation_fn =
DataType type) { [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
return TensorShape({shape.num_elements()}); xla::PrimitiveType ptype;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
}; };
XlaCompiler compiler(options); XlaCompiler compiler(options);
@ -1086,9 +1089,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
// Compiles the graph. // Compiles the graph.
XlaCompiler::Options options = DefaultOptions(); XlaCompiler::Options options = DefaultOptions();
options.shape_representation_fn = [](const TensorShape& shape, options.shape_representation_fn =
DataType type) { [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
return TensorShape({shape.num_elements()}); xla::PrimitiveType ptype;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
}; };
XlaCompiler compiler(options); XlaCompiler compiler(options);

View File

@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
XlaContext::XlaContext( XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder, XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool allow_cpu_custom_calls,
bool is_entry_computation, const std::function<xla::StatusOr<xla::Shape>(
const std::function<xla::StatusOr<TensorShape>(
const TensorShape&, DataType)>* shape_representation_fn) const TensorShape&, DataType)>* shape_representation_fn)
: compiler_(compiler), : compiler_(compiler),
builder_(builder), builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls), allow_cpu_custom_calls_(allow_cpu_custom_calls),
resolve_compile_time_constants_(resolve_compile_time_constants),
is_entry_computation_(is_entry_computation),
shape_representation_fn_(shape_representation_fn) {} shape_representation_fn_(shape_representation_fn) {}
string XlaContext::DebugString() { return "TLA JIT context"; } string XlaContext::DebugString() { return "TLA JIT context"; }
// This is called by the Retval Op to associate a computed value void XlaContext::SetRetval(int index, const XlaExpression& expression) {
// with a specific return value of the subgraph. if (retvals_.size() <= index) {
void XlaContext::AddRetval(int retval_index, DataType type, retvals_.resize(index + 1);
const TensorShape& shape, const xla::XlaOp& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up.
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
} }
XlaExpression e; retvals_[index] = expression;
e.set_handle(handle);
retvals_[retval_index] = Retval{type, shape, e};
} }
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal) {
VLOG(1) << "Adding retval index " << retval_index
<< " with non-data-dependent tensor to XLA computation";
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
Tensor value;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
XlaExpression e;
e.set_constant_value(value);
retvals_[retval_index] = Retval{dtype, value.shape(), e};
return Status::OK();
}
Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
VLOG(1) << "Adding retval index " << retval_index << " with resource "
<< resource->name() << ":" << resource->shape().DebugString()
<< " to XLA computation";
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
XlaExpression e;
e.set_resource(resource);
retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e};
return Status::OK();
}
xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource( Status XlaContext::CreateResource(
XlaResource::Kind kind, int arg_num, string name, DataType type, XlaResource::Kind kind, int arg_num, string name, DataType type,
TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size,
@ -133,7 +93,7 @@ Status XlaContext::CreateResource(
return Status::OK(); return Status::OK();
} }
xla::StatusOr<TensorShape> XlaContext::RepresentationShape( xla::StatusOr<xla::Shape> XlaContext::RepresentationShape(
const TensorShape& shape, DataType type) const { const TensorShape& shape, DataType type) const {
return (*shape_representation_fn_)(shape, type); return (*shape_representation_fn_)(shape, type);
} }

View File

@ -20,8 +20,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
@ -46,9 +46,8 @@ class XlaContext : public ResourceBase {
// Creates a new XlaContext. See the documentation on the class data fields // Creates a new XlaContext. See the documentation on the class data fields
// for descriptions of the arguments. // for descriptions of the arguments.
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool allow_cpu_custom_calls,
bool is_entry_computation, const std::function<xla::StatusOr<xla::Shape>(
const std::function<xla::StatusOr<TensorShape>(
const TensorShape&, DataType)>* shape_representation_fn); const TensorShape&, DataType)>* shape_representation_fn);
// Virtual method defined by ResourceBase. // Virtual method defined by ResourceBase.
@ -57,37 +56,19 @@ class XlaContext : public ResourceBase {
XlaCompiler* compiler() const { return compiler_; } XlaCompiler* compiler() const { return compiler_; }
// Returns the XlaBuilder that Ops use for compiling new expressions. // Returns the XlaBuilder that Ops use for compiling new expressions.
xla::XlaBuilder* builder(); xla::XlaBuilder* builder() { return builder_; }
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
bool resolve_compile_time_constants() const {
return resolve_compile_time_constants_;
}
bool is_entry_computation() const { return is_entry_computation_; }
const std::vector<XlaExpression>& args() const { return args_; } const std::vector<XlaExpression>& args() const { return args_; }
void set_args(std::vector<XlaExpression> args); void set_args(std::vector<XlaExpression> args);
struct Retval { const std::vector<XlaExpression>& retvals() { return retvals_; }
DataType type;
TensorShape shape;
// An XlaExpression representing the Retval's value.
XlaExpression expression;
};
const std::vector<Retval>& retvals() { return retvals_; }
// This is called by the Retval Op to associate a computed value // Sets a return value.
// with a specific return value of the subgraph. // Since we do not always know in advance how many return values there are,
void AddRetval(int retval_index, DataType type, const TensorShape& shape, // grows the return values vector to size index+1 if it is smaller.
const xla::XlaOp& handle); void SetRetval(int index, const XlaExpression& expression);
// As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal);
// As for Retval, but for return values that are resource handles.
Status AddResourceRetval(int retval_index, XlaResource* resource);
// Creates a resource with resource `kind` and initial value `handle`. `name` // Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource` // is a descriptive name for use in error messages. See the `XlaResource`
@ -105,8 +86,8 @@ class XlaContext : public ResourceBase {
// Returns the XLA shape to be used to represent a variable of TF `shape` // Returns the XLA shape to be used to represent a variable of TF `shape`
// and `type`, or of an argument or return value of a top-level computation. // and `type`, or of an argument or return value of a top-level computation.
xla::StatusOr<TensorShape> RepresentationShape(const TensorShape& shape, xla::StatusOr<xla::Shape> RepresentationShape(const TensorShape& shape,
DataType type) const; DataType type) const;
// Get an XLA lambda to compute Max. This is cached in the // Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a // XlaContext since it may be used by multiple Ops. There is a
@ -140,31 +121,19 @@ class XlaContext : public ResourceBase {
// Allow ops to emit CustomCall operations for CPU. // Allow ops to emit CustomCall operations for CPU.
const bool allow_cpu_custom_calls_; const bool allow_cpu_custom_calls_;
// If true, constant return values are returned as Tensors instead of
// run-time computation outputs.
const bool resolve_compile_time_constants_;
// Arguments to the Tensorflow graph, indexed by _Arg index. // Arguments to the Tensorflow graph, indexed by _Arg index.
// Includes both compile-time constant arguments and runtime parameters. // Includes both compile-time constant arguments and runtime parameters.
std::vector<XlaExpression> args_; std::vector<XlaExpression> args_;
// Return values of the Tensorflow graph, indexed by _Retval index. // Return values of the Tensorflow graph, indexed by _Retval index.
std::vector<Retval> retvals_; std::vector<XlaExpression> retvals_;
// Holds ownership of resources. The resources are not ordered. // Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_; std::vector<std::unique_ptr<XlaResource>> resources_;
// Is this a top-level computation, or an inner computation (e.g., a while // Describes the on-host shapes of parameters and return values. Also see:
// body)? // XlaDevice::Options::shape_representation_fn.
const bool is_entry_computation_; const std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>*
// A function that describes how the shapes of
// a) argument and return value, for entry computations
// b) variables, for all computations,
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared shapes
// for computations. Must be non-null.
const std::function<xla::StatusOr<TensorShape>(const TensorShape&, DataType)>*
shape_representation_fn_; shape_representation_fn_;
// Cache of prebuilt computations indexed by their type. // Cache of prebuilt computations indexed by their type.

View File

@ -0,0 +1,145 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
XlaExpression::XlaExpression() = default;
XlaExpression XlaExpression::Invalid() {
XlaExpression e;
e.kind_ = Kind::kInvalid;
return e;
}
XlaExpression XlaExpression::Constant(Tensor value) {
XlaExpression e;
e.kind_ = Kind::kConstant;
e.dtype_ = value.dtype();
e.constant_value_ = value;
return e;
}
XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
XlaExpression e;
e.kind_ = Kind::kXlaOp;
e.dtype_ = dtype;
e.handle_ = value;
return e;
}
XlaExpression XlaExpression::Resource(XlaResource* resource) {
XlaExpression e;
e.kind_ = Kind::kResource;
e.dtype_ = DT_RESOURCE;
e.resource_ = resource;
return e;
}
string XlaExpression::HumanString() const {
switch (kind_) {
case Kind::kInvalid:
return "invalid";
case Kind::kConstant:
return "constant";
case Kind::kXlaOp:
return "xla_op";
case Kind::kResource:
return "resource";
}
}
xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
switch (kind_) {
case Kind::kConstant: {
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(
HostTensorToBorrowingLiteral(constant_value_, &literal));
return xla::ConstantLiteral(builder, literal);
}
case Kind::kXlaOp:
if (builder != handle_.builder()) {
return errors::InvalidArgument(
"Mismatched builders in XlaExpression::AsXlaOp");
}
return handle_;
default:
return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
HumanString());
}
});
}
xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
xla::Client* client) const {
switch (kind()) {
case Kind::kConstant:
return {constant_value()};
case Kind::kXlaOp:
break;
case Kind::kResource:
case Kind::kInvalid:
return errors::InvalidArgument(
"ResolveConstant called on XlaExpression: ", HumanString());
}
TF_ASSIGN_OR_RETURN(bool is_constant,
handle().builder()->IsConstant(handle()));
if (!is_constant) return {absl::nullopt};
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
handle().builder()->BuildConstantSubGraph(handle()));
TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
// The XLA layout is specified minor to major, and TensorFlow uses a major to
// minor order.
std::vector<int64> layout_indices(shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
TF_ASSIGN_OR_RETURN(xla::Literal literal,
client->ComputeConstant(constant_graph, &layout));
Tensor tensor;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
return {tensor};
}
xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
switch (kind_) {
case Kind::kConstant:
return constant_value().shape();
case Kind::kXlaOp: {
TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
handle().builder()->GetShape(handle()));
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
return shape;
}
case Kind::kResource:
return TensorShape({});
case Kind::kInvalid:
return errors::InvalidArgument(
"GetShape() called on invalid XlaExpression");
}
}
} // namespace tensorflow

View File

@ -0,0 +1,115 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA
// compilation.
// An expression is one of:
// * a constant tensor.
// * an xla::XlaOp, representing a symbolic XLA value.
// * a resource, e.g., a variable, represented as an XlaResource pointer.
//
// Constant tensors are mostly an optimization to avoid passing large constants
// to XLA, but are also sometimes used to represent tensors that have no XLA
// representation, for example, DT_STRING tensors. A canonical use case might be
// an error message string.
class XlaExpression {
public:
enum class Kind {
kInvalid,
kConstant,
kXlaOp,
kResource,
};
XlaExpression();
XlaExpression(const XlaExpression&) = default;
XlaExpression& operator=(const XlaExpression&) = default;
// Builds an invalid expression. (Same as the default constructor, but makes
// the intent clearer.)
static XlaExpression Invalid();
// Builds a constant XLA expression.
static XlaExpression Constant(Tensor value);
// Builds a XlaOp expression. Since the mapping from TF data types to XLA
// types is not 1-1, the TF type must also be provided; in general it cannot
// be derived from the XLA type.
static XlaExpression XlaOp(xla::XlaOp value, DataType dtype);
// Builds a resource expression.
static XlaExpression Resource(XlaResource* resource);
Kind kind() const { return kind_; }
DataType dtype() const { return dtype_; }
// handle() returns the XlaOp that backs a kXlaOp expression.
const xla::XlaOp& handle() const { return handle_; }
const Tensor& constant_value() const { return constant_value_; }
XlaResource* resource() const { return resource_; }
// Returns a human-readable summary of the expression.
string HumanString() const;
// Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
// an erroneous XlaOp if the expression is not a constant or an expression.
xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
// If a kXlaOp or kConstant expression can be resolved to a compile-time
// constant, returns the value as a host-memory Tensor. Returns an empty
// optional if it cannot be resolved. Returns an error if passed a resource
// expression.
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
xla::Client* client) const;
// Returns the shape of the tensor.
// The shape of a resource is the shape of a resource handle (i.e., a scalar),
// not the shape of the resource's value.
xla::StatusOr<TensorShape> GetShape() const;
private:
Kind kind_ = Kind::kInvalid;
DataType dtype_ = DT_INVALID;
// The XLA handle of the expression's computation, if kind_ == kXlaOp.
xla::XlaOp handle_;
// The value of the constant, if kind_ == kConstant.
Tensor constant_value_;
// The resource, if kind_ == kResource. Not owned.
XlaResource* resource_ = nullptr;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_

View File

@ -0,0 +1,135 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class XlaExpressionTest : public ::testing::Test {
protected:
void SetUp() override {
client_ = xla::ClientLibrary::LocalClientOrDie();
builder_ = absl::make_unique<xla::XlaBuilder>("acomputation");
constant_ = test::AsScalar<int32>(42);
op_ = xla::ConstantR0<int32>(builder_.get(), 7);
non_constant_op_ = xla::Parameter(
builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
resource_ = absl::make_unique<XlaResource>(
XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
/*tensor_array_gradients=*/std::set<string>(),
/*tensor_array_multiple_writes_aggregate=*/false);
}
xla::Client* client_;
std::unique_ptr<xla::XlaBuilder> builder_;
Tensor constant_;
xla::XlaOp op_;
xla::XlaOp non_constant_op_;
std::unique_ptr<XlaResource> resource_;
};
TEST_F(XlaExpressionTest, Kind) {
EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind());
EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind());
EXPECT_TRUE(XlaExpression::Kind::kConstant ==
XlaExpression::Constant(constant_).kind());
EXPECT_TRUE(XlaExpression::Kind::kXlaOp ==
XlaExpression::XlaOp(op_, DT_INT32).kind());
EXPECT_TRUE(XlaExpression::Kind::kResource ==
XlaExpression::Resource(resource_.get()).kind());
}
TEST_F(XlaExpressionTest, HumanString) {
EXPECT_EQ("invalid", XlaExpression().HumanString());
EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString());
EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString());
EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString());
EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString());
}
TEST_F(XlaExpressionTest, AsXlaOp) {
xla::XlaOp op_as_op =
XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get());
EXPECT_TRUE(op_.IsIdenticalTo(op_as_op));
xla::XlaOp const_as_op =
XlaExpression::Constant(constant_).AsXlaOp(builder_.get());
TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation,
builder_->BuildConstantSubGraph(const_as_op));
TF_ASSERT_OK_AND_ASSIGN(xla::Literal value,
client_->ComputeConstant(computation));
EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0<int32>(42),
value));
}
TEST_F(XlaExpressionTest, GetShape) {
EXPECT_FALSE(XlaExpression().GetShape().ok());
EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok());
TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape,
XlaExpression::Resource(resource_.get()).GetShape());
EXPECT_EQ(TensorShape({}), resource_shape);
TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape,
XlaExpression::XlaOp(op_, DT_INT32).GetShape());
EXPECT_EQ(TensorShape({}), op_shape);
TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape,
XlaExpression::Constant(constant_).GetShape());
EXPECT_EQ(TensorShape({}), constant_shape);
}
TEST_F(XlaExpressionTest, ResolveConstant) {
EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
EXPECT_FALSE(
XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());
TF_ASSERT_OK_AND_ASSIGN(
absl::optional<Tensor> op_constant,
XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_));
ASSERT_TRUE(op_constant.has_value());
test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *op_constant);
TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> op_nonconstant,
XlaExpression::XlaOp(non_constant_op_, DT_FLOAT)
.ResolveConstant(client_));
EXPECT_FALSE(op_nonconstant.has_value());
TF_ASSERT_OK_AND_ASSIGN(
absl::optional<Tensor> constant_constant,
XlaExpression::Constant(constant_).ResolveConstant(client_));
ASSERT_TRUE(constant_constant.has_value());
test::ExpectTensorEqual<int32>(constant_, *constant_constant);
}
} // namespace
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const {
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression = const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data()); reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK(expression->handle().valid() || expression->resource() != nullptr); CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
VLOG(1) << "Fetched T" << expression->handle(); << expression->HumanString();
return expression; return expression;
} }
// Retrieves an uninitialized XlaExpression from a newly-allocated tensor. // Assigns an XlaExpression to a tensor on an XLA compilation device.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { static void AssignExpressionToTensor(Tensor* tensor,
const XlaExpression& value) {
const XlaExpression* expression = const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data()); reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
CHECK(!expression->handle().valid()); CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
return const_cast<XlaExpression*>(expression); << expression->HumanString();
*const_cast<XlaExpression*>(expression) = value;
} }
// Retrieves the XlaOp from an input Tensor to an Op. This computation was const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
// constructed by an Op that executed previously and created the output Tensor return *CastExpressionFromTensor(context_->input(index));
// using CreateOutputTensorFromComputation or CreateConstantOutputTensor.
static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
} }
const xla::XlaOp& XlaOpKernelContext::Input(int index) { const XlaExpression& XlaOpKernelContext::InputExpression(
return GetComputationFromTensor(context_->input(index)); absl::string_view name) {
return *CastExpressionFromTensor(GetInputTensorByName(name));
} }
const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { xla::XlaOp XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(GetInputTensorByName(name)); return InputExpression(index).AsXlaOp(builder());
}
xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
return InputExpression(name).AsXlaOp(builder());
} }
TensorShape XlaOpKernelContext::InputShape(int index) { TensorShape XlaOpKernelContext::InputShape(int index) {
@ -125,77 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name,
Status XlaOpKernelContext::ConstantInputReshaped( Status XlaOpKernelContext::ConstantInputReshaped(
int index, absl::Span<const int64> new_dims, int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal) { xla::Literal* constant_literal) {
const Tensor& tensor = context_->input(index); XlaExpression e = InputExpression(index);
TensorShape new_shape(new_dims); xla::StatusOr<absl::optional<Tensor>> constant_or_status =
if (tensor.NumElements() != new_shape.num_elements()) { e.ResolveConstant(compiler()->client());
return errors::InvalidArgument( if (!constant_or_status.ok()) {
context_->op_kernel().name(), " input ", index, " has shape ", Status status = constant_or_status.status();
tensor.shape().DebugString(),
" but was asked to be reshaped to incompatible shape ",
new_shape.DebugString());
}
const XlaExpression* expression = CastExpressionFromTensor(tensor);
auto copy_tensor_to_literal = [](const Tensor& tensor,
xla::Literal* literal) {
xla::Shape literal_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape));
*literal = xla::Literal(literal_shape);
// memcpy over the payload ...
// TODO(phawkins): handle string types.
size_t total_bytes = tensor.TotalBytes();
if (total_bytes > 0) {
void* dst_ptr = literal->untyped_data();
const void* src_ptr = DMAHelper::base(&tensor);
memcpy(dst_ptr, src_ptr, total_bytes);
}
return Status::OK();
};
// If the tensor has a known constant value, there is no need to invoke XLA.
if (expression->has_constant_value()) {
Tensor temp(tensor.dtype());
if (!temp.CopyFrom(expression->constant_value(), new_shape)) {
// This should never happen. The constant should have a shape compatible
// with the enclosing Tensor.
return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
}
return copy_tensor_to_literal(temp, constant_literal);
}
// Make sure we treat zero-element tensors as constant.
if (new_shape.num_elements() == 0) {
Tensor temp(tensor.dtype(), new_shape);
return copy_tensor_to_literal(temp, constant_literal);
}
xla::XlaOp handle = expression->handle();
if (new_shape != tensor.shape()) {
// Reshape the handle to the desired shape.
handle = xla::Reshape(handle, new_shape.dim_sizes());
}
// The XLA layout is specified minor to major, and TensorFlow's minor
// dimension is the last one.
std::vector<int64> layout_indices(new_shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
xla::StatusOr<bool> is_constant = builder()->IsConstant(handle);
if (!is_constant.ok()) {
Status status = is_constant.status();
errors::AppendToMessage(&status, "while evaluating input ", index, " of ", errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
context_->op_kernel().type_string(), context_->op_kernel().type_string(),
" operator as a compile-time constant."); " operator as a compile-time constant.");
return status; return status;
} }
absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
if (!is_constant.ValueOrDie()) { if (!constant.has_value()) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Input ", index, " to ", context_->op_kernel().type_string(), "Input ", index, " to ", context_->op_kernel().type_string(),
" operator must be a compile-time constant.\n" " operator must be a compile-time constant.\n"
@ -208,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped(
"stateful operation such as a random number generator."); "stateful operation such as a random number generator.");
} }
// Ask the XLA compiler to evaluate the data handle to a literal. Tensor temp(constant->dtype());
xla::StatusOr<xla::XlaComputation> constant_graph = if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
builder()->BuildConstantSubGraph(handle); return errors::InvalidArgument(
if (!constant_graph.ok()) { context_->op_kernel().name(), " input ", index, " has shape ",
return errors::Internal( constant->shape().DebugString(),
"Error getting a compile-time constant graph for ", " but was asked to be reshaped to incompatible shape ",
context_->op_kernel().name(), " input ", index, TensorShape(new_dims).DebugString());
".\nError: ", constant_graph.status().error_message());
} }
xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
" as a compile-time constant.\nError: ",
computed.status().error_message());
}
*constant_literal = std::move(computed).ValueOrDie();
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
return Status::OK(); return Status::OK();
} }
@ -322,6 +259,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
return LiteralToInt64Vector(literal, out); return LiteralToInt64Vector(literal, out);
} }
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
absl::string_view name, std::vector<int64>* out) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInputReshaped(
index, {InputShape(index).num_elements()}, &literal));
return LiteralToInt64Vector(literal, out);
}
Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
xla::Literal* out) { xla::Literal* out) {
xla::Literal literal; xla::Literal literal;
@ -372,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
handles->clear(); handles->clear();
shapes->clear(); shapes->clear();
for (const Tensor& input : inputs) { for (const Tensor& input : inputs) {
handles->push_back(GetComputationFromTensor(input)); handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
shapes->push_back(input.shape()); shapes->push_back(input.shape());
} }
return Status::OK(); return Status::OK();
@ -413,9 +359,12 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
XlaContext& xla_context = XlaContext::Get(ctx); XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
TensorShape representation_shape, xla::Shape representation_shape,
xla_context.RepresentationShape(variable->shape(), variable->type())); xla_context.RepresentationShape(variable->shape(), variable->type()));
if (representation_shape == variable->shape()) { xla::Shape xla_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
*value = variable->value(); *value = variable->value();
} else { } else {
*value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
@ -455,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK(); return Status::OK();
} }
Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, void XlaOpKernelContext::SetOutputExpression(int index,
Tensor** output) { const XlaExpression& expression) {
// The step's default allocator is the dummy XlaCompilationAllocator which Status status = [&] {
// simply allocates a metadata buffer to hold the expression to which it // The step's default allocator is the dummy XlaCompilationAllocator which
// corresponds. // simply allocates a metadata buffer to hold the expression to which it
if (expected_output_dtype(index) == DT_VARIANT) { // corresponds.
// tensor_data() is not supported for variant Tensor (i.e., Tensor* output = nullptr;
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the // Provides a special behavior for DT_VARIANT: a variant is treated as
// XlaExpression inside the Tensor's tensor_data() does not work for // DT_UINT8 scalar as the type to allow mapping for variant to more generic
// variant. Instead construct a uint8 tensor and store the expression in its // types.
// value. if (expression.dtype() == DT_VARIANT) {
// TODO(jpienaar): This should be refactored to stop masquerading // tensor_data() is not supported for variant Tensor (i.e.,
// XlaExpressions as Tensors. // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
*output = new Tensor(); // XlaExpression inside the Tensor's tensor_data() does not work for
TensorShape tensor_shape; // variant. Instead construct a uint8 tensor and store the expression in
TF_RETURN_IF_ERROR( // its value.
context_->allocate_temp(DT_UINT8, tensor_shape, *output)); // TODO(jpienaar): This should be refactored to stop masquerading
context_->set_output(index, **output); // XlaExpressions as Tensors.
} else { output = new Tensor();
TensorShape tensor_shape; TensorShape tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); TF_RETURN_IF_ERROR(
TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); context_->allocate_temp(DT_UINT8, tensor_shape, output));
context_->set_output(index, *output);
} else {
TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
}
AssignExpressionToTensor(output, expression);
return Status::OK();
}();
if (!status.ok()) {
SetStatus(status);
} }
return Status::OK();
} }
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression. SetOutputExpression(
Tensor* output = nullptr; index,
auto shape_or = builder()->GetShape(handle); XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
if (!shape_or.ok()) {
SetStatus(shape_or.status());
return;
}
OP_REQUIRES_OK(context_,
allocate_output(index, shape_or.ValueOrDie(), &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
} }
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
const TensorShape& shape = constant.shape(); SetOutputExpression(index, XlaExpression::Constant(constant));
xla::BorrowingLiteral literal;
OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
xla::XlaOp handle = xla::ConstantLiteral(builder(), literal);
CHECK(handle.valid());
// Make the Tensor that will refer to the expression.
Tensor* output = nullptr;
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_handle(handle);
expression->set_constant_value(constant);
}
void XlaOpKernelContext::SetInvalidOutput(int index) {
Tensor* output = nullptr;
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape({}), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
xla::XlaOp handle;
expression->set_handle(handle);
} }
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
Tensor* output = nullptr; SetOutputExpression(index, XlaExpression::Resource(resource));
// The shape of the output tensor is the shape of the resource itself
// (i.e., a scalar), not the shape of the resource's value.
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape(), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
expression->set_resource(resource);
} }
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
@ -570,10 +482,13 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
XlaContext& xla_context = XlaContext::Get(ctx); XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN(TensorShape representation_shape, TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
xla_context.RepresentationShape(shape, type)); xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) { xla::Shape xla_shape;
handle = xla::Reshape(handle, representation_shape.dim_sizes()); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
handle = xla::Reshape(handle,
xla::AsInt64Slice(representation_shape.dimensions()));
} }
return variable->SetValue(handle); return variable->SetValue(handle);
} }

View File

@ -88,9 +88,9 @@ class XlaOpKernelContext {
// Returns input `index` as a XlaOp. Unlike // Returns input `index` as a XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete // OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor. // Tensor.
const xla::XlaOp& Input(int index); xla::XlaOp Input(int index);
// Returns input `name` as a XlaOp. // Returns input `name` as a XlaOp.
const xla::XlaOp& Input(absl::string_view name); xla::XlaOp Input(absl::string_view name);
// Returns true if all inputs are the same shape, otherwise sets the // Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false. // status to a non-OK value and returns false.
@ -111,14 +111,6 @@ class XlaOpKernelContext {
Status ConstantInput(int index, xla::Literal* constant_literal); Status ConstantInput(int index, xla::Literal* constant_literal);
Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);
// Converts a constant scalar int32 or int64 tensor into an int64. // Converts a constant scalar int32 or int64 tensor into an int64.
Status ConstantInputAsIntScalar(int index, int64* out); Status ConstantInputAsIntScalar(int index, int64* out);
Status ConstantInputAsIntScalar(absl::string_view name, int64* out); Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
@ -134,6 +126,8 @@ class XlaOpKernelContext {
// Reshapes and converts a constant int32 or int64 tensor into a vector of // Reshapes and converts a constant int32 or int64 tensor into a vector of
// int64s. // int64s.
Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out); Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
Status ConstantInputReshapedToIntVector(absl::string_view name,
std::vector<int64>* out);
// Converts a constant int32 or int64 Tensor into an xla int64 Literal. // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out); Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
@ -148,6 +142,10 @@ class XlaOpKernelContext {
Status ConstantInputList(absl::string_view name, Status ConstantInputList(absl::string_view name,
std::vector<xla::Literal>* literals); std::vector<xla::Literal>* literals);
// Returns an XlaExpression describing the value of 'index'.
const XlaExpression& InputExpression(int index);
const XlaExpression& InputExpression(absl::string_view name);
// Outputs // Outputs
int num_outputs() const { return context_->num_outputs(); } int num_outputs() const { return context_->num_outputs(); }
@ -165,9 +163,8 @@ class XlaOpKernelContext {
// SetConstantOutput where possible. // SetConstantOutput where possible.
void SetConstantOutput(int index, const Tensor& host_tensor); void SetConstantOutput(int index, const Tensor& host_tensor);
// Sets output `index` to an invalid value. // Returns an XlaExpression describing the value of 'index'.
// Any subsequent attempt to consume this output will cause an error. void SetOutputExpression(int index, const XlaExpression& expression);
void SetInvalidOutput(int index);
// Status handling. // Status handling.
void SetStatus(const Status& status) { context_->SetStatus(status); } void SetStatus(const Status& status) { context_->SetStatus(status); }
@ -255,10 +252,13 @@ class XlaOpKernelContext {
// Returns the tensor of input `name`. // Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name); const Tensor& GetInputTensorByName(absl::string_view name);
// Wraps OpKernelContext's allocate_output method while providing special // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the // InputShape(index), and stores it in `*constant_literal`. If the input
// type to allow mapping for variant to more generic types. // cannot be evaluated, e.g., because it depends on unbound parameters,
Status allocate_output(int index, const xla::Shape& shape, Tensor** output); // returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);
OpKernelContext* const context_; OpKernelContext* const context_;
}; };

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
@ -129,21 +130,27 @@ XlaOpRegistry::~XlaOpRegistry() = default;
// Lazily register the CPU and GPU JIT devices the first time // Lazily register the CPU and GPU JIT devices the first time
// GetCompilationDevice is called. // GetCompilationDevice is called.
static void* registration_init = [&registry]() { static void* registration_init = [&registry]() {
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
mutex_lock lock(registry.mutex_); mutex_lock lock(registry.mutex_);
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
DeviceRegistration& registration = DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_CPU]; registry.compilation_devices_[DEVICE_CPU];
registration.compilation_device_name = DEVICE_CPU_XLA_JIT; registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
registration.requires_compilation = false; registration.autoclustering_policy =
registration.enable_jit_by_default = false; cpu_global_jit
? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
registration.compile_resource_ops = false; registration.compile_resource_ops = false;
} }
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
DeviceRegistration& registration = DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_GPU]; registry.compilation_devices_[DEVICE_GPU];
registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.requires_compilation = false; registration.autoclustering_policy =
registration.enable_jit_by_default = true; XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
registration.compile_resource_ops = false; registration.compile_resource_ops = false;
} }
return nullptr; return nullptr;

View File

@ -66,19 +66,26 @@ class XlaOpRegistry {
public: public:
typedef OpKernel* (*Factory)(OpKernelConstruction*); typedef OpKernel* (*Factory)(OpKernelConstruction*);
enum class AutoclusteringPolicy {
// Enable autoclustering if the user requests it, e.g., via
// experimental_jit_scope. Does not autocluster if the JIT is enabled
// globally (e.g., via the OptimizerOptions in the TF session
// configuration.)
kIfExplicitlyRequested,
// Enable autoclustering if explicitly requested, or if the JIT is enabled
// globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
kIfEnabledGlobally,
// Always try to autocluster ops placed on this device.
kAlways,
};
// Describes how to compile operators assigned to a device. // Describes how to compile operators assigned to a device.
struct DeviceRegistration { struct DeviceRegistration {
// The name of the an XLA compilation device to use to compile code. // The name of the an XLA compilation device to use to compile code.
string compilation_device_name; string compilation_device_name;
// Do operators assigned to this device require compilation? // When should we autocluster operators assigned to this device?
bool requires_compilation; AutoclusteringPolicy autoclustering_policy;
// If !requires_compilation, should we try to JIT operators on this device
// when XLA JIT compilation is enabled globally via the SessionOptions?
// (It is still possible to explicitly mark operators to JIT compile, even
// if enable_jit_by_default is false.)
bool enable_jit_by_default;
// Enable compilation of operators that use DT_RESOURCE types? // Enable compilation of operators that use DT_RESOURCE types?
bool compile_resource_ops = false; bool compile_resource_ops = false;

Some files were not shown because too many files have changed in this diff Show More