diff --git a/WORKSPACE b/WORKSPACE index 17961829a60..0c7bc085b51 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,33 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +http_archive( + name = "base_images_docker", + sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", + strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", + urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], +) + +http_archive( + name = "bazel_toolchains", + sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", + strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", + urls = [ + "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", + ], +) + +http_archive( + name = "io_bazel_rules_docker", + sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", + strip_prefix = "rules_docker-0.5.1", + urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +) + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") + +remote_config_workspace() + # We must check the bazel version before trying to parse any other BUILD # files, in case the parsing of those build files depends on the bazel # version we require here. @@ -79,3 +106,4 @@ new_http_archive( "http://download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) + diff --git a/configure.py b/configure.py index 2eeeceb3399..234561d94a4 100644 --- a/configure.py +++ b/configure.py @@ -43,7 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' -_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -1555,6 +1555,9 @@ def main(): check_bazel_version('0.15.0') reset_tf_configure_bazelrc() + # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later + write_to_bazelrc('import %workspace%/tools/bazel.rc') + cleanup_makefile() setup_python(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 11b42f349df..859dc3b8d77 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -352,6 +352,7 @@ package_group( "//tensorflow/...", "//tensorflow_estimator/...", "//tensorflow_fold/llgtm/...", + "//tensorflow_text/...", "//third_party/py/tensor2tensor/...", ], ) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 16f633643d4..b8db1b21449 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -95,6 +95,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:server_lib", ], }) + select({ "//tensorflow:with_xla_support": [ @@ -199,7 +200,7 @@ tf_cuda_cc_test( size = "small", srcs = ["c_api_test.cc"], data = [ - ":test_op.so", + ":test_op1.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], kernels = [":test_op_kernel"], @@ -218,6 +219,7 @@ tf_cuda_cc_test( "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/compiler/jit", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", @@ 
-284,8 +286,8 @@ tf_cc_test( ) tf_custom_op_library( - name = "test_op.so", - srcs = ["test_op.cc"], + name = "test_op1.so", + srcs = ["test_op1.cc"], ) tf_kernel_library( diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 4540dcd6638..f13e8777dff 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2810,4 +2810,71 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { } return ret; } + +// TF_Server functions ---------------------------------------------- + +#ifndef __ANDROID__ +TF_Server::TF_Server(std::unique_ptr server) + : target(server->target()), server(std::move(server)) {} +#endif // __ANDROID__ + +TF_Server* TF_NewServer(const void* proto, size_t proto_len, + TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); + return nullptr; +#else + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { + status->status = InvalidArgument( + "Could not parse provided bytes into a ServerDef protocol buffer"); + return nullptr; + } + + std::unique_ptr out_server; + status->status = tensorflow::NewServer(server_def, &out_server); + if (!status->status.ok()) return nullptr; + + return new TF_Server(std::move(out_server)); +#endif +} + +void TF_ServerStart(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Start(); +#endif +} + +void TF_ServerStop(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Stop(); +#endif +} + +void TF_ServerJoin(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Join(); +#endif +} + +const char* TF_ServerTarget(TF_Server* server) { +#ifdef __ANDROID__ + return nullptr; +#else + return server->target.c_str(); +#endif +} + +void TF_DeleteServer(TF_Server* server) { delete server; } + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index da8ad1cec59..3d56268110e 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1668,6 +1668,47 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// -------------------------------------------------------------------------- +// In-process TensorFlow server functionality, for use in distributed training. +// A Server instance encapsulates a set of devices and a Session target that +// can participate in distributed training. A server belongs to a cluster +// (specified by a ClusterSpec), and corresponds to a particular task in a +// named job. The server can communicate with any other server in the same +// cluster. + +// In-process TensorFlow server. +typedef struct TF_Server TF_Server; + +// Creates a new in-process TensorFlow server configured using a serialized +// ServerDef protocol buffer provided via `proto` and `proto_len`. +// +// The server will not serve any requests until TF_ServerStart is invoked. +// The server will stop serving requests once TF_ServerStop or +// TF_DeleteServer is invoked. 
+TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, + size_t proto_len, + TF_Status* status); + +// Starts an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); + +// Stops an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); + +// Blocks until the server has been successfully stopped (via TF_ServerStop or +// TF_ServerClose). +TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); + +// Returns the target string that can be provided to TF_SetTarget() to connect +// a TF_Session to `server`. +// +// The returned string is valid only until TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); + +// Destroy an in-process TensorFlow server, frees memory. If server is running +// it will be stopped and joined. +TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 95652a11378..5ba26d3c585 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -25,6 +25,7 @@ limitations under the License. #include #ifndef __ANDROID__ +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/op_gen_lib.h" #endif #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -179,6 +180,15 @@ struct TF_ApiDefMap { tensorflow::mutex lock; }; +#ifndef __ANDROID__ +struct TF_Server { + TF_Server(std::unique_ptr server); + + const tensorflow::string target; + std::unique_ptr server; +}; +#endif + namespace tensorflow { class TensorCApi { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index b0dc0363fdb..d5934a10395 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; - // Load the library. - TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - TF_Code code = TF_GetCode(status); - string status_msg(TF_Message(status)); - TF_DeleteStatus(status); - ASSERT_EQ(TF_OK, code) << status_msg; +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + { + // Load the library. + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op1.so", status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + // Test op list. + TF_Buffer op_list_buf = TF_GetOpList(lib); + tensorflow::OpList op_list; + EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); + ASSERT_EQ(op_list.op_size(), 1); + EXPECT_EQ("TestCApi1", op_list.op(0).name()); + TF_DeleteLibraryHandle(lib); + } +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) { TF_Buffer* op_list_buffer = TF_GetAllOpList(); tensorflow::OpList op_list; @@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) { EXPECT_TRUE(found); TF_DeleteBuffer(op_list_buffer); } - -#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) - { - // Test op list. 
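As an illustrative sketch (not part of this change), the in-process server C API declared above is meant to be driven roughly as follows; the serialized ServerDef buffer is assumed to be produced by the caller:

#include "tensorflow/c/c_api.h"

// Hypothetical helper, for illustration only: runs the full lifecycle of the
// new in-process server API. `server_def_bytes`/`server_def_len` are assumed
// to hold a ServerDef serialized elsewhere (e.g. from a Python caller).
void RunInProcessServer(const void* server_def_bytes, size_t server_def_len) {
  TF_Status* status = TF_NewStatus();
  TF_Server* server = TF_NewServer(server_def_bytes, server_def_len, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ServerStart(server, status);  // begin serving requests
    // ... point a TF_Session at TF_ServerTarget(server) here ...
    TF_ServerStop(server, status);   // stop serving
    TF_ServerJoin(server, status);   // block until fully stopped
  }
  TF_DeleteServer(server);  // stops and joins a still-running server
  TF_DeleteStatus(status);
}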
- TF_Buffer op_list_buf = TF_GetOpList(lib); - tensorflow::OpList op_list; - EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); - ASSERT_EQ(op_list.op_size(), 1); - EXPECT_EQ("TestCApi", op_list.op(0).name()); - } -#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) - - TF_DeleteLibraryHandle(lib); } void TestEncodeDecode(int line, const std::vector& data) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3ee31a6a7ac..ba3d8533db7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -69,7 +69,7 @@ tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], visibility = [ - "//learning/deepmind/courier:__pkg__", + "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", ], deps = [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3554ec0bf32..408277468d7 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -404,8 +404,7 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { "The passed in handle is a nullptr"); return nullptr; } - tensorflow::Device* d = nullptr; - status->status = h->handle->OpDevice(&d); + tensorflow::Device* d = h->handle->op_device(); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 5006b76f198..52b08245528 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -57,13 +57,9 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } - tensorflow::Device* device; - status->status = handle->handle->Device(&device); - if (!status->status.ok()) { - return nullptr; - } - #ifdef TENSORFLOW_EAGER_USE_XLA + tensorflow::Device* device = handle->handle->device(); + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. tensorflow::XlaDevice* xla_device = dynamic_cast(device); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 104d52430cf..fa1b22e3af4 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -79,10 +79,6 @@ struct TFE_TensorHandle { tensorflow::Device* op_device) : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} - TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - tensorflow::EagerContext* ctx) - : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} - TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; diff --git a/tensorflow/c/test_op1.cc b/tensorflow/c/test_op1.cc new file mode 100644 index 00000000000..b22cc9aef2b --- /dev/null +++ b/tensorflow/c/test_op1.cc @@ -0,0 +1,23 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc"); + +} // namespace tensorflow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index c18b07603ae..83353b79f72 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -170,6 +170,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -516,6 +517,8 @@ tf_gen_op_wrappers_cc( ":array_ops", ":const_op", ":math_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", ], ) diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 6c29f09cde7..16151e77737 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -93,7 +93,7 @@ cc_library( ":tfcompile_lib", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index b95b063348c..d548de8c442 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) { return errors::InvalidArgument("Must specify --cpp_class"); } codegen_opts.gen_hlo_profile_printer_data = - xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + xla::GetDebugOptionsFromFlags().xla_hlo_profile(); TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, &codegen_opts.namespaces)); @@ -132,7 +132,7 @@ int main(int argc, char** argv) { std::vector flag_list; AppendMainFlags(&flag_list, &flags); - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0c41e095c7b..5f25e4626ad 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -21,7 +21,6 @@ package( ) load("//tensorflow:tensorflow.bzl", "cc_header_only_library") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") @@ -52,6 +51,7 @@ cc_library( deps = [ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -65,6 +65,7 @@ 
cc_library( ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), alwayslink = 1, @@ -190,6 +191,7 @@ cc_library( "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:stack", "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", @@ -241,6 +243,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) @@ -253,6 +256,7 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:core_cpu", @@ -263,6 +267,21 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "xla_compilation_cache_test", + srcs = [ + "xla_compilation_cache_test.cc", + ], + deps = [ + ":xla_compilation_cache", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) @@ -500,6 +519,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -524,25 +544,6 @@ cc_library( hdrs = ["union_find.h"], ) -cc_library( - name = "producer_consumer_queue", - hdrs = ["producer_consumer_queue.h"], - deps = ["//tensorflow/core:lib"], -) - -tf_cc_test( - name = "producer_consumer_queue_test", - size = "small", - srcs = ["producer_consumer_queue_test.cc"], - deps = [ - ":producer_consumer_queue", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -606,6 +607,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -648,31 +650,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "xla_launch_util_test", - size = "small", - srcs = ["xla_launch_util_test.cc"], - deps = [ - ":common", - ":xla_compilation_cache", - ":xla_launch_util", - ":xla_tensor", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core/kernels:variable_ops", - ], -) - cc_library( name = "xla_fusion_optimizer", srcs = 
["xla_fusion_optimizer.cc"], diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 054f31ba335..93637a69d5d 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) { return errors::Internal("Could not find compilation device ", device_type.type()); } - *result = registration->requires_compilation; + *result = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; return Status::OK(); } diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 617e31488c7..8a73101c184 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output loop_cond = ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), latch.output_true, increment_by); Output next_iteration = @@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue( value, frame_name); ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output next_iteration = ops::NextIteration( root.WithOpName(prefix + "/next_iteration"), latch.output_true); CHECK(root.graph() diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index a3b193eea74..5e0c4bf6a0c 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -117,6 +117,25 @@ Status PreprocessForEncapsulation(Graph* g, // Information for XLA computation. struct XlaClusterInfo { + // Add an explicitly-defined default constructor for this class. + // + // The compiler may delete the default constructor here because + // host_compute_core is a const member whose type (std::map) doesn't + // necessarily have a user provided constructor -- while libc++ and + // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at + // least >= 7.3 does not. See also c++11 [class.ctor] p5. + // + // TODO(klimek): In c++17 we'll be able to initialize host_compute_core + // without losing aggregate initialization, which allows us to get rid of + // the constructor definitions again. + XlaClusterInfo() {} + XlaClusterInfo(const string& cluster_name, + const NameAttrList& func_name_attrs, Node* node, + const std::map& host_compute_core) + : cluster_name(cluster_name), + func_name_attrs(func_name_attrs), + node(node), + host_compute_core(host_compute_core) {} // XLA cluster name. It might be different from `func_name`. const string cluster_name; // Name and attributes of XLA computation function. 
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 70b019d35fc..8b3587c5087 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -394,12 +394,12 @@ Status ConstructHostGraph( for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; FunctionBody* host_fbody = nullptr; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fld->Find(host_func), AttrSlice(), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &host_fbody)); + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_func), AttrSlice(), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &host_fbody)); std::unique_ptr host_fbody_deleter(host_fbody); // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse @@ -411,52 +411,53 @@ Status ConstructHostGraph( node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); Status s; - ReverseDFS(*host_fbody->graph, /*enter=*/nullptr, - [&](const Node* n) { - if (!s.ok()) { - return; - } + ReverseDFS( + *host_fbody->graph, /*enter=*/nullptr, + [&](const Node* n) { + if (!s.ok()) { + return; + } - Node* copy; - if (node_map.find(n) != node_map.end()) { - // Already copied this node. - copy = node_map.at(n); - } else if (IsKeyPlaceholderNode(*n)) { - // Change a). - copy = key_placeholder; - node_map[n] = copy; - } else { - // Copy the node. - NodeDef copy_def = n->def(); - // Change c). - copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); - if (!s.ok()) { - return; - } - node_map[n] = copy; - } + Node* copy; + if (node_map.find(n) != node_map.end()) { + // Already copied this node. + copy = node_map.at(n); + } else if (IsKeyPlaceholderNode(*n)) { + // Change a). + copy = key_placeholder; + node_map[n] = copy; + } else { + // Copy the node. + NodeDef copy_def = n->def(); + // Change c). + copy_def.clear_device(); + copy = (*host_graph)->AddNode(copy_def, &s); + if (!s.ok()) { + return; + } + node_map[n] = copy; + } - // Only handle input edges. Output edges will be added later as - // its output nodes' input edges. - for (auto e : n->in_edges()) { - if (node_map.find(e->src()) == node_map.end()) { - s = errors::Internal("Cannot find node image for ", - e->src()->DebugString()); - return; - } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); - } + // Only handle input edges. Output edges will be added later as + // its output nodes' input edges. + for (auto e : n->in_edges()) { + if (node_map.find(e->src()) == node_map.end()) { + s = errors::Internal("Cannot find node image for ", + e->src()->DebugString()); + return; + } + (*host_graph) + ->AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); + } - // Change b). - if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); - } - }, - NodeComparatorID()); + // Change b). 
+ if (copy->type_string() == "_XlaRecvAtHost" || + copy->type_string() == "_XlaSendFromHost") { + (*host_graph)->AddControlEdge(copy, sequencer); + } + }, + NodeComparatorID()); if (!s.ok()) { return s; } @@ -838,7 +839,12 @@ Status ExtractOutsideCompilationForFunction( FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( shape_inference_graph); - TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + if (fld->Find(shape_inference_graph)) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + } } } } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index bd8719b7f1a..d984ca15cb7 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/types/optional.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" @@ -34,14 +35,30 @@ limitations under the License. namespace tensorflow { namespace { -Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) { - TF_RET_CHECK(n->type_string() == "Const"); + +// StatusOrOptional instances hold +// +// - A non-OK Status to indicate an error that needs to be propagated out of +// this pass (e.g. the Graph is malformed). +// +// - A nullopt to indicate the function that created the instance failed to do +// what it set out to do but this is not actually an error +// (e.g. TryToGetTensorFromConstOp was passed a non-Const node). +// +// - A T to indicate a successful operation. +template +using StatusOrOptional = xla::StatusOr>; + +StatusOrOptional TryToGetTensorFromConstOp(Node* n) { + if (n->type_string() != "Const") { + return {absl::nullopt}; + } + const TensorProto* proto = nullptr; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); Tensor tensor(proto->dtype()); TF_RET_CHECK(tensor.FromProto(*proto)); - *out_tensor = std::move(tensor); - return Status::OK(); + return {tensor}; } struct SliceInputs { @@ -70,7 +87,7 @@ std::vector IntTensorAsVector(const Tensor& t) { // Packages up the inputs to a Slice operation into an instance of // `SliceInputs`. 
-Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) { +StatusOrOptional GetSliceInputs(Node* slice) { const int kSliceInputIndex = 0; const int kSliceBeginIndex = 1; const int kSliceSizeIndex = 2; @@ -81,23 +98,27 @@ Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) { TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge)); const Edge* slice_begin_edge; TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge)); - slice_inputs->input = + + SliceInputs slice_inputs; + slice_inputs.input = Output(slice_input_edge->src(), slice_input_edge->src_output()); - slice_inputs->begin = + slice_inputs.begin = Output(slice_begin_edge->src(), slice_begin_edge->src_output()); - slice_inputs->size = + slice_inputs.size = Output(slice_size_edge->src(), slice_size_edge->src_output()); - Tensor tf_slice_size; - TF_RETURN_IF_ERROR( - GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size)); - - if (tf_slice_size.dims() != 1) { - return errors::Internal("Expected vector for the slice size input."); + TF_ASSIGN_OR_RETURN(absl::optional tf_slice_size, + TryToGetTensorFromConstOp(slice_inputs.size.node())); + if (!tf_slice_size.has_value()) { + return {absl::nullopt}; } - slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size); - return Status::OK(); + if (tf_slice_size->dims() != 1) { + return {absl::nullopt}; + } + + slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size); + return {slice_inputs}; } // Casts `x` to a DT_INT64 if it isn't one already. @@ -263,36 +284,43 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// Returns true if `n` is a slice we can rewrite to have a static shape -// (i.e. have the output shape only depend on the "size" input). Fills in -// `slice_inputs` in the process. -bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) { +// If `n` is a slice we can rewrite to have a static shape (i.e. have the output +// shape only depend on the "size" input) then returns the a SliceInputs +// representing the inputs to `n`. Otherwise returns nullopt. +StatusOrOptional IsRewritableSlice(Node* n) { if (n->type_string() != "Slice") { - return false; + return {absl::nullopt}; } if (!GetXlaClusterForNode(*n).has_value()) { // There is no need to change slice ops outside XLA clusters. - return false; + return {absl::nullopt}; } - if (!GetSliceInputs(n, slice_inputs).ok()) { - // Could not parse slice inputs. E.g. the sizes input was not a constant. - return false; + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + if (!slice_inputs.has_value()) { + return {absl::nullopt}; } // If slice_size[i] < -1 for any i then executing the slice will throw an // error, and we don't do anything here. 
- return absl::c_all_of(slice_inputs->size_as_vector, - [](int64 size_i) { return size_i >= -1; }); + bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector, + [](int64 size_i) { return size_i >= -1; }); + if (!slice_is_ok) { + return {absl::nullopt}; + } + + return slice_inputs; } Status FindAndRewriteSlices(Graph* g, bool* changed) { std::vector> slices_to_rewrite; for (Node* n : g->nodes()) { - SliceInputs slice_inputs; - if (IsRewritableSlice(n, &slice_inputs)) { - slices_to_rewrite.push_back({n, std::move(slice_inputs)}); + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + IsRewritableSlice(n)); + if (slice_inputs.has_value()) { + slices_to_rewrite.push_back({n, std::move(*slice_inputs)}); } } diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 107d521077c..f79bdc1e2e8 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -44,11 +44,8 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); -// TODO(b/111210515): IncreaseDynamismForAutoJitPass creates slices with index -// type DT_INT64 which do not have a kernel on GPU. -// -// REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, -// IncreaseDynamismForAutoJitPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + IncreaseDynamismForAutoJitPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, PartiallyDeclusterPass); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 6bcae1dcc3d..055de7afcc5 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -39,12 +39,22 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) 
\ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + namespace tensorflow { namespace { -Status PlatformInfoFromContext(OpKernelConstruction* ctx, - XlaPlatformInfo* result) { +XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { DeviceType device_type = ctx->device_type(); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; @@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx, } if (!device_allocator) { - TF_ASSIGN_OR_RETURN(se::Platform* const platform, - se::MultiPlatformManager::PlatformWithId(platform_id)); + xla::StatusOr maybe_platform = + se::MultiPlatformManager::PlatformWithId(platform_id); + OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); + xla_allocator = absl::make_unique( - platform, ctx->device()->GetAllocator({})); + maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({})); } - *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - std::move(xla_allocator), device_allocator); - - return Status::OK(); + return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); } // A closure describing how to run a compiled version of a TensorFlow function. @@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, : OpKernel(ctx), constants_(constants), resources_(resources), - function_(function) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} + function_(function), + platform_info_(PlatformInfoFromContext(ctx)) {} static Status BuildCompilationCache(OpKernelContext* ctx, const XlaPlatformInfo& platform_info, @@ -277,8 +286,10 @@ static Status CompileToLocalExecutable( // rather than a one-element tuple. compile_options.always_return_tuple = false; - return cache->Compile(options, function, constant_args, *variables, ctx, - compile_options, + std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_args, *variables, ctx, &args)); + return cache->Compile(options, function, args, compile_options, lazy ? XlaCompilationCache::CompileMode::kLazy : XlaCompilationCache::CompileMode::kStrict, kernel, executable); @@ -333,18 +344,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - // Helper static functions to construct parameters for // XlaLocalLaunchBase constructor from OpKernelConstruction. 
std::vector ConstantsVector(OpKernelConstruction* ctx) { @@ -381,7 +380,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) { return *func; } -#undef OP_REQUIRES_OK_RETURN +bool MustCompileAttr(OpKernelConstruction* ctx) { + bool must_compile; + OP_REQUIRES_OK_RETURN(ctx, false, + ctx->GetAttr("must_compile", &must_compile)); + return must_compile; +} } // namespace XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) @@ -396,10 +400,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx), constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), - function_(FunctionAttr(ctx)) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_)); -} + function_(FunctionAttr(ctx)), + platform_info_(PlatformInfoFromContext(ctx)), + must_compile_(MustCompileAttr(ctx)) {} void XlaCompileOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaCompileOp " << def().name() @@ -409,13 +412,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map variables; - if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) { + bool cannot_compile_cluster; + { + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster = cannot_compile_cluster_; + } + + if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + cannot_compile_cluster) { executable = nullptr; } else { - OP_REQUIRES_OK(ctx, CompileToLocalExecutable( - ctx, function_, platform_info_, resources_, - constants_, /*lazy=*/!must_compile_, &client, - &variables, &kernel, &executable)); + Status status = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, + /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); + if (must_compile_ || status.code() != error::UNIMPLEMENTED) { + OP_REQUIRES_OK(ctx, status); + } + + if (status.code() == error::UNIMPLEMENTED) { + LOG(WARNING) << "Compilation failed:" << status.ToString() + << ". Falling back to TF function call."; + executable = nullptr; + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster_ = true; + } } AllocatorAttributes host_alloc_attrs; @@ -452,9 +472,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(1, compilation_successful); } -XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) + : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index ac90837e0d9..7b4d4b5b473 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#include + #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" @@ -33,6 +35,7 @@ namespace tensorflow { class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} + XlaPlatformInfo(XlaPlatformInfo&&) = default; explicit XlaPlatformInfo(const DeviceType device_type, se::Platform::Id platform_id, const XlaDevice::Metadata* xla_device_metadata, @@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel { protected: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; - XlaPlatformInfo platform_info_; + const NameAttrList function_; + const XlaPlatformInfo platform_info_; }; // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph @@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel { private: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; + const NameAttrList function_; XlaPlatformInfo platform_info_; - bool must_compile_; + const bool must_compile_; + + // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented + // error when compiling the cluster this _XlaCompile is supposed to compile. + // If `cannot_compile_cluster_` is true then we avoid compiling this cluster + // on any future calls to _XlaCompile. + bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false; + + mutex cannot_compile_cluster_mu_; }; class XlaRunOp : public OpKernel { @@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel { void Compute(OpKernelContext* ctx) override; private: - XlaPlatformInfo platform_info_; + const XlaPlatformInfo platform_info_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 49ff9a3ddd1..5fa6c85f06f 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -22,7 +22,7 @@ cc_library( hdrs = ["mark_for_compilation_pass_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -34,7 +34,7 @@ cc_library( hdrs = ["xla_device_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -46,7 +46,7 @@ cc_library( hdrs = ["build_xla_ops_pass_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -58,7 +58,7 @@ cc_library( hdrs = ["xla_ops_common_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc index 73f4dc73ed8..961c17c17ea 100644 
--- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc @@ -16,7 +16,7 @@ limitations under the License. #include // NOLINT #include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -34,7 +34,7 @@ void AllocateAndParseFlags() { Flag("tf_xla_enable_lazy_compilation", &flags->tf_xla_enable_lazy_compilation, ""), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); } } // namespace diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 7277a1d1f8a..bad306e0b0a 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -19,7 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" @@ -64,7 +65,18 @@ static void AllocateFlags() { Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, "enable fusion of element-wise operations only using XLA when " "global_jit_level is ON*.")}); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "Parsed MarkForCompilationPassFlags:"; + VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit; + VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size; + VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size; + VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug; + VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; + VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel; + VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only; + } } // Append to *append_to flag definitions associated with the XLA bridge's diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index 2affda6ab4e..79b47357a17 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -33,7 +33,7 @@ void AppendMarkForCompilationPassFlags( // The values of flags associated with the XLA bridge's // mark_for_compilation_pass module. -typedef struct { +struct MarkForCompilationPassFlags { int32 tf_xla_auto_jit; // Control compilation of operators into XLA // computations on CPU and GPU devices. 0 = use // ConfigProto setting; -1 = off; 1 = on for things @@ -55,7 +55,7 @@ typedef struct { // is set to ON* and overrides its behavior. If // true, enable fusion of element-wise operations // only using XLA. -} MarkForCompilationPassFlags; +}; // Return a pointer to the MarkForCompilationPassFlags struct; // repeated calls return the same pointer. 
diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc index 1bb2fce2dba..76b80d3034c 100644 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" @@ -41,7 +41,7 @@ static void AllocateFlags() { "Switch a device into 'on-demand' mode, where instead of " "autoclustering ops are compiled one by one just-in-time."), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); } // Return a pointer to the XlaDeviceFlags struct; diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc index ae17fdffb9b..1443d48a734 100644 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" - +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -35,7 +35,13 @@ void AllocateAndParseFlags() { Flag("tf_xla_always_defer_compilation", &flags->tf_xla_always_defer_compilation, ""), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "Parsed XlaOpsCommonFlags:"; + VLOG(1) << " tf_xla_always_defer_compilation = " + << flags->tf_xla_always_defer_compilation; + } } const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 11975a6bb07..70033cae0af 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -61,8 +61,23 @@ struct OperationFilter { // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid // auto-clustering stateful RNG ops. bool allow_stateful_rng_ops; + + // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound + // to cluster ControlTrigger because of how we use deadness analysis. + bool allow_control_trigger; + + // Whether ops with dummy implementations are allowed. We avoid + // auto-clustering these ops so that the user is not surprised when XLA is + // implicitly enabled. If the user explicitly specifies to use XLA, it is fine + // to resort to a dummy implementation. Currently Assert and CheckNumerics ops + // have dummy XLA implementations. 
+ bool allow_dummy_ops; }; +bool IsDummyImplOp(absl::string_view op_name) { + return op_name == "Assert" || op_name == "CheckNumerics"; +} + bool IsStatefulRandomOp(absl::string_view op_name) { return op_name == "RandomUniform" || op_name == "RandomShuffle" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || @@ -225,6 +240,12 @@ bool IsCompilableCall(const NodeDef& call_def, IsStatefulRandomOp(node->type_string())) { return false; } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + return false; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + return false; + } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, lib_runtime)) { @@ -452,7 +473,14 @@ Status FindCompilationCandidates( OperationFilter op_filter; op_filter.allow_resource_ops = registration->compile_resource_ops; - op_filter.allow_stateful_rng_ops = registration->requires_compilation; + op_filter.allow_stateful_rng_ops = + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways); + op_filter.allow_control_trigger = + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways); + op_filter.allow_dummy_ops = (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways); if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, @@ -467,6 +495,15 @@ Status FindCompilationCandidates( VLOG(2) << "Rejecting " << node->name() << ": stateful random operation"; continue; } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op"; + continue; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + VLOG(2) << "Rejecting " << node->name() << ": dummy op (" + << node->type_string() << ")"; + continue; + } if (!op_filter.allow_resource_ops && (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { @@ -597,11 +634,14 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - // We can always *compile* resource operations and stateful RNGs, even if we - // are sometimes unable to auto-cluster them. + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. OperationFilter op_filter; op_filter.allow_resource_ops = true; op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_dummy_ops = true; + return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); } @@ -613,10 +653,8 @@ Status MarkForCompilationPass::Run( GetGlobalJitLevel(options); legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::GetMarkForCompilationPassFlags(); - bool cpu_global_jit = flags->tf_xla_cpu_global_jit; bool fusion_only = flags->tf_xla_fusion_only; - VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; @@ -635,9 +673,6 @@ Status MarkForCompilationPass::Run( return false; } - // If this device requires a JIT, we must say yes. - if (registration->requires_compilation) return true; - // If there is a _XlaCompile annotation, use its value. 
bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); @@ -674,18 +709,21 @@ Status MarkForCompilationPass::Run( return false; } - // Otherwise use the value of global_jit_level. - // Ignore enable_jit_by_default if global jit compilation for CPU - // is explicitly requested via tf_xla_cpu_global_jit flag - bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + // Otherwise use the value of global_jit_level and the device's + // autoclustering policy. bool should_compile = - (ignore_registration || registration->enable_jit_by_default) && - global_jit_level != OptimizerOptions::OFF; + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways || + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && + global_jit_level != OptimizerOptions::OFF); if (!should_compile) { if (global_jit_level == OptimizerOptions::OFF) { VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; } else { - VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + VLOG(2) + << "Rejecting " << node->name() + << ": autoclustering for device only when requested explicitly."; } } return should_compile; @@ -1073,12 +1111,10 @@ Status MarkForCompilationPass::RunImpl( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if the operator is placed on a device that requires - // compilation, or if it contains at least one op that is marked for + // Also, always compile if it contains at least one op that is marked for // compilation that is not an Identity op. if (effective_cluster_sizes[cluster] >= min_cluster_size || - (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || - registration->requires_compilation) { + (effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) { string& name = cluster_names[cluster]; if (name.empty()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index ead1cf4fd5f..24d78c07726 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -817,14 +817,10 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unordered_map clusters = GetClusters(*graph); - ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; - - // ctrl_trigger_a has inputs with mismatching deadness so it won't be - // clustered. ctrl_trigger_b is okay to cluster. - std::unordered_map expected_clusters( - {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}}); - EXPECT_EQ(clusters, expected_clusters); + // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so + // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't + // cluster it because of b/118970344. 
+ EXPECT_TRUE(clusters.empty()); } TEST(XlaCompilationTest, RandomShape) { @@ -923,9 +919,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); - EXPECT_NE(clusters["test/shape_rng"], ""); - EXPECT_NE(clusters["test/reshape"], ""); - EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); + EXPECT_EQ(clusters["test/shape_rng"], ""); + EXPECT_EQ(clusters["test/reshape"], ""); } TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { @@ -1088,7 +1083,7 @@ TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) { EXPECT_NE(clusters["test/c"], ""); } -TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { +TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) { Scope root = Scope::NewRootScope().ExitOnError(); Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); @@ -1104,5 +1099,53 @@ TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { EXPECT_EQ(clusters["test/a"], ""); EXPECT_EQ(clusters["test/b"], ""); } + +TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) { + absl::string_view xla_cpu_device = + "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_cpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/check"], ""); + EXPECT_NE(clusters["test/greaterequal"], ""); + EXPECT_NE(clusters["test/assert"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterDummyOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/assert"], ""); + EXPECT_EQ(clusters["test/check"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 5b961032233..36b345ecbff 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -133,6 +133,10 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { graph->RemoveEdge(out_edge_to_clone); } + if 
(n->out_edges().empty()) { + graph->RemoveNode(n); + } + return Status::OK(); } @@ -191,6 +195,10 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { } } + // Recompute post order since PartiallyDeclusterNode may have deleted nodes. + post_order.clear(); + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); nodes_to_partially_decluster.clear(); TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); @@ -210,7 +218,8 @@ bool IsIntraClusterEdge(const Edge& edge) { bool IsMustCompileDevice(const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - return registration->requires_compilation; + return registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; } return false; diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 74d5ef57184..1fc5da5071f 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -437,5 +437,32 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); } +TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) { + const char* const kClusteredProducer0Name = "ClusteredProducer0"; + const char* const kClusteredProducer1Name = "ClusteredProducer1"; + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName(kClusteredProducer0Name)); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName(kClusteredProducer1Name)); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h deleted file mode 100644 index 7c8c04152d2..00000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue.h +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
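
The two partially_decluster_pass changes above are related: PartiallyDeclusterNode may now delete a node outright once all of its out-edges have been cloned away, so the cached post-order must be rebuilt before the second FindNodesToDecluster pass runs over the graph. A toy illustration of recomputing a post-order over a graph with deleted ("dead") nodes, using simplified stand-in types rather than the real Graph/Node API:

#include <vector>

// Toy graph: node i's consumers are out_edges[i]; dead[i] marks removed nodes.
struct ToyGraph {
  std::vector<std::vector<int>> out_edges;
  std::vector<bool> dead;
};

// DFS post-order over live nodes, analogous in spirit to GetPostOrder.
void PostOrderFrom(const ToyGraph& g, int n, std::vector<bool>* seen,
                   std::vector<int>* order) {
  (*seen)[n] = true;
  for (int c : g.out_edges[n]) {
    if (!g.dead[c] && !(*seen)[c]) PostOrderFrom(g, c, seen, order);
  }
  order->push_back(n);
}

std::vector<int> PostOrder(const ToyGraph& g) {
  std::vector<bool> seen(g.out_edges.size(), false);
  std::vector<int> order;
  for (int n = 0; n < static_cast<int>(g.out_edges.size()); ++n) {
    if (!g.dead[n] && !seen[n]) PostOrderFrom(g, n, &seen, &order);
  }
  // Any order computed before nodes were marked dead is stale and must not be
  // reused, which is why the pass above recomputes post_order after declustering.
  return order;
}
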
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ -#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ - -#include -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { - -// A thread-safe, first-in-first-out queue. -template -class ProducerConsumerQueue { - public: - ProducerConsumerQueue() - : capacity_(std::numeric_limits::max()) {} - ~ProducerConsumerQueue() = default; - - // Wait until the queue is non-full, then append a copy of v. - void Put(const T &v); - - // Wait until the queue is non-empty, then remove and return the head value. - T Get(); - - // If the queue is non-empty, remove the head value, placing it in *pv, and - // return true; otherwise return false. - bool TryGet(T *pv); - - // Set the capacity of the queue; the queue is full whenever count() >= - // capacity(). The initial value is the maximum size_t. Requires size > 0. - void set_capacity(std::size_t size); - - // Return the capacity of the queue. - std::size_t capacity() const; - - // Return the number of elements in the queue. - std::size_t count() const; - - // Implementation details follow. Clients should ignore. - private: - mutable tensorflow::mutex mu_; // protects all fields below - tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); - tensorflow::condition_variable non_full_ GUARDED_BY(mu_); - std::size_t capacity_ GUARDED_BY(mu_); - std::deque queue_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); -}; - -// ------------------------------------------------------ -// Implementation details follow. Clients should ignore. - -// Wait until the queue is non-full, then append a copy of v. -template -void ProducerConsumerQueue::Put(const T &v) { - mutex_lock lock(mu_); - while (queue_.size() >= capacity_) { - non_full_.wait(lock); - } - queue_.push_back(v); - non_empty_.notify_one(); -} - -// Wait until the queue is non-empty, then remove and return the head value. -template -T ProducerConsumerQueue::Get() { - mutex_lock lock(mu_); - while (queue_.empty()) { - non_empty_.wait(lock); - } - non_full_.notify_one(); - T result_value = queue_.front(); - queue_.pop_front(); - return result_value; -} - -// If the queue is non-empty, remove the head value, placing it in *pv, and -// return true; otherwise return false. -template -bool ProducerConsumerQueue::TryGet(T *pv) { - mutex_lock lock(mu_); - bool got_element = !queue_.empty(); - if (got_element) { - non_full_.notify_one(); - *pv = queue_.front(); - queue_.pop_front(); - } - return got_element; -} - -// Set the capacity of the queue; the queue is full whenever count() >= -// capacity(). The initial value is the maximum size_t. Requires size > 0. -template -void ProducerConsumerQueue::set_capacity(std::size_t size) { - mutex_lock lock(mu_); - CHECK_NE(size, 0); - capacity_ = size; - non_full_.notify_all(); -} - -// Return the capacity of the queue. -template -std::size_t ProducerConsumerQueue::capacity() const { - mutex_lock lock(mu_); - std::size_t max_elements = capacity_; - return max_elements; -} - -// Return the number of elements in the queue. 
-template -std::size_t ProducerConsumerQueue::count() const { - mutex_lock lock(mu_); - std::size_t num_elements = queue_.size(); - return num_elements; -} -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc deleted file mode 100644 index f61260c6e52..00000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/producer_consumer_queue.h" - -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -typedef ProducerConsumerQueue IntQueue; - -// Insert integers between low inclusive and high exclusive into q. -void PushRange(IntQueue *q, int low, int high) { - while (low != high) { - q->Put(low); - VLOG(2) << "Pushing " << low; - ++low; - } -} - -// Push the numbers between 0 and 999 inclusive from several threads in the -// pool. -void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { - VLOG(1) << "Adding 20-36"; - pool->Schedule([queue] { PushRange(queue, 20, 36); }); - VLOG(1) << "Adding 7-20"; - pool->Schedule([queue] { PushRange(queue, 7, 20); }); - VLOG(1) << "Adding 36-501"; - pool->Schedule([queue] { PushRange(queue, 36, 501); }); - VLOG(1) << "Adding 501-1000"; - pool->Schedule([queue] { PushRange(queue, 501, 1000); }); - VLOG(1) << "Adding 0-5"; - pool->Schedule([queue] { PushRange(queue, 0, 5); }); - VLOG(1) << "Adding 5-7"; - pool->Schedule([queue] { PushRange(queue, 5, 7); }); -} - -// Pop elements from queue using Get(). Make sure that exactly elements -// were present and their values are all integers between 0 and high-1 -// inclusive. -void GetRange(IntQueue *queue, int high) { - VLOG(1) << "Testing Wait"; - std::vector results; - for (int i = 0; i != high; ++i) { - int r = queue->Get(); - VLOG(2) << "Waited and got " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK(results[i] == i); - } -} - -// Pop elements from queue using TryGet(). Make sure that exactly -// elements were present and their values are all integers between 0 and high-1 -// inclusive. -void TryGetRange(IntQueue *queue, int high) { - std::vector results; - // Give up if we don't get all the elements back from the queue - // in 10 seconds. 
- int timeout = 10; - int r; - for (int i = 0; i != high; ++i) { - while (!queue->TryGet(&r)) { - if (!timeout--) { - LOG(FATAL) << "Can't find all elements in the queue"; - } - VLOG(1) << "Sleeping for a second..."; - sleep(1); - } - VLOG(2) << "Popped " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - CHECK(!queue->TryGet(&r)); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK_EQ(i, results[i]); - } -} - -const int kNumThreads = 15; - -TEST(ProducerConsumerQueue, GetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - GetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, TryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - TryGetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, ParallelGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { GetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -TEST(ProducerConsumerQueue, ParallelTryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 4a5ea9e0a5f..3df5479a55e 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -65,14 +66,14 @@ string XlaCompilationCache::DebugString() { // Compute a string signature which encodes the shapes of the // arguments in the supplied list. 
-string XlaCompilationCache::SignatureDebugString(const Signature& sig) { - string result = sig.name; - for (const auto& a : sig.arg_types) { +string XlaCompilationCache::Signature::HumanString() const { + string result = name; + for (const auto& a : arg_types) { absl::StrAppend(&result, ",", DataTypeString(a.first), a.second.DebugString()); } - for (const auto& v : sig.arg_values) { + for (const auto& v : arg_values) { absl::StrAppend(&result, "; ", v.DebugString()); } return result; @@ -84,7 +85,9 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { - if (arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { + if (arg_values[i].dtype() != other.arg_values[i].dtype() || + arg_values[i].shape() != other.arg_values[i].shape() || + arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { return false; } } @@ -108,96 +111,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( return h; } -Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - Signature* signature) { - signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.reserve(constant_args.size()); - - signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); - - for (int i = 0; i < ctx->num_inputs(); ++i) { - if (constant_args.count(i) > 0) { - // Use the values of compile time constants in the signature. - signature->arg_values.push_back(constant_args.at(i)); - } else if (variable_args.count(i) > 0) { - const OptionalTensor& variable = variable_args.at(i); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); - } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); - } - } else { - signature->arg_types.emplace_back(ctx->input_dtype(i), - ctx->input(i).shape()); +xla::StatusOr +XlaCompilationCache::BuildSignature( + const NameAttrList& function, + absl::Span args) { + Signature signature; + signature.name = Canonicalize(function.name(), AttrSlice(&function.attr())); + for (const XlaCompiler::Argument& arg : args) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + signature.arg_values.push_back(arg.constant_value); + break; + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kResource: + signature.arg_types.emplace_back(arg.type, arg.shape); + break; + default: + return errors::InvalidArgument( + "Unhandled argument kind in XlaCompilationCache: ", + arg.HumanString()); } } - return Status::OK(); + return std::move(signature); } -namespace { - -// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. -Status BuildArguments(const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, - std::vector* args) { - args->resize(ctx->num_inputs()); - - for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { - XlaCompiler::Argument& arg = (*args)[input_num]; - if (constant_args.count(input_num) > 0) { - // Handles compile-time constants. 
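
BuildSignature now derives the cache key directly from the XlaCompiler::Argument list: constant arguments contribute their values, while parameters and resources contribute (dtype, shape) pairs. A simplified standalone sketch of that mapping, with toy types standing in for XlaCompiler::Argument, DataType and TensorShape:

#include <string>
#include <utility>
#include <vector>

enum class ArgKind { kConstant, kParameter, kResource };

struct ToyArg {
  ArgKind kind;
  int dtype = 0;
  std::vector<int> shape;
  std::string constant_value;  // only meaningful for kConstant
};

struct ToySignature {
  std::string name;
  std::vector<std::pair<int, std::vector<int>>> arg_types;  // (dtype, shape)
  std::vector<std::string> arg_values;                      // constant payloads
};

// Constants key the cache by value; parameters and resources key it by type
// and shape only (the real code also rejects unhandled argument kinds).
ToySignature BuildToySignature(const std::string& name,
                               const std::vector<ToyArg>& args) {
  ToySignature sig;
  sig.name = name;
  for (const ToyArg& arg : args) {
    if (arg.kind == ArgKind::kConstant) {
      sig.arg_values.push_back(arg.constant_value);
    } else {
      sig.arg_types.emplace_back(arg.dtype, arg.shape);
    }
  }
  return sig;
}
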
- const Tensor& input = constant_args.at(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - } else if (variable_args.count(input_num) == 0) { - // Handles the non-constant arguments. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - } else { - // Handles resource variables. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - const OptionalTensor& variable = variable_args.at(input_num); - arg.name = variable.name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (variable.present) { - const Tensor& value = variable.value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; - } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do - // permit uninitialized variables. - arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); - } - } - } - - return Status::OK(); -} - -} // namespace - Status XlaCompilationCache::BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, @@ -227,25 +164,38 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { - // Set the compile threshold to 1 to implement CompileMode::kStrict. - int64 compile_threshold = - compile_mode == CompileMode::kLazy ? kDefaultCompilationThreshold : 1; - return CompileImpl(options, function, constant_args, variable_args, ctx, - compile_options, /*compile_single_op=*/false, + absl::optional compile_threshold; + if (compile_mode == CompileMode::kLazy) { + compile_threshold = kDefaultCompilationThreshold; + } + auto compile_fn = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + return compiler->CompileFunction(compile_options, function, args, result); + }; + return CompileImpl(options, function, args, compile_fn, /*compile_threshold=*/compile_threshold, out_compilation_result, out_executable); } +static bool IsMegamorphic(int64 compile_count, int64 execution_count) { + const int64 kCompileThreshold = 10; + const int64 kMinExecutionsPerCompile = 50; + + // This heuristic is trying to capture the following property: have we sunk a + // certain minimum amount of compile time into the cluster that didn't quite + // "pay off"? 
+ return compile_count > kCompileThreshold && + execution_count < kMinExecutionsPerCompile * compile_count; +} + Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { @@ -253,54 +203,41 @@ Status XlaCompilationCache::CompileSingleOp( NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); - return CompileImpl(options, name, constant_args, variable_args, ctx, - compile_options, - /*compile_single_op=*/true, /*compile_threshold=*/1, + auto compile_op = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + std::vector result_dtypes(ctx->num_outputs()); + for (int i = 0; i < result_dtypes.size(); ++i) { + result_dtypes[i] = ctx->expected_output_dtype(i); + } + return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(), + args, result_dtypes, result); + }; + return CompileImpl(options, name, args, compile_op, + /*compile_threshold=*/absl::nullopt, out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, - int64 compile_threshold, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { - VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << constant_args.size() - << " num_variable_args=" << variable_args.size(); - for (int i = 0; i < ctx->num_inputs(); i++) { - TensorShape shape = ctx->input(i).shape(); - VLOG(2) << i << ": dtype=" << DataTypeString(ctx->input_dtype(i)) - << " present=" << ctx->has_input(i) - << " shape=" << shape.DebugString(); - } - for (auto& iterator : variable_args) { - const OptionalTensor& variable = iterator.second; - VLOG(2) << "variable present=" << variable.present - << " type=" << DataTypeString(variable.value.dtype()) - << " shape=" << variable.value.shape().DebugString() - << " TF arg= " << iterator.first; - } - VLOG(2) << "num_outputs = " << ctx->num_outputs(); - for (int i = 0; i < ctx->num_outputs(); i++) { - VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i); + VLOG(2) << "num_inputs=" << args.size(); + for (int i = 0; i < args.size(); i++) { + VLOG(2) << i << ": " << args[i].HumanString(); } } - TF_RET_CHECK(constant_args.size() + variable_args.size() <= - ctx->num_inputs()); + TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); + VLOG(2) << "Signature: " << signature.HumanString(); - Signature signature; - TF_RETURN_IF_ERROR( - BuildSignature(function, constant_args, variable_args, ctx, &signature)); - - VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not // protect the contents of the cache entry. 
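
To make the heuristic above concrete: with kCompileThreshold = 10 and kMinExecutionsPerCompile = 50, a cluster compiled 40 times but executed only 1960 times is megamorphic, while 11 compilations amortized over 600 executions are not. A standalone restatement with the same constants, using plain int64_t in place of TF's int64:

#include <cstdint>

constexpr int64_t kCompileThreshold = 10;
constexpr int64_t kMinExecutionsPerCompile = 50;

// A cluster is "megamorphic" when it has been recompiled many times and those
// compilations have not been amortized over enough executions.
bool IsMegamorphic(int64_t compile_count, int64_t execution_count) {
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

// IsMegamorphic(/*compile_count=*/40, /*execution_count=*/1960) == true   (1960 < 50 * 40)
// IsMegamorphic(/*compile_count=*/11, /*execution_count=*/600)  == false  (600 >= 50 * 11)
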
Entry* entry; @@ -319,25 +256,67 @@ Status XlaCompilationCache::CompileImpl( // (since they get the benefit of XLA right away without waiting for warmup) // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at // most one cluster-compilation's worth of compile time). - bool is_first_execution = [&] { + bool is_first_execution; + + // We avoid compiling clusters that have "gone megamorphic" i.e. have an + // excessive amount of shape dynamism. + bool is_megamorphic; + + { mutex_lock lock(cluster_compile_stats_mu_); auto it = cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) .first; - return it->second.execution_count++ == 0; - }(); + is_first_execution = it->second.execution_count++ == 0; + + // The is_megamorphic bit is "sticky". We assume clusters that have been + // observed to be megamorphic once stay megamorphic forever. + it->second.is_megamorphic |= + IsMegamorphic(/*compile_count=*/it->second.compile_count, + /*execution_count=*/it->second.execution_count); + is_megamorphic = it->second.is_megamorphic; + } // Acquire the cache entry lock and compile, if necessary. // TODO(phawkins): this locking will need to be restructured when we implement // cache eviction. mutex_lock entry_lock(entry->mu); int64 current_request_count = ++entry->request_count; + VLOG(2) << "Compilation cache entry hit: " << entry->compiled + << " signature: " << signature.HumanString() << " with request count " + << current_request_count << " and compile threshold " + << compile_threshold.value_or(0); if (!entry->compiled) { - VLOG(2) << "Compilation cache miss for signature: " - << SignatureDebugString(signature) << " with request count " - << current_request_count << " and compile threshold " - << compile_threshold; - if (!is_first_execution && current_request_count < compile_threshold) { + const bool should_compile = [&] { + if (!compile_threshold.has_value()) { + // Lazy compilation is disabled. + return true; + } + + if (is_megamorphic) { + VLOG(3) << "Not compiling cluster " << function.name() + << " because it is megamorphic."; + return false; + } + + if (is_first_execution) { + return true; + } + + bool reached_compile_threshold = + current_request_count >= *compile_threshold; + if (!reached_compile_threshold) { + VLOG(3) + << "Not compiling cluster " << function.name() + << " because it has not reached compile threshold; threshold is " + << *compile_threshold << " execution count " + << current_request_count << "."; + } + return reached_compile_threshold; + }(); + + if (!should_compile) { + VLOG(2) << "Not compiling for signature: " << signature.HumanString(); *out_compilation_result = nullptr; *out_executable = nullptr; return Status::OK(); @@ -347,21 +326,12 @@ Status XlaCompilationCache::CompileImpl( const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) 
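
The should_compile lambda above collapses to a small decision procedure: strict mode (no threshold) always compiles on a cache miss, megamorphic clusters never compile again, the first execution compiles eagerly, and everything else waits until the request count for this signature reaches the threshold. A condensed standalone sketch, with std::optional standing in for absl::optional:

#include <cstdint>
#include <optional>

bool ShouldCompile(std::optional<int64_t> compile_threshold, bool is_megamorphic,
                   bool is_first_execution, int64_t request_count) {
  if (!compile_threshold.has_value()) return true;  // lazy compilation disabled
  if (is_megamorphic) return false;                 // never pays off; stay in TF
  if (is_first_execution) return true;              // assume static shapes at first
  return request_count >= *compile_threshold;       // otherwise wait for warmup
}
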
- std::vector args; - TF_RETURN_IF_ERROR( - BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - if (compile_single_op) { - entry->compilation_status = - compiler.CompileSingleOp(compile_options, signature.name, ctx, args, - &entry->compilation_result); - } else { - entry->compilation_status = compiler.CompileFunction( - compile_options, function, args, &entry->compilation_result); - } + entry->compilation_status = + compile_fn(&compiler, &entry->compilation_result); TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); entry->compilation_status = diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index b43e5d40e64..846d0c963db 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -17,9 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/graph.pb.h" @@ -30,13 +33,6 @@ limitations under the License. namespace tensorflow { -// Struct that represents a possibly-absent Tensor. -struct OptionalTensor { - string name; // A descriptive name - bool present = false; // Is the tensor present? - Tensor value; // If present, what is the Tensor's value? -}; - // The XlaCompilationCache class caches the results of the XlaCompiler class, // which converts a Tensorflow graph into a compiled XLA compilation. // @@ -58,11 +54,7 @@ class XlaCompilationCache : public ResourceBase { // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. - // `constant_args` is a map of tensorflow argument number to its constant - // value. - // `variable_args` is a snapshot of the current values of the - // resource variable arguments to `function`; uninitialized variables are - // represented by an absent OptionalTensor. + // `args` is a description of the arguments to the computation. // // `compile_mode` controls the behavior of the compilation cache on a cache // miss. If `compile_mode` is `kLazy` then, based on some profitability @@ -78,9 +70,7 @@ class XlaCompilationCache : public ResourceBase { // outputs. Status Compile(const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, @@ -90,8 +80,7 @@ class XlaCompilationCache : public ResourceBase { // XlaCompiler::CompileFunction. 
Status CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); @@ -101,26 +90,6 @@ class XlaCompilationCache : public ResourceBase { string DebugString() override; - private: - // Common implementation of Compile and CompileSingleOp. - Status CompileImpl( - const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op, int64 compile_threshold, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable); - - // Takes `result` which has been compiled from a Tensorflow subgraph to a - // XLA computation already, and generates an XLA LocalExecutable `executable`. - Status BuildExecutable(const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - std::unique_ptr* executable); - - xla::LocalClient* const client_; - const DeviceType device_type_; - // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. struct Signature { @@ -137,14 +106,35 @@ class XlaCompilationCache : public ResourceBase { struct Hash { uint64 operator()(const Signature& signature) const; }; + + // Returns a human-readable description of the signature. + string HumanString() const; }; - static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. - Status BuildSignature(const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, Signature* signature); + static xla::StatusOr BuildSignature( + const NameAttrList& function, + absl::Span args); + + private: + // Common implementation of Compile and CompileSingleOp. + Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); + + // Takes `result` which has been compiled from a Tensorflow subgraph to a + // XLA computation already, and generates an XLA LocalExecutable `executable`. + Status BuildExecutable(const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + std::unique_ptr* executable); + + xla::LocalClient* const client_; + const DeviceType device_type_; // The value associated with a cache entry. struct Entry { @@ -180,7 +170,13 @@ class XlaCompilationCache : public ResourceBase { // Cumulative time spent compiling the cluster. int64 cumulative_compile_time_us = 0; + + // True if we have decided that this cluster is too dynamic (i.e. its shapes + // change too frequently) to profitably JIT compile. Once a cluster is + // tagged megamorphic, it stays megamorphic forever. + bool is_megamorphic = false; }; + mutex cluster_compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. 
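
The Signature key declared above has to keep operator== and Hash consistent, and, per the operator== fix earlier in this file, equality must inspect the dtype and shape of constant arguments rather than only their raw bytes: an int32 constant of shape {4, 0} and a float constant of shape {0, 4} both have empty tensor_data yet must map to distinct cache entries, which is what the new SignatureEquality test below verifies. A toy stand-in showing the equality/hash pairing (not the real Signature class):

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

struct ToySignature {
  std::string name;
  int dtype = 0;
  std::vector<int> shape;
  std::string data;  // raw constant bytes; may be empty for zero-element tensors

  bool operator==(const ToySignature& o) const {
    // Comparing data alone is not enough; dtype and shape disambiguate
    // zero-element constants that share the same (empty) byte string.
    return name == o.name && dtype == o.dtype && shape == o.shape && data == o.data;
  }
};

struct ToySignatureHash {
  std::size_t operator()(const ToySignature& s) const {
    std::size_t h = std::hash<std::string>()(s.name);
    h = h * 31 + std::hash<int>()(s.dtype);
    for (int d : s.shape) h = h * 31 + std::hash<int>()(d);
    return h * 31 + std::hash<std::string>()(s.data);
  }
};
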
diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc new file mode 100644 index 00000000000..018c7c219f4 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(XlaCompilationCacheTest, SignatureEquality) { + NameAttrList fn; + fn.set_name("afunction"); + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kConstant; + args[0].type = DT_INT32; + args[0].shape = TensorShape({4, 0}); + args[0].constant_value = Tensor(DT_INT32, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].type = DT_FLOAT; + args[0].constant_value = Tensor(DT_FLOAT, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].shape = TensorShape({0, 4}); + args[0].constant_value = Tensor(DT_FLOAT, {0, 4}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s3, + XlaCompilationCache::BuildSignature(fn, args)); + + std::vector signatures = {s1, s2, s3}; + for (int i = 0; i < signatures.size(); ++i) { + for (int j = 0; j < signatures.size(); ++j) { + EXPECT_EQ(i == j, signatures[i] == signatures[j]) + << signatures[i].HumanString() << " " << signatures[j].HumanString(); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 31cb32e3059..1fe612d43d1 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -187,8 +187,13 @@ Status XlaCompileOnDemandOp::Compile( compile_options.always_return_tuple = false; std::map variable_args = GetVariables(ctx); - return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - compile_options, result, executable); + + std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arguments, variable_args, ctx, &args)); + + return cache->CompileSingleOp(options, args, ctx, compile_options, result, + executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index cbfeb388050..116e0756036 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -42,8 +42,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - 
registration.requires_compilation = !compile_on_demand; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + compile_on_demand + ? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested + : XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 2289abd2df3..5c1b55cb57f 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -446,7 +446,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, // Any op assigned to the device that isn't rewritten by the graph rewriter // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes // it just-in-time. - kernel_factory::OpKernelRegistrar::Factory factory = + OpKernel* (*factory)(OpKernelConstruction*) = [](OpKernelConstruction* context) -> OpKernel* { return new XlaCompileOnDemandOp(context); }; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 8881b697bc8..49f53b477ef 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -112,6 +112,12 @@ class XlaDevice : public LocalDevice { // compute, host-to-device, and device-to-host communication. bool use_multiple_streams = false; + // A function that describes how the on-host shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared + // shapes for computations. Must be non-null. XlaCompiler::ShapeRepresentationFn shape_representation_fn; // If padded_shape_fn is empty, a default implementation that returns diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index eb3cf27624b..6e6532731e6 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -70,9 +70,12 @@ XlaDeviceContext::XlaDeviceContext( CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); if (!shape_representation_fn_) { - shape_representation_fn_ = - [](const TensorShape& shape, - DataType dtype) -> xla::StatusOr { return shape; }; + shape_representation_fn_ = [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; } } @@ -99,7 +102,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, CHECK(xla_tensor); Status status = [&]() -> Status { - TF_ASSIGN_OR_RETURN(TensorShape shape, + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn_(device_tensor->shape(), device_tensor->dtype())); @@ -111,9 +114,15 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal())); + // The cpu_tensor and literal that we created here hold the data of host + // tensor in descending layout. The layout could be different from layout in + // device_tensor (but the logical shape has to be the same). The + // transfer_manager is responsible to do corresponding transposing when + // transferring the data to device. 
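
A shape_representation_fn, as documented in the xla_device.h comment above, maps a host-side TensorShape plus dtype to the xla::Shape used for parameters, return values and variables; the XlaDeviceContext default now builds a genuine xla::Shape via TensorShapeToXLAShape instead of echoing the TensorShape back. A very rough standalone sketch of the callback's shape, with toy types rather than the real TensorShape/xla::Shape:

#include <cstdint>
#include <functional>
#include <vector>

// Toy stand-ins: a host shape is a dim vector; a device-side shape also
// carries an element type tag.
struct ToyTensorShape { std::vector<int64_t> dims; };
struct ToyXlaShape { int element_type; std::vector<int64_t> dims; };

using ShapeRepresentationFn =
    std::function<ToyXlaShape(const ToyTensorShape&, int /*dtype*/)>;

// Default behaviour sketched here: keep the logical dimensions and record the
// element type; a device-specific implementation could pad or relayout them,
// and the runtime reshapes tensors back to their declared shapes at the edges.
ShapeRepresentationFn default_representation = [](const ToyTensorShape& shape,
                                                  int dtype) {
  return ToyXlaShape{dtype, shape.dims};
};
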
xla::BorrowingLiteral literal( static_cast(DMAHelper::base(cpu_tensor)), - xla_tensor->shaped_buffer().on_host_shape()); + xla::ShapeUtil::MakeShape(shape.element_type(), + xla::AsInt64Slice(shape.dimensions()))); VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " << xla_tensor->shaped_buffer().ToString(); @@ -183,8 +192,15 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); + // Transfer manager requires the shape of the shaped buffer to be the same as + // literal shape except for the layout. Set the literal to use xla_tensor's + // shape as it is derived from the cpu_tensor's shape using + // shape_representation_fn_. xla::MutableBorrowingLiteral literal; - TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal)); + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout( + xla_tensor->shaped_buffer().on_host_shape()), + cpu_tensor, &literal)); TensorReference ref(*device_tensor); transfer_manager_->TransferLiteralFromDevice( diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 241ea8f60df..adf0f994b84 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" +#include "tensorflow/core/kernels/stack.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .TypeConstraint("T") \ .HostMemory("input"), \ - RetvalOp); + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("StackV2") \ + .Device(DEVICE) \ + .HostMemory("max_size") \ + .HostMemory("handle"), \ + StackOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("T", TYPES), \ + TemplatedStackPushOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPopV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("elem_type", TYPES), \ + StackPopOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp); -// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// TODO(b/118881356): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read // and write the tensors they access in order to concatenate them into a batch. 
// We would need either to call out to an XLA computation to perform the diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 8f28b38b5e1..44197016958 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, std::vector* devices) { XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); @@ -53,24 +53,25 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, return Status::OK(); } - XlaDevice::Options options; - options.platform = platform.ValueOrDie(); - options.device_name_prefix = name_prefix; - options.device_name = DEVICE_XLA_GPU; - options.device_ordinal = 0; - options.compilation_device_name = DEVICE_GPU_XLA_JIT; - options.use_multiple_streams = false; - auto device = absl::make_unique(session_options, options); + for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + XlaDevice::Options options; + options.platform = platform.ValueOrDie(); + options.device_name_prefix = name_prefix; + options.device_name = DEVICE_XLA_GPU; + options.device_ordinal = i; + options.compilation_device_name = DEVICE_GPU_XLA_JIT; + options.use_multiple_streams = true; + auto device = absl::make_unique(session_options, options); - // TODO(b/78468222): Uncomment after fixing this bug - // status = device->UseGpuDeviceInfo(); - // if (!status.ok()) { - // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, - // " device"); - // return status; - // } + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, + " device number ", i); + return status; + } - devices->push_back(device.release()); + devices->push_back(device.release()); + } return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index dc37362fd86..e828bae865d 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices( XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, registration); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6e51bfca4a1..3b0bda4caa1 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -191,40 +191,6 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { return Status::OK(); } -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. 
The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. -ScopedShapedBuffer ExtractSubShapedBuffer( - ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator) { - const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_host_shape(), index); - const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_device_shape(), index); - - ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, - shaped_buffer->platform(), - shaped_buffer->device_ordinal()); - - auto& shape_tree = shaped_buffer->buffers(); - auto& sub_shape_tree = sub_shaped_buffer.buffers(); - sub_shape_tree.CopySubtreeFrom(shape_tree, - /*source_base_index=*/{index}, - /*target_base_index=*/{}); - shape_tree.ForEachMutableElement( - [index](const xla::ShapeIndex& shape_index, - tensorflow::se::DeviceMemoryBase* data) { - // shape_index is empty for the root node. Ignore that. - if (!shape_index.empty() && shape_index[0] == index) { - *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); - } - }); - return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); -} -} // namespace internal -using internal::ExtractSubShapedBuffer; - XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors, bool use_multiple_streams) @@ -391,8 +357,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); if (use_multiple_streams_) { xla_tensor->ResetDefinitionEvent(definition_event, stream); } @@ -445,7 +410,6 @@ Status XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - se::DeviceMemoryBase buffer = output.buffer({output_num}); if (variable_infos[i].var()->tensor()->dtype() != write.type) { return errors::Internal("Mismatched type in variable write"); @@ -455,18 +419,20 @@ Status XlaComputationLaunchContext::PopulateOutputs( Tensor output_tensor; TF_RETURN_IF_ERROR( ctx->allocate_temp(write.type, write.shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); - if (use_multiple_streams_) { - xla_tensor->ResetDefinitionEvent(definition_event, stream); + if (write.shape.num_elements() > 0) { + XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); + if (use_multiple_streams_) { + xla_tensor->ResetDefinitionEvent(definition_event, stream); + } } *variable_infos[i].var()->tensor() = output_tensor; } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); *variable_infos[i].var()->tensor() = output_tensor; 
} ++output_num; @@ -474,4 +440,60 @@ Status XlaComputationLaunchContext::PopulateOutputs( return Status::OK(); } +Status XlaComputationLaunchContext::BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args) { + args->resize(ctx->num_inputs()); + + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { + XlaCompiler::Argument& arg = (*args)[input_num]; + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. + const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); + arg.constant_value = input; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); + } else { + // Handles resource variables. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 81e205d13f7..437db019a0e 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -35,6 +35,13 @@ limitations under the License. namespace tensorflow { class XlaAllocator; +// Struct that represents a possibly-absent Tensor. +struct OptionalTensor { + string name; // A descriptive name + bool present = false; // Is the tensor present? + Tensor value; // If present, what is the Tensor's value? +}; + // Takes a snapshot of the values of resource variable arguments, whose indices // are specified in `variable_indices` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is @@ -139,6 +146,13 @@ class XlaComputationLaunchContext { bool allocate_xla_tensors, bool use_multiple_streams); + // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch + // op. + static Status BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args); + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. 
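
BuildXlaCompilerArguments above classifies every kernel input into one of three kinds: entries in constant_args become compile-time constants, entries in variable_args become resource arguments (marked uninitialized when the OptionalTensor snapshot is absent), and the remaining inputs become runtime parameters. A compact standalone restatement of just that classification rule, with toy stand-ins for the TF types:

#include <map>
#include <string>

enum class ArgKind { kConstant, kParameter, kResource };

// Mirrors OptionalTensor: a possibly-absent snapshot of a resource variable.
struct ToyOptionalTensor {
  std::string name;
  bool present = false;
};

ArgKind ClassifyInput(int input_num,
                      const std::map<int, std::string>& constant_args,
                      const std::map<int, ToyOptionalTensor>& variable_args) {
  if (constant_args.count(input_num) > 0) return ArgKind::kConstant;
  if (variable_args.count(input_num) > 0) return ArgKind::kResource;
  // The real code additionally demotes empty non-constant inputs to constants.
  return ArgKind::kParameter;
}
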
// @@ -223,17 +237,6 @@ class XlaTensorBuffer : public TensorBuffer { Allocator* allocator_; }; -// Exposed in this header file for microbenchmarking purposes, but this is an -// internal implementation detail. -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. -xla::ScopedShapedBuffer ExtractSubShapedBuffer( - xla::ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator); -} // namespace internal - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc deleted file mode 100644 index a45932403ec..00000000000 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Contains microbenchmarks for performance critical functions in -// xla_launch_util.cc. - -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs -// (cardinality of each non-leaf node's children). -void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { - tensorflow::testing::StopTiming(); - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); - for (int i = 0; i < depth; ++i) { - std::vector shapes(fan_out, shape); - shape = xla::ShapeUtil::MakeTupleShape(shapes); - } - xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, - /*device_ordinal=*/0); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - // Extract a buffer from approximately the middle of the first level of the - // tree. 
- (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, - /*index=*/fan_out / 2, - /*allocator=*/nullptr) - .release(); - } -} - -BENCHMARK(BM_ExtractSubBuffer) - ->ArgPair(1, 4) - ->ArgPair(1, 8) - ->ArgPair(1, 32) - ->ArgPair(1, 64) - ->ArgPair(1, 128) - ->ArgPair(1, 256) - ->ArgPair(1, 512) - ->ArgPair(2, 4) - ->ArgPair(2, 8) - ->ArgPair(2, 32) - ->ArgPair(2, 64) - ->ArgPair(2, 128); - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - tensorflow::testing::RunBenchmarks(); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 6f8b198262d..d1f7f754c83 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -43,11 +43,10 @@ namespace tensorflow { } } -Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, +Status XlaTensor::AllocateShapedBuffer(DataType dtype, + const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal) { - xla::Shape on_host_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape)); xla::Shape on_device_shape = client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 6d7a6fd66c8..77e80aa2527 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -50,7 +50,7 @@ class XlaTensor { // Assign the internal ShapedBuffer to new memory for the given dtype and // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it // is replaced and the managed memory deallocated. - Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal); // Some Tensors can have complex on-device shapes, including tuple shapes. To diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6945de1eda1..6b8e6bba1e1 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -470,12 +470,12 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", + "//tensorflow/python/ops/signal", ], ) @@ -837,8 +837,6 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - # Stack ops are not implemented in the on-demand compilation model yet. 
- disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 4e6dd6abfc9..332381c59ee 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + import numpy as np from tensorflow.compiler.tests import xla_test @@ -967,7 +969,7 @@ class BinaryOpsTest(xla_test.XLATestCase): self._testBinary( array_ops.expand_dims, np.array([42], dtype=dtype), - np.int32(0), + np.array([0], dtype=np.int64), expected=np.array([[42]], dtype=dtype)) self._testBinary( array_ops.expand_dims, @@ -994,15 +996,21 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[[1, 2], [3, 4]]], dtype=dtype), np.int32(3), expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.array([2], dtype=np.int64), + expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype)) def testPad(self): - for dtype in self.numeric_types: + for dtype, pad_type in itertools.product( + self.numeric_types, [np.int32, np.int64]): self._testBinary( array_ops.pad, np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[1, 2], [2, 1]], dtype=np.int32), + [[1, 2], [2, 1]], dtype=pad_type), expected=np.array( [[0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0], @@ -1016,7 +1024,7 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[0, 3], [2, 1]], dtype=np.int32), + [[0, 3], [2, 1]], dtype=pad_type), expected=np.array( [[7, 7, 1, 2, 3, 7], [7, 7, 4, 5, 6, 7], diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index b3e13fbaa6b..e92afd5d6fe 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -24,10 +24,10 @@ import numpy as np import scipy.signal as sps from tensorflow.compiler.tests import xla_test -from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import signal from tensorflow.python.ops import spectral_ops from tensorflow.python.platform import googletest diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 561715ee1c3..6f51ae33a1b 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -593,6 +593,67 @@ class LazyCompilationTest(test.TestCase): self.assertFalse( InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun")) + def testIsMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Make the cluster go megamorphic by running it with lots of shape + # signatures where the cluster is executed with each signature only a few + # times. Then check that we don't compile the cluster ever again. + + for shape in range(10, 50): + for _ in range(0, 49): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 50): + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] 
* 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + + def testIsNotMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Run the cluster with lots of shape signatures, but in a way that it + # isn't megamorphic (i.e. each shape signature sees a lot of executions). + # Then check that the cluster has not been marked as megamorphic. + + for shape in range(10, 50): + for _ in range(0, 1000): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 10): + sess.run(y, feed_dict={x: [0.] * 60}) + + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] * 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + if __name__ == "__main__": os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index cfccf5f3d2a..a6b58020126 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2466,20 +2466,21 @@ TEST_F(OpTest, Pack) { }); } -// TODO(b/31741898): crashes on GPU. TEST_F(OpTest, Pad) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(); - // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. 
- // DataType tpaddings = Choose({DT_INT32, DT_INT64}); - DataType tpaddings = DT_INT32; + DataType tpaddings = Choose({DT_INT32, DT_INT64}); std::vector paddings_vec; - std::uniform_int_distribution distribution(0, 7); for (int i = 0; i < t_dims.size(); ++i) { - paddings_vec.push_back(distribution(generator())); - paddings_vec.push_back(distribution(generator())); + std::uniform_int_distribution pad_distribution(0, t_dims[i]); + int pad_size = pad_distribution(generator()); + std::uniform_int_distribution lower_distribution(0, pad_size); + int low_pad_size = lower_distribution(generator()); + paddings_vec.push_back(low_pad_size); + paddings_vec.push_back(pad_size - low_pad_size); + t_dims[i] -= pad_size; } Tensor paddings; CHECK( diff --git a/tensorflow/compiler/tests/resampler_ops_test.py b/tensorflow/compiler/tests/resampler_ops_test.py index d05554fdb68..f87ac3360c9 100644 --- a/tensorflow/compiler/tests/resampler_ops_test.py +++ b/tensorflow/compiler/tests/resampler_ops_test.py @@ -37,7 +37,7 @@ class ResamplerOpsTest(xla_test.XLATestCase): out = sess.run(resampled, {input_image: image_np, warp: warp_np}) self.assertAllCloseAccordingToType( - expected, out, half_rtol=1e-2, bfloat16_rtol=3e-2) + expected, out, rtol=5e-3, half_rtol=1e-2, bfloat16_rtol=3e-2) def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np, expected_grad_data, expected_grad_warp): diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index dd2c252d383..77cdeac8168 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -40,6 +40,19 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" + def testWriteEmptyShape(self): + # Verifies that we can pass an uninitialized variable with an empty shape, + # assign it a value, and successfully return it. + for dtype in self.numeric_types: + with self.test_session() as sess, self.test_scope(): + zeros = np.zeros([3, 0], dtype=dtype) + v = resource_variable_ops.ResourceVariable(zeros) + p = array_ops.placeholder(dtype) + x = v.assign(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(zeros, sess.run(y, {p: zeros})) + def testOneWriteOneOutput(self): # Regression test for a bug where computations with one non-constant # output and one variable update were mishandled. 
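// [Illustrative sketch, not part of this change] The padding scheme used by
// the updated randomized Pad test above: for each dimension, draw a total pad
// no larger than the dimension, split it into a before/after pair, and shrink
// the dimension by the same amount so the padded output keeps the originally
// drawn extent. Function and variable names here are illustrative only.
#include <cstdint>
#include <random>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> RandomPaddings(
    std::vector<int64_t>& dims, std::mt19937& rng) {
  std::vector<std::pair<int64_t, int64_t>> paddings;
  for (int64_t& dim : dims) {
    std::uniform_int_distribution<int64_t> total_dist(0, dim);
    const int64_t total = total_dist(rng);
    std::uniform_int_distribution<int64_t> before_dist(0, total);
    const int64_t before = before_dist(rng);
    paddings.emplace_back(before, total - before);  // pad before / pad after
    dim -= total;  // input shrinks so input + padding == original extent
  }
  return paddings;
}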
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5fc9a352ff9..e0171415492 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -166,6 +166,7 @@ cc_library( "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", + "xla_expression.cc", "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", @@ -180,6 +181,7 @@ cc_library( "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", + "xla_expression.h", "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", @@ -194,6 +196,7 @@ cc_library( ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -217,6 +220,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -362,8 +366,12 @@ tf_cc_test( tf_cc_test( name = "xla_compiler_test", - srcs = ["xla_compiler_test.cc"], + srcs = [ + "xla_compiler_test.cc", + "xla_expression_test.cc", + ], deps = [ + ":common", ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", @@ -386,6 +394,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -435,7 +444,7 @@ cc_library( "dump_graph.h", ], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc index a6c908ba011..2eb1f8cd849 100644 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.cc +++ b/tensorflow/compiler/tf2xla/dump_graph_flags.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" @@ -41,7 +41,7 @@ static void AllocateFlags() { "Path prefix to which graphs dumped during debugging should be " "written."), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); } // Append to *append_to flag definitions associated with the XLA bridge's diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f818d80022d..9ef9f49f422 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -242,23 +242,20 @@ Status FunctionalizeControlFlowPass::Run( continue; } const string func_attr = it->second; - if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != - kNodeTypeToFunctionAttrMapping->end()) { - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); - VLOG(2) << "Graph has node " << n->type_string() - << ". 
Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( - absl::StrCat(func.name(), "_f15n_")); - bool modified; - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); - if (modified) { - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - } + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + bool modified; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); } } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 706ed4f5bbf..efb75749722 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -40,6 +40,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -51,12 +52,11 @@ namespace { Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { - auto builder = ctx->builder(); auto client = ctx->compiler()->client(); - std::vector compile_time_constant_flags(expressions.size()); + std::vector arg_must_be_compile_time_constant(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant, /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); @@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.type = ctx->input_type(i); arg.shape = ctx->InputShape(i); - if (arg.type == DT_RESOURCE) { - return errors::InvalidArgument( - "Resource as function argument is not yet implemented."); - } else if (expressions[i]->has_constant_value()) { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = expressions[i]->constant_value(); - } else if (compile_time_constant_flags[i]) { - arg.kind = XlaCompiler::Argument::kConstant; - TF_RET_CHECK(expressions[i]->resource() == nullptr) - << "Input with resource is not yet implemented."; - TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( - expressions[i]->handle())); - TF_ASSIGN_OR_RETURN(auto literal, - client->ComputeConstant(constant_graph)); - TF_RETURN_IF_ERROR( - LiteralToHostTensor(literal, arg.type, &arg.constant_value)); - } else { - arg.kind = XlaCompiler::Argument::kParameter; + switch (expressions[i]->kind()) { + case XlaExpression::Kind::kConstant: + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = expressions[i]->constant_value(); + break; + case XlaExpression::Kind::kXlaOp: + if (arg_must_be_compile_time_constant[i]) { + TF_ASSIGN_OR_RETURN(absl::optional value, + expressions[i]->ResolveConstant(client)); + if (!value.has_value()) { + return errors::InvalidArgument( + "Argument to function must be a compile-time constant, but " + "unable to resolve argument value to a constant."); + } + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = *value; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + break; + case XlaExpression::Kind::kResource: + return errors::Unimplemented( + "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument("Invalid function argument"); } } return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 276d744c096..2db2514397d 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -14,11 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel { } const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; - if (arg.resource() != nullptr) { - ctx->SetResourceOutput(0, arg.resource()); - } else if (arg.has_constant_value()) { - ctx->SetConstantOutput(0, arg.constant_value()); - } else { - ctx->SetOutput(0, arg.handle()); - } + OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, + errors::InvalidArgument("Invalid/missing argument expression")); + ctx->SetOutputExpression(0, arg); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 9fa57b76f8e..c022284fec6 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -94,14 +94,10 @@ class BCastGradArgsOp : public XlaOpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), errors::InvalidArgument("In[", i, "] must be a vector.", in_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal)); + std::vector vec; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec)); - BCast::Vec vec; - for (int64 i = 0; i < in_shape.num_elements(); ++i) { - vec.push_back(literal.Get({i})); - } - shapes.push_back(vec); + shapes.push_back(BCast::Vec(vec.begin(), vec.end())); } BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e28755dd73b..cd7c7f4a82d 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -45,15 +46,13 @@ class ConcatBaseOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_tensor_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_tensor_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); - // TODO(annarev): add a helper to support int64 input. 
- const int32 concat_dim = literal.Get({}); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_tensor_shape.DebugString())); + int64 concat_dim; + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim)); std::vector values; std::vector shapes; @@ -63,9 +62,7 @@ class ConcatBaseOp : public XlaOpKernel { const TensorShape& input_shape = shapes[0]; int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; - OP_REQUIRES(ctx, - (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), + OP_REQUIRES(ctx, 0 <= axis && axis < input_dims, errors::InvalidArgument( "ConcatOp : Expected concatenating dimensions in the range " "[", @@ -75,14 +72,11 @@ class ConcatBaseOp : public XlaOpKernel { // elements. std::vector input_data; int output_concat_dim = 0; - const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { xla::XlaOp handle = values[i]; const TensorShape& in_shape = shapes[i]; - const bool in_is_scalar = IsLegacyScalar(in_shape); OP_REQUIRES( - ctx, - in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), + ctx, in_shape.dims() == input_dims, errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", input_shape.DebugString(), " vs. shape[", i, @@ -131,11 +125,10 @@ class ConcatOffsetOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_shape = ctx->InputShape(0); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_shape.DebugString())); for (int i = 1; i < ctx->num_inputs(); ++i) { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), errors::InvalidArgument("input ", i, @@ -162,39 +155,38 @@ class ConcatOffsetOp : public XlaOpKernel { // [0, 5, 0, 0] const int32 N = ctx->num_inputs() - 1; const TensorShape inp0_shape = ctx->InputShape(1); - xla::Literal inp0_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal)); - const int64 dims = inp0_shape.num_elements(); + std::vector inp0_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims)); + const int64 inp0_rank = inp0_shape.num_elements(); - xla::Literal concat_dim_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); - const int64 cdim = concat_dim_literal.Get({}); + int64 cdim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim)); - VLOG(1) << "ConcatOffset " << cdim << "," << dims; - int32 axis = cdim < 0 ? cdim + dims : cdim; - OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), + VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank; + int32 axis = cdim < 0 ? cdim + inp0_rank : cdim; + OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank), errors::InvalidArgument("Concat dim is out of range: ", axis, - " vs. ", dims)); + " vs. 
", inp0_rank)); int32 offset = 0; for (int i = 0; i < N; ++i) { const TensorShape inp_shape = ctx->InputShape(1 + i); - OP_REQUIRES(ctx, dims == inp_shape.num_elements(), - errors::InvalidArgument("input ", i, " should contain ", dims, - " elements, but got ", + OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(), + errors::InvalidArgument("input ", i, " should contain ", + inp0_rank, " elements, but got ", inp_shape.num_elements())); - xla::Literal inp_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal)); + std::vector inp_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims)); - Tensor out_constant(DT_INT32, TensorShape({dims})); + Tensor out_constant(DT_INT32, TensorShape({inp0_rank})); auto out_vec = out_constant.vec(); - for (int64 j = 0; j < dims; ++j) { + for (int64 j = 0; j < inp0_rank; ++j) { if (j == axis) { out_vec(j) = offset; - offset += inp_literal.Get({j}); + offset += inp_dims[j]; } else { - const int32 inp0_element = inp0_literal.Get({j}); - const int32 inp_element = inp_literal.Get({j}); - OP_REQUIRES(ctx, (inp0_element == inp_element), + const int32 inp0_element = inp0_dims[j]; + const int32 inp_element = inp_dims[j]; + OP_REQUIRES(ctx, inp0_element == inp_element, errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", inp0_element, " vs. ", inp_element)); diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 2628ef8e245..dff8af80022 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); - if (proto_.dtype() == DT_STRING) { - LOG(WARNING) << "Not computing Const of type DT_STRING"; - ctx->SetInvalidOutput(0); - return; - } xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index e9bdb15aa0c..35e0625dbb0 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -33,39 +34,20 @@ class FillOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // The output of this Op is a tensor of shape 'dims_shape' with each // element set to the scalar 'dims_literal'. - const TensorShape dims_shape = ctx->InputShape(0); - const TensorShape value_shape = ctx->InputShape(1); + const TensorShape dims_shape = ctx->InputShape("dims"); + const TensorShape value_shape = ctx->InputShape("value"); OP_REQUIRES( - ctx, IsLegacyVector(dims_shape), + ctx, TensorShapeUtils::IsVector(dims_shape), errors::InvalidArgument("dims must be a vector of int32, got shape ", dims_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(value_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(value_shape), errors::InvalidArgument("value must be a scalar, got shape ", value_shape.DebugString())); - // Evaluate the 'dims' constant input, reshaping to a vector if it - // was a 'legacy' vector (secretly a scalar). 
- xla::Literal dims_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped( - 0, {dims_shape.num_elements()}, &dims_literal)); - // Convert the dims literal into a vector that we can pass to - // XlaBuilder. - std::vector broadcast; - broadcast.reserve(dims_literal.shape().dimensions(0)); - for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { - broadcast.push_back(dims_literal.Get({i})); - } - // Look up the value input, reshaping to a scalar if it was a - // 'legacy' scalar (secretly a vector). - xla::XlaOp data = ctx->Input(1); - if (value_shape.dims() > 0) { - CHECK_EQ(value_shape.dims(), 1); - data = xla::Reshape(data, {}); - } - // Emit the actual computation, which broadcasts the scalar to the - // desired shape. - auto result = xla::Broadcast(data, broadcast); + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims)); + auto result = xla::Broadcast(ctx->Input("value"), dims); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index d069373086a..e310db2162d 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -48,9 +48,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // We require that the dimension argument is a constant, since it lets us // dispatch to a specialized custom-call function without any run-time // overhead, when compiling ahead-of-time. - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = literal.Get({}); + int64 dim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); OP_REQUIRES( ctx, dim < input_shape.dims(), diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 4833a9662dd..f6b8534f4d7 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -41,10 +41,8 @@ class MirrorPadOp : public XlaOpKernel { for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); - TF_ASSIGN_OR_RETURN(int64 lhs_padding, - pad_literal.GetIntegralAsS64({dimno, 0})); - TF_ASSIGN_OR_RETURN(int64 rhs_padding, - pad_literal.GetIntegralAsS64({dimno, 1})); + int64 lhs_padding = pad_literal.Get({dimno, 0}); + int64 rhs_padding = pad_literal.Get({dimno, 1}); int64 dim_size = original_shape.dimensions(dimno); // Padding amounts on each side must be no more than the size of the @@ -65,8 +63,8 @@ class MirrorPadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); @@ -81,23 +79,19 @@ class MirrorPadOp : public XlaOpKernel { TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 
1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); // Evaluate the 'padding' constant input, reshaping to a matrix. xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::XlaBuilder* b = ctx->builder(); - auto in0 = ctx->Input(0); + auto in0 = ctx->Input("input"); xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 3f5445b4821..36ea70ac392 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -29,40 +30,36 @@ class PadOp : public XlaOpKernel { explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); const int dims = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); - if (fixed_dims == 0) { + xla::XlaOp input = ctx->Input("input"); + if (dims == 0) { // Tensor is rank 0. Return it unchanged. - ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(0, input); return; } - // Evaluate the 'padding' constant input, reshaping to a matrix. xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::PaddingConfig config; - for (int i = 0; i < fixed_dims; ++i) { + for (int i = 0; i < dims; ++i) { auto* dim = config.add_dimensions(); - int before = pad_literal.Get({i, 0}); - int after = pad_literal.Get({i, 1}); + int before = pad_literal.Get({i, 0}); + int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, errors::InvalidArgument( "Paddings must be non-negative: ", before, " ", after)); @@ -73,12 +70,13 @@ class PadOp : public XlaOpKernel { // PadV2 added a "constant_values" input that indicates the pad value. 
xla::XlaOp constant_values; if (ctx->num_inputs() == 3) { - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), - errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")), + errors::InvalidArgument("constant_values must be a scalar.")); + ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(input, zero, config)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 47a4eac2066..fa1b6b91710 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -36,7 +37,7 @@ class ReshapeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); const TensorShape sizes_shape = ctx->InputShape(1); // Preliminary validation of sizes. - OP_REQUIRES(ctx, IsLegacyVector(sizes_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape), errors::InvalidArgument("sizes input must be 1-D, not shape ", sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index e172c649325..6970dd0a006 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -46,61 +47,8 @@ class RetvalOp : public XlaOpKernel { // compilation. 
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::XlaOp input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - DataType input_type = ctx->input_type(0); - XlaContext& tc = XlaContext::Get(ctx); - - if (input_type == DT_RESOURCE) { - XlaResource* resource; - OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - ctx->SetStatus(tc.AddResourceRetval(index_, resource)); - return; - } - - auto is_constant = ctx->builder()->IsConstant(input); - if (!is_constant.ok()) { - ctx->SetStatus(is_constant.status()); - return; - } - - if (tc.resolve_compile_time_constants() && - (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); - OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); - } else { - TensorShape shape = ctx->InputShape(0); - ctx->SetStatus(is_constant.status()); - TensorShape representation_shape; - if (tc.is_entry_computation()) { - xla::StatusOr shape_or_status = - tc.RepresentationShape(shape, ctx->input_type(0)); - if (!shape_or_status.ok()) { - ctx->SetStatus(shape_or_status.status()); - return; - } else { - representation_shape = shape_or_status.ValueOrDie(); - } - } else { - representation_shape = shape; - } - - xla::XlaOp output = input; - if (tc.is_entry_computation()) { - output = xla::Reshape(input, representation_shape.dim_sizes()); - } else { - // The core from which a return value is returned depends on the - // device assignment of the input to the retval. Since we can't change - // the device assignment of "input" at this point, we must always - // introduce an operator here, even if the shape does not change. - // TODO(b/76097077): propagate device assignments onto arguments and - // return values of functions, and then reshape unconditionally. - output = - xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); - } - tc.AddRetval(index_, dtype_, shape, output); - } + XlaContext& xla_context = XlaContext::Get(ctx); + xla_context.SetRetval(index_, ctx->InputExpression(0)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 56b80cb4a29..2ceadaf79c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel { } // XlaBuilder::Rev() requires concrete values for dimensions arg. xla::Literal lax; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); - std::vector revdims(x_shape.dims()); - std::copy(lax.data().begin(), lax.data().end(), - revdims.begin()); - std::vector dimensions; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax)); + std::vector dimensions; for (int d = 0; d < x_shape.dims(); ++d) { - if (revdims[d]) { + if (lax.Get({d})) { dimensions.push_back(d); } } diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 379f4aeb0fc..60b011ba6d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -30,31 +30,6 @@ limitations under the License. 
namespace tensorflow { namespace { -template -Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - *value = literal.Get({}); - return Status::OK(); -} - -Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - switch (literal.shape().element_type()) { - case xla::S32: - *value = literal.Get({}); - break; - case xla::S64: - *value = literal.Get({}); - break; - default: - return errors::InvalidArgument("Invalid argument type for argument", - index); - } - return Status::OK(); -} - // The type-specific part of the implementation of Range. template xla::StatusOr CreateRangeTensor( @@ -98,13 +73,13 @@ class RangeOp : public XlaOpKernel { const TensorShape start_in_shape = ctx->InputShape(0); const TensorShape limit_in_shape = ctx->InputShape(1); const TensorShape delta_in_shape = ctx->InputShape(2); - OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape), errors::InvalidArgument("limit must be a scalar, not shape ", limit_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape), errors::InvalidArgument("delta must be a scalar, not shape ", delta_in_shape.DebugString())); xla::Literal start, limit, delta; @@ -147,9 +122,9 @@ class LinSpaceOp : public XlaOpKernel { explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape start_in_shape = ctx->InputShape(0); - const TensorShape stop_in_shape = ctx->InputShape(1); - const TensorShape num_in_shape = ctx->InputShape(2); + const TensorShape start_in_shape = ctx->InputShape("start"); + const TensorShape stop_in_shape = ctx->InputShape("stop"); + const TensorShape num_in_shape = ctx->InputShape("num"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); @@ -163,16 +138,20 @@ class LinSpaceOp : public XlaOpKernel { DataType type = ctx->input_type(0); int64 num; - OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num)); OP_REQUIRES(ctx, num > 0, errors::InvalidArgument("Requires num > 0: ", num)); Tensor out_constant(type, TensorShape({num})); + xla::Literal start_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal)); + xla::Literal stop_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal)); + switch (type) { case DT_FLOAT: { - float start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + float start = start_literal.GetFirstElement(); + float stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) { flat(0) = start; @@ -185,9 +164,8 @@ class LinSpaceOp : public XlaOpKernel { break; } case DT_DOUBLE: { - double start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + double start = start_literal.GetFirstElement(); + double stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) 
{ flat(0) = start; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 37b026aeb05..12830816ec1 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { @@ -108,21 +109,16 @@ class ExpandDimsOp : public XlaOpKernel { explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape dim_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape dim_shape = ctx->InputShape("dim"); - // TODO(phawkins): the standard implementation of ExpandDimsOp seems to - // accept legacy scalars, even when they should be forbidden by the graphdef - // version. - OP_REQUIRES(ctx, dim_shape.num_elements() == 1, + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims)); + OP_REQUIRES(ctx, dims.size() == 1, errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); - - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); - - int dim = literal.data()[0]; + int dim = dims[0]; OP_REQUIRES(ctx, (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), @@ -148,7 +144,7 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 34980ead818..88da64e5a21 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" @@ -42,8 +43,8 @@ class SliceOp : public XlaOpKernel { OP_REQUIRES( ctx, - IsLegacyVector(begin_tensor_shape) && - IsLegacyVector(size_tensor_shape) && + TensorShapeUtils::IsVector(begin_tensor_shape) && + TensorShapeUtils::IsVector(size_tensor_shape) && begin_tensor_shape.num_elements() == input_shape.dims() && size_tensor_shape.num_elements() == input_shape.dims(), errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 230a343f796..7a0e240400b 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -35,26 +35,16 @@ class SplitOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const int32 num_split = num_outputs(); - const TensorShape index_shape = ctx->InputShape(0); + const TensorShape split_dim_shape = ctx->InputShape("split_dim"); const TensorShape input_shape = ctx->InputShape(1); - xla::Literal literal_index; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(split_dim_shape), + errors::InvalidArgument("split_dim must be a scalar but has rank ", + split_dim_shape.dims())); + int64 split_dim_orig; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig)); - int32 split_dim_orig; - if (index_shape.dims() == 0) { - split_dim_orig = literal_index.Get({}); - } else { - OP_REQUIRES( - ctx, index_shape.dims() == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - "scalar or a vector with 1 element")); - OP_REQUIRES( - ctx, index_shape.dim_size(0) == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - "scalar or a vector with 1 element")); - split_dim_orig = literal_index.Get({0}); - } int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() : split_dim_orig; OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), @@ -138,7 +128,6 @@ class SplitVOp : public XlaOpKernel { // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; - std::vector split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); OP_REQUIRES(ctx, split_size_shape.dims() == 1 && @@ -150,12 +139,11 @@ class SplitVOp : public XlaOpKernel { split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); // Get the dimension of this split. 
- xla::Literal split_size_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); + std::vector split_sizes; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes)); for (int i = 0; i < num_split; ++i) { - int slice_size; - slice_size = split_size_literal.Get({i}); + int64 slice_size = split_sizes[i]; if (slice_size == -1) { OP_REQUIRES( ctx, neg_one_dim == -1, @@ -164,7 +152,6 @@ class SplitVOp : public XlaOpKernel { i)); neg_one_dim = i; } else { - split_sizes_vec[i] = slice_size; total_split_size += slice_size; } } @@ -183,7 +170,7 @@ class SplitVOp : public XlaOpKernel { total_split_size)); if (neg_one_dim >= 0) { - split_sizes_vec[neg_one_dim] = + split_sizes[neg_one_dim] = input_shape.dim_size(split_dim) - total_split_size; } @@ -195,7 +182,7 @@ class SplitVOp : public XlaOpKernel { std::vector strides(input_shape.dims(), 1); for (int i = 0; i < num_split; ++i) { TensorShape output_shape(input_shape); - int slice_size = split_sizes_vec[i]; + int slice_size = split_sizes[i]; output_shape.set_dim(split_dim, slice_size); // Slice out the ith split from the split dimension. diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d79cdad9fa2..7b96b43ad83 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp); +REGISTER_XLA_OP( + Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(), + StackOp); class StackPushOp : public XlaOpKernel { public: @@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp); }; -REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp); +REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp); class StackPopOp : public XlaOpKernel { public: @@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp); }; -REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp); +REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp); class StackCloseOp : public XlaOpKernel { public: @@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp); }; -REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp); +REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 7b2cd5a5b08..e1c764f3d5c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" @@ -44,7 +45,7 @@ class TileOp : public XlaOpKernel { const TensorShape multiples_shape = ctx->InputShape("multiples"); OP_REQUIRES( - ctx, IsLegacyVector(multiples_shape), + ctx, TensorShapeUtils::IsVector(multiples_shape), errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", multiples_shape.DebugString())); OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(), diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 48a211942d7..c9b324a243e 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel { : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape perm_tensor_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("x"); + const TensorShape perm_tensor_shape = ctx->InputShape("perm"); // Preliminary validation of sizes. OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), @@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel { ". But input(1) is a vector of size ", perm_tensor_shape.num_elements())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); - - std::vector perm(dims); - std::copy(literal.data().begin(), literal.data().end(), - perm.begin()); + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm)); std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { - const int32 d = perm[i]; + const int64 d = perm[i]; OP_REQUIRES( ctx, 0 <= d && d < dims, errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); @@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel { xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - transposed = ctx->Input(0); + transposed = ctx->Input("x"); } else { - transposed = xla::Transpose(ctx->Input(0), transposed_order); + transposed = xla::Transpose(ctx->Input("x"), transposed_order); } // Conjugate the transposed result if this is ConjugateTransposeOp. diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 0bdfc057261..a0ea6422d73 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -80,24 +80,8 @@ XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); XLAJIT_MAKE_UNARY(Neg, -x); -// Implements Banker's rounding: numbers that are equidistant between two -// integers are rounded towards even. 
-xla::XlaOp RoundToEven(xla::XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); - - auto round_val = xla::Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); -} - -XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x)); XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 20103ec3ae0..67d08290033 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor) { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal)); + return literal.Clone(); +} + Status HostTensorToMutableBorrowingLiteral( Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { xla::Shape xla_shape; diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 1db7470ee2a..a153dddee61 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,11 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); + +// Returns a Literal with the contents of 'host_tensor', backed by its own +// storage (i.e., not reusing 'host_tensor's buffers.) +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor); + // Returns a MutableBorrowingLiteral that utilizes the same underlying buffer // owned by 'host_tensor', but is mutable via the xla::Literal methods. 
Status HostTensorToMutableBorrowingLiteral( diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 8b559c87506..c9f486edc8d 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package( default_visibility = [ "//learning/deepmind/public/wavenet/python:__subpackages__", + "//learning/deepmind/research/alphastar:__subpackages__", "//learning/tfx:__subpackages__", "//tensorflow:internal", ], diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index cb7843850c3..ddb284966ee 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto( "XLACompilationDevice::MakeTensorFromProto should not be called"); } -XlaExpression::XlaExpression() = default; - -void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } - -void XlaExpression::set_constant_value(Tensor value) { - has_constant_value_ = true; - constant_value_ = std::move(value); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index a6e78825334..de6a3356e05 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,9 +18,6 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" @@ -38,8 +35,8 @@ class XlaCompilationAllocator; // This is a 'dummy' TensorFlow device that is only used to execute a // subgraph of XLA compilation Ops to construct a compiled version // of the subgraph's computation. It has a 'dummy' allocator that -// backs each Tensor with metadata indicating the computation the -// Tensor represents. +// backs each Tensor with an XlaExpression. The shape of the Tensor +// matches the shape of XlaExpression. // // We deliberately don't register a device factory because we *never* // want placement to put Ops on a compilation device. The device is created @@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -// A XlaExpression wraps an XLA computation. Each Tensor on an -// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the XlaOp. Each -// expression is either a constant, or a function of previously-compiled -// expressions. -class XlaExpression { - public: - XlaExpression(); - - // handle() stores the XLA handle of the computation that the - // expression represents. - void set_handle(const xla::XlaOp& h); - const xla::XlaOp& handle() const { return handle_; } - - void set_constant_value(Tensor value); - bool has_constant_value() const { return has_constant_value_; } - const Tensor& constant_value() const { return constant_value_; } - - void set_resource(XlaResource* resource) { resource_ = resource; } - XlaResource* resource() const { return resource_; } - - private: - // The XLA handle of the expression's computation. 
- xla::XlaOp handle_; - - // If this expression is a constant with a known value, 'constant_value' is a - // host-memory Tensor containing the value. Used to avoid invoking XLA for - // expressions that are trivially constant. - bool has_constant_value_ = false; - Tensor constant_value_; - - XlaResource* resource_ = nullptr; // Not owned. -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index e177a5f07f5..a08d030ce71 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -36,10 +36,13 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -48,7 +51,7 @@ namespace { // Checks that arguments `args` match types `types`. Status CheckSignature(const DataTypeVector& types, - const std::vector& args) { + absl::Span args) { if (args.size() != types.size()) { return errors::Internal("Compilation arguments have ", args.size(), " elements while function has ", types.size()); @@ -63,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types, return Status::OK(); } +// Uses the _Arg and _Retval nodes in the graph to determine a core assignment +// for each argument and return value. +xla::StatusOr, std::map>> +ComputeArgAndRetvalCores(const Graph& graph) { + auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharding, + ParseShardingFromDevice(*n, std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + return sharding.value().tile_assignment_devices(0); + } else { + return -1; + } + }; + std::map arg_cores; + std::map retval_cores; + for (const Node* n : graph.nodes()) { + if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Arg index"; + arg_cores[index] = core; + } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Retval index"; + TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n)); + retval_cores[index] = core; + } + } + return std::make_pair(std::move(arg_cores), std::move(retval_cores)); +} + +Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, + XlaCompilationDevice* device, FunctionLibraryRuntime* flib, + int64 step_id) { + // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the + // resource manager takes ownership via Create, and unrefs via Cleanup. 
We + // explicitly add a reference to ensure the refcount at entry is maintained at + // all exit points; Create and Cleanup are always called in this function. + // + // The Executor requires us to use ScopedStepContainer. We wrap it in a + // unique_ptr so we can capture the cleanup status in the end. + xla_context->Ref(); + Status status; + auto step_container = absl::make_unique( + step_id, [&status, device](const string& name) { + status = device->resource_manager()->Cleanup(name); + }); + TF_RETURN_IF_ERROR(device->resource_manager()->Create( + step_container->name(), XlaContext::kXlaContextResourceName, + xla_context)); + + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); + // Explicitly clean up the step container, to capture the cleanup status. + step_container.reset(); + return Status::OK(); +} + +// Builds the XLA computation. +// - `args` is the list of input arguments +// - `retvals` is the list of retvals produced by _Retval operators, in index +// order. +// - `args_core` and `retval_cores` are mapping from arg/return indices to core +// assignments. +// - If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// - Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// - `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a ResourceUpdate, whose `index` is the index of a +// resource variable argument to the computation to be updated, and `type` is +// the type of the final output. +Status BuildComputation( + const std::vector& args, + const std::vector& retvals, + const std::map& arg_cores, const std::map& retval_cores, + const std::vector>& resources, + std::unique_ptr token_output, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, + std::vector* resource_updates) { + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); }); + + // Builds a no-op XLA computation. We need to set the sharding of outputs, but + // cannot change the sharding of the existing output op. To do this, we build + // a new identity op to which shardings can be applied. 
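// Editor's illustrative note: the identity_op lambda defined below relies on
// GetTupleElement(Tuple(builder, {op}), 0) producing a fresh instruction that
// computes the same value as `op`. Building that fresh instruction inside an
// xla::XlaScopedShardingAssignment scope attaches the desired sharding to it
// without modifying the sharding already carried by `op` itself, e.g.:
//   xla::XlaScopedShardingAssignment assign(
//       builder, xla::sharding_builder::AssignDevice(/*core=*/0));
//   xla::XlaOp pinned = identity_op(value);  // `pinned` carries the sharding.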
+ auto identity_op = [builder](xla::XlaOp op) { + return xla::GetTupleElement(xla::Tuple(builder, {op}), 0); + }; + + std::vector elems; + elems.reserve(retvals.size()); + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + const XlaExpression& retval = retvals[i]; + output.type = retval.dtype(); + switch (retval.kind()) { + case XlaExpression::Kind::kConstant: + output.is_constant = true; + output.constant_value = retval.constant_value(); + output.shape = output.constant_value.shape(); + break; + + case XlaExpression::Kind::kXlaOp: { + output.is_constant = false; + TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); + xla::XlaOp value = retval.handle(); + auto it = retval_cores.find(i); + xla::XlaScopedShardingAssignment assign_sharding( + builder, it == retval_cores.end() + ? absl::optional() + : xla::sharding_builder::AssignDevice(it->second)); + if (shape_representation_fn) { + // If there is a shape representation function, reshape the output + // tensor to the shape given by the representation shape function. + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( + output.shape, output.type)); + value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); + } else if (it != retval_cores.end()) { + // Apply the sharding to the output, if there is a core assignment. + value = identity_op(value); + } + elems.push_back(value); + break; + } + + case XlaExpression::Kind::kResource: + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); + output.shape = retval.resource()->shape(); + break; + + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument( + "Invalid expression returned by computation. " + "This probably means a return value was not set."); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. + std::vector arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + for (const XlaResource* resource : arg_resources) { + DCHECK_LT(resource->arg_num(), args.size()); + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + auto it = arg_cores.find(resource->arg_num()); + const int core = it == arg_cores.end() ? -1 : it->second; + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::XlaScopedShardingAssignment assign_sharding( + builder, core == -1 ? 
absl::optional() + : xla::sharding_builder::AssignDevice(core)); + + xla::XlaOp handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Ensures the correct sharding is applied to the output. + handle = identity_op(handle); + + elems.push_back(handle); + } + } + + // If we have token output, append it as the last one. + if (token_output) { + elems.push_back(*token_output); + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); + } + + xla::StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + return Status::OK(); +} + } // namespace bool XlaCompiler::Argument::operator==( @@ -83,6 +320,39 @@ bool XlaCompiler::Argument::operator==( return constant_value.tensor_data() == other.constant_value.tensor_data(); } +string XlaCompiler::Argument::HumanString() const { + string common; + if (!name.empty()) { + common = absl::StrCat(" name=", name); + } + absl::StrAppend(&common, " type=", DataTypeString(type), + " shape=", shape.DebugString()); + switch (kind) { + case kInvalid: + return "invalid"; + case kConstant: + return absl::StrCat("kind=constant", common, + " value=", constant_value.DebugString()); + case kResource: { + string output = absl::StrCat("kind=resource", common, " resource_kind=", + XlaResource::KindToString(resource_kind), + " initialized=", initialized); + if (tensor_array_size >= 0) { + absl::StrAppend(&output, " tensor_array_size=", tensor_array_size); + } + if (!tensor_array_gradients.empty()) { + absl::StrAppend(&output, " tensor_array_gradients=", + absl::StrJoin(tensor_array_gradients, ",")); + } + return output; + } + case kParameter: + return absl::StrCat("kind=parameter", common); + case kToken: + return absl::StrCat("token", common); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), @@ -110,8 +380,13 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) // The default shape representation function is the identity. 
if (!options_.shape_representation_fn) { - options_.shape_representation_fn = [](const TensorShape& shape, - DataType type) { return shape; }; + options_.shape_representation_fn = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; } } @@ -171,15 +446,16 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { return graph; } -Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, - const NameAttrList& function, - std::vector args, - XlaCompiler::CompilationResult* result) { +Status XlaCompiler::CompileFunction( + const XlaCompiler::CompileOptions& options, const NameAttrList& function, + absl::Span args, + XlaCompiler::CompilationResult* result) { const string function_id = Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; - auto it = cache_.find({function_id, args}); + const std::vector arg_vector(args.begin(), args.end()); + auto it = cache_.find({function_id, arg_vector}); if (it != cache_.end()) { *result = it->second; return Status::OK(); @@ -212,14 +488,16 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kArgOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). 
for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kRetOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -235,7 +513,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; - cache_[{function_id, args}] = *result; + cache_[{function_id, arg_vector}] = *result; return Status::OK(); } @@ -247,25 +525,24 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape; if (is_entry_computation) { TF_ASSIGN_OR_RETURN( - shape, options_.shape_representation_fn(arg.shape, arg.type)); + *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); } else { - shape = arg.shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); } - return TensorShapeToXLAShape(arg.type, shape, xla_shape); + return Status::OK(); } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, - options_.shape_representation_fn(arg.shape, arg.type)); - return TensorShapeToXLAShape(arg.type, representation_shape, - xla_shape); + TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( + arg.shape, arg.type)); + + return Status::OK(); } case XlaResource::kTensorArray: { if (arg.tensor_array_size < 0) { @@ -314,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, } } -namespace { - -Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, - XlaCompilationDevice* device, FunctionLibraryRuntime* flib, - int64 step_id) { - // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the - // resource manager takes ownership via Create, and unrefs via Cleanup. We - // explicitly add a reference to ensure the refcount at entry is maintained at - // all exit points; Create and Cleanup are always called in this function. - // - // The Executor requires us to use ScopedStepContainer. We wrap it in a - // unique_ptr so we can capture the cleanup status in the end. - xla_context->Ref(); - Status status; - auto step_container = absl::make_unique( - step_id, [&status, device](const string& name) { - status = device->resource_manager()->Cleanup(name); - }); - TF_RETURN_IF_ERROR(device->resource_manager()->Create( - step_container->name(), XlaContext::kXlaContextResourceName, - xla_context)); - - GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); - TF_RETURN_IF_ERROR(graph_compiler.Compile()); - // Explicitly clean up the step container, to capture the cleanup status. - step_container.reset(); - return Status::OK(); -} - -// Builds the XLA computation. -// `args` is the list of input arguments, `retvals` is the list of retvals -// produced by _Retval operators, in index order. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. 
-// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. -Status BuildComputation( - const std::vector& args, - const std::vector& arg_cores, - const std::vector& retvals, - const std::vector>& resources, - std::unique_ptr token_output, - bool return_updated_values_for_all_resources, bool always_return_tuple, - xla::XlaBuilder* builder, xla::XlaComputation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector* outputs, - std::vector* resource_updates) { - std::vector elems; - elems.reserve(retvals.size()); - for (int i = 0; i < retvals.size(); ++i) { - XlaCompiler::OutputDescription& output = (*outputs)[i]; - output.type = retvals[i].type; - output.shape = retvals[i].shape; - const XlaExpression& retval = retvals[i].expression; - if (retval.has_constant_value()) { - output.is_constant = true; - output.constant_value = retval.constant_value(); - } else if (retval.resource() != nullptr) { - output.is_constant = false; - output.input_index = retval.resource()->arg_num(); - } else { - output.is_constant = false; - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. - std::vector arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - // Attach a common operator name as metadata. This has no semantic effect — it - // merely makes the HLO graph more readable when visualized via TensorBoard, - // since TensorBoard forms groups out of operators with similar names. - xla::OpMetadata retval_metadata; - retval_metadata.set_op_name("XLA_Retvals"); - builder->SetOpMetadata(retval_metadata); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = - modified || - !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.shape = resource->shape(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? 
absl::optional() - : xla::sharding_builder::AssignDevice(core)); - - xla::XlaOp handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. - handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); - } - } - - // If we have token output, append it as the last one. - if (token_output) { - elems.push_back(*token_output); - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. We *always* form a tuple here to ensure that - // the output value is the last thing added into the XLA computation, even - // if there is only one output value. - auto tuple = xla::Tuple(builder, elems); - if (!always_return_tuple && elems.size() == 1) { - xla::GetTupleElement(tuple, 0); - } - builder->ClearOpMetadata(); - - xla::StatusOr computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, - std::vector* arg_cores, std::vector* arg_expressions, + const std::map& arg_cores, + std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { arg_expressions->resize(args.size()); - *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. @@ -504,7 +622,7 @@ Status XlaCompiler::BuildArguments( arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), /*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); - arg_expression.set_resource(resource); + arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { input_mapping->push_back(i); } @@ -516,7 +634,7 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kConstant: - arg_expression.set_constant_value(arg.constant_value); + arg_expression = XlaExpression::Constant(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -541,26 +659,6 @@ Status XlaCompiler::BuildArguments( *input_shapes = arg_shapes; } - // Use the _Arg nodes in the graph to resolve core assignments. - for (const Node* n : graph.nodes()) { - if (absl::string_view(n->type_string()) != "_Arg") continue; - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - TF_RET_CHECK(index >= 0 && index < args.size()) - << "_Arg out of bounds: " << index << " vs " << args.size(); - TF_ASSIGN_OR_RETURN( - auto sharding, - ParseShardingFromDevice(*n, std::numeric_limits::max())); - if (sharding.has_value()) { - TF_RET_CHECK(sharding.value().type() == - xla::OpSharding::Type::OpSharding_Type_MAXIMAL); - const int core = sharding.value().tile_assignment_devices(0); - if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) { - (*arg_cores)[index] = core; - } - } - } - // Attach a common operator name as metadata. 
This has no semantic effect — it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. @@ -576,11 +674,10 @@ Status XlaCompiler::BuildArguments( xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); for (int64 parameter : *input_mapping) { - const int core = (*arg_cores)[parameter]; - const int root_device = 0; + auto it = arg_cores.find(parameter); + const int core = it == arg_cores.end() ? 0 : it->second; *tuple_sharding.add_tuple_shardings() = - core == -1 ? xla::sharding_builder::AssignDevice(root_device) - : xla::sharding_builder::AssignDevice(core); + xla::sharding_builder::AssignDevice(core); } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); @@ -589,7 +686,8 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -597,7 +695,8 @@ Status XlaCompiler::BuildArguments( } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -632,14 +731,14 @@ Status XlaCompiler::BuildArguments( // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. if (is_entry_computation) { - arg_expression.set_handle( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); + arg_expression = XlaExpression::XlaOp( + xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); } else { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } break; case XlaCompiler::Argument::kToken: { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); break; } case XlaCompiler::Argument::kConstant: @@ -653,46 +752,48 @@ Status XlaCompiler::BuildArguments( } Status XlaCompiler::CompileSingleOp( - const XlaCompiler::CompileOptions& options, string const& name, - OpKernelContext* ctx, const std::vector& args, - CompilationResult* result) { + const XlaCompiler::CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result) { // TODO(b/74182462): We implement this by creating a new dummy Graph including // _Arg nodes, and let CompileGraph walk it. This could be optimized. std::unique_ptr graph(new Graph(OpRegistry::Global())); Status status; // First create the actual node we care about computing. - Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + Node* main_node = graph->AddNode(node_def, &status); TF_RETURN_IF_ERROR(status); // Create dummy _Arg nodes. Link these to `node` and also via a control // dependency edge to the _SOURCE node. 
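// Editor's illustrative sketch: for input 0 of the op being compiled, the loop
// below builds a node equivalent to
//   name: "_arg0"  op: "_Arg"
//   attr { T: <the input dtype, or DT_RESOURCE for resource arguments> }
//   attr { index: 0 }
// with a control edge from _SOURCE and a data edge into input 0 of `main_node`.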
- for (int64 i = 0; i < ctx->num_inputs(); ++i) { + for (int64 i = 0; i < args.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); - Status status = NodeBuilder(name, "_Arg") - .ControlInput(graph->source_node()) - .Attr("T", ctx->input_dtype(i)) - .Attr("index", i) - .Finalize(graph.get(), &node); + string arg_name = absl::StrCat("_arg", i); + Status status = + NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) + .ControlInput(graph->source_node()) + .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE + : args[i].type) + .Attr("index", i) + .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); graph->AddEdge(node, 0, main_node, i); } // Similarly with return values, create dummy _Retval nodes fed by `node`. - for (int64 i = 0; i < ctx->num_outputs(); ++i) { + for (int64 i = 0; i < result_types.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); - Status status = NodeBuilder(name, "_Retval") + string retval_name = absl::StrCat("_retval", i); + Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) .Input(main_node, i) - .Attr("T", ctx->expected_output_dtype(i)) + .Attr("T", result_types[i]) .Attr("index", i) .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } FixupSourceAndSinkEdges(graph.get()); - return CompileGraph(options, name, std::move(graph), args, result); + return CompileGraph(options, node_def.name(), std::move(graph), args, result); } namespace { @@ -747,12 +848,38 @@ Status ValidateGraph(const Graph* graph, return Status::OK(); } +// Converts the value of any expressions whose values are known at compile-time +// to constants. +Status ResolveConstantExpressionsToConstants( + xla::Client* client, absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kXlaOp) { + TF_ASSIGN_OR_RETURN(absl::optional constant, + expression.ResolveConstant(client)); + if (constant.has_value()) { + expression = XlaExpression::Constant(*constant); + } + } + } + return Status::OK(); +} + +void ConvertConstantsToExpressions(xla::XlaBuilder* builder, + absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kConstant) { + expression = + XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype()); + } + } +} + } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; @@ -774,13 +901,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext( - this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, options.is_entry_computation, - &options_.shape_representation_fn); + XlaContext* context = + new XlaContext(this, &builder, options_.allow_cpu_custom_calls, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); - std::vector real_args(args); + std::vector real_args(args.begin(), args.end()); int token_input_index = -1; std::unique_ptr token_output; if (options.add_token_input_output) { @@ -792,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, real_args.push_back(token_arg); } + std::map 
arg_cores; + std::map retval_cores; + TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores), + ComputeArgAndRetvalCores(*graph)); + std::vector arg_expressions; - std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( - *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores, &arg_expressions, &result->input_mapping, &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); @@ -843,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); + std::vector retvals = context->retvals(); + if (options.resolve_compile_time_constants) { + TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals))); + } else { + ConvertConstantsToExpressions(&builder, absl::Span(retvals)); + } TF_RETURN_IF_ERROR(BuildComputation( - real_args, arg_cores, context->retvals(), context->resources(), - std::move(token_output), options.return_updated_values_for_all_resources, + real_args, retvals, arg_cores, retval_cores, context->resources(), + std::move(token_output), + options.is_entry_computation ? options_.shape_representation_fn + : ShapeRepresentationFn{}, + options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, &result->resource_updates)); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 2cc603a5801..63426124686 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,10 +18,13 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" @@ -118,7 +121,7 @@ class XlaCompiler { // The type of the argument. If the argument is a resource, this // is the type of the variable's value, not DT_RESOURCE. - DataType type; + DataType type = DT_INVALID; // The shape of the argument. For: // * a parameter: the shape of the parameter. @@ -155,6 +158,9 @@ class XlaCompiler { std::set tensor_array_gradients; bool operator==(const Argument& other) const; + + // Returns a human-readable summary of the argument. + string HumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -259,8 +265,7 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function(const TensorShape&, - DataType)> + typedef std::function(const TensorShape&, DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. 
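As a usage sketch of the new ShapeRepresentationFn signature (hedged: this mirrors the updated tests further down rather than prescribing an API, and assumes the type_util.h and XLA shape_util headers are available), a function that stores every argument and variable as a flattened 1-D XLA array could look like:

XlaCompiler::Options options;  // device_type, client, flib_def, etc. still need to be set.
options.shape_representation_fn =
    [](const TensorShape& shape, DataType dtype) -> xla::StatusOr<xla::Shape> {
  xla::PrimitiveType ptype;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &ptype));
  // Represent the value as a rank-1 array with the same element count.
  return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
};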
@@ -316,22 +321,23 @@ class XlaCompiler { Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, - std::vector args, CompilationResult* result); + absl::Span args, + CompilationResult* result); // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. Status CompileGraph(const CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result); - // Compiles a single Op, given by an OpKernelContext, into an + // Compiles a single Op, given by `node_def`, into an // xla::XlaComputation. Similar to CompileFunction but takes a single Op as // input. - Status CompileSingleOp(const CompileOptions& options, string const& name, - OpKernelContext* ctx, - const std::vector& args, + Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result); // Returns the shape of the XLA parameter for an argument 'arg'. @@ -411,7 +417,8 @@ class XlaCompiler { Status BuildArguments(const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, - XlaContext* context, std::vector* arg_cores, + XlaContext* context, + const std::map& arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 4ef154f856b..aaee208f634 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -1018,9 +1019,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); @@ -1086,9 +1089,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Compiles the graph. 
XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 20e1ee2ddb3..43095fbb473 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( + bool allow_cpu_custom_calls, + const std::function( const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants), - is_entry_computation_(is_entry_computation), shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } -// This is called by the Retval Op to associate a computed value -// with a specific return value of the subgraph. -void XlaContext::AddRetval(int retval_index, DataType type, - const TensorShape& shape, const xla::XlaOp& handle) { - VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; - // Add the return value to the list being built up. - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); +void XlaContext::SetRetval(int index, const XlaExpression& expression) { + if (retvals_.size() <= index) { + retvals_.resize(index + 1); } - XlaExpression e; - e.set_handle(handle); - retvals_[retval_index] = Retval{type, shape, e}; + retvals_[index] = expression; } -Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal) { - VLOG(1) << "Adding retval index " << retval_index - << " with non-data-dependent tensor to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - XlaExpression e; - e.set_constant_value(value); - retvals_[retval_index] = Retval{dtype, value.shape(), e}; - return Status::OK(); -} - -Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { - VLOG(1) << "Adding retval index " << retval_index << " with resource " - << resource->name() << ":" << resource->shape().DebugString() - << " to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - XlaExpression e; - e.set_resource(resource); - retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; - return Status::OK(); -} - -xla::XlaBuilder* XlaContext::builder() { return builder_; } - Status XlaContext::CreateResource( XlaResource::Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, @@ -133,7 +93,7 @@ Status XlaContext::CreateResource( return Status::OK(); } -xla::StatusOr XlaContext::RepresentationShape( +xla::StatusOr XlaContext::RepresentationShape( const TensorShape& shape, DataType type) const { 
return (*shape_representation_fn_)(shape, type); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 4da891634e9..dbfd344c9ba 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -20,8 +20,8 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -46,9 +46,8 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( + bool allow_cpu_custom_calls, + const std::function( const TensorShape&, DataType)>* shape_representation_fn); // Virtual method defined by ResourceBase. @@ -57,37 +56,19 @@ class XlaContext : public ResourceBase { XlaCompiler* compiler() const { return compiler_; } // Returns the XlaBuilder that Ops use for compiling new expressions. - xla::XlaBuilder* builder(); + xla::XlaBuilder* builder() { return builder_; } bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool resolve_compile_time_constants() const { - return resolve_compile_time_constants_; - } - bool is_entry_computation() const { return is_entry_computation_; } - const std::vector& args() const { return args_; } void set_args(std::vector args); - struct Retval { - DataType type; - TensorShape shape; - // An XlaExpression representing the Retval's value. - XlaExpression expression; - }; - const std::vector& retvals() { return retvals_; } + const std::vector& retvals() { return retvals_; } - // This is called by the Retval Op to associate a computed value - // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, const TensorShape& shape, - const xla::XlaOp& handle); - - // As for Retval, but for return values that are compile-time constants. - Status AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal); - - // As for Retval, but for return values that are resource handles. - Status AddResourceRetval(int retval_index, XlaResource* resource); + // Sets a return value. + // Since we do not always know in advance how many return values there are, + // grows the return values vector to size index+1 if it is smaller. + void SetRetval(int index, const XlaExpression& expression); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` @@ -105,8 +86,8 @@ class XlaContext : public ResourceBase { // Returns the XLA shape to be used to represent a variable of TF `shape` // and `type`, or of an argument or return value of a top-level computation. - xla::StatusOr RepresentationShape(const TensorShape& shape, - DataType type) const; + xla::StatusOr RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. 
There is a @@ -140,31 +121,19 @@ class XlaContext : public ResourceBase { // Allow ops to emit CustomCall operations for CPU. const bool allow_cpu_custom_calls_; - // If true, constant return values are returned as Tensors instead of - // run-time computation outputs. - const bool resolve_compile_time_constants_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // Is this a top-level computation, or an inner computation (e.g., a while - // body)? - const bool is_entry_computation_; - - // A function that describes how the shapes of - // a) argument and return value, for entry computations - // b) variables, for all computations, - // should be represented in XLA. Parameters/return values will be shaped - // according to this function, and reshaped back to/from their declared shapes - // for computations. Must be non-null. - const std::function(const TensorShape&, DataType)>* + // Describes the on-host shapes of parameters and return values. Also see: + // XlaDevice::Options::shape_representation_fn. + const std::function(const TensorShape&, DataType)>* shape_representation_fn_; // Cache of prebuilt computations indexed by their type. diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc new file mode 100644 index 00000000000..ca0309166b7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_expression.h" + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +XlaExpression::XlaExpression() = default; + +XlaExpression XlaExpression::Invalid() { + XlaExpression e; + e.kind_ = Kind::kInvalid; + return e; +} + +XlaExpression XlaExpression::Constant(Tensor value) { + XlaExpression e; + e.kind_ = Kind::kConstant; + e.dtype_ = value.dtype(); + e.constant_value_ = value; + return e; +} + +XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { + XlaExpression e; + e.kind_ = Kind::kXlaOp; + e.dtype_ = dtype; + e.handle_ = value; + return e; +} + +XlaExpression XlaExpression::Resource(XlaResource* resource) { + XlaExpression e; + e.kind_ = Kind::kResource; + e.dtype_ = DT_RESOURCE; + e.resource_ = resource; + return e; +} + +string XlaExpression::HumanString() const { + switch (kind_) { + case Kind::kInvalid: + return "invalid"; + case Kind::kConstant: + return "constant"; + case Kind::kXlaOp: + return "xla_op"; + case Kind::kResource: + return "resource"; + } +} + +xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + switch (kind_) { + case Kind::kConstant: { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR( + HostTensorToBorrowingLiteral(constant_value_, &literal)); + return xla::ConstantLiteral(builder, literal); + } + case Kind::kXlaOp: + if (builder != handle_.builder()) { + return errors::InvalidArgument( + "Mismatched builders in XlaExpression::AsXlaOp"); + } + return handle_; + default: + return errors::InvalidArgument("AsXlaOp called on XlaExpression: ", + HumanString()); + } + }); +} + +xla::StatusOr> XlaExpression::ResolveConstant( + xla::Client* client) const { + switch (kind()) { + case Kind::kConstant: + return {constant_value()}; + case Kind::kXlaOp: + break; + case Kind::kResource: + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveConstant called on XlaExpression: ", HumanString()); + } + + TF_ASSIGN_OR_RETURN(bool is_constant, + handle().builder()->IsConstant(handle())); + if (!is_constant) return {absl::nullopt}; + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildConstantSubGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. 
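// Editor's illustrative note: for a rank-3 tensor, the std::iota over the
// reversed vector below yields layout_indices = {2, 1, 0}; listing the
// fastest-varying TensorFlow dimension first is XLA's minor-to-major encoding
// of a row-major layout.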
+ std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor)); + return {tensor}; +} + +xla::StatusOr XlaExpression::GetShape() const { + switch (kind_) { + case Kind::kConstant: + return constant_value().shape(); + case Kind::kXlaOp: { + TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, + handle().builder()->GetShape(handle())); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + return shape; + } + case Kind::kResource: + return TensorShape({}); + case Kind::kInvalid: + return errors::InvalidArgument( + "GetShape() called on invalid XlaExpression"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h new file mode 100644 index 00000000000..bed6761d362 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA +// compilation. +// An expression is one of: +// * a constant tensor. +// * an xla::XlaOp, representing a symbolic XLA value. +// * a resource, e.g., a variable, represented as an XlaResource pointer. +// +// Constant tensors are mostly an optimization to avoid passing large constants +// to XLA, but are also sometimes used to represent tensors that have no XLA +// representation, for example, DT_STRING tensors. A canonical use case might be +// an error message string. +class XlaExpression { + public: + enum class Kind { + kInvalid, + kConstant, + kXlaOp, + kResource, + }; + + XlaExpression(); + XlaExpression(const XlaExpression&) = default; + XlaExpression& operator=(const XlaExpression&) = default; + + // Builds an invalid expression. (Same as the default constructor, but makes + // the intent clearer.) + static XlaExpression Invalid(); + + // Builds a constant XLA expression. + static XlaExpression Constant(Tensor value); + + // Builds a XlaOp expression. Since the mapping from TF data types to XLA + // types is not 1-1, the TF type must also be provided; in general it cannot + // be derived from the XLA type. 
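// (Editor's illustrative note: the mapping is not invertible because, for
// example, DT_INT32 and DT_QINT32 both lower to xla::S32.)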
+ static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + + // Builds a resource expression. + static XlaExpression Resource(XlaResource* resource); + + Kind kind() const { return kind_; } + + DataType dtype() const { return dtype_; } + + // handle() returns the XlaOp that backs a kXlaOp expression. + const xla::XlaOp& handle() const { return handle_; } + + const Tensor& constant_value() const { return constant_value_; } + + XlaResource* resource() const { return resource_; } + + // Returns a human-readable summary of the expression. + string HumanString() const; + + // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns + // an erroneous XlaOp if the expression is not a constant or an expression. + xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; + + // If a kXlaOp or kConstant expression can be resolved to a compile-time + // constant, returns the value as a host-memory Tensor. Returns an empty + // optional if it cannot be resolved. Returns an error if passed a resource + // expression. + xla::StatusOr> ResolveConstant( + xla::Client* client) const; + + // Returns the shape of the tensor. + // The shape of a resource is the shape of a resource handle (i.e., a scalar), + // not the shape of the resource's value. + xla::StatusOr GetShape() const; + + private: + Kind kind_ = Kind::kInvalid; + + DataType dtype_ = DT_INVALID; + + // The XLA handle of the expression's computation, if kind_ == kXlaOp. + xla::XlaOp handle_; + + // The value of the constant, if kind_ == kConstant. + Tensor constant_value_; + + // The resource, if kind_ == kResource. Not owned. + XlaResource* resource_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc new file mode 100644 index 00000000000..84202c93139 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class XlaExpressionTest : public ::testing::Test { + protected: + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + builder_ = absl::make_unique("acomputation"); + constant_ = test::AsScalar(42); + op_ = xla::ConstantR0(builder_.get(), 7); + non_constant_op_ = xla::Parameter( + builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"); + resource_ = absl::make_unique( + XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"), + DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1, + /*tensor_array_gradients=*/std::set(), + /*tensor_array_multiple_writes_aggregate=*/false); + } + + xla::Client* client_; + std::unique_ptr builder_; + Tensor constant_; + xla::XlaOp op_; + xla::XlaOp non_constant_op_; + std::unique_ptr resource_; +}; + +TEST_F(XlaExpressionTest, Kind) { + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind()); + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind()); + EXPECT_TRUE(XlaExpression::Kind::kConstant == + XlaExpression::Constant(constant_).kind()); + EXPECT_TRUE(XlaExpression::Kind::kXlaOp == + XlaExpression::XlaOp(op_, DT_INT32).kind()); + EXPECT_TRUE(XlaExpression::Kind::kResource == + XlaExpression::Resource(resource_.get()).kind()); +} + +TEST_F(XlaExpressionTest, HumanString) { + EXPECT_EQ("invalid", XlaExpression().HumanString()); + EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString()); + EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString()); + EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString()); + EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString()); +} + +TEST_F(XlaExpressionTest, AsXlaOp) { + xla::XlaOp op_as_op = + XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get()); + EXPECT_TRUE(op_.IsIdenticalTo(op_as_op)); + + xla::XlaOp const_as_op = + XlaExpression::Constant(constant_).AsXlaOp(builder_.get()); + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, + builder_->BuildConstantSubGraph(const_as_op)); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal value, + client_->ComputeConstant(computation)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0(42), + value)); +} + +TEST_F(XlaExpressionTest, GetShape) { + EXPECT_FALSE(XlaExpression().GetShape().ok()); + EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok()); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape, + XlaExpression::Resource(resource_.get()).GetShape()); + EXPECT_EQ(TensorShape({}), resource_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape, + XlaExpression::XlaOp(op_, DT_INT32).GetShape()); + EXPECT_EQ(TensorShape({}), op_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape, + 
XlaExpression::Constant(constant_).GetShape()); + EXPECT_EQ(TensorShape({}), constant_shape); +} + +TEST_F(XlaExpressionTest, ResolveConstant) { + EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok()); + EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok()); + EXPECT_FALSE( + XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional op_constant, + XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_)); + ASSERT_TRUE(op_constant.has_value()); + test::ExpectTensorEqual(test::AsScalar(7), *op_constant); + + TF_ASSERT_OK_AND_ASSIGN(absl::optional op_nonconstant, + XlaExpression::XlaOp(non_constant_op_, DT_FLOAT) + .ResolveConstant(client_)); + EXPECT_FALSE(op_nonconstant.has_value()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional constant_constant, + XlaExpression::Constant(constant_).ResolveConstant(client_)); + ASSERT_TRUE(constant_constant.has_value()); + test::ExpectTensorEqual(constant_, *constant_constant); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index dd3498ef7aa..8dd8def0549 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().valid() || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle(); + CHECK(expression->kind() != XlaExpression::Kind::kInvalid) + << expression->HumanString(); return expression; } -// Retrieves an uninitialized XlaExpression from a newly-allocated tensor. -static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { +// Assigns an XlaExpression to a tensor on an XLA compilation device. +static void AssignExpressionToTensor(Tensor* tensor, + const XlaExpression& value) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK(!expression->handle().valid()); - return const_cast(expression); + CHECK(expression->kind() == XlaExpression::Kind::kInvalid) + << expression->HumanString(); + *const_cast(expression) = value; } -// Retrieves the XlaOp from an input Tensor to an Op. This computation was -// constructed by an Op that executed previously and created the output Tensor -// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. 
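A minimal usage sketch of the XlaExpression API introduced above, modelled on xla_expression_test.cc. This is not part of the patch; the function name is illustrative and error handling is elided.

```c++
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void XlaExpressionExample() {
  // Wrap a symbolic XLA value. The TF dtype must be passed explicitly
  // because the TF->XLA type mapping is not 1-1 (see the class comment).
  xla::XlaBuilder builder("example");
  xla::XlaOp seven = xla::ConstantR0<int32>(&builder, 7);
  XlaExpression expr = XlaExpression::XlaOp(seven, DT_INT32);
  CHECK(expr.kind() == XlaExpression::Kind::kXlaOp);

  // An expression that depends only on constants can be folded back into a
  // host-memory Tensor at compile time via ResolveConstant().
  xla::Client* client = xla::ClientLibrary::LocalClientOrDie();
  absl::optional<Tensor> folded = expr.ResolveConstant(client).ValueOrDie();
  if (folded.has_value()) {
    LOG(INFO) << folded->scalar<int32>()();  // prints 7
  }
}

}  // namespace tensorflow
```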
-static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { - return CastExpressionFromTensor(tensor)->handle(); +const XlaExpression& XlaOpKernelContext::InputExpression(int index) { + return *CastExpressionFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(int index) { - return GetComputationFromTensor(context_->input(index)); +const XlaExpression& XlaOpKernelContext::InputExpression( + absl::string_view name) { + return *CastExpressionFromTensor(GetInputTensorByName(name)); } -const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { - return GetComputationFromTensor(GetInputTensorByName(name)); +xla::XlaOp XlaOpKernelContext::Input(int index) { + return InputExpression(index).AsXlaOp(builder()); +} + +xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) { + return InputExpression(name).AsXlaOp(builder()); } TensorShape XlaOpKernelContext::InputShape(int index) { @@ -125,77 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name, Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal) { - const Tensor& tensor = context_->input(index); - TensorShape new_shape(new_dims); - if (tensor.NumElements() != new_shape.num_elements()) { - return errors::InvalidArgument( - context_->op_kernel().name(), " input ", index, " has shape ", - tensor.shape().DebugString(), - " but was asked to be reshaped to incompatible shape ", - new_shape.DebugString()); - } - const XlaExpression* expression = CastExpressionFromTensor(tensor); - - auto copy_tensor_to_literal = [](const Tensor& tensor, - xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); - - *literal = xla::Literal(literal_shape); - - // memcpy over the payload ... - // TODO(phawkins): handle string types. - size_t total_bytes = tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&tensor); - memcpy(dst_ptr, src_ptr, total_bytes); - } - return Status::OK(); - }; - - // If the tensor has a known constant value, there is no need to invoke XLA. - if (expression->has_constant_value()) { - Tensor temp(tensor.dtype()); - if (!temp.CopyFrom(expression->constant_value(), new_shape)) { - // This should never happen. The constant should have a shape compatible - // with the enclosing Tensor. - return errors::Internal("Incompatible shapes in ConstantInputReshaped."); - } - - return copy_tensor_to_literal(temp, constant_literal); - } - - // Make sure we treat zero-element tensors as constant. - if (new_shape.num_elements() == 0) { - Tensor temp(tensor.dtype(), new_shape); - - return copy_tensor_to_literal(temp, constant_literal); - } - - xla::XlaOp handle = expression->handle(); - if (new_shape != tensor.shape()) { - // Reshape the handle to the desired shape. - handle = xla::Reshape(handle, new_shape.dim_sizes()); - } - - // The XLA layout is specified minor to major, and TensorFlow's minor - // dimension is the last one. 
- std::vector layout_indices(new_shape.dims()); - std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); - xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); - - xla::StatusOr is_constant = builder()->IsConstant(handle); - if (!is_constant.ok()) { - Status status = is_constant.status(); + XlaExpression e = InputExpression(index); + xla::StatusOr> constant_or_status = + e.ResolveConstant(compiler()->client()); + if (!constant_or_status.ok()) { + Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", context_->op_kernel().type_string(), " operator as a compile-time constant."); return status; } - - if (!is_constant.ValueOrDie()) { + absl::optional constant = constant_or_status.ValueOrDie(); + if (!constant.has_value()) { return errors::InvalidArgument( "Input ", index, " to ", context_->op_kernel().type_string(), " operator must be a compile-time constant.\n" @@ -208,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped( "stateful operation such as a random number generator."); } - // Ask the XLA compiler to evaluate the data handle to a literal. - xla::StatusOr constant_graph = - builder()->BuildConstantSubGraph(handle); - if (!constant_graph.ok()) { - return errors::Internal( - "Error getting a compile-time constant graph for ", - context_->op_kernel().name(), " input ", index, - ".\nError: ", constant_graph.status().error_message()); + Tensor temp(constant->dtype()); + if (!temp.CopyFrom(*constant, TensorShape(new_dims))) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + constant->shape().DebugString(), + " but was asked to be reshaped to incompatible shape ", + TensorShape(new_dims).DebugString()); } - xla::StatusOr computed = compiler()->client()->ComputeConstant( - constant_graph.ValueOrDie(), &layout); - if (!computed.ok()) { - return errors::Internal("Error evaluating ", context_->op_kernel().name(), - " input ", index, - " as a compile-time constant.\nError: ", - computed.status().error_message()); - } - *constant_literal = std::move(computed).ValueOrDie(); + TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); return Status::OK(); } @@ -322,6 +259,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector( return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputReshapedToIntVector( + absl::string_view name, std::vector* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInputReshaped( + index, {InputShape(index).num_elements()}, &literal)); + return LiteralToInt64Vector(literal, out); +} + Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal* out) { xla::Literal literal; @@ -372,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(GetComputationFromTensor(input)); + handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } return Status::OK(); @@ -413,9 +359,12 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, XlaContext& xla_context = XlaContext::Get(ctx); TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, + xla::Shape representation_shape, xla_context.RepresentationShape(variable->shape(), variable->type())); - if (representation_shape == variable->shape()) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR( 
+ TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); + if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { *value = variable->value(); } else { *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); @@ -455,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, - Tensor** output) { - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - if (expected_output_dtype(index) == DT_VARIANT) { - // tensor_data() is not supported for variant Tensor (i.e., - // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the - // XlaExpression inside the Tensor's tensor_data() does not work for - // variant. Instead construct a uint8 tensor and store the expression in its - // value. - // TODO(jpienaar): This should be refactored to stop masquerading - // XlaExpressions as Tensors. - *output = new Tensor(); - TensorShape tensor_shape; - TF_RETURN_IF_ERROR( - context_->allocate_temp(DT_UINT8, tensor_shape, *output)); - context_->set_output(index, **output); - } else { - TensorShape tensor_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); - TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); +void XlaOpKernelContext::SetOutputExpression(int index, + const XlaExpression& expression) { + Status status = [&] { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + Tensor* output = nullptr; + // Provides a special behavior for DT_VARIANT: a variant is treated as + // DT_UINT8 scalar as the type to allow mapping for variant to more generic + // types. + if (expression.dtype() == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in + // its value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, output)); + context_->set_output(index, *output); + } else { + TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); + TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); + } + AssignExpressionToTensor(output, expression); + return Status::OK(); + }(); + if (!status.ok()) { + SetStatus(status); } - return Status::OK(); } void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { - // Makes the host Tensor that will refer to the expression. - Tensor* output = nullptr; - auto shape_or = builder()->GetShape(handle); - if (!shape_or.ok()) { - SetStatus(shape_or.status()); - return; - } - - OP_REQUIRES_OK(context_, - allocate_output(index, shape_or.ValueOrDie(), &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. 
- XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); + SetOutputExpression( + index, + XlaExpression::XlaOp(handle, context_->expected_output_dtype(index))); } void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { - const TensorShape& shape = constant.shape(); - - xla::BorrowingLiteral literal; - OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); - - xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); - CHECK(handle.valid()); - - // Make the Tensor that will refer to the expression. - Tensor* output = nullptr; - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); - expression->set_constant_value(constant); -} - -void XlaOpKernelContext::SetInvalidOutput(int index) { - Tensor* output = nullptr; - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape({}), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::XlaOp handle; - expression->set_handle(handle); + SetOutputExpression(index, XlaExpression::Constant(constant)); } void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { - Tensor* output = nullptr; - // The shape of the output tensor is the shape of the resource itself - // (i.e., a scalar), not the shape of the resource's value. - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape(), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_resource(resource); + SetOutputExpression(index, XlaExpression::Resource(resource)); } Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { @@ -570,10 +482,13 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN(TensorShape representation_shape, + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, xla_context.RepresentationShape(shape, type)); - if (shape != representation_shape) { - handle = xla::Reshape(handle, representation_shape.dim_sizes()); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { + handle = xla::Reshape(handle, + xla::AsInt64Slice(representation_shape.dimensions())); } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index aa00a454968..c06efa2c474 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -88,9 +88,9 @@ class XlaOpKernelContext { // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::XlaOp& Input(int index); + xla::XlaOp Input(int index); // Returns input `name` as a XlaOp. 
- const xla::XlaOp& Input(absl::string_view name); + xla::XlaOp Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -111,14 +111,6 @@ class XlaOpKernelContext { Status ConstantInput(int index, xla::Literal* constant_literal); Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); - // Evaluates input `index`, reshapes it to `new_shape` if new_shape != - // InputShape(index), and stores it in `*constant_literal`. If the input - // cannot be evaluated, e.g., because it depends on unbound parameters, - // returns a non-Ok status. If InputShape(index).num_elements() != - // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped(int index, absl::Span new_dims, - xla::Literal* constant_literal); - // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); Status ConstantInputAsIntScalar(absl::string_view name, int64* out); @@ -134,6 +126,8 @@ class XlaOpKernelContext { // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. Status ConstantInputReshapedToIntVector(int index, std::vector* out); + Status ConstantInputReshapedToIntVector(absl::string_view name, + std::vector* out); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); @@ -148,6 +142,10 @@ class XlaOpKernelContext { Status ConstantInputList(absl::string_view name, std::vector* literals); + // Returns an XlaExpression describing the value of 'index'. + const XlaExpression& InputExpression(int index); + const XlaExpression& InputExpression(absl::string_view name); + // Outputs int num_outputs() const { return context_->num_outputs(); } @@ -165,9 +163,8 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output `index` to an invalid value. - // Any subsequent attempt to consume this output will cause an error. - void SetInvalidOutput(int index); + // Sets output `index` to the given XlaExpression `expression`. + void SetOutputExpression(int index, const XlaExpression& expression); // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } @@ -255,10 +252,13 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); - // Wraps OpKernelContext's allocate_output method while providing special - // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the - // type to allow mapping for variant to more generic types. - Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input + // cannot be evaluated, e.g., because it depends on unbound parameters, + // returns a non-Ok status. If InputShape(index).num_elements() != + // new_shape.num_elements(), returns an error status.
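Taken together, the xla_op_kernel changes above mean an XLA op kernel now works with value-typed xla::XlaOp handles and XlaExpressions. A hypothetical kernel sketch, not part of the patch ("MyAddOffset" and the class name are made up for illustration):

```c++
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {

class AddOffsetOp : public XlaOpKernel {
 public:
  explicit AddOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Input() now returns an xla::XlaOp by value; internally it is
    // InputExpression(index).AsXlaOp(builder()).
    xla::XlaOp x = ctx->Input(0);

    // Constant inputs are resolved through XlaExpression::ResolveConstant
    // rather than the removed handle/IsConstant path. This assumes input 1
    // is marked as a compile-time constant in the op's registration.
    int64 offset;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &offset));

    // SetOutput() is now a thin wrapper over SetOutputExpression().
    ctx->SetOutput(0, xla::Add(x, xla::ScalarLike(x, offset)));
  }
};

// Hypothetical op name, used only for this sketch.
REGISTER_XLA_OP(Name("MyAddOffset"), AddOffsetOp);

}  // namespace tensorflow
```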
+ Status ConstantInputReshaped(int index, absl::Span new_dims, + xla::Literal* constant_literal); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 9f00de708cc..dcd0e9c5c1f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -129,21 +130,27 @@ XlaOpRegistry::~XlaOpRegistry() = default; // Lazily register the CPU and GPU JIT devices the first time // GetCompilationDevice is called. static void* registration_init = [®istry]() { + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + mutex_lock lock(registry.mutex_); if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + cpu_global_jit + ? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally + : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; registration.compile_resource_ops = false; } if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = true; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally; registration.compile_resource_ops = false; } return nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 45a40c0acc0..0bdd4a10854 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -66,19 +66,26 @@ class XlaOpRegistry { public: typedef OpKernel* (*Factory)(OpKernelConstruction*); + enum class AutoclusteringPolicy { + // Enable autoclustering if the user requests it, e.g., via + // experimental_jit_scope. Does not autocluster if the JIT is enabled + // globally (e.g., via the OptimizerOptions in the TF session + // configuration.) + kIfExplicitlyRequested, + // Enable autoclustering if explicitly requested, or if the JIT is enabled + // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N. + kIfEnabledGlobally, + // Always try to autocluster ops placed on this device. + kAlways, + }; + // Describes how to compile operators assigned to a device. struct DeviceRegistration { // The name of the an XLA compilation device to use to compile code. string compilation_device_name; - // Do operators assigned to this device require compilation? - bool requires_compilation; - - // If !requires_compilation, should we try to JIT operators on this device - // when XLA JIT compilation is enabled globally via the SessionOptions? - // (It is still possible to explicitly mark operators to JIT compile, even - // if enable_jit_by_default is false.) 
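A sketch of what a device registration looks like under the new AutoclusteringPolicy enum; the values mirror the GPU registration in the xla_op_registry.cc hunk above, and the helper function name is illustrative only.

```c++
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {

XlaOpRegistry::DeviceRegistration MakeExampleRegistration() {
  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
  // Replaces the old requires_compilation / enable_jit_by_default pair:
  // autocluster when the JIT is enabled globally (session options or
  // TF_XLA_FLAGS); explicit jit-scope requests are honoured either way.
  registration.autoclustering_policy =
      XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
  registration.compile_resource_ops = false;
  return registration;
}

}  // namespace tensorflow
```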
- bool enable_jit_by_default; + // When should we autocluster operators assigned to this device? + AutoclusteringPolicy autoclustering_policy; // Enable compilation of operators that use DT_RESOURCE types? bool compile_resource_ops = false; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 63b09c8f02a..a322eb9015e 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -26,6 +26,19 @@ limitations under the License. namespace tensorflow { +/*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) { + switch (kind) { + case XlaResource::kInvalid: + return "invalid"; + case XlaResource::kVariable: + return "variable"; + case XlaResource::kStack: + return "stack"; + case XlaResource::kTensorArray: + return "tensorarray"; + } +} + XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index aa9ce1b171f..857b9a928bb 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -35,6 +36,7 @@ class XlaResource { kTensorArray, kStack, }; + static absl::string_view KindToString(Kind kind); XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d6b60c5f991..91096cf1d04 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -68,7 +68,7 @@ cc_library( visibility = [":friends"], deps = [ ":xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", ], ) @@ -735,6 +735,70 @@ tf_cc_test( ], ) +cc_library( + name = "parse_flags_from_env", + srcs = ["parse_flags_from_env.cc"], + hdrs = ["parse_flags_from_env.h"], + deps = + [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "parse_flags_from_env_test", + srcs = ["parse_flags_from_env_test.cc"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "debug_options_flags", + srcs = [ + "debug_options_flags.cc", + "debug_options_parsers.h", + ], + hdrs = ["debug_options_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "debug_options_parsers_test", + size = "small", + srcs = [ + "debug_options_parsers.h", + "debug_options_parsers_test.cc", + ], + deps = + [ + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/strings:str_format", + ], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 0cbe68d7efd..42da0ebf499 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -68,6 +68,7 @@ cc_library( deps = [ ":global_data", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", @@ -76,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -236,13 +236,13 @@ tf_cc_test( deps = [ ":xla_builder", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index f5f8d5c6b1f..eef2844e0df 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -210,11 +210,10 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const XlaComputation& computation, absl::Span arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteGraphRequest request; +StatusOr Client::Compile( + const XlaComputation& computation, absl::Span argument_shapes, + const ExecutionOptions* execution_options) { + CompileRequest request; *request.mutable_computation() = computation.proto(); if (execution_options == nullptr) { @@ -222,6 +221,34 @@ StatusOr> Client::Execute( } else { *request.mutable_execution_options() = *execution_options; } + if (request.execution_options().device_handles_size() > 1) { + return InvalidArgument( + "Compiling with multiple device handles is not supported. Use " + "'Execute' instead."); + } + + // The argument shapes affect how the computation is compiled. 
+ for (const auto& arg_shape : argument_shapes) { + *request.add_input_shape_with_layout() = arg_shape; + } + + CompileResponse response; + VLOG(1) << "making compile request: " << request.ShortDebugString(); + Status s = stub_->Compile(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + TF_RET_CHECK(response.has_handle()); + return response.handle(); +} + +StatusOr> Client::Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile) { + ExecuteRequest request; + *request.mutable_handle() = handle; for (GlobalData* argument : arguments) { CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); @@ -229,7 +256,7 @@ StatusOr> Client::Execute( ExecuteResponse response; VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->ExecuteGraph(&request, &response); + Status s = stub_->Execute(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -238,15 +265,62 @@ StatusOr> Client::Execute( if (execution_profile != nullptr) { *execution_profile = response.profile(); + } + + return absl::make_unique(stub_, response.output()); +} + +StatusOr> Client::Execute( + const XlaComputation& computation, absl::Span arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + if (execution_options != nullptr && + execution_options->device_handles_size() > 1) { + std::vector computation_instances = { + XlaComputationInstance{ + computation, + std::vector(arguments.begin(), arguments.end()), + *execution_options, execution_profile}}; + TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); + // The result selection is a bit hacky, but better than assuming it is + // device 0. + // + // TODO(b/118493728): Allow Execute to return one result per computation. + for (int64 i = 0; i < results.size(); i++) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); + if (!ShapeUtil::IsEmptyTuple(shape)) { + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return std::move(results[i]); + } + } + TF_RET_CHECK(!results.empty()); + VLOG(1) << "Defaulting to device 0 result"; + return std::move(results[0]); + } + + // The argument shapes affect how the computation is compiled. 
+ std::vector arg_shapes(arguments.size()); + for (int i = 0; i < arguments.size(); i++) { + TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i])); + } + + TF_ASSIGN_OR_RETURN(auto handle, + Compile(computation, arg_shapes, execution_options)); + + TF_ASSIGN_OR_RETURN(auto result, + Execute(handle, arguments, execution_profile)); + + if (execution_profile != nullptr) { if (VLOG_IS_ON(1)) { TF_ASSIGN_OR_RETURN( auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); + ExecutionStatsAsString(computation, *execution_profile)); VLOG(1) << execution_stats; } } - return absl::make_unique(stub_, response.output()); + return std::move(result); } StatusOr>> Client::ExecuteParallel( @@ -274,10 +348,11 @@ StatusOr>> Client::ExecuteParallel( } std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { + for (size_t i = 0; i < response.responses_size(); ++i) { outputs.push_back( absl::make_unique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { + if (i < computations.size() && + computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } } @@ -390,8 +465,7 @@ StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); + GetComputationStats(computation, GetDebugOptionsFromFlags())); int64 total_flops = computation_stats.flop_count() + computation_stats.transcendental_count(); if (profile.compute_time_ns() > 0) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 6f4d33c469f..d0ac4703c63 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -40,6 +40,31 @@ class Client { explicit Client(ServiceInterface* stub); virtual ~Client(); + // Compiles the computation with the given argument shapes and returns the + // handle to the compiled executable. The compiled executable is cached on the + // service, and the returned handle can be used for execution without + // re-compilation. + // * The shape and layout of the arguments affect how + // the computation is compiled. If argument_shapes is empty, the parameters' + // shape and layout will be used in the compilation. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * execution_options.device_handles should be empty. If you need + // non-empty device handles, call 'Execute' instead. + StatusOr Compile( + const XlaComputation& computation, + absl::Span argument_shapes, + const ExecutionOptions* execution_options = nullptr); + + // Executes the compiled executable for the given handle with the given + // arguments and returns the global data that was produced from the execution. + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + StatusOr> Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile = nullptr); + + // Executes the computation with the given arguments and returns the global // data that was produced from the execution.
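A sketch of the compile-once / execute-many flow enabled by the new Client::Compile and handle-based Execute above. The exact template arguments follow the signatures in client.h; the function name is illustrative, the caller is assumed to already hold a Client*, an XlaComputation with one parameter, and a GlobalData* argument, and error handling is elided.

```c++
#include <memory>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"

namespace xla {

std::unique_ptr<GlobalData> CompileOnceRunTwice(
    Client* client, const XlaComputation& computation, GlobalData* arg_data) {
  // The argument shape affects how the computation is compiled.
  Shape arg_shape = client->GetShape(*arg_data).ValueOrDie();

  // Compile() caches the executable on the service and returns a handle.
  ExecutionHandle handle =
      client->Compile(computation, {arg_shape}).ValueOrDie();

  // The same handle can be executed repeatedly without re-compilation.
  std::unique_ptr<GlobalData> first =
      client->Execute(handle, {arg_data}).ConsumeValueOrDie();
  return client->Execute(handle, {arg_data}).ConsumeValueOrDie();
}

}  // namespace xla
```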
// * If execution_options is not nullptr, these options are passed to the diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index d3d7edb42a3..08a887a6e46 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -265,6 +265,22 @@ XlaOp Digamma(XlaOp input) { return result; } +// Implements Banker's rounding: numbers that are equidistant between two +// integers are rounded towards even. +XlaOp RoundToEven(XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + auto one = xla::ScalarLike(x, 1.0); + auto two = xla::ScalarLike(x, 2.0); + + auto round_val = xla::Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * xla::Floor(half * x); + auto is_odd = xla::Eq(nearest_even_int, one); + return xla::Select(xla::Or(xla::Gt(fraction, half), + xla::And(xla::Eq(fraction, half), is_odd)), + round_val + one, round_val); +} + // Trigonometric functions. // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index a6cafd42077..3f06d04b9ae 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -51,6 +51,10 @@ XlaOp Lgamma(XlaOp input); // Computes an approximation of the digamma function. XlaOp Digamma(XlaOp input); +// Rounds the given number to even when the number is equidistant between two +// integers. +XlaOp RoundToEven(XlaOp x); + // Trigonometric functions // Computes the arc cosine of 'x'. diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 14c259a7fa2..ae2ea225d1a 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -136,5 +136,17 @@ XLA_TEST_F(MathTest, Digamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, RoundToEven) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-1.4, -1.5, -2.5, -0.5, 0, 0.5, 1.5, 2.5, 3.5, 4.5}); + RoundToEven(x); + + std::vector expected = {-1.0, -2.0, -2.0, -0.0, 0, + 0.0, 2.0, 2.0, 4.0, 4.0}; + + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index feb2f8ec9da..e49451ca970 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -60,8 +60,8 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. // - // The given ExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. Status ValidateExecutionOptions( const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); @@ -69,8 +69,8 @@ class LocalExecutable { // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. // - // The given ServiceExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ServiceExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. 
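The RoundToEven helper added to math.cc above implements banker's rounding with Floor/Select. A scalar re-implementation of the same formula, useful for sanity-checking it outside XLA (a sketch, not part of the patch); it reproduces the expectations in math_test.cc, e.g. 1.5 -> 2, 2.5 -> 2, -2.5 -> -2:

```c++
#include <cmath>
#include <cstdio>

// Scalar version of the RoundToEven logic: ties between two integers are
// rounded towards the even one.
double RoundToEvenScalar(double x) {
  double round_val = std::floor(x);
  double fraction = x - round_val;
  // 1.0 if floor(x) is odd, 0.0 if it is even (floor(x) minus the largest
  // even integer <= x).
  double nearest_even_int = round_val - 2.0 * std::floor(0.5 * x);
  bool is_odd = (nearest_even_int == 1.0);
  if (fraction > 0.5 || (fraction == 0.5 && is_odd)) {
    return round_val + 1.0;
  }
  return round_val;
}

int main() {
  const double inputs[] = {-1.5, -2.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5};
  for (double v : inputs) {
    std::printf("RoundToEven(%.1f) = %.1f\n", v, RoundToEvenScalar(v));
  }
  return 0;
}
```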
StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const absl::Span arguments); @@ -114,8 +114,8 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // The given ExecutableBuildOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableBuildOptions override any values from TF_XLA_FLAGS + // environment variable. StatusOr> Compile( const XlaComputation& computation, const absl::Span argument_layouts, diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index f9c23b44810..0a587725d20 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2305,6 +2305,19 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, }); } +XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferGetDimensionSizeShape(operand_shape, dimension)); + instr.add_dimensions(dimension); + return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize, + {operand}); + }); +} + StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(first_error_); @@ -3158,4 +3171,8 @@ XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { return builder->Iota(shape, iota_dimension); } +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension) { + return operand.builder()->GetDimensionSize(operand, dimension); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 908a616b4ea..68314a026ea 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -933,6 +933,8 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); + XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, absl::Span operands = {}); @@ -1355,6 +1357,8 @@ class XlaBuilder { const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + + friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); }; // RAII-style object: sets the current sharding assignment in builder on @@ -2129,6 +2133,10 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& grad_output, float epsilon, int64 feature_index); +// Returns the size of the given dimension of the operand. The operand must be +// array shaped. +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + // Implementation details below this point. template diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index dfe5fd5eb23..8aa85c3cd63 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -18,7 +18,7 @@ limitations under the License. 
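A minimal builder-side use of the new GetDimensionSize op added above, mirroring the xla_builder_test.cc case that follows (the function name is illustrative and not part of the patch):

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

XlaComputation BuildGetDimensionSizeExample() {
  XlaBuilder b("get_dimension_size_example");
  XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  // Lowers to an HLO instruction with opcode kGetDimensionSize; the result
  // is the size of dimension 1 (here, 7) as a scalar integer.
  GetDimensionSize(x, /*dimension=*/1);
  return b.Build().ConsumeValueOrDie();
}

}  // namespace xla
```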
#include #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -43,7 +43,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -54,7 +54,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -349,6 +349,15 @@ TEST_F(XlaBuilderTest, CollectivePermute) { EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); } +TEST_F(XlaBuilderTest, GetDimensionSize) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + GetDimensionSize(x, 1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc similarity index 96% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc rename to tensorflow/compiler/xla/debug_options_flags.cc index 3ed3afcfced..033887d7c11 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -13,17 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include // NOLINT(build/c++11): only using std::call_once, not mutex. 
#include #include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace legacy_flags { - namespace { DebugOptions* flag_values; @@ -101,8 +99,8 @@ void AllocateFlags() { [](string comma_separated_values) { auto* extra_options_map = flag_values->mutable_xla_backend_extra_options(); - impl::parse_xla_backend_extra_options(extra_options_map, - comma_separated_values); + parse_xla_backend_extra_options(extra_options_map, + comma_separated_values); return true; }; @@ -111,8 +109,8 @@ void AllocateFlags() { [](string reduce_precision_option_value) { HloReducePrecisionOptions* option_proto = flag_values->add_hlo_reduce_precision_options(); - return impl::parse_xla_reduce_precision_option( - option_proto, reduce_precision_option_value); + return parse_xla_reduce_precision_option(option_proto, + reduce_precision_option_value); }; flag_objects = new std::vector({ @@ -353,5 +351,4 @@ xla::DebugOptions GetDebugOptionsFromFlags() { return *flag_values; } -} // namespace legacy_flags } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h similarity index 81% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.h rename to tensorflow/compiler/xla/debug_options_flags.h index b53157f59c6..60e59abc2a2 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ #include @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Appends flag definitions for debug options to flag_list. void AppendDebugOptionsFlags(std::vector* flag_list); @@ -32,7 +31,6 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // first. xla::DebugOptions GetDebugOptionsFromFlags(); -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/debug_options_parsers.h similarity index 94% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h rename to tensorflow/compiler/xla/debug_options_parsers.h index ee7eb019c07..80aadfd5ece 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/debug_options_parsers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ #include #include "absl/strings/numbers.h" @@ -23,8 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla.pb.h" namespace xla { -namespace legacy_flags { -namespace impl { template void parse_xla_backend_extra_options(T* extra_options_map, @@ -140,8 +138,6 @@ inline bool parse_xla_reduce_precision_option( return true; } -} // namespace impl -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/debug_options_parsers_test.cc similarity index 88% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc rename to tensorflow/compiler/xla/debug_options_parsers_test.cc index 6f197aec53c..8003c3496d5 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/debug_options_parsers_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" #include #include @@ -23,13 +23,12 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace xla { -namespace legacy_flags { // Test that the xla_backend_extra_options flag is parsed correctly. 
TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { std::unordered_map test_map; string test_string = "aa=bb,cc,dd=,ee=ff=gg"; - impl::parse_xla_backend_extra_options(&test_map, test_string); + parse_xla_backend_extra_options(&test_map, test_string); EXPECT_EQ(test_map.size(), 4); EXPECT_EQ(test_map.at("aa"), "bb"); EXPECT_EQ(test_map.at("cc"), ""); @@ -41,7 +40,7 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -56,7 +55,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot;"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -71,7 +70,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -84,7 +83,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -96,7 +95,6 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz"); } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { diff --git a/tensorflow/compiler/xla/execution_options_util.cc b/tensorflow/compiler/xla/execution_options_util.cc index e83ff7cddd6..cf569863bbe 100644 --- a/tensorflow/compiler/xla/execution_options_util.cc +++ b/tensorflow/compiler/xla/execution_options_util.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
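After the move out of legacy_flags shown above, call sites drop both the legacy_flags:: namespace and the legacy_flags/ include path. A one-line sketch of the updated usage (the wrapper function is illustrative only):

```c++
#include "tensorflow/compiler/xla/debug_options_flags.h"

xla::DebugOptions GetOptionsForExample() {
  // Previously xla::legacy_flags::GetDebugOptionsFromFlags().
  return xla::GetDebugOptionsFromFlags();
}
```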
==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" namespace xla { ExecutionOptions CreateDefaultExecutionOptions() { ExecutionOptions execution_options; - *(execution_options.mutable_debug_options()) = - legacy_flags::GetDebugOptionsFromFlags(); + *(execution_options.mutable_debug_options()) = GetDebugOptionsFromFlags(); return execution_options; } diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md index 5376a04669d..ded1e582b24 100644 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -58,7 +58,7 @@ sess = tf.Session(config=config) > compiled for the CPU. JIT compilation for CPU operations must be done via > the manual method documented below. -#### Manual +#### Manual with experimental_jit_scope() JIT compilation can also be turned on manually for one or more operators. This is done by tagging the operators to compile with the attribute @@ -79,6 +79,16 @@ The `_XlaCompile` attribute is currently supported on a best-effort basis. If an operator cannot be compiled, TensorFlow will silently fall back to the normal implementation. +#### Manual with xla.compile() + +Unlike experimental_jit_scope() which silently falls back to normal Tensorflow +on uncompilable operator, xla.compile() returns an explicit error. This is +useful if you want more predictable behaviors from XLA compilation. + +Please see +[xla.compile() tutorial Colab](https://colab.sandbox.google.com/github/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb) +for how to use it. + ### Placing operators on XLA devices Another way to run computations via XLA is to place an operator on a specific diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index a3cdfe19b2e..73a9db75f6b 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1339,6 +1339,22 @@ the semantics for `tf.gather_nd`. index `X` in the gather indices array picks an entire row and the result is the concatenation of all these rows. +## GetDimensionSize + +See also +[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Returns the size of the given dimension of the operand. The operand must be +array shaped. + + `GetDimensionSize(operand, dimension)` + +| Arguments | Type | Semantics | +| ----------- | ------- | --------------------------------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | +: : : dimension : + ## GetTupleElement See also diff --git a/tensorflow/compiler/xla/g3doc/overview.md b/tensorflow/compiler/xla/g3doc/overview.md index 6a172c3ae15..d3428b72761 100644 --- a/tensorflow/compiler/xla/g3doc/overview.md +++ b/tensorflow/compiler/xla/g3doc/overview.md @@ -4,11 +4,8 @@ -> Note: XLA is experimental and considered alpha. Most use cases will not -> see improvements in performance (speed or decreased memory usage). We have -> released XLA early so the Open Source Community can contribute to its -> development, as well as create a path for integration with hardware -> accelerators. +> Note: XLA is still under development. 
Some use cases will not +> see improvements in speed or decreased memory usage. XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that optimizes TensorFlow computations. The results are improvements in diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb new file mode 100644 index 00000000000..a83e3f78598 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb @@ -0,0 +1,412 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "f4TSNCvpENrW" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "vamNSA0vEP-m" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "xD_ydfejEV7H" + }, + "outputs": [], + "source": [ + "#@title MIT License\n", + "#\n", + "# Copyright (c) 2017 François Chollet\n", + "#\n", + "# Permission is hereby granted, free of charge, to any person obtaining a\n", + "# copy of this software and associated documentation files (the \"Software\"),\n", + "# to deal in the Software without restriction, including without limitation\n", + "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", + "# and/or sell copies of the Software, and to permit persons to whom the\n", + "# Software is furnished to do so, subject to the following conditions:\n", + "#\n", + "# The above copyright notice and this permission notice shall be included in\n", + "# all copies or substantial portions of the Software.\n", + "#\n", + "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", + "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", + "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", + "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", + "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", + "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", + "# DEALINGS IN THE SOFTWARE." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1oSi4lHFt3z" + }, + "source": [ + "# Welcome to `xla.compile()` tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b7noD9NjFRL-" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/jit#turning_on_jit_compilation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "v9YbsuLZaBXy" + }, + "source": [ + "xla.compile() is a new experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/).\n", + "\n", + "Please run all code blocks in order." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "45kUPj5ZFrRa" + }, + "outputs": [], + "source": [ + "import tensorflow as tf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9NMQFjroSMns" + }, + "source": [ + "Imports XLA library, which includes xla.compile() experimental API." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-Uggy03rSGJm" + }, + "outputs": [], + "source": [ + "from tensorflow.contrib.compiler import xla" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GZVNiRmTDV-5" + }, + "source": [ + "Define some necessary constants and prepare MNIST dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "f37TSEGvGX4_" + }, + "outputs": [], + "source": [ + "# Size of each input image, 28 x 28 pixels\n", + "IMAGE_SIZE = 28 * 28\n", + "# Number of distinct number labels, [0..9]\n", + "NUM_CLASSES = 10\n", + "# Number of examples in each training batch (step)\n", + "TRAIN_BATCH_SIZE = 100\n", + "# Number of training steps to run\n", + "TRAIN_STEPS = 1000" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "TiVXchblG5hK" + }, + "outputs": [], + "source": [ + "# Loads MNIST dataset.\n", + "train, test = tf.keras.datasets.mnist.load_data()\n", + "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", + "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n", + "\n", + "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n", + "images, labels = iterator.get_next()\n", + "images = tf.reshape(images, [-1, IMAGE_SIZE])\n", + "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "x_ZehpZP-SfS" + }, + "source": [ + "## Defines build_mnist_model function to construct model\n", + "\n", + "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n", + "\n", + "When called, it returns two values. `y` is a `tf.Tensor` representing predicted probability of each target class, `train_step` is a `tf.Operation` that increments `global_step` and applies variable update." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ZbhJl_WvGa3g" + }, + "outputs": [], + "source": [ + "def build_mnist_model(x, y_):\n", + " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n", + "\n", + " cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n", + " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", + "\n", + " return y, train_step" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7Jh3lyQHDfM9" + }, + "source": [ + "## Uses xla.compile with build_mnist_model function to enable XLA" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EtDwez_1gjzv" + }, + "source": [ + "Following code block wraps the model with xla.compile(), which allows the target function with provided inputs to be executed by XLA." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kYpCXCdRHNuN" + }, + "outputs": [], + "source": [ + "[y] = xla.compile(build_mnist_model, inputs=[images, labels])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4giQh62IrZGF" + }, + "source": [ + "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n", + "\n", + "xla.compile does not return any\n", + "`tf.Operation` nodes that can be executed independently from the generated XLA ops. Instead, returned `tf.Operation` nodes from the target function are added as control dependencies of all returned `tf.Tensor` values. 
This triggers execution of the `tf.Operation` nodes when the returned tensors are evaluated.\n", + "\n", + "In pseudo-code, xla.compile's implementation looks as follows:\n", + "\n", + "---\n", + "```\n", + "# Ask Tensorflow to execute code in XLA-friendly manner\n", + "\n", + "y, train_step = build_mnist_model(images, labels)\n", + "with tf.control_dependencies([train_step]):\n", + " y = tf.identity(y)\n", + "\n", + "# Ask Tensorflow to STOP executing code in XLA-friendly manner\n", + "```\n", + "---\n", + "\n", + "xla.compile() always returns a list of `tf.Tensor`'s (even if there is only one-element)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "TPGas4jjFLZl" + }, + "source": [ + "If you were to print the constructed graph now, you will see that it is not much different from a normal Tensorflow graph and you won't be able to find XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, Tensorflow triggers a series of graph rewrite passes that actually generate XLA ops, which compiles and executes computation when all inputs are ready." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EZD1m_n1DxAF" + }, + "source": [ + "## Trains and tests the model" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "qe28bAHNHUG2" + }, + "outputs": [], + "source": [ + "# Creates session and initialize all variables.\n", + "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n", + "sess = tf.Session()\n", + "sess.run(tf.global_variables_initializer())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qgsKmz3n2UiW" + }, + "source": [ + "Following code block trains model.\n", + "\n", + "Note that evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_GxF6jTRHVuA" + }, + "outputs": [], + "source": [ + "# Feeds training dataset\n", + "sess.run(iterator.make_initializer(train_ds))\n", + "\n", + "# Runs TRAIN_STEPS steps\n", + "for i in range(TRAIN_STEPS):\n", + " sess.run(y)\n", + "print(\"Model trained for %s steps.\" % TRAIN_STEPS)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "dHlQlRSRHXD1" + }, + "outputs": [], + "source": [ + "# Tests trained model\n", + "\n", + "# Feeds testing dataset\n", + "sess.run(iterator.make_initializer(test_ds))\n", + "\n", + "# Calculates accuracy\n", + "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n", + "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", + "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ynJQIuzjHYOb" + }, + "outputs": [], + "source": [ + "# Cleans up session\n", + "sess.close()" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "xla.compile() Tutorial", + "provenance": [], + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 93522d2ca87..fa94d0afb4c 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -24,8 +24,7 @@ limitations under the License. namespace xla { namespace { -void SetMinorToMajorLayout(Shape* shape, - std::initializer_list dimensions) { +void SetMinorToMajorLayout(Shape* shape, std::vector dimensions) { shape->mutable_layout()->clear_minor_to_major(); for (auto dimension : dimensions) { shape->mutable_layout()->add_minor_to_major(dimension); @@ -122,7 +121,7 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { std::vector linear_indexes = {0, 1439999999, 1145567336, 43883404, 617295214, 1117613654}; - std::vector> minor_to_major_orders; + std::vector> minor_to_major_orders; minor_to_major_orders.push_back({6, 5, 4, 3, 2, 1, 0}); minor_to_major_orders.push_back({0, 1, 2, 3, 4, 5, 6}); minor_to_major_orders.push_back({4, 5, 1, 2, 6, 0, 3}); diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD deleted file mode 100644 index 3e79129aafd..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ /dev/null @@ -1,82 +0,0 @@ -# Legacy command-line flags for the XLA libraries. - -# Please do not add more flags to this package. - -# The XLA libraries were written in an environment that allowed command-line -# flags to be scattered freely throughout the libraries. This model, while -# initially convenient, leads to a proliferation in unused command-line flags -# in tests and binaries, and serious problems in servers, where one might wish -# parameters to be different in independent RPC calls to the same routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. 
- -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "parse_flags_from_env", - srcs = ["parse_flags_from_env.cc"], - hdrs = ["parse_flags_from_env.h"], - deps = - [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "parse_flags_from_env_test", - srcs = ["parse_flags_from_env_test.cc"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_library( - name = "debug_options_flags", - srcs = [ - "debug_options_flags.cc", - "debug_options_parsers.h", - ], - hdrs = ["debug_options_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "debug_options_parsers_test", - size = "small", - srcs = [ - "debug_options_parsers.h", - "debug_options_parsers_test.cc", - ], - deps = - [ - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], -) diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 80dfdb83c35..cb00a0ab16d 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1434,10 +1434,14 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case U8: return EqualElementsInternal(other, &multi_index); + case S16: + return EqualElementsInternal(other, &multi_index); case S32: return EqualElementsInternal(other, &multi_index); case S64: return EqualElementsInternal(other, &multi_index); + case U16: + return EqualElementsInternal(other, &multi_index); case U32: return EqualElementsInternal(other, &multi_index); case U64: @@ -1506,6 +1510,11 @@ bool LiteralBase::IsAll(int8 value) const { return AllElementsEqualValue(piece.data(), value); } return false; + case U16: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; case U32: if (value >= 0) { return AllElementsEqualValue(piece.data(), value); @@ -1518,6 +1527,8 @@ bool LiteralBase::IsAll(int8 value) const { return false; case S8: return AllElementsEqualValue(piece.data(), value); + case S16: + return AllElementsEqualValue(piece.data(), value); case S32: return AllElementsEqualValue(piece.data(), value); case S64: @@ -1739,12 +1750,16 @@ bool LiteralBase::IsZero(absl::Span indices) const { switch (shape().element_type()) { case U8: return Get(indices) == 0; + case U16: + return Get(indices) == 0; case U32: return Get(indices) == 0; case U64: return Get(indices) == 0; case S8: return Get(indices) == 0; + case S16: + return Get(indices) == 0; case S32: return Get(indices) == 0; case S64: @@ -1802,6 +1817,20 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; + case U16: + *proto->mutable_u16s() = string( + reinterpret_cast(data().data()), 
size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_u16s()); + } + break; + case S16: + *proto->mutable_s16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_s16s()); + } + break; case F16: *proto->mutable_f16s() = string( reinterpret_cast(data().data()), size_bytes()); @@ -1916,6 +1945,22 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; + case S16: { + const string& s(proto.s16s()); + TF_RET_CHECK(data().size() * sizeof(int16_t) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + case U16: { + const string& s(proto.u16s()); + TF_RET_CHECK(data().size() * sizeof(uint16_t) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; case F16: { const string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 9d34d9d5041..b044f0ad73f 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -141,8 +141,10 @@ int64 RecursiveElementCount(const Shape& shape) { total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); } return total; - } else { + } else if (ShapeUtil::IsArray(shape)) { return ShapeUtil::ElementsIn(shape); + } else { + return 0; } } diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 3511760ac1c..8cec37897a9 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -1394,6 +1394,28 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { EXPECT_EQ(h1, r[3]); } +TEST_F(LiteralUtilTest, CopyFromProto_u16) { + uint16 u1(0xabcd); + uint16 u2(0x1234); + + const unsigned char uint16_vals[8] = {0xcd, 0xab, 0x34, 0x12, + 0x34, 0x12, 0xcd, 0xab}; + LiteralProto p; + p.mutable_shape()->set_element_type(U16); + p.mutable_shape()->clear_dimensions(); + p.mutable_shape()->add_dimensions(4); + LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + p.clear_u16s(); + p.set_u16s(uint16_vals, 8); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); + ASSERT_EQ(4, r.size()); + EXPECT_EQ(u1, r[0]); + EXPECT_EQ(u2, r[1]); + EXPECT_EQ(u2, r[2]); + EXPECT_EQ(u1, r[3]); +} + TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc similarity index 98% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc rename to tensorflow/compiler/xla/parse_flags_from_env.cc index 2a4e49b05aa..40481331b69 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -22,7 +22,7 @@ limitations under the License. 
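As a side note on the new U16 literal support exercised by the CopyFromProto_u16 test above: the proto's u16s field stores the element bytes little-endian, so the expected buffer in that test can be reproduced with a plain struct.pack. This standalone snippet is only an illustration of the byte layout and does not depend on the XLA build:

```python
import struct

# The literal holds [0xabcd, 0x1234, 0x1234, 0xabcd]; little-endian packing
# yields exactly the uint16_vals bytes used by the test.
vals = [0xabcd, 0x1234, 0x1234, 0xabcd]
packed = struct.pack('<4H', *vals)
assert packed == bytes([0xcd, 0xab, 0x34, 0x12, 0x34, 0x12, 0xcd, 0xab])
```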
#include #include -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried static const char kWS[] = " \t\r\n"; // whitespace @@ -202,5 +201,4 @@ void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv) { *pargv = &env_argv->argv; } -} // namespace legacy_flags } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h b/tensorflow/compiler/xla/parse_flags_from_env.h similarity index 90% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h rename to tensorflow/compiler/xla/parse_flags_from_env.h index b54482ad2ba..fe86ee687f8 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h +++ b/tensorflow/compiler/xla/parse_flags_from_env.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ +#define TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ // This module exports ParseFlagsFromEnv(), which allows other modules to parse // flags from the environtment variable TF_XLA_FLAGS, or (if the first @@ -50,7 +50,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet // unrecognized flags passed in from the environment, and return its @@ -60,7 +59,6 @@ bool ParseFlagsFromEnv(const std::vector& flag_list); // Used only for testing. Not to be used by clients. void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv); -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ +#endif // TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/parse_flags_from_env_test.cc similarity index 96% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc rename to tensorflow/compiler/xla/parse_flags_from_env_test.cc index 138c0c852e2..edd6538402d 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include #include @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Test that XLA flags can be set from the environment. // Failure messages are accompanied by the text in msg[]. 
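The renamed parse_flags_from_env module still reads flags from the TF_XLA_FLAGS environment variable (kEnvVar above), so XLA debug options are typically supplied through the environment rather than the command line. A sketch of that usage follows; the specific flag shown, --xla_hlo_profile, is assumed as an example of a registered debug option and is not taken from this patch:

```python
import os

# Set before the process first builds XLA's DebugOptions; the variable is
# read once, the first time the flags are parsed.
os.environ["TF_XLA_FLAGS"] = "--xla_hlo_profile"

import tensorflow as tf  # XLA picks the flags up when it is first used
```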
@@ -159,12 +158,11 @@ TEST(ParseFlagsFromEnv, EnvAndFlag) { } } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { // Save name of binary so that it may invoke itself. - xla::legacy_flags::binary_name = argv[0]; + xla::binary_name = argv[0]; bool recursing = false; xla::int32 int_flag = 1; const std::vector flag_list = { @@ -173,7 +171,7 @@ int main(int argc, char* argv[]) { tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list); + bool parse_ok = xla::ParseFlagsFromEnv(flag_list); if (!parse_ok) { LOG(QFATAL) << "can't parse from environment\n" << usage; } diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 21685c4a5b9..63ac1c66492 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") py_library( name = "xla_client", @@ -66,6 +67,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", @@ -81,6 +83,7 @@ tf_py_wrap_cc( srcs = ["xla.i"], swig_includes = [ "local_computation_builder.i", + "//tensorflow/python:platform/base.i", ], deps = [ ":local_computation_builder", @@ -89,5 +92,7 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ], + ] + if_cuda_is_configured([ + "//tensorflow/compiler/xla/service:gpu_plugin", + ]), ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b1fae826ab1..4d2a37cfac3 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -56,6 +57,12 @@ tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; +string* GetPlatformNameString() { + static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = + new string("Host"); + return platform_name_string; +} + Status InitializeReplicaCount(int replica_count) { if (replica_count < 1) { return InvalidArgument("Replica count must be >= 1; got %d.", @@ -72,17 +79,33 @@ Status InitializeReplicaCount(int replica_count) { return Status::OK(); } +Status InitializePlatformName(const string& platform_name) { + string* g_platform_name = GetPlatformNameString(); + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return FailedPrecondition( + "Attempted to set the platform name to %s, but a local XLA service was " + "previously created with a platform name of %s.", + platform_name, *g_platform_name); + } + TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status()); + *g_platform_name = platform_name; + return Status::OK(); +} + int GetReplicaCount() { tensorflow::mutex_lock lock(g_local_client_mutex); return g_replica_count; } LocalClient* GetOrCreateLocalClient() { + string* platform_name = GetPlatformNameString(); tensorflow::mutex_lock lock(g_local_client_mutex); if (g_local_client != nullptr) { return g_local_client; } LocalClientOptions options; + options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); options.set_number_of_replicas(g_replica_count); g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); CHECK(g_local_client != nullptr); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 82f84ddb35b..9e617c48bdc 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -39,6 +39,12 @@ namespace swig { // returned. Status InitializeReplicaCount(int replica_count); +// Initializes the platform name that XLA will be initialized with (when +// first obtaining a handle to the local XLA service). If this is called after +// the handle to the local XLA service has been established, then an error is +// returned. +Status InitializePlatformName(const string& platform_name); + // Returns the replica count that is currently set, regardless of whether the // local XLA service has been instantiated yet or not. 
int GetReplicaCount(); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index c13d00d2530..feabfdb889c 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -977,6 +977,7 @@ tensorflow::ImportNumpy(); %unignore xla; %unignore xla::swig; %unignore xla::swig::InitializeReplicaCount; +%unignore xla::swig::InitializePlatformName; %unignore xla::swig::GetReplicaCount; %unignore xla::swig::TransferToInfeedLocal; %unignore xla::swig::TransferToInfeedLocalReplica; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 07e0e093255..92b0685dbba 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1371,6 +1371,18 @@ def initialize_replica_count(replica_count): c_api.InitializeReplicaCount(replica_count) +def initialize_platform_name(platform_name): + """Initializes the desired platform name to use on XLA service init. + + Args: + platform_name: string name of platform. + + Raises: + A runtime exception if the XLA service has already been initialized. + """ + c_api.InitializePlatformName(platform_name) + + def get_replica_count(): """Returns the current replica count used for the XLA service. diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 4e1435fa30a..d8123a6de28 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -47,11 +47,18 @@ namespace xla { }); } -::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/, + const CompileRequest* arg, + CompileResponse* result) { return DelegateRPC( - [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); + [this, arg, result]() { return service_->Compile(arg, result); }); +} + +::grpc::Status GRPCService::Execute(::grpc::ServerContext* /*context*/, + const ExecuteRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Execute(arg, result); }); } ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index ca1b09b6480..3e586b288a5 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -39,9 +39,13 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + ::grpc::Status Compile(::grpc::ServerContext* context, + const CompileRequest* arg, + CompileResponse* result) override; + + ::grpc::Status Execute(::grpc::ServerContext* context, + const ExecuteRequest* arg, + ExecuteResponse* result) override; ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 7b8ab158e13..66abf66cfd6 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,10 +62,17 @@ Status 
GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::Compile(const CompileRequest* request, + CompileResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteGraph(context, *request, response); + return grpc_stub_->Compile(context, *request, response); + }); +} + +Status GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Execute(context, *request, response); }); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 8dfcb761387..f02b401399f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,8 +43,11 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status Compile(const CompileRequest* request, + CompileResponse* response) override; + + Status Execute(const ExecuteRequest* request, + ExecuteResponse* response) override; Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index 551ae895e05..e4f332cda22 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -128,11 +128,14 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. + // Compiles the provided computation into executable. Returns the handle of + // the executable. + rpc Compile(CompileRequest) returns (CompileResponse) {} + + // Invokes the provided executable with the provided global data passed as + // immutable arguments. The request contains the handle to the executable. // Returns global data output and execution timing. - rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) { - } + rpc Execute(ExecuteRequest) returns (ExecuteResponse) {} // Invokes the provided list of computations in parallel with the provided // global data for each computation. 
Returns a list of global data output and diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index cd8c20d43ea..19b5c1ca25d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,7 +87,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -124,7 +123,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -158,12 +156,12 @@ tf_cc_test( ":bfloat16_propagation", ":bfloat16_support", ":hlo", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -281,7 +279,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -323,7 +321,6 @@ cc_library( ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", - ":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal", @@ -365,7 +362,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -402,6 +398,7 @@ cc_library( srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], deps = [ + ":hlo", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", @@ -420,7 +417,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -466,7 +462,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -519,7 +514,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -568,7 +562,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -591,7 +584,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -603,11 +595,11 @@ cc_library( hdrs = ["platform_util.h"], deps = [ ":compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", @@ -647,6 +639,7 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", + ":compilation_cache", ":compiler", ":computation_layout", ":device_memory_allocator", @@ -662,6 +655,7 @@ cc_library( ":source_map_util", ":stream_pool", ":transfer_manager", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -673,7 +667,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", @@ -730,12 +723,12 @@ cc_library( ":computation_layout", ":platform_util", ":service", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -811,6 +804,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", + "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/memory", ], @@ -833,6 +827,7 @@ cc_library( ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:status", @@ -840,7 +835,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -1086,7 +1080,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1103,6 +1096,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + 
":hlo_reachability", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1168,7 +1162,6 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1388,6 +1381,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":hlo_reachability", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", @@ -1428,7 +1422,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -1504,7 +1497,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1556,7 +1548,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1593,7 +1584,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1643,7 +1633,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1695,6 +1685,19 @@ cc_library( ], ) +tf_cc_test( + name = "while_loop_analysis_test", + srcs = ["while_loop_analysis_test.cc"], + deps = [ + ":while_loop_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1703,9 +1706,9 @@ cc_library( ":call_inliner", ":hlo", ":hlo_pass", + ":hlo_query", ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -1717,10 +1720,12 @@ tf_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ + ":hlo", + ":hlo_dce", ":hlo_matchers", ":while_loop_simplifier", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", 
"//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -1751,7 +1756,7 @@ tf_cc_test( ":hlo_matchers", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1779,7 +1784,7 @@ tf_cc_test( ":implicit_broadcast_remover", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1824,7 +1829,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1858,7 +1862,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -2264,7 +2268,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2327,13 +2330,26 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", ], ) +cc_library( + name = "compilation_cache", + srcs = ["compilation_cache.cc"], + hdrs = ["compilation_cache.h"], + deps = [ + ":executable", + ":hlo_module_config", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "layout_assignment", srcs = [ @@ -2403,14 +2419,13 @@ tf_cc_test( ":hlo_graph_dumper", ":hlo_matchers", ":hlo_runner", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2528,7 +2543,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2595,7 +2609,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2657,7 +2670,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2698,7 +2711,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -2737,7 +2749,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2809,10 +2820,9 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -3000,7 +3010,6 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -3279,6 +3288,8 @@ cc_library( ":tuple_util", "//tensorflow/compiler/xla:literal_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", ], ) @@ -3305,6 +3316,7 @@ cc_library( ":hlo", ":hlo_pass", ":tuple_util", + ":while_loop_analysis", ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -3324,7 +3336,7 @@ tf_cc_test( ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3354,7 +3366,7 @@ tf_cc_test( ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3415,7 +3427,7 @@ tf_cc_test( ":indexed_array_analysis", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", ], @@ -3512,7 +3524,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/memory", diff --git 
a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 85fc42f7475..89e62bd2f0d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include #include #include #include @@ -107,6 +108,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override; + Status HandleAnd(HloInstruction* logical_and) override; + Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBitcastConvert(HloInstruction* bitcast) override; @@ -141,6 +144,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMultiply(HloInstruction* multiply) override; + Status HandleNegate(HloInstruction* negate) override; + + Status HandleNot(HloInstruction* logical_not) override; + + Status HandleOr(HloInstruction* logical_or) override; + Status HandlePad(HloInstruction* pad) override; Status HandlePower(HloInstruction* power) override; @@ -306,9 +315,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Tries to use a kDot in place of the given convolution. StatusOr SimplifyConvToDot(HloInstruction* convolution); - // Tries to simplify a slice(pad(...)) where the result of the slice is a - // scalar. - StatusOr TrySimplifySliceOfPad(HloInstruction* slice); + // Tries to simplify a slice where the result of the slice is a scalar. + StatusOr TrySimplifyScalarSlice(HloInstruction* slice); + + // Tries to convert slice(reshape(X)) into reshape(slice(X)) + StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. @@ -423,6 +434,43 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); + // Simplify logical and + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A && True => A + VLOG(10) << "trying transform [A && True => A]: " + << logical_and->ToString(); + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + // True && A => A + VLOG(10) << "trying transform [True && A => A]: " + << logical_and->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // A && False => False + VLOG(10) << "trying transform [A && False => False]: " + << logical_and->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // False && A => False + VLOG(10) << "trying transform [False && A => False]: " + << logical_and->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { // If a bitcast feeds a bitcast, make it a single bitcast. 
HloInstruction* op; @@ -1229,6 +1277,64 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) { + // negate(negate(x)) => x + HloInstruction* x; + if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) && + ReplaceInstructionIfSameShape(negate, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) { + // not(not(x)) => x + HloInstruction* x; + if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) && + ReplaceInstructionIfSameShape(logical_not, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs)))); + + // Simplify logical or + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A || True => True + VLOG(10) << "trying transform [A || True => True]: " + << logical_or->ToString(); + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + // True || A => True + VLOG(10) << "trying transform [True || A => True]: " + << logical_or->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // A || False => A + VLOG(10) << "trying transform [A || False => A]: " + << logical_or->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // False || A => A + VLOG(10) << "trying transform [False || A => A]: " + << logical_or->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); @@ -1826,60 +1932,160 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor::TrySimplifySliceOfPad( +StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( HloInstruction* slice) { // Only try to do this for effective scalars. We could do the same for slicing // out larger pieces of padding (replacing with a broadcast of the padding // value), but this is probably not worth it. - if (!ShapeUtil::IsEffectiveScalar(slice->shape()) || - slice->operand(0)->opcode() != HloOpcode::kPad) { + if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { return false; } - VLOG(10) << "Trying to simplify scalar slice of pad"; - // Check there's no internal padding. Again, we could handle that too, since - // everything is statically known, but it's not worth it. - auto pad = Cast(slice->mutable_operand(0)); - auto padding_config = pad->padding_config(); - int64 rank = padding_config.dimensions_size(); - if (HasInteriorPadding(padding_config)) { - VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; - return false; + if (slice->operand(0)->opcode() == HloOpcode::kPad) { + VLOG(10) << "Trying to simplify scalar slice of pad"; + // Check there's no internal padding. Again, we could handle that too, since + // everything is statically known, but it's not worth it. 
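The new HandleAnd, HandleOr, HandleNot, and HandleNegate rewrites above rely on ordinary Boolean and arithmetic identities. A small NumPy check of those identities follows; it is illustrative only and unrelated to the HLO implementation:

```python
import numpy as np

a = np.array([True, False, True])
t = np.ones_like(a)   # all-True predicate
f = np.zeros_like(a)  # all-False predicate

assert np.array_equal(a & t, a)   # A && True  => A
assert np.array_equal(a & f, f)   # A && False => False
assert np.array_equal(a | f, a)   # A || False => A
assert np.array_equal(a | t, t)   # A || True  => True
assert np.array_equal(~~a, a)     # not(not(A)) => A

x = np.array([1.5, -2.0])
assert np.array_equal(-(-x), x)   # negate(negate(X)) => X
```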
+ auto pad = Cast(slice->mutable_operand(0)); + auto padding_config = pad->padding_config(); + int64 rank = padding_config.dimensions_size(); + if (HasInteriorPadding(padding_config)) { + VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; + return false; + } + + // Check whether the scalar we're slicing out falls into the padding. + bool in_padding = [&]() { + for (int64 i = 0; i < rank; ++i) { + int64 start = slice->slice_starts(i); + int64 low = padding_config.dimensions(i).edge_padding_low(); + int64 data = pad->operand(0)->shape().dimensions(i); + if (start >= low && start < low + data) { + return false; + } + } + return true; + }(); + + if (in_padding) { + VLOG(10) << "Folding scalar slice of pad into padding value"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), + pad->mutable_padding_value()))); + return true; + } else { + // We already know the output of the slice is scalar. If the padded + // value is scalar, and it's not in the padding, then it's exactly the + // output value. + bool replaced = + ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); + if (replaced) { + VLOG(10) << "Folding scalar slice of pad into padded value"; + } else { + VLOG(10) << "Not folding scalar slice of pad into padded value as they " + "have different shapes."; + } + return replaced; + } } - // Check whether the scalar we're slicing out falls into the padding. - bool in_padding = [&]() { - for (int64 i = 0; i < rank; ++i) { - int64 start = slice->slice_starts(i); - int64 low = padding_config.dimensions(i).edge_padding_low(); - int64 data = pad->operand(0)->shape().dimensions(i); - if (start >= low && start < low + data) { - return false; + if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { + VLOG(10) << "Trying to simplify scalar slice of concat"; + // Only do this for R1, there's no chance of this being useful otherwise. + if (ShapeUtil::Rank(slice->shape()) != 1) { + VLOG(10) << "Not folding, slice is not rank 1"; + return false; + } + HloConcatenateInstruction* concat = + Cast(slice->mutable_operand(0)); + int64 operand_start = 0; + int64 operand_num = 0; + // Weird loop structure to avoid annoying off-by-one errors. 
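// The loop below walks the R1 concatenate's operands, accumulating their
// lengths until the next operand would begin past the requested slice start;
// the sliced element then lives in that operand at offset
// slice_start - operand_start. A standalone sketch with concrete numbers
// (struct and function names are illustrative only; a valid slice_start is
// assumed, mirroring the TF_RET_CHECK in the loop):
#include <cstdint>
#include <iostream>
#include <vector>

struct ConcatOffset {
  int64_t operand_num;    // which concatenate operand holds the element
  int64_t operand_start;  // where that operand begins in the concatenated R1
};

ConcatOffset FindConcatOperand(const std::vector<int64_t>& operand_lengths,
                               int64_t slice_start) {
  int64_t operand_start = 0;
  int64_t operand_num = 0;
  while (true) {
    const int64_t next_operand_start =
        operand_start + operand_lengths[operand_num];
    if (next_operand_start > slice_start) {
      break;  // slice_start falls inside the current operand
    }
    operand_start = next_operand_start;
    ++operand_num;
  }
  return {operand_num, operand_start};
}

int main() {
  // For concat(a[3], b[4], c[5]) and slice start 8 the element is c[8 - 7],
  // i.e. operand 2 at offset 1.
  const ConcatOffset off = FindConcatOperand({3, 4, 5}, 8);
  std::cout << off.operand_num << " " << (8 - off.operand_start) << "\n";
}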
+ while (true) { + TF_RET_CHECK(operand_num < concat->operand_count()); + const HloInstruction* operand = concat->operand(operand_num); + int64 next_operand_start = operand_start + operand->shape().dimensions(0); + if (next_operand_start > slice->slice_starts(0)) { + break; + } + operand_start = next_operand_start; + operand_num++; + } + + bool replaced = ReplaceInstructionIfSameShape( + slice, concat->mutable_operand(operand_num)); + if (replaced) { + VLOG(10) << "Folding scalar slice of concat into concat operand"; + } else { + VLOG(10) << "Folding scalar slice of concat into slice of concat operand"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateSlice( + slice->shape(), concat->mutable_operand(operand_num), + {slice->slice_starts(0) - operand_start}, + {slice->slice_starts(0) - operand_start + 1}, + slice->slice_strides()))); + } + return true; + } + + return false; +} + +bool IsUnstridedSlice(const HloInstruction* hlo) { + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); +} + +StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( + HloInstruction* slice) { + CHECK_EQ(slice->opcode(), HloOpcode::kSlice); + if (!IsUnstridedSlice(slice)) { + return false; + } + HloInstruction* reshape = slice->mutable_operand(0); + if (reshape->opcode() != HloOpcode::kReshape) { + return false; + } + HloInstruction* new_slice_operand = reshape->mutable_operand(0); + int64 slice_rank = ShapeUtil::Rank(slice->shape()); + std::vector sliced_dims; + for (int64 i = 0; i < slice_rank; ++i) { + if (slice->slice_starts(i) != 0 || + slice->slice_limits(i) != reshape->shape().dimensions(i)) { + sliced_dims.push_back(i); + } + } + + if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && + slice->slice_starts(0) == 0) { + const Shape& new_slice_shape = new_slice_operand->shape(); + const int64 rank = ShapeUtil::Rank(new_slice_shape); + std::vector new_slice_starts(rank, 0); + std::vector new_slice_stides(rank, 1); + std::vector new_slice_limits(new_slice_shape.dimensions().begin(), + new_slice_shape.dimensions().end()); + int64 slice_elements = ShapeUtil::ElementsIn(slice->shape()); + for (int64 i = rank - 1; i >= 0; --i) { + if (slice_elements >= new_slice_limits[i]) { + if (slice_elements % new_slice_limits[i] != 0) { + return false; + } + slice_elements /= new_slice_limits[i]; + } else { + new_slice_limits[i] = slice_elements; + slice_elements = 1; } } - return true; - }(); - - if (in_padding) { - VLOG(10) << "Folding scalar slice of pad into padding value"; + HloInstruction* new_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(new_slice_shape.element_type(), + new_slice_limits), + new_slice_operand, new_slice_starts, new_slice_limits, + new_slice_stides)); TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( - slice, HloInstruction::CreateReshape(slice->shape(), - pad->mutable_padding_value()))); + slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); return true; - } else { - // We already know the output of the slice is scalar. If the padded - // value is scalar, and it's not in the padding, then it's exactly the - // output value. 
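// TryToReorderSliceAndReshape above only fires when the slice keeps a leading
// prefix of dimension 0 of the reshape output, i.e. a contiguous prefix of
// the flattened data. The limit computation then re-expresses that prefix as
// a slice of the reshape's operand by consuming the operand's dimensions from
// the last one backwards. A standalone sketch of that computation (function
// and variable names are illustrative only):
#include <cstdint>
#include <vector>

bool ComputeOperandSliceLimits(const std::vector<int64_t>& operand_dims,
                               int64_t slice_elements,
                               std::vector<int64_t>* new_limits) {
  *new_limits = operand_dims;
  for (int64_t i = static_cast<int64_t>(operand_dims.size()) - 1; i >= 0;
       --i) {
    if (slice_elements >= (*new_limits)[i]) {
      if (slice_elements % (*new_limits)[i] != 0) {
        return false;  // the prefix cuts through the middle of this dimension
      }
      slice_elements /= (*new_limits)[i];
    } else {
      (*new_limits)[i] = slice_elements;  // partial extent in this dimension
      slice_elements = 1;                 // leading dimensions collapse to 1
    }
  }
  return true;
}

// Example: X has shape {2, 3, 4} and is reshaped to {6, 4}; slicing rows
// [0, 3) of the reshape keeps 12 elements, and
// ComputeOperandSliceLimits({2, 3, 4}, 12, &limits) yields {1, 3, 4}, so the
// slice can be taken from X first and the smaller result reshaped to {3, 4}.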
- bool replaced = - ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); - if (replaced) { - VLOG(10) << "Folding scalar slice of pad into padded value"; - } else { - VLOG(10) << "Not folding scalar slice of pad into padded value as they " - "have different shapes."; - } - return replaced; } + return false; } Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { @@ -1888,12 +2094,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { return Status::OK(); } - auto is_unstrided_slice = [](const HloInstruction* hlo) { - return absl::c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); - }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && - is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { + IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) { HloInstruction* operand_slice = slice->mutable_operand(0); std::vector new_slice_starts = slice->slice_starts(); std::vector new_slice_limits = slice->slice_limits(); @@ -1907,11 +2109,15 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { new_slice_starts, new_slice_limits, slice->slice_strides())); } - TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifySliceOfPad(slice)); + TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); if (replaced) { return Status::OK(); } + TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); + if (replaced) { + return Status::OK(); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 7b3e957fbcf..e4c4da1b0e7 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -33,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -54,10 +53,11 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloTestBase {}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -67,18 +67,19 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { + auto m = CreateNewVerifiedModule(); Shape r0s32 = ShapeUtil::MakeShape(S32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -88,12 +89,12 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), zero); } @@ -166,6 +167,7 @@ TEST_F(AlgebraicSimplifierTest, SelectIdentical) { // Test that Reduce(Reduce(A)) -> Reduce(A) TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create add computation. 
HloInstruction* zero = builder.AddInstruction( @@ -180,7 +182,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); HloInstruction* param = builder.AddInstruction( @@ -193,17 +195,18 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, dims1, add_computation)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reduce(param, zero)); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); } // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -213,18 +216,19 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Constant())); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. 
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -239,17 +243,18 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -261,17 +266,18 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create add computation. 
HloComputation* add_computation = nullptr; @@ -284,7 +290,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); HloInstruction* param0 = builder.AddInstruction( @@ -297,17 +303,18 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateBroadcast(r2f32, zero, {}))}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMap); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -319,64 +326,68 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14f, 3.14f, 3.14f}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); } TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14, 3.14, 4}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + 
ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); } // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -386,18 +397,19 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A - Const is canonicalized to A + (-Const). TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -407,18 +419,19 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); } // Test that (A/B)/C is simplified to A/(B*C). 
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -432,14 +445,14 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Divide(param0, param1), param2)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Multiply(param1, param2))); @@ -447,6 +460,7 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { // Test that A/(B/C) is simplified to (A*C)/B. TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -460,14 +474,14 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Divide(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Multiply(param0, param2), param1)); @@ -475,6 +489,7 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -492,7 +507,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -500,7 +515,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -509,6 +524,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { // Test that A/exp(B) is simplified to A*exp(-B). 
TEST_F(AlgebraicSimplifierTest, DivOfExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -520,14 +536,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Exp(op::Negate(param1)))); @@ -535,6 +551,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { // Test that A/pow(B,C) is simplified to A*pow(B,-C). TEST_F(AlgebraicSimplifierTest, DivOfPower) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -548,14 +565,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -564,6 +581,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). 
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -577,14 +595,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -592,6 +610,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { // A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -602,11 +621,11 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Constant())); @@ -614,6 +633,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -627,10 +647,10 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, inner_power, exp2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Power(base, op::Multiply(exp1, exp2))); } @@ -638,6 +658,7 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. 
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { + auto m = CreateNewVerifiedModule(); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -651,14 +672,15 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, inner_power, exp2)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); } // Test that A/1 is simplified to A for a scalar. TEST_F(AlgebraicSimplifierTest, DivOneScalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -668,18 +690,19 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A/1 is simplified to A for an array. TEST_F(AlgebraicSimplifierTest, DivOneArray) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -689,18 +712,19 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that complex(real(c), imag(c)) is simplified to c. 
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); HloComputation::Builder builder(TestName()); @@ -713,18 +737,19 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { HloInstruction* cplx = builder.AddInstruction( HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that real(complex(r,i)) is simplified to r. TEST_F(AlgebraicSimplifierTest, RealOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -737,18 +762,19 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { HloInstruction* real = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that imag(complex(r,i)) is simplified to i. 
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -761,18 +787,19 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { HloInstruction* imag = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); } // Test that get_element(make_tuple({A,B}),1) is simplified to B TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -788,18 +815,19 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param1, param2)); } // Test that exp(A)/exp(B) is simplified to exp(A-B) TEST_F(AlgebraicSimplifierTest, ExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -813,14 +841,14 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Subtract(param0, param1))); @@ -828,6 +856,7 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { // Test that exp(A)*exp(B) is simplified to exp(A+B) TEST_F(AlgebraicSimplifierTest, ExpMul) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -841,14 +870,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Add(param0, param1))); @@ -856,6 +885,7 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { // Test that pow(exp(A), B) is simplified to exp(A*B) TEST_F(AlgebraicSimplifierTest, PowExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -867,14 +897,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(op::Exp(param0), param1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Multiply(param0, param1))); @@ -882,6 +912,7 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { // Test that ln(pow(A, B)) is simplified to ln(A)*B TEST_F(AlgebraicSimplifierTest, LnPow) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -893,14 +924,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Power(param0, param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Log(param0), param1)); @@ -908,6 +939,7 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { // Test that ln(exp(A)) is simplified to A TEST_F(AlgebraicSimplifierTest, LnExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -917,19 +949,20 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that ln(exp(A)/exp(B)) is simplified to A-B TEST_F(AlgebraicSimplifierTest, LnExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -945,14 +978,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, 
HloOpcode::kLog, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } @@ -960,6 +993,7 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { // Test that pow(A, 0) where A is a scalar is simplified to the scalar // constant 1. TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -969,13 +1003,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -984,6 +1018,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). TEST_F(AlgebraicSimplifierTest, Pow0Vector) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {42}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -993,13 +1028,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast()); @@ -1012,6 +1047,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { // Test that pow(A, 1) is simplified to A. 
TEST_F(AlgebraicSimplifierTest, Pow1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1021,19 +1057,20 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that pow(A, 2) is simplified to A*A. TEST_F(AlgebraicSimplifierTest, Pow2) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1043,19 +1080,20 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } // Test that pow(A, -1) is simplified to 1/A. TEST_F(AlgebraicSimplifierTest, PowNegative1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1065,13 +1103,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); @@ -1081,6 +1119,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { } TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); @@ -1113,17 +1152,18 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + 
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Convolution(lhs, rhs)); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1148,24 +1188,25 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::ReduceWindow(param, op::Constant())); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1182,17 +1223,18 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); - module().AddEntryComputation(builder.Build()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + m->AddEntryComputation(builder.Build()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Pad(param, op::Constant())); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -1206,39 +1248,41 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - module().AddEntryComputation(std::move(computation)); + m->AddEntryComputation(std::move(computation)); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reshape(op::Broadcast(op::Reshape(op)))); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - 
EXPECT_THAT(module().entry_computation()->root_instruction(), op); + EXPECT_THAT(m->entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); } // Test that copies are removed. TEST_F(AlgebraicSimplifierTest, RemoveCopy) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1246,18 +1290,19 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1268,24 +1313,25 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 2, 0, 3}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param)); AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie()); + ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. EXPECT_THAT(computation->root_instruction(), op::Copy(param)); AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } // Test that unary concatenates are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1293,19 +1339,20 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } // Test that empty operands of concatenates are removed. TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1322,7 +1369,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -1330,7 +1377,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0, param0, param1)); @@ -1338,6 +1385,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { // Test that reduce of concat is simplified. TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r3f32 = ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength}); @@ -1363,7 +1411,7 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength}); @@ -1373,11 +1421,11 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { builder.AddInstruction(HloInstruction::CreateReduce( reduce_shape, Concatenate, zero, {1, 2}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -1387,6 +1435,7 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { // Test a concatenate with only empty operands is removed. 
TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1401,20 +1450,21 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(empty_literal, empty_slice)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); } // Test that concat with a scalar broadcast becomes a pad. TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); @@ -1427,17 +1477,18 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); } // Test that a simplification which changes layouts is not performed if layout // sensitive is true. TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1445,7 +1496,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to different layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1455,7 +1506,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); @@ -1464,6 +1515,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { // Test that a simplification which preserves layouts is performed if layout // sensitive is true. 
TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1471,7 +1523,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to same layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1481,7 +1533,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Copy has been removed. EXPECT_THAT(computation->root_instruction(), param0); @@ -1490,6 +1542,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { // Test that a reshape which could be replaced with a bitcast is not if // add_bitcasts is false. TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1502,13 +1555,13 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); @@ -1516,6 +1569,7 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { // Test transforming reshapes and transposes of rng. TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); @@ -1532,11 +1586,11 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { ShapeUtil::MakeShape(F32, {4}), transpose)) ->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that that reshape(transpose(rng)) is replace by a single rng of the // same shape as the reshape. @@ -1547,6 +1601,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { // Test transforming reshapes to bitcasts under various conditions. 
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1578,7 +1633,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Tuple(transformable_reshape, dimensions_wrong_reshape, @@ -1586,7 +1641,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - simplifier.Run(&module()).ValueOrDie(); + simplifier.Run(m.get()).ValueOrDie(); // Verify that only the first reshape is replaced. EXPECT_THAT( @@ -1597,6 +1652,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. @@ -1613,13 +1669,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. @@ -1637,11 +1694,12 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1655,19 +1713,20 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. 
EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1681,19 +1740,20 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1706,19 +1766,20 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1733,18 +1794,19 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1757,13 +1819,13 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); 
EXPECT_EQ(std::vector({2, 1, 0}), @@ -1772,6 +1834,7 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { // Test merging reshape and broadcast. TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {5}), "param0")); @@ -1780,20 +1843,21 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } // Test merging broadcast and reshape. TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0")); @@ -1802,19 +1866,20 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1823,20 +1888,21 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1845,14 +1911,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - HloComputation* computation = 
module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); EXPECT_THAT(computation->root_instruction()->dimensions(), @@ -1860,6 +1926,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1868,14 +1935,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); const std::vector broadcast_dims = @@ -1885,6 +1952,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1893,33 +1961,34 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction(HloInstruction::CreateIota( ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), 
op::Iota()); EXPECT_TRUE( @@ -1927,18 +1996,19 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { } TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); auto result_shape = iota->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Iota()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -1948,37 +2018,39 @@ TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Iota()); EXPECT_EQ(Cast(computation->root_instruction()) @@ -1987,19 +2059,20 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), op::Iota()); const int64 iota_dim = @@ -2009,19 +2082,20 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); } @@ -2043,14 +2117,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2076,7 +2150,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -2095,7 +2169,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -2110,14 +2184,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2133,14 +2207,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), 
op::Slice(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2162,14 +2236,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice, /*start_indices=*/{2, 3}, /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); @@ -2178,6 +2252,57 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 11; + const int64 dim1 = 12; + const int64 dim2 = 13; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {dim0, dim1, dim2}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2}), original_reshape, + /*start_indices=*/{0, 0, 0}, + /*limit_indices=*/{dim0 - 2, dim1, dim2}, /*strides=*/{1, 1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Slice(param))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {3600, 512}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {960, 512}), original_reshape, + /*start_indices=*/{0, 0}, + /*limit_indices=*/{960, 512}, /*strides=*/{1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto builder = HloComputation::Builder(TestName()); @@ -2185,11 +2310,11 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto keys = builder.AddInstruction( 
HloInstruction::CreateParameter(0, keys_shape, "keys")); builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } @@ -2207,15 +2332,191 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { builder.AddInstruction(HloInstruction::CreateSort( ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, keys, {values0, values1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values0, values1)); } +// Test that A && True is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that True && A is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_true, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A && False is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that False && A is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that A || True is simplified to True +TEST_F(AlgebraicSimplifierTest, OrTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test that True || A is simplified to True +TEST_F(AlgebraicSimplifierTest, OrTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, const_true, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test that A || False is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = 
ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that False || A is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Used for TEST_Ps that test merging (or not) of a kPad instruction into a // convolution's Window. 
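The parameterized convolution-padding tests that follow check when an explicit kPad feeding a convolution can be folded into the convolution's window. A rough sketch (illustrative HLO, not from the patch):

  input  = f32[1,8,8,3] parameter(0)
  filter = f32[3,3,3,16] parameter(1)
  zero   = f32[] constant(0)
  pad.0  = f32[1,12,12,3] pad(input, zero), padding=0_0x2_2x2_2x0_0
  conv   = f32[1,10,10,16] convolution(pad.0, filter), window={size=3x3}, dim_labels=b01f_01io->b01f

When the pad value is zero and the padding is non-negative with no interior padding, the pad can be absorbed into the window, giving convolution(input, filter) with window={size=3x3 pad=2_2x2_2}; the testcases where expected_conv_window is empty cover configurations where this fold must not happen.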
struct ConvPaddingTestcase { @@ -2337,15 +2638,15 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(), lhs_pad, filter, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); @@ -2455,15 +2756,15 @@ TEST_P(ConvFilterPaddingTest, DoIt) { input, rhs_pad, /*feature_group_count=*/1, window, dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); @@ -2604,7 +2905,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2724,7 +3025,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -2734,10 +3035,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2763,7 +3064,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -2772,7 +3073,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2782,7 +3083,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2864,7 +3165,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2954,12 +3255,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); @@ -2970,6 +3271,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { // Dots add computations to the parent module. Test that, when the HloModule's // computations are updated, then iterator invalidation doesn't occur // when running on subsequent computations. + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {1}); HloComputation::Builder builder(TestName() + ".Dot"); HloInstruction* x = @@ -2991,15 +3293,16 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - module().AddEmbeddedComputation(std::move(dot_computation)); - module().AddEntryComputation(call_builder.Build()); + m->AddEmbeddedComputation(std::move(dot_computation)); + m->AddEntryComputation(call_builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Test that a constant with tuple shape becomes a tuple of constants. 
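In HLO-text terms, the next test starts from a single kConstant whose shape is the tuple (f32[], f32[3]) and expects it to be split apart; the expected result is roughly (illustrative names; the values come from the test):

  c0     = f32[] constant(7.3)
  c1     = f32[3] constant({1.1, 2, 3.3})
  ROOT t = (f32[], f32[3]) tuple(c0, c1)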
TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -3008,11 +3311,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } @@ -3021,6 +3324,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { // of its input equals the size of its output. In this case, the dynamic slice // is equal to its input. TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -3032,10 +3336,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), /*slice_sizes=*/{10, 100, 1000})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Parameter()); } @@ -3043,6 +3347,7 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update. TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -3065,16 +3370,17 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction(HloInstruction::CreateParameter( 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::DynamicSlice(op::Parameter(), op::Parameter())); } // Test that two consecutive broadcasts can be merged to one. 
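The broadcast-merging rewrite checked below composes the dimension maps of two stacked broadcasts into one. A minimal sketch (shapes illustrative, not from the patch):

  x  = f32[2] parameter(0)
  b1 = f32[2,3] broadcast(x), dimensions={0}
  b2 = f32[4,2,3] broadcast(b1), dimensions={1,2}

Since x's only dimension maps to output dimension 1 through the two broadcasts, this collapses to a single f32[4,2,3] broadcast(x), dimensions={1}.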
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* input_array = builder.AddInstruction( @@ -3085,12 +3391,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { builder.AddInstruction( HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); EXPECT_THAT(root->dimensions(), ElementsAre(2)); @@ -3098,6 +3404,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { // Test that two consecutive broadcasts can be merged to one. TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); @@ -3111,12 +3418,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); @@ -3124,6 +3431,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { // Test that a broadcast of an iota can be merged to one iota. TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* iota = @@ -3131,12 +3439,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); EXPECT_EQ(Cast(root)->iota_dimension(), 2); @@ -3144,6 +3452,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { // Test that a broadcast of an iota can be merged to one iota. 
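Broadcasting an iota is itself just an iota over the larger shape, with the iota dimension remapped through the broadcast. Sketch (illustrative, not from the patch):

  i = f32[2,2] iota(), iota_dimension=1
  b = f32[2,2,2] broadcast(i), dimensions={0,2}

This simplifies to a single f32[2,2,2] iota(), iota_dimension=2, because the original iota dimension 1 lands at output dimension 2, which is what the iota_dimension() checks below verify.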
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); HloInstruction* iota = @@ -3152,12 +3461,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); EXPECT_EQ(Cast(root)->iota_dimension(), 2); @@ -3174,9 +3483,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3196,9 +3504,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3218,9 +3525,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3238,9 +3544,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3249,6 +3554,92 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { EXPECT_THAT(root, op::Parameter()); } +TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[2:3]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(1)); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { + const 
char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[4:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Slice(op::Parameter(2))); + EXPECT_EQ(root->slice_starts(0), 1); + EXPECT_EQ(root->slice_limits(0), 2); +} + +TEST_F(AlgebraicSimplifierTest, NegateNegate) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + neg.0 = f32[2] negate(param.0) + ROOT neg.1 = f32[2] negate(neg.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(0)); +} + +TEST_F(AlgebraicSimplifierTest, NotNot) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = pred[2] parameter(0) + not.0 = pred[2] not(param.0) + ROOT not.1 = pred[2] not(not.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(0)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -3278,6 +3669,7 @@ class PadReduceWindowEffectiveBroadcastTest PadReduceWindowEffectiveBroadcastCase> {}; TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { + auto m = CreateNewVerifiedModule(); const auto& param = GetParam(); // a and b are parallel bounds we can either turn into a B F S0 S1 or @@ -3326,7 +3718,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Window window = window_util::MakeWindow( @@ -3340,10 +3732,10 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { builder.AddInstruction(HloInstruction::CreateReduceWindow( output_shape, pad, zero, window, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -3392,6 +3784,7 @@ class DotStrengthReductionTest public ::testing::WithParamInterface< ::testing::tuple> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + auto module = CreateNewVerifiedModule(); int m, k, n; bool transpose_lhs, transpose_rhs; PrimitiveType element_type; @@ 
-3421,10 +3814,10 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { dot_dnums.add_rhs_contracting_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = dot_should_be_transformed || (transpose_lhs && transpose_rhs); @@ -3452,7 +3845,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloVerifiedTestBase, + : public HloTestBase, public ::testing::WithParamInterface {}; // Test that we transform @@ -3460,6 +3853,7 @@ class DotOfConcatSimplificationTest // to // add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3498,10 +3892,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -3519,6 +3913,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { // to // add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3562,10 +3957,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3590,6 +3985,7 @@ DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { // Test that DynamicUpdateSlice update param with any dimension equal to zero // gets removed. 
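The zero-sized-update case for dynamic-update-slice below amounts to the following (illustrative HLO; the index operand's exact shape is an assumption, not taken from the patch):

  operand  = f32[10] parameter(0)
  update   = f32[0] parameter(1)
  idx      = u32[1] parameter(2)
  ROOT dus = f32[10] dynamic-update-slice(operand, update, idx)

Because the update has a dimension of size zero it cannot modify anything, so the simplifier replaces the whole dynamic-update-slice with operand.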
TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10}); HloInstruction* const operand = builder.AddInstruction( @@ -3602,11 +3998,11 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( dslice_shape, operand, update, start_indices)); const HloComputation* const computation = - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), operand); } @@ -3625,7 +4021,7 @@ struct DotOfGatherTestSpec { }; class DotOfGatherSimplificationTest - : public HloVerifiedTestBase, + : public HloTestBase, public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) @@ -3634,6 +4030,7 @@ class DotOfGatherSimplificationTest // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3680,10 +4077,10 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3704,6 +4101,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3750,10 +4148,10 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a64..52ec1a794c5 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -17,14 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloTestBase {}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { @@ -38,11 +37,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -61,11 +61,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -84,11 +85,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -107,11 +109,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -130,11 +133,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -153,11 +157,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = 
m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index f7ac8f54829..08cf8026177 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloVerifiedTestBase; +using BatchNormExpanderTest = HloTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -59,14 +59,14 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { param0, param1, param2, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -101,14 +101,14 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { param1, param2, param3, param4, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -126,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie()); - for (auto* instruction : module().entry_computation()->instructions()) { + for (auto* instruction : m->entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 5f93740887a..4ce351acc2c 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,11 +65,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { +class BFloat16ConversionFoldingTest : public HloTestBase { protected: BFloat16ConversionFoldingTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -103,10 +103,10 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -138,10 +138,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { HloInstruction* convert2 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -173,10 +173,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { HloInstruction* convert2 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -203,10 +203,10 @@ 
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { HloInstruction* convert1 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -216,7 +216,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("add"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -252,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index cb075a5e38a..9f97d18c565 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,11 +68,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloVerifiedTestBase { +class BFloat16NormalizationTest : public HloTestBase { protected: BFloat16NormalizationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -106,10 +106,10 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module)); + EXPECT_FALSE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -134,10 +134,10 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { HloInstruction* mul1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -164,10 
+164,10 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { HloInstruction* sub1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -191,7 +191,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, reduce_comp_param0, reduce_comp_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(reduce_comp_builder.Build()); @@ -205,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -233,7 +233,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -263,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -272,7 +272,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); @@ -290,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -299,7 +299,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); @@ -314,7 +314,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); @@ -342,10 +342,10 @@ 
TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 0af71eaac96..5be7141aae4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,11 +55,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloVerifiedTestBase { +class BFloat16PropagationTest : public HloTestBase { protected: BFloat16PropagationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. 
@@ -121,10 +121,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
   HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
       ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
 
-  auto module = CreateNewModule();
+  auto module = CreateNewVerifiedModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module));
+  EXPECT_TRUE(PropagatePrecision(module.get()));
 
   EXPECT_EQ(computation->root_instruction(), root);
   EXPECT_TRUE(OutputsBF16(xpose));
@@ -136,6 +136,62 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
   EXPECT_FALSE(OutputsBF16(c));
 }
 
+TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) {
+  auto module = CreateNewVerifiedModule();
+
+  auto sub_builder = HloComputation::Builder("max");
+  HloInstruction* p0 = sub_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "a"));
+  HloInstruction* p1 = sub_builder.AddInstruction(
+      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "b"));
+  sub_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, p0, p1));
+  auto max_computation = module->AddEmbeddedComputation(sub_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+
+  HloInstruction* a =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+  HloInstruction* b =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+  HloInstruction* c =
+      builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c"));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
+  Window window;
+  WindowDimension dim;
+  dim.set_size(2);
+  dim.set_stride(1);
+  dim.set_padding_high(1);
+  dim.set_window_dilation(1);
+  dim.set_base_dilation(1);
+  *window.add_dimensions() = dim;
+  *window.add_dimensions() = dim;
+  HloInstruction* rw =
+      builder.AddInstruction(HloInstruction::CreateReduceWindow(
+          shape, add,
+          builder.AddInstruction(
+              HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
+          window, max_computation));
+  HloInstruction* xpose =
+      builder.AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(F32, {4, 2}), c, {1, 0}));
+  HloInstruction* dot = builder.AddInstruction(
+      CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, rw));
+  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
+
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(PropagatePrecision(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), root);
+  EXPECT_TRUE(OutputsBF16(add));
+  EXPECT_TRUE(OutputsBF16(xpose));
+  EXPECT_TRUE(OutputsBF16(rw));
+}
+
 // Tests that side-effecting all-reduce should not be changed.
TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { auto module = CreateNewVerifiedModule(); @@ -186,10 +242,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -242,10 +298,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { HloInstruction* output_tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -281,10 +337,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction* dot = builder.AddInstruction( CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -310,10 +366,10 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -321,7 +377,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { // Tests that BF16 is propagated properly through fused computations. TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -356,7 +412,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -369,7 +425,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { // Tests that changes to BF16 that cannot be propagated outside a fusion are // discarded. 
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -393,7 +449,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -408,7 +464,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { // (BF16, BF16) fusion_computation(F32 a, F32 b) // = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -439,7 +495,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -458,7 +514,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { // on_true and on_false must match, so that as long as one of them is F32, the // other must be F32 as well. TEST_F(BFloat16PropagationTest, SelectOverTuples) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -489,7 +545,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -502,7 +558,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { // Tests that BF16 is propagated properly through a while computation with // non-tuple input/output. TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -545,7 +601,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -561,7 +617,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { // made to the while body and thus the fusion node inside it. 
TEST_F(BFloat16PropagationTest, ConditionPreventsPropagationForFusionInsideWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -610,7 +666,7 @@ TEST_F(BFloat16PropagationTest, auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -622,7 +678,7 @@ TEST_F(BFloat16PropagationTest, // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -690,7 +746,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -709,7 +765,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // Tests that BF16 is not propagated through multiple whiles that invoke the // same computation as long as one while prevents the propagation. TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -820,7 +876,7 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -859,10 +915,10 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( bf16_shape, HloOpcode::kAdd, convert0, convert1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -895,10 +951,10 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -941,10 +997,10 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction* root = builder.AddInstruction( 
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
 
-  auto module = CreateNewModule();
+  auto module = CreateNewVerifiedModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module));
+  EXPECT_TRUE(PropagatePrecision(module.get()));
 
   EXPECT_EQ(computation->root_instruction(), root);
   EXPECT_TRUE(OutputsBF16(a_trans));
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc
index 5b48f10505e..2b9502f63a8 100644
--- a/tensorflow/compiler/xla/service/bfloat16_support.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_support.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
 
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -107,6 +108,21 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
     case HloOpcode::kSelect:
     case HloOpcode::kTupleSelect:
       return operand_index == 1 || operand_index == 2;
+    case HloOpcode::kReduce:
+    case HloOpcode::kReduceWindow: {
+      HloComputation* reduce_comp = hlo.called_computations()[0];
+      for (HloInstruction* inst : reduce_comp->instructions()) {
+        if (inst->opcode() == HloOpcode::kParameter) {
+          continue;
+        }
+        for (int64 i = 0; i < inst->operand_count(); ++i) {
+          if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) {
+            return false;
+          }
+        }
+      }
+      return true;
+    }
     default:
       break;
   }
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index ee4e5942731..40c012a5e42 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -641,7 +641,7 @@ Status BufferAssignment::ComputeSummaryStats() {
   bool schedule_complete = true;
   for (const auto& computation : module_->computations()) {
     if (!computation->IsFusionComputation()) {
-      const std::vector<const HloInstruction*>* sequence =
+      const HloInstructionSequence* sequence =
          liveness_->hlo_ordering().SequentialOrder(*computation);
       if (sequence == nullptr) {
         schedule_complete = false;
@@ -1180,7 +1180,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       const HloComputation* computation = pair.first;
       const flat_hash_set<const LogicalBuffer*>& buffers_to_assign = pair.second;
-      const std::vector<const HloInstruction*>* instruction_sequence =
+      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
       CHECK(instruction_sequence != nullptr) << computation->name();
       schedule.set_sequence(computation, *instruction_sequence);
@@ -1215,7 +1215,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       const HloComputation* computation = pair.first;
       const flat_hash_set<const LogicalBuffer*>& buffers_to_assign = pair.second;
-      const std::vector<const HloInstruction*>* instruction_sequence =
+      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
       CHECK(instruction_sequence != nullptr) << computation->name();
       auto color_map = SplitBuffersByColor(buffers_to_assign);
@@ -1230,7 +1230,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       TF_ASSIGN_OR_RETURN(
           const HeapSimulator::Result result,
           HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
-                             HloInstructionSequence(*instruction_sequence),
+                             *instruction_sequence,
                              assignment->points_to_analysis(),
                              assignment->buffer_size_,
options)); AssignBuffersFromHeapSimulator(result, assignment, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 327211d3efd..b1fc50cb188 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -81,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloVerifiedTestBase { +class BufferAssignmentTest : public HloTestBase { protected: ~BufferAssignmentTest() override {} @@ -334,16 +334,16 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); } } @@ -358,17 +358,17 @@ TEST_F(BufferAssignmentTest, BufferForConst) { LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); @@ -387,10 +387,10 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate, param0, constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation() // reports for the instruction directly. 
EXPECT_EQ(buffers->HasTopLevelAllocation(tuple), @@ -410,10 +410,10 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // The copy node now has an output buffer. GetAssignedOutputAllocation(*buffers, copy); } @@ -439,10 +439,10 @@ TEST_F(BufferAssignmentTest, Basic) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -538,7 +538,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -553,7 +553,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -599,7 +599,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -622,7 +622,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -671,10 +671,10 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -706,7 +706,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // param0[100x10] ---> (map x+1) // // Builds the map function. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto map_computation = module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); auto inner_last = map_computation->root_instruction(); @@ -725,7 +725,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 size1 = ValidateBuffers(level1, *buffers); @@ -761,7 +761,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // out-of-order reductions could overwrite an element before a use.) // // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(BuildReduceComputation("f32+f32")); @@ -784,7 +784,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const std::vector instrs = GetInstructions(exp3); ValidateBuffers(instrs, *buffers); @@ -812,7 +812,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // const4[f32[4]] --- tuple --- while[condition, body] // // Builds the nested condition and body. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto condition_computation = module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); auto body_computation = @@ -840,7 +840,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; // Assigns buffers and fetches sizes. 
- auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 sizec = ValidateBuffers(levelc, *buffers); int64 sizeb = ValidateBuffers(levelb, *buffers); @@ -878,7 +878,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { } TEST_F(BufferAssignmentTest, ExampleConditional) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto true_computation = module->AddEmbeddedComputation( BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); auto false_computation = module->AddEmbeddedComputation( @@ -905,7 +905,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { EXPECT_EQ(2, true_instrs.size()); EXPECT_EQ(2, false_instrs.size()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); ValidateBuffers(conditional_instrs, *buffers); ValidateBuffers(true_instrs, *buffers); ValidateBuffers(false_instrs, *buffers); @@ -941,9 +941,9 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto neg = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // tanh and exp2 can reuse exp1's buffer EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); @@ -970,9 +970,9 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1003,9 +1003,9 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1040,9 +1040,9 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1075,9 +1075,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. 
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1107,9 +1107,9 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 10}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1145,9 +1145,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); builder.AddInstruction(HloInstruction::CreateTuple({broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1160,7 +1160,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { // Verify that buffers for embedded computations are properly marked as // thread-local and that embedded parameters are not marked as // is_entry_computation_parameter. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto vec_shape = ShapeUtil::MakeShape(F32, {42}); auto scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1191,7 +1191,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { HloInstruction::CreateMap(vec_shape, {call}, map_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Allocations for the map computation should be thread-local and not // live-out. @@ -1238,9 +1238,9 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { ShapeUtil::MakeShape(S32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be four allocations: one for vector of pointers, and one for // each tuple element. @@ -1274,9 +1274,9 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Only some of the elements of the input param are liveout. 
EXPECT_FALSE( @@ -1318,9 +1318,9 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); } @@ -1332,9 +1332,9 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), ShapeUtil::MakeShape(S32, {101})}), /*operands=*/{}, /*custom_call_target=*/"foo_function")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); EXPECT_TRUE( @@ -1347,7 +1347,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { TEST_F(BufferAssignmentTest, TupleCallAsOutput) { // Test a computation which returns a tuple call value. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1365,7 +1365,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { HloInstruction::CreateCall(tuple_shape, {param}, sub_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. @@ -1388,7 +1388,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { // B: call(C, param) // C: call(D, param) // D: param - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1427,7 +1427,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { module->AddEntryComputation(std::move(a_computation)); module->AddEmbeddedComputation(std::move(b_computation)); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), @@ -1461,9 +1461,9 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto bitcast = builder.AddInstruction( HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Bitcast should get the same allocation as the param. EXPECT_EQ(1, assignment->Allocations().size()); @@ -1488,9 +1488,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect, pred_param, tuple_param0, tuple_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Select shallow copies one of its operands so it defines its own top-level // buffer and receives its own allocation. 
@@ -1526,9 +1526,9 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto copy = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape, HloOpcode::kCopy, tuple_element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be no buffer reuse. The copy should not reuse the tuple // buffer. @@ -1568,9 +1568,9 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); // Run buffer assignment with alignment=1. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module, /*alignment=*/1); + auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); // There are 5 allocations: 3 parameters, 1 output, and 1 temp. EXPECT_EQ(5, assignment->Allocations().size()); @@ -1589,7 +1589,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { EXPECT_EQ(80, slice_bc.allocation()->size()); // Re-run buffer assignment with alignment=64. - assignment = RunBufferAssignment(module, /*alignment=*/64); + assignment = RunBufferAssignment(module.get(), /*alignment=*/64); EXPECT_EQ(5, assignment->Allocations().size()); slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); @@ -1632,10 +1632,10 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); const std::vector& peak_buffers = @@ -1673,11 +1673,11 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignmentWithInstructionSequence( - module, {param, log, rev, neg, concat, root}); + module.get(), {param, log, rev, neg, concat, root}); // The temporary buffer should hold the 4 interior instructions. 
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); @@ -1698,7 +1698,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { } TEST_F(BufferAssignmentTest, PeakBuffersWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); HloComputation* condition; { @@ -1733,7 +1733,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); const std::vector<const LogicalBuffer*>& peak_buffers = buffer.PeakMemoryLogicalBuffers(); @@ -1783,13 +1783,13 @@ ENTRY main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); HloInstruction* constant_1 = - module().entry_computation()->GetInstructionWithName("constant.1.1"); + m->entry_computation()->GetInstructionWithName("constant.1.1"); HloInstruction* constant_2 = - module().entry_computation()->GetInstructionWithName("constant.1.2"); + m->entry_computation()->GetInstructionWithName("constant.1.2"); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); { const BufferAllocation& allocation_for_const_1 = @@ -1818,7 +1818,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloVerifiedTestBase { +class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr<HloComputation> BuildWhileConditionComputation( const string& name) { @@ -1878,7 +1878,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1917,8 +1917,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Verify 'input0' and read-only use while0{0} alias. EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), @@ -1974,20 +1974,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module.
- const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* param = - module().entry_computation()->parameter_instruction(0); + m->entry_computation()->parameter_instruction(0); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1995,7 +1994,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_param, assignment->GetUniqueSlice(param, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2042,20 +2041,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* constant = - module().entry_computation()->GetInstructionWithName("constant.42"); + m->entry_computation()->GetInstructionWithName("constant.42"); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -2063,7 +2061,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, assignment->GetUniqueSlice(constant, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2121,7 +2119,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto token = builder.AddInstruction(HloInstruction::CreateToken()); @@ -2156,7 +2154,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // any copies inserted for BufferAssignment to run. 
int64 instruction_count = module->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module).status()); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); ASSERT_EQ(instruction_count, module->instruction_count()); // Create a sequential order among all the instructions in the entry @@ -2175,12 +2173,12 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run(module, - absl::make_unique<SequentialHloOrdering>(schedule), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run( + module.get(), absl::make_unique<SequentialHloOrdering>(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2202,7 +2200,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2234,8 +2232,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // while0 and while1 buffers should be completely aligned.
EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), @@ -2247,7 +2245,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -2277,13 +2275,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); } - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } @@ -2308,13 +2306,14 @@ ENTRY Main { )"; HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_text, config); + config.set_debug_options(GetDebugOptionsFromFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(hlo_text, config)); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); - HloComputation* main = module().entry_computation(); - HloComputation* callee = module().GetComputationWithName("Callee"); + HloComputation* main = m->entry_computation(); + HloComputation* callee = m->GetComputationWithName("Callee"); EXPECT_NE(callee, nullptr); HloInstruction* param0 = callee->parameter_instruction(0); @@ -2338,7 +2337,7 @@ ENTRY Main { } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -2385,11 +2384,11 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } - RunCopyInsertion(module); + RunCopyInsertion(module.get()); HloSchedule schedule = ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); @@ -2407,18 +2406,18 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run(module, - absl::make_unique<SequentialHloOrdering>(schedule), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run( + module.get(), absl::make_unique<SequentialHloOrdering>(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2462,8 +2461,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); +
RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Get BufferAllocation for root instruction. auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) .ConsumeValueOrDie() diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 17e50905059..aeee543e843 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -117,7 +117,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto log = builder.AddInstruction( HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -164,7 +164,7 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -213,7 +213,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto reverse = builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -247,7 +247,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -289,7 +289,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -336,7 +336,7 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build(add)); HloSchedule schedule(module.get()); @@ -373,7 +373,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto outer_tuple = builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -393,7 +393,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { TEST_F(BufferLivenessTest, EmbeddedComputation) { // Test MaybeLiveOut and MayInterfere for embedded computation. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); auto embedded_param = embedded_builder.AddInstruction( @@ -450,7 +450,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0.shape(), tuple_constant, 0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -514,7 +514,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -576,7 +576,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -646,7 +646,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. @@ -802,7 +802,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 34f3f914d59..a3ac2568b0f 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloVerifiedTestBase { +class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation( @@ -93,10 +93,10 @@ class CallGraphTest : public HloVerifiedTestBase { TEST_F(CallGraphTest, SingletonComputation) { // Test the call graph of a module with a single computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -112,13 +112,13 @@ TEST_F(CallGraphTest, SingletonComputation) { TEST_F(CallGraphTest, UnreachableComputation) { // Test the call graph of a module with an entry computation and an // unreachable computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -134,13 +134,13 @@ TEST_F(CallGraphTest, UnreachableComputation) { TEST_F(CallGraphTest, ParallelComputation) { // Test a call graph of a module with an entry computation which calls another // computation in a parallel context via kMap. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* map_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -163,13 +163,13 @@ TEST_F(CallGraphTest, ParallelComputation) { TEST_F(CallGraphTest, SequentialComputations) { // Test a call graph of a module with an entry computation which calls another // computation in a sequential context via kCall. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* called_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -196,7 +196,7 @@ TEST_F(CallGraphTest, SequentialComputations) { TEST_F(CallGraphTest, ContextBothComputations) { // Test a call graph of a module with an entry computation which calls another // computation in both a parallel and sequential context. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -239,7 +239,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { TEST_F(CallGraphTest, ComputationWithConditional) { // Test a call graph of a module with a conditional. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* true_computation = module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); HloComputation* false_computation = @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(3, call_graph->nodes().size()); @@ -298,7 +298,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -418,7 +418,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { // c // // Calls are made via kCall, kWhile, and kMap instructions. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -479,10 +479,10 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { TEST_F(CallGraphTest, VisitSingletonComputation) { // Test the call graph visitor with a call graph with a single node. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -494,12 +494,12 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { TEST_F(CallGraphTest, VisitUnreachableComputation) { // Test the call graph visitor with a call graph with an unreachable node. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); // Test visitation of only reachable nodes. { @@ -531,9 +531,9 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index e6b56654359..0b6e323f75c 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloVerifiedTestBase; +using CallInlinerTest = HloTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -51,7 +51,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { HloInstruction* one = inner.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); TF_ASSERT_OK(zero->AddControlDependencyTo(one)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* inner_computation = module->AddEmbeddedComputation(inner.Build()); @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(), @@ -79,7 +79,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // returns false should be identical to just returning false). TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Create a lambda that calls a function that returns the false predicate.
// Note we also use this lambda twice by reference, just to make the test a @@ -107,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -120,7 +120,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { // whole pass. TEST_F(CallInlinerTest, InlineWithoutRunningPass) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder just_false(TestName() + ".false"); auto* true_constant = just_false.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { const Shape f32 = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( @@ -163,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc new file mode 100644 index 00000000000..2662fe46705 --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compilation_cache.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = counter++; + return id; +} + +} // namespace + +ExecutionHandle CompilationCache::Insert( + std::unique_ptr<Executable> executable) { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = GetUniqueId(); + VLOG(2) << "inserting cache key: " << key; + CHECK_EQ(cache_.count(key), 0); + cache_.emplace(key, std::move(executable)); + + ExecutionHandle handle; + handle.set_handle(key); + return handle; +} + +StatusOr<std::shared_ptr<Executable>> CompilationCache::LookUp( + const ExecutionHandle& handle) const { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = handle.handle(); + VLOG(2) << "looking up cache key: " << key; + if (cache_.count(key) == 0) { + VLOG(2) << "cache key not found: " << key; + return InvalidArgumentStrCat("can not find executable with handle ", key); + } else { + auto& result = cache_.at(key); + VLOG(2) << "hit executable: " << result->module().name(); + return result; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h new file mode 100644 index 00000000000..5f94def509d --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// A cache which stores Executables indexed by computation handle and version. +// +// TODO(b/119042872): Provide mechanism for removing computations from the +// compilation cache. +class CompilationCache { + public: + CompilationCache() {} + + ExecutionHandle Insert(std::unique_ptr<Executable> executable); + + // Lookup the Executable for the specified handle in the cache. Return a + // shared_ptr to the Executable if it exists in the cache.
+ StatusOr<std::shared_ptr<Executable>> LookUp( + const ExecutionHandle& handle) const; + + protected: + mutable tensorflow::mutex mutex_; + + using CacheKey = int64; + + absl::flat_hash_map<CacheKey, std::shared_ptr<Executable>> cache_ + GUARDED_BY(mutex_); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 6d67f970020..67132274c0d 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 80c630c6201..8f08c244908 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -110,6 +110,6 @@ Compiler::GetPlatformCompilers() { } AotCompilationOptions::AotCompilationOptions() - : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + : debug_options_(GetDebugOptionsFromFlags()) {} } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167d..289eb6d9023 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -37,7 +37,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ConditionalSimplifierTest : public HloVerifiedTestBase { +class ConditionalSimplifierTest : public HloTestBase { public: // Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module); @@ -96,25 +96,28 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { } TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { - HloComputation* computation = MakeConditional(&module()); - ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); + ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Constant())); } TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* true_op = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK( true_op->AddControlDependencyTo(computation->root_instruction())); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -125,11 +128,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -138,18 +142,19 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); auto token = false_computation->AddInstruction(HloInstruction::CreateToken()); false_computation->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 
0ac4a65ec6a..7f7f1503a09 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -51,7 +51,8 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; // Runs the visitor on a computation. - static bool Run(HloComputation* computation); + static bool Run(HloComputation* computation, + bool canonicalize_depthwise_filter); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -59,18 +60,24 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation) - : computation_(computation) {} + explicit ConvolutionVisitor(HloComputation* computation, + bool canonicalize_depthwise_filter = false) + : computation_(computation), + filter_expansion_(!canonicalize_depthwise_filter) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; // Whether rewrite has occurred. bool changed_ = false; + + // Whether filter expansion is required. + bool filter_expansion_; }; -bool ConvolutionVisitor::Run(HloComputation* computation) { - ConvolutionVisitor visitor(computation); +bool ConvolutionVisitor::Run(HloComputation* computation, + bool canonicalize_depthwise_filter) { + ConvolutionVisitor visitor(computation, canonicalize_depthwise_filter); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -190,9 +197,49 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { HloInstruction* filter_mask = GetExpandedFilterMask( filter->shape(), input_feature_dim, output_feature_dim, group_count, add); HloInstruction* expanded_filter; - // We want to repeat 'filter' in the 'input_feature_dim' dimension - // 'group_count' times. + if (group_size == 1) { + bool depthwise_separable = + (group_count == filter->shape().dimensions(output_feature_dim)); + // If the code generator handles depthwise separable convolutions + // inherently, then no filter expansion is needed. + if (!filter_expansion_ && depthwise_separable) { + const int64 old_kernel_input_feature_dimension = + dim_numbers.kernel_input_feature_dimension(); + const int64 old_kernel_output_feature_dimension = + dim_numbers.kernel_output_feature_dimension(); + + // For depthwise convolutions, we want the kernel input feature dimension + // to be smaller than the output feature dimension. If that's not the + // case, we swap the dimensions. 
+ if (old_kernel_input_feature_dimension > + old_kernel_output_feature_dimension) { + Shape reshaped_filter_shape = filter->shape(); + auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); + std::swap(dimensions[old_kernel_input_feature_dimension], + dimensions[old_kernel_output_feature_dimension]); + + auto reshaped_filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + dim_numbers.set_kernel_input_feature_dimension( + old_kernel_output_feature_dimension); + + dim_numbers.set_kernel_output_feature_dimension( + old_kernel_input_feature_dimension); + + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), + reshaped_filter, group_count, convolution->window(), dim_numbers, + convolution->precision_config()); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); + } + return Status::OK(); + } + // We want to repeat 'filter' in the 'input_feature_dim' dimension + // 'group_count' times. Shape reshaped_filter_shape = ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); auto reshaped_filter = @@ -237,7 +284,7 @@ StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp)) { + if (ConvolutionVisitor::Run(comp, filter_expansion_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index ce0138e56fb..cb6bc04c00a 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -27,7 +27,8 @@ namespace xla { // convolutions with feature_group_count = 1. class ConvolutionFeatureGroupConverter : public HloModulePass { public: - ConvolutionFeatureGroupConverter() {} + ConvolutionFeatureGroupConverter(bool canonicalize_depthwise_filter = false) + : filter_expansion_(canonicalize_depthwise_filter) {} absl::string_view name() const override { return "convolution-feature-group-converter"; @@ -36,6 +37,9 @@ class ConvolutionFeatureGroupConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + // Tells whether filter expansion is required. + bool filter_expansion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4533ebb99bb..7446bc7cc11 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -94,7 +94,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -114,7 +114,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -127,7 +127,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = @@ -181,7 +181,7 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -217,7 +217,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloInstruction* old_root = module->entry_computation()->root_instruction(); @@ -238,7 +238,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -261,7 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); @@ -283,7 +283,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -310,7 +310,7 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(HloOpcode::kParameter, @@ -351,7 +351,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { auto gte = 
builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -388,7 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -403,7 +403,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { class WhileCopyInsertionTest : public CopyInsertionTest { protected: - WhileCopyInsertionTest() : module_(CreateNewModule()) {} + WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {} // Builds a While condition computation which reads the induction variable // from the tuple parameter, and returns a predicate indicating whether this @@ -1295,7 +1295,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { TEST_F(CopyInsertionTest, SwizzlingWhile) { // Test a while instruction with a body which permutes its tuple parameter // elements. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1362,7 +1362,7 @@ TEST_F(CopyInsertionTest, CrossingParameters) { // | / \ | // | / \| // (p1 , p0) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1395,7 +1395,7 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1428,7 +1428,7 @@ TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1461,7 +1461,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1496,7 +1496,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { // | | | // | | | // +-- (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1534,7 +1534,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { // | Add----+ // | | | // +-- (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1569,7 +1569,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // the operation (instruction) on the element makes the live range of the // respective input and output elements different than if the instruction were // not there (as in the SwizzlingWhile test above). 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1632,7 +1632,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { // the while body is a single constant (both loop state elements are the same // constant). This means no copies are necessary because both loop state // elements are the same so interchanging them is a no-op. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1693,7 +1693,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { const Shape loop_state_shape = ShapeUtil::MakeTupleShape( {element_shape, element_shape, element_shape, element_shape}); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, element_shape, "param_0")); @@ -1783,7 +1783,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). The body constant should be copied. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -1896,7 +1896,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); @@ -1936,7 +1936,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); @@ -2003,7 +2003,7 @@ std::unique_ptr MakeBenchmarkWhileBody( void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { tensorflow::testing::StopTiming(); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); CopyInsertion copy_insertion; const Shape element_shape = ShapeUtil::MakeShape(F32, {}); std::vector tuple_params(num_tuple_inputs); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 36e25cbe678..2763d18121a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -824,7 +824,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -846,7 +845,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -887,7 +885,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -961,17 +958,16 @@ tf_cc_test( srcs = ["cpu_copy_insertion_test.cc"], deps = [ ":cpu_copy_insertion", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -997,7 +993,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 2083f440fdd..c58175428fe 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloVerifiedTestBase { +class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -87,7 +87,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); @@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -150,7 +150,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( @@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index c9fb34be1cd..c085f85fb73 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloVerifiedTestBase { +class CpuCopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -65,7 +65,7 @@ class CpuCopyInsertionTest : public HloVerifiedTestBase { TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). Each constant should be copied. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 3); @@ -103,7 +103,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { // Test a kCall instruction which calls a computation which produces a three // element tuple: one is a constant, one is a parameter, and one is produced // in the computation. The constant and parameter should be copied. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index e6b6fcdf684..9cbfb88834b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloVerifiedTestBase { +class CpuHloSupportCheckerTest : public HloTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(CpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 7d99b914d4f..c95a514ca04 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -58,7 +58,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -77,7 +77,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -98,7 +98,7 @@ TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -119,7 +119,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto 
computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -138,7 +138,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -157,7 +157,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -321,7 +321,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -350,7 +350,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -370,7 +370,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, broadcast1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -392,7 +392,7 @@ TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -410,7 +410,7 @@ TEST_F(OpcodeFusionTest, Exponential_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -429,7 +429,7 @@ TEST_F(OpcodeFusionTest, Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -447,7 +447,7 @@ TEST_F(OpcodeFusionTest, Reverse_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -466,7 +466,7 @@ TEST_F(OpcodeFusionTest, Slice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2}), 
HloOpcode::kNegate, slice1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -489,7 +489,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, transpose2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -498,7 +498,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { } TEST_F(OpcodeFusionTest, UnaryMapOfExp) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -517,7 +517,7 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { } TEST_F(OpcodeFusionTest, BinaryMapOfExps) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -542,7 +542,7 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { } TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -573,7 +573,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); @@ -641,7 +641,7 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) { builder.AddInstruction( HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -670,7 +670,7 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { builder.AddInstruction(HloInstruction::CreateBinary( large_shape, HloOpcode::kAdd, small_exp, large_param)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -712,7 +712,7 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, } TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -725,7 +725,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/false); @@ -738,7 +738,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/19, 
/*add_extra_use_for_dot=*/false); @@ -751,7 +751,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/true); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 97659b88a79..2cd52e4a18a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -73,7 +73,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -114,7 +114,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -158,7 +158,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -192,7 +192,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -232,7 +232,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -353,7 +353,7 @@ static void AssertCorrectLayoutForDotOutputFusion( } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -365,7 +365,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult 
layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -377,7 +377,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -389,7 +389,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -401,7 +401,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, @@ -413,7 +413,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index d6968323f33..620c45fa391 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1536,7 +1536,8 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs, rhs}, {lhs->getType()}, b); @@ -1551,7 +1552,8 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( case HloOpcode::kMinimum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs, rhs}, {lhs->getType()}, b); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fad76338a57..f0b65046c14 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -17,13 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class ParallelTaskAssignmentTest : public HloVerifiedTestBase { +class ParallelTaskAssignmentTest : public HloTestBase { protected: const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = cpu::CpuExecutable::ShapeSizeBytes; @@ -35,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { + : HloTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} @@ -57,8 +57,9 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -84,8 +85,9 @@ TEST_F(ParallelTaskAssignmentTest, } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -100,8 +102,9 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -116,8 +119,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 1a3d82de954..7d8e51f909e 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloVerifiedTestBase { +class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloVerifiedTestBase { +class ShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { +class RandomShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4b129c95d46..382dfd0d99d 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,7 +48,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 18ee25ba915..691b3c7bee2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -50,7 +50,7 @@ class CpuEigenDotOperationTest /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(entry_computation)); CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index 00a7aa2ad2f..d201a151d7a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -46,7 +46,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); CompileAndVerifyIr(std::move(module), filecheck_pattern, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 1deb412064b..04a81dfd35f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloVerifiedTestBase { +class CpuFusionTest : public HloTestBase { protected: CpuFusionTest() {} @@ -57,11 +57,11 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -104,11 +104,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -131,7 +131,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal.shape(); @@ -183,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -250,12 +250,12 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); // Create computation and module. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -310,11 +310,11 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate1, negate2, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). 
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index a434c04a980..773336c7a92 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -91,7 +91,7 @@ TEST_P(CpuUnaryIntrinsicTest, DoIt) { /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); string check_lines{spec.check_lines.data(), spec.check_lines.size()}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index b35fd9dad87..f5419b7063b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -56,7 +56,7 @@ TEST_F(CpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb6..64fb5031839 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -18,19 +18,19 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class DefuserTest : public HloVerifiedTestBase { +class DefuserTest : public HloTestBase { protected: // Returns the number of fusion instructions in the module. 
- int FusionCount() { + int FusionCount(const HloModule* m) { int count = 0; - for (HloComputation* computation : module().computations()) { + for (HloComputation* computation : m->computations()) { if (computation->IsFusionComputation()) { count++; } @@ -43,6 +43,7 @@ class DefuserTest : public HloVerifiedTestBase { }; TEST_F(DefuserTest, NoFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -51,13 +52,14 @@ TEST_F(DefuserTest, NoFusionInstruction) { builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - module().AddEntryComputation(builder.Build()); - EXPECT_EQ(0, FusionCount()); + m->AddEntryComputation(builder.Build()); + EXPECT_EQ(0, FusionCount(m.get())); - EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(defuser_.Run(m.get()).ValueOrDie()); } TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -66,21 +68,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -91,21 +94,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion())); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add(op::Parameter(), op::Parameter()))); } TEST_F(DefuserTest, NonTrivialFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -128,22 +132,23 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); 
computation->CreateFusionInstruction( {add2, constant, div, mul, sub, negate, add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, MultipleFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -166,7 +171,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add2, constant, div, mul}, HloInstruction::FusionKind::kLoop); computation->CreateFusionInstruction({sub, negate, add}, @@ -174,15 +179,16 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, NestedFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -193,7 +199,7 @@ TEST_F(DefuserTest, NestedFusionInstructions) { auto negate = builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); auto outer_fusion = computation->CreateFusionInstruction( {negate, add}, HloInstruction::FusionKind::kLoop); HloInstruction* fused_negate = outer_fusion->fused_expression_root(); @@ -203,9 +209,9 @@ TEST_F(DefuserTest, NestedFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add())); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 4159aa281fa..d6371283221 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -108,6 +108,7 @@ class DfsHloVisitorBase { virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h 
b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 4cd10ab06cd..e57184f639f 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -203,6 +203,9 @@ class DfsHloVisitorWithDefaultBase Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } + Status HandleGetDimensionSize(HloInstructionPtr get_size) override { + return DefaultAction(get_size); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 515267edd7c..f98c943669b 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1815,8 +1815,6 @@ StatusOr ElementalIrEmitter::EmitElementalGather( // Clamp the gather index so that the gather region fits in the operand. // gather_dim_component_extended_inbound = // clamp(gather_dim_component_extended, 0, largest_valid_start_index); - - // TODO(b/111078873): This is implementation defined behavior. bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); auto gather_dim_component_extended_inbound = EmitIntegralMin( index.GetConstantWithIndexType(largest_valid_start_index), diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 47c56e2f7fb..10b8c01ff13 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -17,7 +17,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 3a6780f2a67..45f620f3f33 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "absl/types/variant.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 5fbd73a5363..8eeb930b481 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. 
namespace xla { namespace { -class FlattenCallGraphTest : public HloVerifiedTestBase { +class FlattenCallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -108,7 +108,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module); + std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -149,7 +149,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // Test corner case of a computation used as a body and a loop condition. TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation; { HloComputation::Builder builder(TestName() + ".cond"); @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -201,7 +201,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { // C // TEST_F(FlattenCallGraphTest, FlattenCalls) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* c_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -224,7 +224,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { } TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* sub_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, 
RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); // The true and false computations must now be different. EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1e8435fe542..b1629616acd 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -111,7 +111,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -463,7 +462,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -627,7 +626,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], ) @@ -849,7 +848,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -909,7 +907,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -1036,6 +1033,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc index fa3afa6a5d3..af9303a5b76 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -29,7 +29,7 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {}; +class CudnnConvPadForTensorCoresTest : public HloTestBase {}; TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index c46672c598b..4ce877f62a5 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -254,7 +254,7 @@ MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // TODO(b/119479517): Theoretically cuDNN supports grouped convolutions also // for the backward input convolution, but at least for now with version 7.1.4 // it is slower. This needs to be re-evaluated for future cuDNN versions. // Note that we already have the necessary code down below, the only thing to diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 87a835f2504..443883a89f6 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -34,11 +34,11 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvRewriterTest : public HloVerifiedTestBase { +class CudnnConvRewriterTest : public HloTestBase { public: CudnnConvRewriterTest() - : HloVerifiedTestBase(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false) { + : HloTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -118,10 +118,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { metadata.set_op_name("foo"); conv->set_metadata(metadata); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -152,10 +152,10 @@ TEST_F(CudnnConvRewriterTest, activations, gradients, /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -182,10 +182,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -212,10 +212,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -241,10 +241,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( 
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -292,10 +292,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { /*feature_group_count=*/1, conv_window, conv_dnums) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( @@ -338,10 +338,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -371,10 +371,10 @@ TEST_F(CudnnConvRewriterTest, default_conv_window_, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -425,10 +425,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -475,10 +475,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -529,10 +529,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -584,10 +584,10 @@ TEST_F(CudnnConvRewriterTest, conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( 
entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -600,7 +600,8 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { constant_arr.FillIota(0); string constant_str = LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); - ParseAndVerifyModule(absl::StrFormat(R"( + + const string module_str = absl::StrFormat(R"( HloModule test ENTRY entry_computation { @@ -610,10 +611,12 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, dim_labels=bf01_01oi->bf01, feature_group_count=1 })", - constant_str)); - EXPECT_TRUE(RunPass(&module())); + constant_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); EXPECT_THAT( - module().entry_computation()->root_instruction(), + m->entry_computation()->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, op::Reverse(op::Constant())), 0)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 02a0d028c11..91609c730b6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -42,7 +42,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { // Only the entry computation can possibly be sequentially ordered, and only // if we've assigned all instructions to a single stream. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return &computation == module_->entry_computation() ? entry_sequence_.get() : nullptr; @@ -51,7 +51,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { string ToString() const override { return ToStringHelper("GpuHloOrdering"); } private: - std::unique_ptr> entry_sequence_; + std::unique_ptr entry_sequence_; }; GpuHloOrdering::GpuHloOrdering( @@ -60,8 +60,8 @@ GpuHloOrdering::GpuHloOrdering( : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = absl::make_unique>( - thunk_launch_order); + entry_sequence_ = + absl::make_unique(thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -124,7 +124,8 @@ GpuHloOrdering::GpuHloOrdering( for (auto* computation : module->computations()) { if (computation != module->entry_computation() && !computation->IsFusionComputation()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, + HloReachabilityMap::Build(computation)); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index b857fa775a7..6d3aed15ebe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,14 +24,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloVerifiedTestBase { +class GpuHloScheduleTest : public HloTestBase { protected: using HloVec = std::vector; @@ -44,7 +44,7 @@ class GpuHloScheduleTest : public HloVerifiedTestBase { .ConsumeValueOrDie(); } - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -79,7 +79,7 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr streams = AssignStreams(*module); @@ -139,7 +139,7 @@ TEST_F(GpuHloScheduleTest, SequentialAdd) { HloInstruction* add3 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(add3)); std::unique_ptr streams = AssignStreams(*module); @@ -209,7 +209,7 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr streams = AssignStreams(*module); @@ -288,7 +288,7 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr streams = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 7d01eeb0256..b511155f85f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloVerifiedTestBase { +class GpuHloSupportCheckerTest : public HloTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(GpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 4822b820f4e..8cc76c872c6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -61,7 +61,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { HloInstruction::CreateParameter(1, ashape, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(add)); @@ -148,7 +148,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { {operand, scale, offset, mean, variance, epsilon, feature_index}, kCudnnBatchNormForwardInferenceCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -217,7 +217,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, kCudnnBatchNormForwardTrainingCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -298,7 +298,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { feature_index}, kCudnnBatchNormBackwardCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 7f2b59810f0..43f43b50e4a 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ 
b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -47,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kScatter || hlo.opcode() == HloOpcode::kSlice || hlo.opcode() == HloOpcode::kTranspose; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 57e66f5a12c..fb77bc4b8eb 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -41,7 +41,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), exp1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -61,7 +61,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), negate1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -80,7 +80,7 @@ TEST_F(InstructionFusionTest, HloInstruction* reshape2 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -99,7 +99,7 @@ TEST_F(InstructionFusionTest, HloInstruction* transpose2 = builder.AddInstruction( HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -134,7 +134,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -723,7 +723,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { sum = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); } - auto module = CreateNewModule(); + auto module = 
CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(b.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -805,5 +805,26 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) { op::Reduce(op::Broadcast(op::Parameter()), op::Constant())); } +TEST_F(InstructionFusionTest, FuseReverse) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY Reverse { + p0 = f32[50,96,1024]{2,1,0} parameter(0) + add = f32[50,96,1024]{2,1,0} add(p0, p0) + ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0} + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Reverse(op::Add(op::Parameter(), op::Parameter()))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 21e44e1e7d3..87b6cd640ac 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2197,9 +2197,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - auto index_type = b_.getInt64Ty(); + CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); + CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); // Naive C++ code for the outer loops: // @@ -2213,42 +2214,119 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { // } // } // - // This follows the algorithm described on Wikipedia: - // https://en.wikipedia.org/wiki/Bitonic_sorter + // This follows the alternative representation of the algorithm described on + // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter + // + // Each mask specifies how to derive from one position in the array the + // position with which it should be compared (we calculate the xor of the + // position with the mask). + // As an optimization, we can move the 'mask' loop to inside the + // sorting/comparison loop if the comparisons happen within a small block of + // the array. To make this work, we collect all consecutive masks that are + // smaller than our chosen power of 2 tile size, and pass them to SortInPlace. + // Each thread then processes one tile of data. + const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages); + + // If we cannot combine several xor masks together, we don't use tiling, so we + // calculate the standard launch dimensions for the shape. However we only + // need to iterate through ~half of the dimension to sort (rounded up to the + // next highest power of 2), because each iteration compares one pair of + // elements. + Shape standard_iteration_shape = keys_shape; + uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1); + standard_iteration_shape.set_dimensions(dimension_to_sort, + standard_num_iterations_in_sort_dim); + LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions( + standard_iteration_shape, ir_emitter_context_->device_description()); + + // Calculate the launch dimensions for the case where we use tiling. 
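The comments above describe how the new sort emitter batches consecutive xor masks that stay inside one tile, so that a single kernel launch can process each tile in shared memory with kTileSize / 2 threads per block. A standalone host-side sketch of that grouping logic follows; num_stages and kTileSize are small illustrative values, and EmitKernel stands in for the emit_kernel lambda in the real code:

#include <cstdint>
#include <iostream>
#include <vector>

// Prints the xor-mask group that would be handed to one kernel launch.
void EmitKernel(const std::vector<int64_t>& xor_masks) {
  std::cout << "kernel:";
  for (int64_t m : xor_masks) std::cout << " " << m;
  std::cout << "\n";
}

int main() {
  const int64_t num_stages = 4;  // sort-dimension bound in (8, 16]
  const int64_t kTileSize = 4;   // tiny tile, just to show the grouping
  std::vector<int64_t> xor_masks;
  for (int64_t stage = 0; stage < num_stages; ++stage) {
    for (int64_t mask = stage; mask >= 0; --mask) {
      // The first mask of a stage compares within a bitonic block of size
      // 2^(stage+1); the remaining masks are single-bit butterfly steps.
      int64_t xor_mask =
          (mask == stage) ? (1LL << (stage + 1)) - 1 : (1LL << mask);
      if (xor_mask >= kTileSize) {
        // Comparisons cross tile boundaries: flush any pending small masks,
        // then emit this mask as its own untiled kernel.
        if (!xor_masks.empty()) {
          EmitKernel(xor_masks);
          xor_masks.clear();
        }
        EmitKernel({xor_mask});
      } else {
        // Comparisons stay inside one tile; batch them into a single launch
        // that keeps the tile resident in shared memory.
        xor_masks.push_back(xor_mask);
      }
    }
  }
  if (!xor_masks.empty()) EmitKernel(xor_masks);
  // Prints: kernel: 1 3 1 / kernel: 7 / kernel: 2 1 / kernel: 15 /
  //         kernel: 4 / kernel: 2 1
}

The real emitter additionally falls back to the untiled path (no_tiling) when the tile is smaller than 128, when kTileSize / 2 exceeds the device's threads-per-block limit, or when the per-operand tiles would not fit in shared memory.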
We split + // the dimension that should be sorted into tiles of size 'kTileSize'. This + // means we first need to round 'dimension_to_sort_bound' up to be a multiple + // of the tile size. + int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize); + Shape iteration_shape = keys_shape; + + // We iterate through the element pairs that should be compared. + uint64 num_iterations_in_sort_dim = rounded_bound / 2; + iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim); + uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape); + + // For correctness reasons we need exactly 'kTileSize' / 2 many threads per + // block. Each thread is responsible for copying exactly two adjacent elements + // into shared memory, and then does a comparison of two possibly different + // elements taken from shared memory. + const uint64 kThreadsPerBlock = kTileSize / 2; + + // Check whether we should use any tiling. We might not be able to use it if + // we have not enough threads, or not enough shared memory. Also it does not + // give a speedup if the tile size is < 128. + int64 total_shared_memory_needed = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + total_shared_memory_needed += + kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type()); + } + bool no_tiling = + kTileSize < 128 || + kThreadsPerBlock > + ir_emitter_context_->device_description().threads_per_block_limit() || + total_shared_memory_needed > + ir_emitter_context_->device_description().shared_memory_per_block(); + + uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); + LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); + + auto emit_kernel = [&](absl::Span xor_masks) { + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + LaunchDimensions launch_dimensions = xor_masks.size() > 1 + ? tiled_launch_dimensions + : standard_launch_dimensions; + UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), + ir_emitter_context_->llvm_module()); + IrArray keys_array; + std::vector values_arrays; + values_arrays.reserve(sort->operand_count() - 1); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + if (i == 0) { + keys_array = GetIrArray(*sort, *sort, shape_index); + } else { + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + } + } + return llvm_ir::EmitSortInPlace( + dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_masks, + &b_, launch_dimensions, + xor_masks.size() > 1 ? 
num_iterations_in_sort_dim + : standard_num_iterations_in_sort_dim, + kTileSize); + }; + std::vector xor_masks; for (int64 stage = 0; stage < num_stages; ++stage) { for (int64 mask = stage; mask >= 0; --mask) { - thunks.push_back( - BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - keys_shape, ir_emitter_context_->device_description()); - UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), - ir_emitter_context_->llvm_module()); - - llvm::Value* xor_mask; + int64 xor_mask; if (mask == stage) { - xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1); + xor_mask = (1LL << (stage + 1)) - 1; } else { - xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); + xor_mask = 1LL << mask; } - - IrArray keys_array; - std::vector values_arrays; - values_arrays.reserve(sort->operand_count() - 1); - for (int64 i = 0; i < sort->operand_count(); ++i) { - ShapeIndex shape_index = - sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); - if (i == 0) { - keys_array = GetIrArray(*sort, *sort, shape_index); - } else { - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + if (xor_mask >= kTileSize || no_tiling) { + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + xor_masks.clear(); } + TF_RETURN_IF_ERROR(emit_kernel({xor_mask})); + } else { + xor_masks.push_back(xor_mask); } - TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_mask, - &b_, &launch_dimensions)); } } + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + } AddThunkToThunkSequence( absl::make_unique(std::move(thunks), sort)); @@ -3261,13 +3339,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( param->shape().element_type(), module_), kTileSize + 1), kTileSize); - const int kNVPTXSharedMemoryAddrSpace = 3; - auto* tile_base_ptr = new llvm::GlobalVariable( - *b_.GetInsertBlock()->getParent()->getParent(), tile_type, - /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::UndefValue::get(tile_type), - llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr, - llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); + auto* tile_base_ptr = llvm_ir::AllocateSharedMemoryTile( + b_.GetInsertBlock()->getParent()->getParent(), tile_type, + IrName(hlo, StrCat("tile", id))); param_shmem_buffers[id] = tile_base_ptr; VLOG(3) << "Added shmem buffer for parameter " << id << ": " << llvm_ir::DumpToString(*tile_base_ptr); @@ -3454,6 +3528,29 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( return launch_dimensions; } +namespace { +// Returns true to indicate it is safe to use the tile based shared memory +// transpose implementation to implement the kernel for the instruction. +// +// An instruction is not safe for such an implementation if it can change the +// element order of a tensor without changing the dimension of the tensor, and +// the instruction has a corresponding elemental_ir_emitter. 
+bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { + auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { + HloOpcode opcode = instr->opcode(); + CHECK_NE(opcode, HloOpcode::kFusion); + return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); + }; + + if (hlo->opcode() == HloOpcode::kFusion) { + return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), + is_safe_for_tile_based_transpose); + } + + return is_safe_for_tile_based_transpose(hlo); +} +} // namespace + bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { HloOpcode opcode = hlo->opcode(); CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); @@ -3498,6 +3595,10 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } + if (!IsInstructionSafeForTileBasedTranspose(hlo)) { + return false; + } + // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb // shared memory per SM. (This is increased to 96kb in Volta, but we don't diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 9427d3d54ad..d9b06828e2b 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -140,6 +140,18 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, return false; } + // The emitter only supports in-place DUS for fusions with a single DUS at the + // root. Don't sibling fuse DUS for now. + // TODO(b/119178699): Multi-output fusing DUS can improve performance if we + // share the input and output buffers and add support to the emitter. + if (instr1->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice || + (instr2->opcode() == HloOpcode::kFusion && + instr2->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice)) { + return false; + } + // Do this check last, as it may be expensive. return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 1d4856e0cae..dc221f22a74 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -580,7 +580,7 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { // ... // where each of the (pi * pj)'s is represented as a fusion node so that // multi-output fusion will pay attention to it. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); @@ -621,5 +621,39 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { } } +TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { + auto module = ParseHloString(R"(HloModule dus_mof + fusion.1 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + fusion.2 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + ENTRY entry { + p.00 = f16[50,96,1024]{2,1,0} parameter(0) + p.01 = f16[50,96,1024]{2,1,0} parameter(1) + p.1 = s32[1]{0} parameter(2) + p.2 = f16[1,96,1024]{2,1,0} parameter(3) + + f1 = f16[50,96,1024] fusion(p.00, p.1, p.2), kind=kLoop, calls=fusion.1 + f2 = f16[50,96,1024] fusion(p.01, p.1, p.2), kind=kLoop, calls=fusion.2 + ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2) + })") + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 5b6cf2c04d0..4775baf44ae 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -122,7 +122,7 @@ std::unique_ptr AssignStreams(const HloModule& module) { auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = - computation.ComputeReachability(); + HloReachabilityMap::Build(&computation); std::vector seen_gemms; // The execution of different RNG Hlo instructions in the same module updates // a common global variable. To avoid a race condition, we simply assign all diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index c4f43cc9a61..f2ef11e1e6a 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,16 +21,16 @@ limitations under the License. 
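Both stream_assignment.cc here and gpu_hlo_schedule.cc above pick up the same API move: reachability is no longer produced by the member function HloComputation::ComputeReachability() (deleted further down in hlo_computation.cc) but by the factory HloReachabilityMap::Build(), which takes the computation by pointer. A minimal before/after sketch, using only what is visible in this change; producer and consumer stand for arbitrary HloInstruction pointers:

// Before: reachability came from the computation itself.
//   std::unique_ptr<HloReachabilityMap> reachability =
//       computation.ComputeReachability();
// After: build it through the HloReachabilityMap factory instead.
std::unique_ptr<HloReachabilityMap> reachability =
    HloReachabilityMap::Build(&computation);
if (reachability->IsReachable(producer, consumer)) {
  // `consumer` transitively depends on `producer` via data or control edges,
  // so the two cannot safely run concurrently on separate streams.
}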
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloVerifiedTestBase { +class StreamAssignmentTest : public HloTestBase { protected: - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -55,7 +55,7 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr assignment = AssignStreams(*module); @@ -76,7 +76,7 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr assignment = AssignStreams(*module); @@ -120,7 +120,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr assignment = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index ed46f08d597..d798b316437 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -37,7 +37,7 @@ cc_library( hdrs = ["gpu_codegen_test.h"], tags = tf_cuda_tests_tags(), deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 79e77d4c4d6..9e3ff8750b8 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -23,9 +23,10 @@ limitations under the License. 
namespace xla { namespace gpu { -std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { +std::unique_ptr GpuCodegenTest::CreateNewUnverifiedModuleWithFTZ( + bool ftz) { HloModuleConfig config; - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_gpu_ftz(ftz); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); // TODO(b/38354253): Change tests to use Parameters instead of Constants. diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h index e4a3573babb..d2f30ae7bc4 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -26,9 +26,9 @@ namespace gpu { // Tests that verify IR or PTX emitted by the GPU backend is as expected. class GpuCodegenTest : public LlvmIrGenTestBase { protected: - // Like HloTestBase::CreateNewModule(), with a flag for configuring the ftz - // option. - std::unique_ptr CreateNewModuleWithFTZ(bool ftz); + // Like HloTestBase::CreateNewUnverifiedModule(), with a flag for configuring + // the ftz option. + std::unique_ptr CreateNewUnverifiedModuleWithFTZ(bool ftz); // Compiles the given HLO module to PTX and verifies the PTX matches the given // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 780539c1642..268b48a1cad 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -46,7 +46,7 @@ TEST_F(GpuCopyTest, UseMemcpy) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // There should not be any kernel prefixed "copy". 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index 177b94934c7..d0ccd8619bd 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -39,7 +39,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/1, param_shape, "y")); builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } @@ -54,7 +54,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/0, param_shape, "x")); builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index a06576df7b8..da8e513a2c3 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -51,7 +51,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); // Check the optimized IR as the unoptimized IR contains dead udiv and urem. diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 15d1e269cc2..a302b582ede 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -193,6 +193,33 @@ TEST_F(GpuKernelTilingTest, /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { + const char *const kHloString = R"( + HloModule FusionTransposeWithReverseNotTiled + fused_computation.1 { + arg0 = f32[128,64]{1,0} parameter(0) + copy0 = f32[128,64]{0,1} copy(arg0) + ROOT reverse0 = f32[128,64]{0,1} reverse(copy0), dimensions={0} + } + + ENTRY reverse_break_assumption { + param0 = f32[128,64]{1,0} parameter(0) + ROOT fusion0 = f32[128,64]{0,1} fusion(param0), kind=kLoop, + calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6a9ecd9dae7..ea1fee040dd 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -48,7 +48,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -73,7 +73,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { builder.AddInstruction(HloInstruction::CreateTuple({add, square})); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -95,7 +95,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { // reduce in the foreseeable future. But if that turns out to be wrong, I give // you, future reader, permission to delete this test. TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation* reduce_computation; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 15198865bda..14285459b5a 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -47,7 +47,7 @@ TEST_F(GpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 0f2d5568caf..4636f1d9d20 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -85,7 +85,7 @@ TEST_F(GpuUnrollingTest, UnrollFourTimes) { TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { // The default unrolling factor is 4. HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc index 5fa9e91050a..3d00ac4dc7b 100644 --- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -32,7 +32,7 @@ namespace gpu { namespace { using match::Concatenate; -class VariadicOpSplitterTest : public HloVerifiedTestBase {}; +class VariadicOpSplitterTest : public HloTestBase {}; TEST_F(VariadicOpSplitterTest, DontSplit) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 926b59a1b85..c7f51127649 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -29,7 +29,7 @@ namespace { class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() - : module_(CreateNewModule()), + : module_(CreateNewUnverifiedModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index e30e7667f30..fad3215fc81 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,16 +30,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; +class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); @@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module); + HloSchedule schedule(module.get()); schedule.set_sequence(cond_computation, {cond_param, cond_iter, cond_data, cond_lt}); schedule.set_sequence(body_computation, {body_param}); @@ -351,7 +351,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloVerifiedTestBase { +class HeapSimulatorTest : public HloTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 5c8d97b2d15..7e6150e9415 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,17 +39,17 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloVerifiedTestBase { +class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : HloVerifiedTestBase() { - module_ = CreateNewModule(); + HloAliasAnalysisTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); } // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_, + analysis_ = HloAliasAnalysis::Run(module_.get(), /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); return *analysis_; @@ -93,7 +93,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { // never occurs, but HLO graphs with interference can be explicitly // constructed. bool AnyValuesInSameBufferInterfere() { - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); for (const HloBuffer& buffer : analysis_->buffers()) { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { @@ -110,7 +110,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { return false; } - HloModule* module_; + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -638,7 +638,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { module_->AddEntryComputation(builder.Build()); FlattenCallGraph flattener; - TF_ASSERT_OK(flattener.Run(module_).status()); + TF_ASSERT_OK(flattener.Run(module_.get()).status()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -1012,7 +1012,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { const HloAliasAnalysis& analysis = RunAnalysis(); - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } @@ -1054,13 +1054,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { { // Dependency ordering should interfere because the negate and while are // unordered. - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } // For a sequential order, if there is interference iff the negate is after // the while. - HloSchedule schedule(module_); + HloSchedule schedule(module_.get()); schedule.set_sequence(body, {body_param, body_root}); schedule.set_sequence(condition, {cond_param, cond_root}); { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b0f7cd91ad1..0c20d207ddb 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -321,7 +322,7 @@ void HloComputation::ComputeInstructionPostOrder( // Add the operands to the stack in reverse order so the first operand is // processed first. This will produce a more natural ordering and a nicer - // result for thigns like HLO stringification. + // result for things like HLO stringification. const auto& operands = current->operands(); for (int64 i = operands.size() - 1; i >= 0; --i) { dfs_stack.emplace_back(operands[i]); @@ -739,72 +740,6 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return RemoveInstructionAndUnusedOperands(old_instruction); } -std::unique_ptr HloComputation::ComputeReachability() - const { - const auto& all = MakeInstructionPostOrder(); - auto result = absl::make_unique(all); - auto channel_dependency_map = ComputeChannelDependencies(); - - std::vector inputs; - for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); - - switch (hlo->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - break; - } - case HloOpcode::kCrossReplicaSum: { - auto all_reduce_id = hlo->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - } - break; - } - default: - break; - } - - result->FastSetReachabilityToUnion(inputs, hlo); - } - return result; -} - -void HloComputation::UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map) { - std::queue worklist; - worklist.push(instruction); - - std::vector inputs; - - while (!worklist.empty()) { - const HloInstruction* item = worklist.front(); - worklist.pop(); - - inputs.assign(item->operands().begin(), item->operands().end()); - inputs.insert(inputs.end(), item->control_predecessors().begin(), - item->control_predecessors().end()); - - if (reachability_map->SetReachabilityToUnion(inputs, item)) { - // Add immediate successors to worklist. 
- for (const HloInstruction* user : item->users()) { - worklist.push(user); - } - for (const HloInstruction* succ : item->control_successors()) { - worklist.push(succ); - } - } - } -} - std::vector HloComputation::CollectUnreachableRoots() const { std::vector unreachable_roots; for (auto* instruction : instructions()) { @@ -911,14 +846,46 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - /*extras=*/{}, context, suffix); + context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + replacements.emplace(std::move(r3)); + return CloneWithReplacements(std::move(replacements), context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context, - const string& suffix) { + HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { context_ptr = absl::make_unique(parent(), suffix); @@ -939,18 +906,50 @@ std::unique_ptr HloComputation::CloneWithReplacements( }; VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; + + // We want to do a postorder walk over [replace(i) for i in instructions_]. + // We can't reuse MakeInstructionPostOrder() for this, because that will + // generate a postorder of plain instructions_, and our replacements may + // change the postorder! + // + // The postorder we want here is simpler than what MakeInstructionPostOrder() + // does -- we only care about operand dependencies -- so let's just do it + // ourselves. 
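The comment above explains why the clone walk cannot reuse MakeInstructionPostOrder(): a replacement may map an existing instruction to a brand-new one that lives only in the `replacements` map, so the postorder must be derived from the replaced operand graph rather than from the original instruction list. A hedged sketch of how the CloneWithReplacementPairs overloads (declared later in hlo_computation.h) are meant to be called; `computation` and `old_param` are placeholders for an existing computation and one of its instructions:

// The replacement constant exists only in the replacements map; it is cloned
// into the new computation as part of the replaced postorder walk.
std::unique_ptr<HloInstruction> zero =
    HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f));
std::unique_ptr<HloComputation> clone =
    computation->CloneWithReplacementPairs({old_param, std::move(zero)});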
std::vector postorder; - for (HloInstruction* instr : extras) { - postorder.push_back(instr); - } - for (HloInstruction* instr : MakeInstructionPostOrder()) { - if (HloInstruction* replacement = replace(instr)) { - postorder.push_back(replacement); + absl::flat_hash_map visited; + for (const auto& instr : instructions_) { + std::vector dfs_stack; + HloInstruction* new_instr = replace(instr.get()); + if (!new_instr) { + continue; + } + dfs_stack.push_back(new_instr); + + while (!dfs_stack.empty()) { + auto* cur = dfs_stack.back(); + auto it = visited.find(cur); + if (it != visited.end()) { + dfs_stack.pop_back(); + if (it->second == kVisited) { + continue; + } + CHECK_EQ(it->second, kVisiting); + postorder.push_back(cur); + it->second = kVisited; + continue; + } + + visited.insert({cur, kVisiting}); + for (HloInstruction* operand : cur->operands()) { + HloInstruction* new_operand = replace(operand); + if (new_operand) { + dfs_stack.emplace_back(new_operand); + } + } } } std::vector> instructions; - std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { @@ -960,9 +959,8 @@ std::unique_ptr HloComputation::CloneWithReplacements( << operand->ToString() << ", used by " << instr->ToString(); new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, context); - instructions.push_back(std::move(new_instr)); + instructions.push_back( + instr->CloneWithNewOperands(instr->shape(), new_operands, context)); } Builder builder(name() + "." + suffix); for (auto& instr : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index dec96d11a93..fc7d2035e5b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" @@ -215,19 +214,6 @@ class HloComputation { // this order, definitions of values always appear before their uses. std::vector MakeInstructionPostOrder() const; - // Computes and returns the reachability between HLO instructions in the - // computation. The returned HloReachabilityMap is constructed such that - // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a - // directed path (from producer to consumer) from 'a' to 'b'. Both data - // dependencies (operands) and control dependencies are considered for - // reachability. Trivially an instruction is reachable from itself. - std::unique_ptr ComputeReachability() const; - - // Updates the given reachability map after the immediate predecessor set - // (operands and control predecessors) of 'instruction' has changed. - void UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this @@ -333,14 +319,38 @@ class HloComputation { // the map's value to replace that instruction in the cloned computation. 
// // If replacements maps a key to nullptr, we remove that instruction from the - // new computation. - // If additional instructions are used by instructions in replacement map, - // they must be passed in post-order in the extras span. + // new computation. If an element of `replacements` references an instruction + // that's not already in the computation, it's cloned and added to the new + // computation. + // + // All relevant instructions are cloned, *including* unique_ptr in the + // `replacements` map. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); + + // Convenience overloads for CloneWithReplacements. You want to do + // + // CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}}) // ERROR + // + // but that doesn't work because std::initializer_list is not movable. These + // overloads let you do + // + // CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)}); // OK + // + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -355,6 +365,14 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + using ChannelDependencyMap = + absl::flat_hash_map>; + ChannelDependencyMap ComputeChannelDependencies() const; + // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. bool HasSideEffect() const; @@ -410,14 +428,6 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating - // instructions. 
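The ERROR/OK contrast above exists because std::initializer_list only exposes const access to its elements, so a braced list of map entries would have to copy the move-only unique_ptr values. Below is a minimal self-contained sketch of the same workaround, with Key, Value, and MakeReplacements as stand-ins for the real HloInstruction types and overloads:

#include <memory>
#include <unordered_map>
#include <utility>

// Stand-ins for the real key/value types (const HloInstruction* and
// std::unique_ptr<HloInstruction> in the actual API).
using Key = const int*;
using Value = std::unique_ptr<int>;
using ReplacementMap = std::unordered_map<Key, Value>;

// Pair-by-value overload: the caller moves pairs in, and they are moved on
// into the map, so the move-only Value is never copied.
ReplacementMap MakeReplacements(std::pair<Key, Value> r1,
                                std::pair<Key, Value> r2) {
  ReplacementMap replacements;
  replacements.emplace(std::move(r1));
  replacements.emplace(std::move(r2));
  return replacements;
}

void Demo() {
  int a = 0, c = 0;
  auto b = std::make_unique<int>(1);
  auto d = std::make_unique<int>(2);

  // ReplacementMap m{{&a, std::move(b)}, {&c, std::move(d)}};  // ERROR:
  // the initializer_list constructor copies its elements, and Value is
  // move-only.

  ReplacementMap m = MakeReplacements({&a, std::move(b)},
                                      {&c, std::move(d)});  // OK
  (void)m;
}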
- using ChannelDependencyMap = - absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; - enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 2aaaef1d36d..1e7a6e197f5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -65,7 +65,7 @@ class HloComputationTest : public HloTestBase { }; TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto negate_computation = module->AddEntryComputation(CreateNegateComputation()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); @@ -73,7 +73,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { // Create computation which calls one other computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map_computation = @@ -85,7 +85,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map1_computation = @@ -119,7 +119,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } @@ -134,7 +134,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2)); @@ -151,7 +151,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Trace instructions should be at the end of the sort. 
EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -170,7 +170,7 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), UnorderedElementsAre(constant1, constant2, constant3, constant4)); @@ -192,7 +192,7 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { r0f32_, HloOpcode::kAdd, constant2, constant3)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); @@ -217,7 +217,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { constant2, constant3)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Visitor which keeps track of which instructions have been visited. class TestVisitor : public DfsHloVisitorWithDefault { @@ -257,7 +257,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -274,7 +274,7 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -376,7 +376,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { // copied. 
auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateToken()); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); @@ -393,7 +393,7 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -412,7 +412,7 @@ TEST_F(HloComputationTest, CycleDetection) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); @@ -440,7 +440,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); @@ -466,7 +466,7 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { HloInstruction::CreateParameter(0, r0f32_, "param0")); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); @@ -484,107 +484,6 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } -TEST_F(HloComputationTest, Reachability) { - // Test reachability of a non-trivial computation: - // - // const1 const2 - // | | - // | +-------+ - // | | | - // add .. negate - // | . | - // | .... exp - // | | - // +---+ +-+---+ - // | | | - // multiply copy - // - // There is a control dependency from 'add' to 'exp'. 
- auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); - auto mul = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); - - auto module = CreateNewModule(); - auto computation = - module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); - - TF_CHECK_OK(add->AddControlDependencyTo(exp)); - auto reachability = computation->ComputeReachability(); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_TRUE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_TRUE(reachability->IsReachable(constant1, copy)); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - EXPECT_TRUE(reachability->IsReachable(constant2, negate)); - EXPECT_TRUE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_TRUE(reachability->IsReachable(constant2, copy)); - - EXPECT_FALSE(reachability->IsReachable(exp, constant1)); - EXPECT_FALSE(reachability->IsReachable(exp, constant2)); - EXPECT_FALSE(reachability->IsReachable(exp, add)); - EXPECT_FALSE(reachability->IsReachable(exp, negate)); - EXPECT_TRUE(reachability->IsReachable(exp, exp)); - EXPECT_TRUE(reachability->IsReachable(exp, mul)); - EXPECT_TRUE(reachability->IsReachable(exp, copy)); - - EXPECT_FALSE(reachability->IsReachable(mul, constant1)); - EXPECT_FALSE(reachability->IsReachable(mul, constant2)); - EXPECT_FALSE(reachability->IsReachable(mul, add)); - EXPECT_FALSE(reachability->IsReachable(mul, negate)); - EXPECT_FALSE(reachability->IsReachable(mul, exp)); - EXPECT_TRUE(reachability->IsReachable(mul, mul)); - EXPECT_FALSE(reachability->IsReachable(mul, copy)); - - EXPECT_TRUE(reachability->IsConnected(constant1, copy)); - EXPECT_TRUE(reachability->IsConnected(copy, constant1)); - EXPECT_FALSE(reachability->IsConnected(negate, add)); - EXPECT_FALSE(reachability->IsConnected(add, negate)); - - // Remove the control dependency then update and verify the reachability map - ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); - computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_FALSE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_FALSE(reachability->IsReachable(constant1, copy)); - - // 
Change a use within the graph then update and verify the reachability map - ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); - computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - EXPECT_FALSE(reachability->IsReachable(constant2, negate)); - EXPECT_FALSE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_FALSE(reachability->IsReachable(constant2, copy)); -} - TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -606,7 +505,7 @@ TEST_F(HloComputationTest, Stringification) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -641,7 +540,7 @@ TEST_F(HloComputationTest, StringificationIndent) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = @@ -677,7 +576,7 @@ TEST_F(HloComputationTest, StringificationCanonical) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -700,27 +599,5 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -TEST_F(HloComputationTest, ChannelReachability) { - const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); - HloComputation::Builder builder("ChannelReachability"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); - auto send = - builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); - auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); - auto recv = - builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); - auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(recv_done)); - auto reachability = computation->ComputeReachability(); - EXPECT_TRUE(reachability->IsReachable(param, recv_done)); - EXPECT_FALSE(reachability->IsReachable(send, recv)); - EXPECT_FALSE(reachability->IsReachable(send_done, recv)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 4f898ce61c3..5e37883d3d8 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -52,8 +52,10 @@ 
StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, and AfterAll operation. - // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // Skip Constant, Parameter, Tuple, AfterAll operation. + // Tuple constants are not directly supported by any backends, hence + // folding Tuple is not useful and would in fact be expanded back into + // kTuple by Algebraic Simplifier. // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this // special case is not necessary. @@ -63,6 +65,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { instruction->opcode() == HloOpcode::kAfterAll) { continue; } + // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index e45f905f715..d12f920722e 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloVerifiedTestBase; +using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -46,13 +46,13 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -67,13 +67,13 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -88,13 +88,13 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), 
op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -130,11 +130,11 @@ TEST_F(HloConstantFoldingTest, Concatenate) { Shape shape = ShapeUtil::MakeShape(F32, dimensions); builder.AddInstruction(HloInstruction::CreateConcatenate( shape, operands, test_config.concat_dimension)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -157,11 +157,11 @@ TEST_F(HloConstantFoldingTest, Slice) { Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); builder.AddInstruction(HloInstruction::CreateSlice( shape, literal_instruction, slice_start, slice_limits, slice_strides)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -182,11 +182,11 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { const int64 permutation[] = {1, 2, 0, 4, 3}; builder.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -219,27 +219,28 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - ParseAndVerifyModule(kConstantFoldReduce); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_TRUE(result); - EXPECT_EQ(6, module() - .entry_computation() + EXPECT_EQ(6, m->entry_computation() ->root_instruction() ->literal() .GetFirstElement()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - ParseAndVerifyModule(kConstantFoldReduce); - HloInstruction* add = module().computations().begin()->root_instruction(); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); + HloInstruction* add = m->computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_FALSE(result); - EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reduce()); } const char* const kConstantFoldLargePad = R"( diff --git 
a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 108aeea097d..fdfb38b858c 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -269,7 +269,7 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(map->to_apply())); + ProcessNestedSubcomputation(map->to_apply())); // Compute the cost of all elements for this Map operation. const int64 element_count = ShapeUtil::ElementsIn(map->shape()); @@ -285,7 +285,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. // This counts the number of times the reduction function is applied, so it @@ -311,7 +311,7 @@ Status HloCostAnalysis::HandleReduceWindow( auto function = reduce_window->to_apply(); // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each // output element there are window_size - 1 reductions to perform. @@ -336,9 +336,9 @@ Status HloCostAnalysis::HandleSelectAndScatter( // Compute the properties of the select and scatter function. // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties select_properties, - ProcessSubcomputation(instruction->select())); + ProcessNestedSubcomputation(instruction->select())); TF_ASSIGN_OR_RETURN(const Properties scatter_properties, - ProcessSubcomputation(instruction->scatter())); + ProcessNestedSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter // source element there are window_size - 1 select computations to perform and @@ -574,7 +574,7 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation())); + ProcessNestedSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. // Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -595,7 +595,7 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { Status HloCostAnalysis::HandleCall(const HloInstruction* call) { TF_ASSIGN_OR_RETURN(current_properties_, - ProcessSubcomputation(call->to_apply())); + ProcessUnnestedSubcomputation(call->to_apply())); current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -624,13 +624,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { // Since the number of iterations of the while node will not always be // something that we can statically analyze, we cannot precisely compute the // cost of a while node. For now compute the cost of a single iteration. - // - // TODO(b/26346211): Improve the cost analysis for while nodes. 
TF_ASSIGN_OR_RETURN(const Properties body_properties, - ProcessSubcomputation(xla_while->while_body())); + ProcessUnnestedSubcomputation(xla_while->while_body())); - TF_ASSIGN_OR_RETURN(const Properties condition_properties, - ProcessSubcomputation(xla_while->while_condition())); + TF_ASSIGN_OR_RETURN( + const Properties condition_properties, + ProcessUnnestedSubcomputation(xla_while->while_condition())); current_properties_.clear(); for (const auto& property : body_properties) { @@ -647,10 +646,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { // Compute the cost of the true and false computations and take the maximum // from those for each property. - TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, - ProcessSubcomputation(conditional->true_computation())); - TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, - ProcessSubcomputation(conditional->false_computation())); + TF_ASSIGN_OR_RETURN( + const Properties true_computation_properties, + ProcessUnnestedSubcomputation(conditional->true_computation())); + TF_ASSIGN_OR_RETURN( + const Properties false_computation_properties, + ProcessUnnestedSubcomputation(conditional->false_computation())); current_properties_ = true_computation_properties; for (const auto& property : false_computation_properties) { if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { @@ -680,7 +681,7 @@ Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { const int64 element_count = ShapeUtil::ElementsIn(scatter->operand(2)->shape()); TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(scatter->to_apply())); + ProcessNestedSubcomputation(scatter->to_apply())); for (const auto& property : sub_properties) { if (property.first != kBytesAccessedKey) { current_properties_[property.first] = property.second * element_count; @@ -689,6 +690,11 @@ Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { return Status::OK(); } +Status HloCostAnalysis::HandleGetDimensionSize( + const HloInstruction* /*get_size*/) { + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -725,11 +731,20 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } -StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation) { +StatusOr +HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) { HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.properties(); } +StatusOr +HloCostAnalysis::ProcessUnnestedSubcomputation(HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + hlo_properties_.insert(visitor.hlo_properties_.begin(), + visitor.hlo_properties_.end()); + return visitor.properties(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 46b4bbeef22..8ced9d776e1 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -107,6 +107,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const 
HloInstruction* gather) override; Status HandleScatter(const HloInstruction* scatter) override; + Status HandleGetDimensionSize(const HloInstruction* get_size) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -153,7 +154,24 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // Returns the properties computed from visiting the computation rooted at the // given hlo. - StatusOr ProcessSubcomputation(HloComputation* computation); + // + // The difference between ProcessNestedSubcomputation and + // ProcessUnnestedSubcomputation is that we expect to get profile results for + // an unnested subcomputation's individual instructions, while we expect that + // a nested subcomputation is completely subsumed by its parent. + // + // For example, subcomputations inside kFusion and kMap are considered nested, + // while subcomputations inside kWhile and kConditional are considered + // unnested. + // + // Another way of thinking of this is, kFusion is implemented on the GPU + // backend using just one GPU kernel, while kWhile's body is implemented as a + // sequence of kernels, one for each HLO therein. Backends don't necessarily + // need to follow this same implementation strategy, but we assume they do for + // the purposes of this platform-generic cost analysis. + StatusOr ProcessNestedSubcomputation(HloComputation* computation); + StatusOr ProcessUnnestedSubcomputation( + HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 9acee892d59..6a15b3440c6 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -387,7 +387,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -429,7 +429,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -472,7 +472,7 @@ TEST_F(DomainCostAnalysis, DomainCost) { auto domain = builder.AddInstruction( HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index e07a196d115..aaa9ec60eb3 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,22 +19,22 @@ limitations under the License. 
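The nested/unnested distinction documented in hlo_cost_analysis.h above comes down to whether the child computation's per-instruction results remain visible to the parent analysis. Here is a compressed sketch of that idea, using invented stand-in types (Properties keyed by name, instructions keyed by an integer id) rather than the real HloCostAnalysis API:

#include <map>
#include <string>

using Properties = std::map<std::string, float>;  // e.g. flops, bytes accessed

struct SubAnalysis {
  std::map<int, Properties> per_instruction;  // keyed by instruction id
  Properties totals;
};

struct ParentAnalysis {
  std::map<int, Properties> per_instruction;

  // Nested child (fusion/map style): the parent instruction subsumes the
  // whole child, so only the aggregated totals are kept.
  Properties AbsorbNested(const SubAnalysis& child) { return child.totals; }

  // Unnested child (while/conditional body style): each instruction in the
  // child runs as its own step, so its per-instruction entries are merged
  // into the parent and stay individually queryable.
  Properties AbsorbUnnested(const SubAnalysis& child) {
    per_instruction.insert(child.per_instruction.begin(),
                           child.per_instruction.end());
    return child.totals;
  }
};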
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -class HloCreationUtilsTest : public HloVerifiedTestBase { +class HloCreationUtilsTest : public HloTestBase { protected: - HloModule* CreateModuleWithProgramShape( + std::unique_ptr CreateModuleWithProgramShape( PrimitiveType primitive_type, absl::Span input_shape_dims, absl::Span output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims); - auto module = CreateNewModule("test"); + auto module = CreateNewVerifiedModule("test"); *entry_computation = module->AddEntryComputation( CreateComputationWithSignature({&input_shape}, output_shape, "entry") .ValueOrDie()); @@ -47,10 +47,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{2}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, CollapseFirstNDims(param, 1)); @@ -67,9 +66,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, + auto module = CreateModuleWithProgramShape( + S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed, @@ -92,10 +90,9 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{1, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, PrependDegenerateDims(param, 1)); @@ -113,10 +110,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -134,10 +130,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{1, 1}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{1, 1}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * 
with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -154,10 +149,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6}, + /*output_shape_dims=*/{3, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded, ExpandFirstDimIntoNDims(param, {3, 1, 2})); @@ -176,10 +170,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{6}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{6}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zero_padded_param, @@ -197,10 +190,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -218,10 +210,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(F32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 9b18b0284f6..1eb0260468c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloVerifiedTestBase { +class HloCseTest : public HloTestBase { protected: HloCseTest() {} }; @@ -59,13 +59,13 @@ TEST_F(HloCseTest, CombineTwoConstants) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); @@ -89,14 +89,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); @@ -121,14 +121,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); @@ -171,13 +171,13 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { shape_r0, HloOpcode::kAdd, root, constants[i])); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. 
@@ -201,7 +201,7 @@ TEST_F(HloCseTest, NonscalarConstants) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( {common_constant1, common_constant2, uncommon_constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -233,14 +233,14 @@ TEST_F(HloCseTest, IdenticalInstructions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -277,21 +277,21 @@ index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - } - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -327,20 +327,20 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 - } - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - 
ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -369,22 +369,21 @@ condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 = f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] %constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2), condition=%condition.1, body=%body - } + })"; - )"); - - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -411,13 +410,14 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -439,14 +439,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -470,14 +470,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -488,7 +488,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { TEST_F(HloCseTest, FusionInternalCSE) { // Test that we can CSE expressions that live within a fusion node // computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); @@ -512,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -554,14 +554,14 @@ TEST_F(HloCseTest, IdenticalExpressions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(8, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -586,7 +586,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, rng1, rng2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -595,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -607,7 +607,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // rng_function is an impure function because it does RNG. 
HloComputation* rng_function = nullptr; @@ -649,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -659,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule m add_computation { @@ -680,11 +680,12 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })"); + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -697,19 +698,19 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -730,11 +731,12 @@ ENTRY %entry { domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})"); +})"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - const HloInstruction* sub = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + const HloInstruction* sub = m->entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 909853106d5..6422346c101 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -43,7 +43,7 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(CreateNewModule()) {} + HloDataflowAnalysisTest() : module_(CreateNewUnverifiedModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
@@ -1884,7 +1884,7 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, class HloDataflowAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -2476,7 +2476,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -2511,7 +2511,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 3b5cde2996c..6c8095d3977 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -59,7 +59,7 @@ TEST_F(HloDceTest, NoDeadCode) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -80,7 +80,7 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -110,7 +110,7 @@ TEST_F(HloDceTest, DeadParameters) { builder.AddInstruction(HloInstruction::CreateUnary( live_param->shape(), HloOpcode::kNegate, live_param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); @@ -150,7 +150,7 @@ TEST_F(HloDceTest, ControlDependencies) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency between two instructions. @@ -175,7 +175,7 @@ TEST_F(HloDceTest, ControlDependencies) { // Tests that a dead call instruction is removed. TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Called computation for the call instruction. @@ -215,7 +215,7 @@ TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { // Tests that a while instruction with an infeed (effectul instruction) in its // body is not removed, even its user count is 0. 
TEST_F(HloDceTest, CalledComputationWithSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Condition computation of a while instruction. @@ -270,7 +270,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { // Tests that a nested call instruction with a side effect is not removed. TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Nested called computation with a side effect. @@ -323,7 +323,7 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { } TEST_F(HloDceTest, RemoveDeadSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); @@ -364,7 +364,7 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { } TEST_F(HloDceTest, KeepUsedSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index b90e8db2339..acdb42128e3 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" @@ -22,13 +22,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloVerifiedTestBase { +class HloDomainTest : public HloTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -64,13 +63,6 @@ class HloDomainTest : public HloVerifiedTestBase { } return false; } - - StatusOr ParseModule(absl::string_view hlo_string) { - HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_string, config); - return &module(); - } }; // Dummy DomainMetadata implementation which create kDomain boundaries around @@ -144,31 +136,32 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); } TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { @@ -186,11 +179,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(!isolator_changed); } @@ -213,26 +207,27 @@ ENTRY entry { } )"; - 
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "b", "a")); - EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); - EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "b", "a")); - EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e_element")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -250,11 +245,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_FALSE(isolator_changed); } @@ -273,15 +269,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_FALSE(remover_changed); - HloInstruction* add = FindInstruction(module, "c"); + HloInstruction* add = FindInstruction(module.get(), "c"); ASSERT_NE(add, nullptr); auto device = add->sharding_unique_device(); EXPECT_TRUE(device.has_value()); @@ -304,41 +301,42 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator sharding_isolator([]() { return ShardingDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, - sharding_isolator.Run(module)); + sharding_isolator.Run(module.get())); EXPECT_TRUE(sharding_isolator_changed); HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); 
EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { @@ -359,16 +357,17 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); - EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); - EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); + EXPECT_TRUE(HasDomainEdge(module.get(), "infeed.data", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); // Inject unassigned tuple/gte within the infeed domain, to simulate the // HLO passes adding unexpected instructions. 
@@ -384,7 +383,7 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + HloInstruction* infeed_data = FindInstruction(module.get(), "infeed.data"); ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); @@ -410,7 +409,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); struct Assignment { @@ -446,25 +445,26 @@ ENTRY entry { sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "tuple", "param")); - EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple")); + EXPECT_TRUE(HasDomainEdge(module.get(), "tuple", "param")); + EXPECT_FALSE(HasDomainEdge(module.get(), "gte", "tuple")); // Remove %tuple and %gte (tuple simplification) - HloInstruction* gte = FindInstruction(module, "gte"); - HloInstruction* tuple = FindInstruction(module, "tuple"); + HloInstruction* gte = FindInstruction(module.get(), "gte"); + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); module->entry_computation()->set_root_instruction(tuple->mutable_operand(0)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple)); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); const HloInstruction* root = module->entry_computation()->root_instruction(); @@ -486,11 +486,11 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto hlo_string = module->ToString(); - ASSERT_TRUE(ParseModule(hlo_string).status().ok()); + ASSERT_TRUE(ParseAndReturnVerifiedModule(hlo_string).status().ok()); } // Tuple inputs are domain instructions. @@ -507,20 +507,21 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tpl instruction, in order to test domain sharding // application. 
- auto tpl = FindInstruction(module, "tpl"); + auto tpl = FindInstruction(module.get(), "tpl"); tpl->clear_sharding(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), @@ -555,36 +556,37 @@ ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } // Emulate instructions inserted at top and bottom within nested tuple domain. @@ -603,15 +605,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tuple.0 instruction, in order to test domain sharding // application. 
- auto tuple0 = FindInstruction(module, "tuple.0"); + auto tuple0 = FindInstruction(module.get(), "tuple.0"); tuple0->clear_sharding(); // Insert the following instructons above and below tuple.0, to emulate other @@ -655,7 +658,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_TRUE(tuple0->has_sharding()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 608a42bb607..d95b6ad04f2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -50,9 +50,9 @@ namespace { static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloVerifiedTestBase { + public HloTestBase { protected: - HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } @@ -60,14 +60,14 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(&module()).ValueOrDie(); + type_converter.Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*module().entry_computation(), arg_literals) + return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - // Evaluate function that takes in a local module instead of using module_ - // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // Evaluate function that takes in a local module instead of using m_ + // that is in HloTestBase. Once m_ in HloTestBase is // removed, this should be the default Evaluate function. 
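
The evaluator-test conversion that starts here replaces HloVerifiedTestBase::module() with a module owned by the fixture itself (the m_ member introduced further down in this file). A condensed sketch of where the fixture ends up, with the Span element type and the ownership of evaluator_ assumed rather than quoted from the patch:

    class HloEvaluatorTestSketch : public HloTestBase {
     protected:
      // Runs the interpreter over the entry computation of the member module m_.
      Literal Evaluate(absl::Span<const Literal* const> args = {}) {
        return evaluator_->Evaluate(*m_->entry_computation(), args)
            .ConsumeValueOrDie();
      }
      std::unique_ptr<HloEvaluator> evaluator_ = absl::make_unique<HloEvaluator>();
      // Tests either build into m_ (m_->AddEntryComputation(b.Build())) or
      // replace it wholesale via
      // TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(...)).
      std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
    };
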
Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { @@ -88,7 +88,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -108,7 +108,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -116,6 +116,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, } bool use_bfloat16_; + std::unique_ptr m_ = CreateNewVerifiedModule(); }; #define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ @@ -135,7 +136,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -156,7 +157,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -181,7 +182,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -322,7 +323,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(args); @@ -346,7 +347,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -367,7 +368,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, {1, 2})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -386,7 +387,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -406,7 +407,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - 
module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -428,7 +429,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -448,7 +449,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -468,7 +469,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -503,7 +504,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -530,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -574,7 +575,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -619,7 +620,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -658,7 +659,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -704,7 +705,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -748,7 +749,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -802,7 +803,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); 
Literal result = Evaluate(); @@ -857,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -941,7 +942,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1019,7 +1020,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1079,7 +1080,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1143,7 +1144,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1215,7 +1216,7 @@ TEST_P(HloEvaluatorTest, b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1286,7 +1287,7 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1297,11 +1298,12 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). 
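
The comment on the precise-reduce test that follows ("float += 1 saturates at 1<<24") is a plain single-precision fact and can be checked outside XLA; this small standalone snippet only illustrates why the evaluator accumulates the reduction in a double:

    #include <cstdio>

    int main() {
      float f = 1 << 24;   // 2^24 = 16777216; beyond this, float cannot hold n + 1
      double d = 1 << 24;
      std::printf("%.1f\n", f + 1.0f);  // prints 16777216.0: the +1 is rounded away
      std::printf("%.1f\n", d + 1.0);   // prints 16777217.0: double keeps the +1
      return 0;
    }
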
TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1319,12 +1321,12 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m->AddEmbeddedComputation(add_computation.Build()); HloInstruction* reduce_instruction = b.AddInstruction( HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{0}, add_func)); - module().AddEntryComputation(b.Build()); + m->AddEntryComputation(b.Build()); HloEvaluator hlo_eval; Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); @@ -1337,7 +1339,7 @@ void BM_ReducePrecisely(int num_iters) { tensorflow::testing::StopTiming(); HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1396,14 +1398,14 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction( HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1438,7 +1440,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1455,7 +1457,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1490,7 +1492,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1507,7 +1509,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1541,7 +1543,7 @@ TEST_P(HloEvaluatorTest, 
ReduceWindowAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; WindowDimension dim; @@ -1564,7 +1566,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1594,7 +1596,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; @@ -1625,7 +1627,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1657,7 +1659,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1691,7 +1693,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1727,7 +1729,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1764,7 +1766,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1800,7 +1802,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1839,7 +1841,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1877,7 +1879,7 @@ TEST_P(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1966,7 +1968,7 @@ ENTRY main { slice_sizes={1, 3} } )"; - ParseAndVerifyModule(hlo_text); + 
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); @@ -1990,7 +1992,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); @@ -2014,7 +2016,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -2039,7 +2041,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2066,7 +2068,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2092,7 +2094,7 @@ ENTRY main { slice_sizes={1,1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({1, 1}); @@ -2115,7 +2117,7 @@ ENTRY main { slice_sizes={1,1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2139,7 +2141,7 @@ ENTRY main { slice_sizes={1, 0} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), @@ -2161,7 +2163,7 @@ ENTRY main { slice_sizes={1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal start_indices = @@ -2192,7 +2194,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2223,7 +2225,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2256,7 +2258,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2288,7 +2290,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, 
ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2320,7 +2322,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); @@ -2354,7 +2356,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2386,7 +2388,7 @@ ENTRY main { index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -2418,7 +2420,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2455,7 +2457,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2491,7 +2493,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2523,7 +2525,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2555,7 +2557,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{}, {}}); @@ -2585,7 +2587,7 @@ ENTRY main { index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal scatter_indices = @@ -2736,7 +2738,7 @@ ENTRY main { ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); @@ -2754,7 +2756,7 @@ ENTRY main { ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc 
b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 13a74fd8a11..05cc1593e4e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1043,6 +1043,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGetDimensionSize: return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f6ed86b4165..26786ee950b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -312,6 +312,10 @@ StatusOr> HloInstruction::CreateFromProto( proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { + TF_RET_CHECK(ShapeUtil::IsTuple(proto.shape()) && + (ShapeUtil::TupleElementCount(proto.shape()) == 2)) + << "Infeed should have a tuple shape with 2 operands, but has: " + << proto.shape(); const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -530,6 +534,12 @@ StatusOr> HloInstruction::CreateFromProto( absl::make_unique(exit_hlo_sharding)); break; } + case HloOpcode::kGetDimensionSize: + TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.dimensions_size() == 1); + instruction = CreateGetDimensionSize(proto.shape(), operands(0), + proto.dimensions(0)); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -1001,6 +1011,14 @@ HloInstruction::CreateSelectAndScatter( broadcast_dimensions); } +/* static */ std::unique_ptr +HloInstruction::CreateGetDimensionSize(const Shape& shape, + HloInstruction* operand, + int64 dimension) { + return absl::make_unique(shape, operand, + dimension); +} + /* static */ std::unique_ptr HloInstruction::CreateBroadcastSequence( const Shape& output_shape, HloInstruction* operand, @@ -1109,7 +1127,11 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) { void HloInstruction::SetupDerivedInstruction( HloInstruction* derived_instruction) const { - if (sharding_ != nullptr) { + if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType( + shape_, derived_instruction->shape())) { + // Only copy sharding if the shape of the two instruction is compatible + // because copying it between differently shaped instructions can produce + // invalid shardings. derived_instruction->set_sharding(*sharding_); } else { derived_instruction->clear_sharding(); @@ -1268,6 +1290,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kIota: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1715,6 +1738,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kScatter: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2440,6 +2464,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAfterAll(this); case HloOpcode::kIota: return visitor->HandleIota(this); + case HloOpcode::kGetDimensionSize: + return visitor->HandleGetDimensionSize(this); // These opcodes are not handled here. 
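
The hlo_instruction.cc hunks above wire a kGetDimensionSize opcode through proto deserialization, cloning, and visiting, and add the CreateGetDimensionSize factory. A hedged usage sketch; the shapes are chosen purely for illustration, and the s32[] result shape is an assumption, not something this patch states:

    HloComputation::Builder b("get_dimension_size_sketch");
    HloInstruction* param = b.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(F32, {8, 16}), "p"));
    // Asks for the size of dimension 1 of %p as a scalar value.
    b.AddInstruction(HloInstruction::CreateGetDimensionSize(
        ShapeUtil::MakeShape(S32, {}), param, /*dimension=*/1));
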
case HloOpcode::kTrace: @@ -2639,49 +2665,7 @@ Status HloInstruction::Accept( return this->Accept(&visitor); } -Status HloInstruction::AcceptOrdered( - DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; - TF_RET_CHECK(OrderIsTopologicalSort(order)); - - // Compute the predecessors of this instruction. - std::unordered_set predecessors; - TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) { - predecessors.insert(instruction); - return Status::OK(); - })); - - for (auto* const_instruction : order) { - if (!ContainsKey(predecessors, const_instruction)) { - // Instruction is not a predecessors of 'this'. - continue; - } - - // The visitor can mark instructions as visited to skip particular - // instructions. - if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO %" << const_instruction->name() - << " as it was already visited."; - continue; - } - - // TODO(b/78350259): Eliminate const laundering. - HloInstruction* instruction = - const_cast(const_instruction); - - TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO %" << instruction->name(); - TF_RETURN_IF_ERROR(instruction->Visit(visitor)); - visitor->SetVisited(*instruction); - TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); - } - - return visitor->FinishVisit(this); -} - -const Shape& HloInstruction::shape() const { - return shape_; -} +const Shape& HloInstruction::shape() const { return shape_; } std::vector HloInstruction::OperandIndices( const HloInstruction* operand) const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 15a4da8dbe0..818d4ede0f3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -767,6 +767,9 @@ class HloInstruction { // when we plumb a primordial token from the entry computation. static std::unique_ptr CreateToken(); + static std::unique_ptr CreateGetDimensionSize( + const Shape& shape, HloInstruction* operand, int64 dimension); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -954,16 +957,6 @@ class HloInstruction { Status Accept( const std::function& visitor_func) const; - // Visits all instructions rooted at this instruction using the given visitor - // in the given order. 'order' must contain at least the set of instructions - // rooted at this node (ie, those accessible from a DFS traversal from this - // instruction). Instructions contained in 'order' which are not in the set of - // instructions rooted at this node are ignored. 'order' must also be a valid - // topological sort of these instructions (defs appear before uses) though - // need not be a DFS post-order. - Status AcceptOrdered(DfsHloVisitor* visitor, - const std::vector& order); - // Visit this instruction and only this instruction with the given visitor. template Status Visit(DfsHloVisitorBase* visitor); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d93351fe043..8048e332cb5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,7 +39,7 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloVerifiedTestBase { +class HloInstructionTest : public HloTestBase { protected: Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; @@ -151,7 +151,7 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); @@ -188,7 +188,7 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); @@ -221,7 +221,7 @@ TEST_F(HloInstructionTest, RepeatedUser) { builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -256,7 +256,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -305,7 +305,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -327,7 +327,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Builds an x+1.0 computation to use in a Map. auto embedded_builder = HloComputation::Builder("f32+1"); @@ -375,7 +375,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. 
@@ -416,7 +416,7 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -451,7 +451,7 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -479,7 +479,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -516,7 +516,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -546,7 +546,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); @@ -611,7 +611,7 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; @@ -629,7 +629,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -647,7 +647,7 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -669,7 +669,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - auto module = CreateNewModule(); 
+ auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -692,7 +692,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -749,7 +749,7 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); @@ -817,7 +817,7 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -977,7 +977,7 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); int visit_num = 0; @@ -1006,7 +1006,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); @@ -1016,7 +1016,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { } TEST_F(HloInstructionTest, MapIsElementwise) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); HloComputation::Builder builder(TestName()); HloComputation::Builder map_builder("id"); @@ -1067,7 +1067,7 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); @@ -1108,7 +1108,7 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); @@ -1151,7 +1151,7 @@ 
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1192,7 +1192,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1204,7 +1204,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { } TEST_F(HloInstructionTest, FusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1226,7 +1226,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. @@ -1330,7 +1330,7 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* loop = builder.AddInstruction( @@ -1373,7 +1373,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1408,7 +1408,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*index_vector_dim=*/2), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1443,7 +1443,7 @@ TEST_F(HloInstructionTest, StringifyScatter) { update_builder.AddInstruction( HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* update_computation = module->AddEmbeddedComputation(update_builder.Build()); @@ -1495,7 +1495,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1531,7 +1531,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module 
= CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1587,7 +1587,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 88495e80000..4c765aa375c 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2349,4 +2349,43 @@ HloInstructionProto HloDomainInstruction::ToProto() const { return proto; } + +HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction( + const Shape& shape, HloInstruction* operand, int64 dimension) + : HloInstruction(HloOpcode::kGetDimensionSize, shape), + dimension_(dimension) { + AppendOperand(operand); +} + +HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(dimension()); + return proto; +} + +std::vector HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + return {StrCat("dimensions={", dimension(), "}")}; +} + +bool HloGetDimensionSizeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return dimension() == casted_other.dimension(); +} + +std::unique_ptr +HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + if (new_operands.size() != 1) { + LOG(FATAL) << "expects 1 operand"; + } + return absl::make_unique( + shape, new_operands[0], dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index bf4daf2be47..d43a8973ccf 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1385,6 +1385,33 @@ class HloDomainInstruction : public HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; }; + +class HloGetDimensionSizeInstruction : public HloInstruction { + public: + explicit HloGetDimensionSizeInstruction(const Shape& shape, + HloInstruction* operand, + int64 dimension); + + // Returns the dimension sizes or numbers associated with this instruction. + int64 dimension() const { return dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + int64 dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 1717770301e..170ec93a334 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -165,6 +165,7 @@ namespace opcode_matchers { } HLO_MATCHER(Abs); HLO_MATCHER(Add); +HLO_MATCHER(AllToAll); HLO_MATCHER(Bitcast); HLO_MATCHER(Broadcast); HLO_MATCHER(BatchNormGrad); @@ -178,6 +179,7 @@ HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 2f15997fc17..984a6266abb 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -65,7 +65,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto sub = builder.AddInstruction( HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloMemoryScheduler scheduler([](const BufferValue& buffer) { @@ -172,7 +172,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, abs_abs2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, ScheduleModule(*module, @@ -218,7 +218,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( @@ -242,7 +242,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { } TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); // param != 0 diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 6a838b7eb96..14bf17f4be1 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -559,7 +559,8 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { std::unique_ptr HloModule::Clone(const HloModuleConfig& config, const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = absl::make_unique(name_ + "-" + suffix, config); + auto module = absl::make_unique( + absl::StrCat(name_, suffix.empty() ? 
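The HloGetDimensionSizeInstruction defined above is created through the HloInstruction::CreateGetDimensionSize factory declared in the hlo_instruction.h hunk earlier in this patch. A rough sketch of building one; the s32[] result shape is an assumption here, since the authoritative rule is ShapeInference::InferGetDimensionSizeShape used by the verifier hunk later on:

// Sketch only. Assumes an s32[] result shape for get-dimension-size; the
// authoritative shape comes from ShapeInference::InferGetDimensionSizeShape.
#include <memory>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

std::unique_ptr<HloComputation> BuildGetDimensionSizeExample() {
  HloComputation::Builder builder("get_dimension_size_example");
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {10, 20});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "p0"));
  // Query the size of dimension 1 (statically 20 here) as a runtime value.
  HloInstruction* dim_size = builder.AddInstruction(
      HloInstruction::CreateGetDimensionSize(ShapeUtil::MakeShape(S32, {}),
                                             param, /*dimension=*/1));
  // ToString() includes the attribute emitted by ExtraAttributesToStringImpl,
  // e.g. "... get-dimension-size(... %p0), dimensions={1}".
  VLOG(1) << dim_size->ToString();
  return builder.Build();
}

}  // namespace xla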
"" : "-", suffix), config); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 39f38b417ab..3ae67e4e5ee 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -63,7 +63,7 @@ class HloModuleTest : public HloTestBase { TEST_F(HloModuleTest, OneComputationPostOrder) { // Create a module with a single computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(CreateConstantComputation()); EXPECT_THAT(module->MakeComputationPostOrder(), @@ -72,7 +72,7 @@ TEST_F(HloModuleTest, OneComputationPostOrder) { TEST_F(HloModuleTest, TwoComputationsPostOrder) { // Create a module with two unconnected computations. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation1 = module->AddEntryComputation(CreateConstantComputation()); auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); @@ -88,7 +88,7 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { TEST_F(HloModuleTest, CloneTest) { // Create and copy a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -111,7 +111,7 @@ TEST_F(HloModuleTest, CloneTest) { } TEST_F(HloModuleTest, CloneHasFusion) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); // Create the fused computation. HloComputation* fused_computation; @@ -154,7 +154,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -174,7 +174,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { TEST_F(HloModuleTest, LargeConstantToString) { // Create a module with a single computation. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( @@ -194,8 +194,8 @@ TEST_F(HloModuleTest, LargeConstantToString) { } TEST_F(HloModuleTest, UniqueModuleId) { - auto module_a = CreateNewModule(); - auto module_b = CreateNewModule(); + auto module_a = CreateNewUnverifiedModule(); + auto module_b = CreateNewUnverifiedModule(); EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e6bfb8025d4..70c7d70b41c 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -83,6 +83,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGetDimensionSize, "get-dimension-size") \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 23d41d91d69..f5f99bece18 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -334,7 +334,7 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. for (auto* computation : module->MakeNonfusionComputations()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, HloReachabilityMap::Build(computation)); } } @@ -374,11 +374,10 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( return order_position_.at(a) < order_position_.at(b); } -const std::vector* -SequentialHloOrdering::SequentialOrder( +const HloInstructionSequence* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { return schedule_.is_computation_scheduled(&computation) - ? &schedule_.sequence(&computation).instructions() + ? &schedule_.sequence(&computation) : nullptr; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 66313492eb2..a07214c22c0 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" @@ -64,7 +65,7 @@ class HloOrdering { // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. - virtual const std::vector* SequentialOrder( + virtual const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const = 0; // Return the call graph of the module used to compute ordering. @@ -96,7 +97,7 @@ class PredecessorHloOrdering : public HloOrdering { // Returns nullptr indicating the computation does not have a sequential // ordering. 
- const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return nullptr; } @@ -185,7 +186,7 @@ class SequentialHloOrdering : public HloOrdering { ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override; string ToString() const override; diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index b045adc9640..2ab8aa57f6e 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -53,7 +53,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // %c = Constant(42.0f) // // This results in a diamond-shaped callgraph. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder_c = HloComputation::Builder("C"); @@ -126,7 +126,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { // %constant = Constant(1.0) // return While(%constant, body, condition) // - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -176,7 +176,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { // Entry parameter should always be defined before other instruction. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -209,7 +209,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { // %while = While(%constant, body, condition) // %add = Add(%constant, %while) // - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -407,7 +407,7 @@ TEST_F(HloOrderingTest, // %dead = Constant(123.0) // // %root should interfere with %dead. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -455,7 +455,7 @@ TEST_F(HloOrderingTest, // ROOT %call = call({%c}), subcomputation // // %root should interfere with %dead. 
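With the hlo_ordering.h changes above, SequentialOrder now returns the HloInstructionSequence itself instead of the raw instruction vector, so callers reach the vector through instructions(). A short sketch of the adapted call pattern:

// Sketch: consuming the HloInstructionSequence now returned by
// SequentialOrder() instead of a raw vector of HloInstruction pointers.
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void DumpSequentialOrder(const SequentialHloOrdering& ordering,
                         const HloComputation& computation) {
  const HloInstructionSequence* sequence =
      ordering.SequentialOrder(computation);
  if (sequence == nullptr) {
    return;  // The computation has no sequential schedule.
  }
  // The underlying vector now lives behind instructions().
  for (HloInstruction* instruction : sequence->instructions()) {
    VLOG(2) << instruction->name();
  }
}

}  // namespace xla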
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto subbuilder = HloComputation::Builder(TestName() + ".sub"); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 450660b94b7..4390145c6bd 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -108,7 +108,7 @@ class HloParser { bool ParseInstructionList(HloComputation** computation, const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); @@ -608,10 +608,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } - return ParseInstruciontRhs(builder, name, name_loc); + return ParseInstructionRhs(builder, name, name_loc); } -bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, +bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc) { Shape shape; HloOpcode opcode; @@ -1547,6 +1547,18 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); + case HloOpcode::kGetDimensionSize: + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateGetDimensionSize( + shape, operands[0], (*dimensions)[0])); + break; } instruction->SetAndSanitizeName(name); @@ -2708,7 +2720,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - std::vector split1 = absl::StrSplit(str, "_"); + std::vector split1 = absl::StrSplit(str, '_'); if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; @@ -3389,7 +3401,7 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { // e.g. // // f32[10] fusion(...), calls={...} - if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) { return false; } } else { diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 81eeb9f13bf..d830fa61438 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,7 +44,9 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, +// e.g., "{replicated}". StatusOr ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). 
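The kGetDimensionSize case added to ParseInstructionRhs above accepts a braced dimensions={...} attribute, matching the "get-dimension-size" opcode string registered in hlo_opcode.h. A sketch of feeding such text through ParseHloString; the s32[] result shape is an assumption, as the parser itself does not infer shapes:

// Sketch only: text accepted by the new kGetDimensionSize parser case.
// The s32[] result shape is an assumption; shape checking is left to the
// verifier.
#include <memory>

#include "tensorflow/compiler/xla/service/hlo_parser.h"

namespace xla {

StatusOr<std::unique_ptr<HloModule>> ParseGetDimensionSizeExample() {
  constexpr char kHloText[] = R"(
HloModule GetDimensionSizeExample

ENTRY main {
  p0 = f32[10,20]{1,0} parameter(0)
  ROOT size = s32[] get-dimension-size(p0), dimensions={1}
}
)";
  return ParseHloString(kHloText);
}

}  // namespace xla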
@@ -55,10 +57,6 @@ StatusOr ParseWindow(absl::string_view str); StatusOr ParseConvolutionDimensionNumbers( absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr ParseSharding(absl::string_view str); - // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index eae6d19792f..c59bdc0a0b3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1150,6 +1150,25 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } +)" +}, +// cross-replica-sum with all-reduce-id +{ +"CrossReplicaSumAllReduce", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add +} + )" }, // all-to-all diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc index ee8cb12b231..20384b9da6b 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -19,14 +19,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloPassPipelineTest : public HloVerifiedTestBase { +class HloPassPipelineTest : public HloTestBase { protected: StatusOr ParseModuleGroup( absl::Span hlo_strings) { diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2d5197be9e6..f968a4a9445 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -104,5 +104,20 @@ bool IsScalarConstant(const HloInstruction* instruction) { return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape()); } +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes) { + for (const auto* instr : comp->instructions()) { + if (opcodes.count(instr->opcode())) { + return true; + } + for (const HloComputation* subcomp : instr->called_computations()) { + if (ContainsInstrWithOpcode(subcomp, opcodes)) { + return true; + } + } + } + return false; +} + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c0826a6aee1..215051f8834 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -16,6 +16,8 @@ limitations under the License. 
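hlo_query::ContainsInstrWithOpcode, added above (its declaration follows in the hlo_query.h hunk), recurses into called computations, so it also sees instructions nested in while bodies, fusions, and conditionals. A small usage sketch:

// Sketch: using the new hlo_query::ContainsInstrWithOpcode helper.
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"

namespace xla {

bool UsesLoopsOrCollectives(const HloComputation* entry) {
  // Also matches instructions nested in while bodies, fusions, conditionals,
  // etc., because the helper recurses into called computations.
  return hlo_query::ContainsInstrWithOpcode(
      entry, {HloOpcode::kWhile, HloOpcode::kCrossReplicaSum,
              HloOpcode::kAllToAll});
}

}  // namespace xla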
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -41,6 +43,12 @@ bool AllOperandsAreConstants(const HloInstruction& instruction); // Returns whether the instruction is a scalar constant. bool IsScalarConstant(const HloInstruction* instruction); +// Determines whether the given computation contains an instruction with one of +// the given opcodes. Checks both comp's instructions and the instructions of +// any computations nested within it. +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes); + // Returns an operand of an instruction with the given opcode. If there are // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 961930f0a88..4aa80677524 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_reachability.h" namespace xla { @@ -22,7 +24,7 @@ HloReachabilityMap::HloReachabilityMap( : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { - indices_[hlo] = bit_vectors_.size(); + indices_[GetKey(hlo)] = bit_vectors_.size(); bit_vectors_.emplace_back(size_); } CHECK_EQ(size_, indices_.size()); // instructions should be unique @@ -71,4 +73,70 @@ bool HloReachabilityMap::IsConnected(const HloInstruction* a, return IsReachable(a, b) || IsReachable(b, a); } +std::unique_ptr HloReachabilityMap::Build( + const HloComputation* computation) { + const auto& all = computation->MakeInstructionPostOrder(); + auto result = absl::make_unique(all); + auto channel_dependency_map = computation->ComputeChannelDependencies(); + + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.assign(hlo->operands().begin(), hlo->operands().end()); + inputs.insert(inputs.end(), hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + + result->FastSetReachabilityToUnion(inputs, hlo); + } + return result; +} + +void HloReachabilityMap::UpdateReachabilityThroughInstruction( + const HloInstruction* instruction) { + std::queue worklist; + worklist.push(instruction); + + std::vector inputs; + + while (!worklist.empty()) { + const HloInstruction* item = worklist.front(); + worklist.pop(); + + inputs.assign(item->operands().begin(), item->operands().end()); + inputs.insert(inputs.end(), 
item->control_predecessors().begin(), + item->control_predecessors().end()); + + if (SetReachabilityToUnion(inputs, item)) { + // Add immediate successors to worklist. + for (const HloInstruction* user : item->users()) { + worklist.push(user); + } + for (const HloInstruction* succ : item->control_successors()) { + worklist.push(succ); + } + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 3e27d098aeb..7823b06a41b 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -16,27 +16,30 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ +#include #include #include +#include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace xla { -class HloInstruction; - // A class for representing reachability between HloInstructions. // -// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix -// and it is up to the user of the class to set the adjacency matrix such that -// it represents reachability, i.e. such that it is transitive. That the graph -// be transitive is thus not an invariant of this class, but it is required for -// the name of the class and its methods to make sense. +// It has an adjacency matrix and it is up to the user of the class to set the +// adjacency matrix such that it represents reachability, i.e. such that it is +// transitive. That the graph be transitive is thus not an invariant of this +// class, but it is required for the name of the class and its methods to make +// sense. class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given @@ -44,6 +47,15 @@ class HloReachabilityMap { explicit HloReachabilityMap( absl::Span instructions); + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + static std::unique_ptr Build( + const HloComputation* computation); + // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true @@ -70,6 +82,10 @@ class HloReachabilityMap { // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. 
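HloReachabilityMap::Build and UpdateReachabilityThroughInstruction, documented above, are exercised thoroughly by the new reachability tests below; as a compact sketch of the intended call pattern:

// Sketch: building a reachability map and refreshing it after a graph edit.
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void ReachabilityExample(const HloComputation* computation,
                         const HloInstruction* producer,
                         HloInstruction* consumer) {
  auto reachability = HloReachabilityMap::Build(computation);
  if (reachability->IsReachable(producer, consumer)) {
    // There is a data or control path from producer to consumer.
    VLOG(1) << producer->name() << " reaches " << consumer->name();
  }
  // After consumer's operands or control predecessors change, propagate the
  // update through its transitive users instead of rebuilding from scratch.
  reachability->UpdateReachabilityThroughInstruction(consumer);
}

}  // namespace xla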
+ void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); + // Returns true if "b" is reachable from "a" // // Note that this function only correctly answers queries about reachability @@ -83,7 +99,9 @@ class HloReachabilityMap { bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; // Checks if an instruction is in the Reachability map. - bool IsPresent(const HloInstruction* a) const { return indices_.contains(a); } + bool IsPresent(const HloInstruction* a) const { + return indices_.contains(GetKey(a)); + } private: // A bit-vector implementation specialized for this use case which provides a @@ -146,18 +164,24 @@ class HloReachabilityMap { absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); + uint64 GetKey(const HloInstruction* instruction) const { + uint64 unique_id = absl::bit_cast(instruction->unique_id()); + uint64 module_id = + absl::bit_cast(instruction->parent()->parent()->unique_id()); + return (module_id << 32) | unique_id; + } // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { - return FindOrDie(indices_, instruction); + return FindOrDie(indices_, GetKey(instruction)); } // The number of instructions in the reachability map. const size_t size_; - // Dense assignment from HloInstruction* to number. These numbers index - // into the bit_vectors_ vector and into the bits within a BitVector. - absl::flat_hash_map indices_; + // Dense assignment from HloInstruction::unique_id to number. These numbers + // index into the bit_vectors_ vector and into the bits within a BitVector. + absl::flat_hash_map indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index d9848cee0bf..59517670980 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloVerifiedTestBase {}; +class HloReachabilityTest : public HloTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: @@ -48,7 +48,8 @@ TEST_F(HloReachabilityTest, Reachability) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto e = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.Build(); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); HloReachabilityMap reachability({a, b, c, d, e}); reachability.SetReachable(a, a); @@ -81,6 +82,130 @@ TEST_F(HloReachabilityTest, Reachability) { EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d)); } +TEST_F(HloReachabilityTest, NonTrivialReachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... 
exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp)); + + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = HloReachabilityMap::Build(computation); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); + + // Remove the control dependency then update and verify the reachability map + ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); + reachability->UpdateReachabilityThroughInstruction(exp); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + 
EXPECT_FALSE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_FALSE(reachability->IsReachable(constant1, copy)); + + // Change a use within the graph then update and verify the reachability map + ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); + reachability->UpdateReachabilityThroughInstruction(negate); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_FALSE(reachability->IsReachable(constant2, negate)); + EXPECT_FALSE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_FALSE(reachability->IsReachable(constant2, copy)); +} + +TEST_F(HloReachabilityTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = HloReachabilityMap::Build(computation); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f7e82fb1f88..22c3c40a93a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloVerifiedTestBase { +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -162,7 +162,7 @@ class HloRematerializationTest : public HloVerifiedTestBase { // Test rematerialization of a single computation produced by // MakeRematerializableComputation. 
TEST_F(HloRematerializationTest, SingleComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // with rematerialization so pick a memory limit between these values (14KB). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, module)); + /*memory_limit_bytes=*/14 * 1024, module.get())); EXPECT_TRUE(changed); // Root should not have changed. @@ -203,7 +203,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, module)); + /*memory_limit_bytes=*/20 * 1024, module.get())); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -225,7 +225,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { // computation should be the one chosen because rematerialization in the while // will presumably be more expensive. TEST_F(HloRematerializationTest, RematerializeAroundWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // bit lower (17KB) to force rematerialization of the entry computation. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, module)); + /*memory_limit_bytes=*/17 * 1024, module.get())); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -261,7 +261,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while. Both the entry computation and while body computation should have // computations rematerialized. TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, module)); + /*memory_limit_bytes=*/15 * 1024, module.get())); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -293,7 +293,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { // Test rematerialization of a doubly nested computation. All computations // should have an instruction rematerialized. 
TEST_F(HloRematerializationTest, RematerializeNestedComputations) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // ~12K so pick something slightly larger. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, module)); + /*memory_limit_bytes=*/13 * 1024, module.get())); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -346,7 +346,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + // // rng + tanh + exp - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( @@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get())); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -420,7 +420,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // The value %bcast is live across each call of Subcomputation (which requires // 8KB) though the value is not used in the calls. Rematerializing %bcast // across these calls reduces peak memory use from ~20KB down to ~16KB. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -533,7 +533,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // (ie %bcast is used indirectly by %negate), otherwise the %negate operand // aliases %add_2. const bool indirectly_used = GetParam(); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. 
if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 45c684d6675..11994d99c93 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -66,7 +66,7 @@ class HloSubcomputationUnificationTest : public HloTestBase { }; TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -103,7 +103,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { } TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -143,7 +143,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { // Do not unify subcomputations with different parameter shapes. TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -184,7 +184,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { // Regression test for b/31466798. Checks that entry_computation is still valid // after unification. TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); for (int i = 0; i < 2; ++i) { HloComputation::Builder builder("pow"); auto x = diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 6fd734a2b9e..1e2b31a1f2b 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloVerifiedTestBase { +class HloTfGraphBuilderTest : public HloTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 136824a3356..27fd685a69a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -755,6 +756,12 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); } +Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { + return CheckShape( + get_size, ShapeInference::InferGetDimensionSizeShape( + get_size->operand(0)->shape(), get_size->dimensions(0))); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -1331,6 +1338,15 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleCrossReplicaSum(HloInstruction* crs) override { + if (crs->all_reduce_id().has_value()) { + TF_RET_CHECK(crs->all_reduce_id().value() > 0) + << "All reduce id must be greater than 0 for " + << crs->ToShortString(); + } + return Status::OK(); + } + Status Preprocess(HloInstruction* instruction) override { auto previous = instructions_by_name_.find(instruction->name()); TF_RET_CHECK(previous == instructions_by_name_.end()) diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 83b6244d1be..9fbfd6a21c1 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -94,6 +94,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleGather(HloInstruction* gather) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; + Status HandleGetDimensionSize(HloInstruction* get_size) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index afe01e5487c..5ddfe0a944f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -35,7 +35,7 @@ namespace { using ::testing::HasSubstr; -// This class cannot be converted to use HloVerifiedTestBase. It explicitly +// This class cannot be converted to use HloTestBase. It explicitly // uses HloTestBase to create and test malformed HLOs. 
class HloVerifierTest : public HloTestBase { public: @@ -66,7 +66,7 @@ TEST_F(HloVerifierTest, NullInstructionParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -85,7 +85,7 @@ TEST_F(HloVerifierTest, NullComputationParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -104,7 +104,7 @@ TEST_F(HloVerifierTest, DifferentOperandParents) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloComputation::Builder emb_builder(TestName()); @@ -138,7 +138,7 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { builder.AddInstruction( HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Run the verifier twice. It should fail both times, because it shouldn't @@ -303,7 +303,7 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto status = verifier().Run(module.get()).status(); @@ -327,7 +327,7 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index e103222b55f..90904ac0011 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -90,20 +90,29 @@ string HumanReadableProfileBuilder::ToString() const { op.optimal_seconds < 0 ? "" : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), - op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), - op.transcendental_count <= 0 - ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + op.flop_count > 0 && nsecs > 0 + ? HumanReadableNumFlops(op.flop_count, nsecs) + : "", + op.transcendental_count > 0 && nsecs > 0 + ? 
HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + : "", bytes_per_sec, bytes_per_cycle, op.name); }; - float optimal_seconds_sum = 0.0; + double optimal_seconds_sum = 0; int64 total_flops = 0.; int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { if (op.optimal_seconds > 0) { - optimal_seconds_sum += op.optimal_seconds; + // An op can run faster than the estimated optimum. For example, we might + // estimate a fusion's speed by looking at the size of its operands and + // result, but perhaps the fusion doesn't read the entirety of all of its + // inputs. For the purposes of summing the instructions' optimal speeds, + // we treat the "optimum" as the smallest of either the estimated optimum + // and the actual speed. + optimal_seconds_sum += + std::min(double{op.optimal_seconds}, CyclesToSeconds(op.cycles)); } total_flops += std::max(op.flop_count, int64{0}); total_transcendentals += std::max(op.transcendental_count, int64{0}); @@ -114,7 +123,7 @@ string HumanReadableProfileBuilder::ToString() const { print_op({is_entry_computation_ ? "[total] [entry]" : "[total]", "[total]", /*category=*/"", total_cycles_, total_flops, total_transcendentals, - total_bytes, optimal_seconds_sum}, + total_bytes, static_cast(optimal_seconds_sum)}, /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. @@ -155,8 +164,10 @@ string HumanReadableProfileBuilder::ToString() const { entry.text = op.name; entry.short_text = op.short_name; entry.category_text = op.category; - entry.metric = - CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6; + // Ignore ops that run faster than the estimated optimal here, as we do + // when calculating optimal_seconds_sum. + entry.metric = std::max( + 0., CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6); total_discrepancy_in_microseconds += entry.metric; table.AddEntry(std::move(entry)); } diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d5225..cf6cf897fe1 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -18,19 +18,20 @@ limitations under the License. 
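The human_readable_profile_builder.cc change above keeps an op that happens to run faster than its estimated optimum from distorting the totals: each op now contributes the smaller of its estimated optimal time and its measured time, and the per-op discrepancy is clamped at zero. A small numeric illustration with made-up timings:

// Illustration only: the timings are invented, and OpInfo is a simplification
// of the profile builder's per-op record.
#include <algorithm>
#include <cstdio>

struct OpInfo {
  double optimal_seconds;  // estimated lower bound for the op
  double actual_seconds;   // measured time, i.e. CyclesToSeconds(op.cycles)
};

int main() {
  // The second op beat its (pessimistic) estimate.
  const OpInfo ops[] = {{1.0e-6, 1.5e-6}, {2.0e-6, 1.2e-6}, {0.5e-6, 0.7e-6}};

  double naive_sum = 0, clamped_sum = 0;
  for (const OpInfo& op : ops) {
    naive_sum += op.optimal_seconds;
    // Treat the "optimum" as the smaller of the estimate and the actual time,
    // so an overestimated op cannot inflate the total.
    clamped_sum += std::min(op.optimal_seconds, op.actual_seconds);
  }
  std::printf("naive: %.2g s, clamped: %.2g s\n", naive_sum, clamped_sum);
  // Prints: naive: 3.5e-06 s, clamped: 2.7e-06 s
  return 0;
}

The implicit_broadcast_remover_test.cc hunk continues below.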
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { +class ImplicitBroadcastRemoverTest : public HloTestBase { protected: ImplicitBroadcastRemover remover_; }; TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -41,15 +42,16 @@ TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_FALSE(remover_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(remover_.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -60,13 +62,13 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); @@ -76,6 +78,7 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -86,9 +89,9 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Subtract(op::Parameter(), @@ -98,6 +101,7 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {1, 4, 1}); @@ -108,9 +112,9 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, 
param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, @@ -120,6 +124,7 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { } TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); @@ -132,9 +137,9 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), @@ -147,6 +152,7 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { TEST_F(ImplicitBroadcastRemoverTest, TernaryScalarAndDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -159,9 +165,9 @@ TEST_F(ImplicitBroadcastRemoverTest, builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 2d03aebc1ac..20cc18f9815 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -16,12 +16,12 @@ limitations under the License. 
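For context on the implicit_broadcast_remover_test.cc changes above: an implicit broadcast is a binary or ternary HLO op whose operand shape differs from the result shape (a scalar, or a shape with size-1 dimensions). The pass rewrites such operands into an explicit broadcast, plus a reshape when degenerate dimensions have to be dropped; that is what the op::Broadcast and op::Reshape matchers above assert. A toy version of the triggering condition, with plain dimension vectors standing in for XLA shapes:

// Toy condition only; the real pass inspects xla::Shape objects.
#include <cstdio>
#include <vector>

using Dims = std::vector<long long>;

bool IsImplicitBroadcast(const Dims& operand, const Dims& result) {
  // Scalars, lower-rank operands, and size-1 dimensions all qualify.
  return operand != result;
}

int main() {
  const Dims result = {2, 4, 6};
  std::printf("scalar operand  -> %d\n", IsImplicitBroadcast({}, result));
  std::printf("{1,4,1} operand -> %d\n", IsImplicitBroadcast({1, 4, 1}, result));
  std::printf("{2,4,6} operand -> %d\n", IsImplicitBroadcast({2, 4, 6}, result));
  return 0;
}

The indexed_array_analysis_test.cc hunk continues below.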
#include #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { namespace { -class IndexedArrayAnalysisTest : public HloVerifiedTestBase { +class IndexedArrayAnalysisTest : public HloTestBase { protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { @@ -61,12 +61,12 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { const string& root_expression, bool print_constants) { IndexedArrayAnalysis indexed_tensor_analysis; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN( - IndexedArrayAnalysis::Array* const array_result, - indexed_tensor_analysis.GetArrayFor( - module().entry_computation()->root_instruction())); + TF_ASSERT_OK_AND_ASSIGN(IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + m->entry_computation()->root_instruction())); string string_result = CanonicalizeWhitespace( indexed_tensor_analysis.ToString(array_result, print_constants)); LOG(INFO) << string_result; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index a85793e4774..7f2d7e7cffc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -155,6 +155,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGetDimensionSize: return true; } @@ -452,7 +453,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); @@ -566,7 +567,7 @@ bool InstructionFusion::MultiOutputFusionCreatesCycle( // A consumer operand may have been multi-output fused into a parallel // consumer and thus be missing from the original reachability map. if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { - reachability_ = consumer->parent()->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(consumer->parent()); } return reachability_->IsReachable(a, b); }; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 4045e886dd9..198bd7fce5f 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -22,6 +22,7 @@ limitations under the License.
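In instruction_fusion.cc above, reachability is now built with the free factory HloReachabilityMap::Build(computation) instead of a method on the computation, and the map is rebuilt when a multi-output-fused operand is missing from it; the instruction_fusion.h hunk below adds the matching hlo_reachability include. A toy reachability map over a string-keyed DAG gives the flavor of how such a map guards multi-output fusion against cycles; the types and the Build/IsReachable signatures here are simplifications, not the XLA API.

// Toy graph: users[x] lists the instructions that consume x.
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

struct ReachabilityMap {
  // reachable[a] holds every node reachable from a, including a itself.
  std::map<std::string, std::set<std::string>> reachable;

  static ReachabilityMap Build(
      const std::map<std::string, std::vector<std::string>>& users) {
    ReachabilityMap m;
    for (const auto& entry : users) {
      const std::string& node = entry.first;
      std::vector<std::string> stack = {node};
      while (!stack.empty()) {
        std::string cur = stack.back();
        stack.pop_back();
        if (!m.reachable[node].insert(cur).second) continue;  // already seen
        auto it = users.find(cur);
        if (it == users.end()) continue;
        for (const std::string& user : it->second) stack.push_back(user);
      }
    }
    return m;
  }

  bool IsReachable(const std::string& from, const std::string& to) const {
    auto it = reachable.find(from);
    return it != reachable.end() && it->second.count(to) > 0;
  }
};

int main() {
  // p feeds both a and c, and a also feeds c. Multi-output fusing p with c
  // would create a cycle because c's other operand, a, is reachable from p.
  std::map<std::string, std::vector<std::string>> users = {
      {"p", {"a", "c"}}, {"a", {"c"}}, {"c", {}}};
  ReachabilityMap m = ReachabilityMap::Build(users);
  std::printf("fusing p and c creates a cycle: %d\n",
              (int)m.IsReachable("p", "a"));
  return 0;
}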
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/platform/macros.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index da1ad90959d..39904bd54b0 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -133,7 +133,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -149,7 +149,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {}), param0, {})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( @@ -172,7 +172,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_FALSE( @@ -361,7 +361,7 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary2, computation->root_instruction()); EXPECT_TRUE( @@ -385,7 +385,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index c9b40d3c619..b0fc1af8b89 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -110,3 +110,5 @@ REGISTER_MODULE_INITIALIZER( // open-source project, so this will be a no-op there. 
REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, multi_platform_manager); +REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, + interpreter_platform); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 6b033946698..a9041192220 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2092,6 +2092,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kTrace: case HloOpcode::kTranspose: case HloOpcode::kTuple: + case HloOpcode::kGetDimensionSize: return true; } } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 47bfca2fd6e..2400b7bb7c4 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,20 +49,19 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloVerifiedTestBase { +class LayoutAssignmentTest : public HloTestBase { protected: - void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout, + void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, /*channel_constraints=*/channel_constraints); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m).status()); } - std::vector LayoutOf(HloModule* module, absl::string_view name) { + std::vector LayoutOf(HloModule* m, absl::string_view name) { auto minor_to_major = - FindInstruction(module, name)->shape().layout().minor_to_major(); + FindInstruction(m, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } @@ -91,7 +90,7 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { TEST_F(LayoutAssignmentTest, ComputationLayout) { // Verify the layouts of the root and parameter instructions of a computation // match the ComputationLayout for two different layouts. 
- std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); @@ -101,8 +100,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout layout = LayoutUtil::MakeLayout(minor_to_major); Shape shape(ashape); @@ -113,7 +112,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -131,8 +130,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout col_major_layout = LayoutUtil::MakeLayout({1, 0}); Shape col_major_shape(ashape); @@ -149,7 +148,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -160,7 +159,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { // Verify that the layout of the fused parameters in a fusion instruction // match that of the fusion operands. Other fused instructions should have no // layout. 
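Nearly every case in layout_assignment_test.cc pins layouts through minor_to_major lists, so the convention is worth spelling out: the first entry is the minor-most, fastest-varying dimension, so {1, 0} on a rank-2 array is row-major and {0, 1} is column-major. Below is a standalone index computation for the f32[42,12] shape used in the ComputationLayout test above; it illustrates the convention and is not the LayoutUtil implementation.

// Illustration of minor_to_major ordering; not XLA code.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t LinearIndex(const std::vector<int64_t>& dims,
                    const std::vector<int64_t>& minor_to_major,
                    const std::vector<int64_t>& index) {
  int64_t linear = 0;
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {  // walk from minor-most to major-most
    linear += index[dim] * stride;
    stride *= dims[dim];
  }
  return linear;
}

int main() {
  const std::vector<int64_t> dims = {42, 12};  // f32[42,12], as in the test
  const std::vector<int64_t> idx = {3, 5};
  std::printf("minor_to_major {1,0} (row-major):    %lld\n",
              (long long)LinearIndex(dims, {1, 0}, idx));  // 3 * 12 + 5 = 41
  std::printf("minor_to_major {0,1} (column-major): %lld\n",
              (long long)LinearIndex(dims, {0, 1}, idx));  // 5 * 42 + 3 = 213
  return 0;
}

The FusionInstruction test introduced by the comment just above then exercises the same two layouts on fused parameters.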
- std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); auto constant_literal1 = LiteralUtil::CreateR2WithLayout( @@ -180,8 +179,8 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {negate2, negate1, add}, HloInstruction::FusionKind::kLoop); @@ -194,7 +193,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -229,13 +228,13 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { auto negate = builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kNegate, get_element0)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -267,17 +266,17 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()}); TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -302,11 +301,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { auto nested_tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, inner_tuple})); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = nested_tuple->shape(); *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); @@ -316,7 +315,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module, &computation_layout); + 
AssignLayouts(m.get(), &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -332,9 +331,9 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module) + .Run(m.get()) .ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}), @@ -361,9 +360,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -374,7 +372,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -403,8 +401,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { HloInstruction::CreateTranspose(bshape, log, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + auto computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -415,7 +413,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -439,9 +437,9 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { HloInstruction::CreateBroadcast(bshape, param, {1, 2})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); Shape input_shape_with_layout(ashape); Shape output_shape_with_layout(cshape); @@ -454,7 +452,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -488,9 +486,8 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { 
HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(tuple)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tuple)); ComputationLayout computation_layout(computation->ComputeProgramShape()); Shape param_shape_with_layout(f32_4); @@ -507,7 +504,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -558,9 +555,8 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1)); auto reshape = builder.AddInstruction( HloInstruction::CreateReshape(cshape, concatenate)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(reshape)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(reshape)); Shape param0_shape_with_layout(ashape); Shape param1_shape_with_layout(ashape); @@ -573,7 +569,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -593,11 +589,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -611,11 +607,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 
3, 0, 1})); } @@ -681,12 +677,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - ParseAndVerifyModule(module_str); - + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); @@ -721,9 +717,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -735,19 +732,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(0) .layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -757,7 +754,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); @@ -784,7 +781,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { true_builder.AddInstruction(HloInstruction::CreateTuple({add})); } HloComputation* true_computation = - module->AddEmbeddedComputation(true_builder.Build()); + m->AddEmbeddedComputation(true_builder.Build()); auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); { @@ -800,14 +797,14 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = - module->AddEmbeddedComputation(false_builder.Build()); + m->AddEmbeddedComputation(false_builder.Build()); builder.AddInstruction(HloInstruction::CreateConditional( result_tshape, pred, tuple, true_computation, tuple, false_computation)); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); ComputationLayout 
computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -828,13 +825,13 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module).status(); + Status error_status = layout_assignment.Run(m.get()).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -861,9 +858,10 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -873,12 +871,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(&module(), &computation_layout, &channel_constraints); + AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0)); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}), ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } @@ -897,17 +895,17 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 ar.0 = f32[2,2] cross-replica-sum(gte), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) ROOT ar.1 = f32[2,2] cross-replica-sum(const), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -917,12 +915,12 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, 
&channel_constraints); + AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1)); - const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); } @@ -938,11 +936,12 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -966,11 +965,12 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -997,11 +997,12 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1028,11 +1029,12 @@ TEST_F(LayoutAssignmentTest, } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1050,11 +1052,12 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1107,20 +1110,21 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); // Sanity check to verify that there's a layout mismatch. 
- EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); // Make sure that layout assignment did not magically eliminate the mismatch, // in which case the test didn't prove anything. - EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); } TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { @@ -1136,32 +1140,32 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { // and result layout should match that of the computation. { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ASSERT_THAT(root, op::CustomCall(op::Parameter())); ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); } { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ASSERT_THAT(root, op::CustomCall(op::Parameter())); ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); @@ -1179,24 +1183,24 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3 } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); *computation_layout.mutable_parameter_layout(1) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout( 
ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); // The custom call should be partially encapsulated in kCopy instructions // because of the layout mismatches. - ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); @@ -1211,18 +1215,18 @@ ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall())); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); } @@ -1238,25 +1242,25 @@ ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); *computation_layout.mutable_parameter_layout(1) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); - ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall(op::Tuple()))); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); } @@ -1273,36 +1277,34 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, // Try with a couple different layouts. In each case the custom calls operand // and result layout should match that of the computation. 
TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - ExpectTupleLayoutIs(module->result_shape(), {{1, 0}, {1, 0}}); + ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}}); - const HloInstruction* custom_call = - FindInstruction(module.get(), "custom-call"); + const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call"); ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); } Status AssignLayoutsToComputation( - HloModule* module, - ChannelLayoutConstraints* channel_constraints = nullptr) { - if (!module->entry_computation_layout().result_layout().LayoutIsSet()) { - module->mutable_entry_computation_layout() + HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) { + if (!m->entry_computation_layout().result_layout().LayoutIsSet()) { + m->mutable_entry_computation_layout() ->mutable_result_layout() ->SetToDefaultLayout(); } LayoutAssignment layout_assignment( - module->mutable_entry_computation_layout(), + m->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout, channel_constraints); - return layout_assignment.Run(module).status(); + return layout_assignment.Run(m).status(); } TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { @@ -1325,16 +1327,16 @@ TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { auto add = b.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1)); b.AddInstruction(HloInstruction::CreateTuple({add, transpose})); - auto module = CreateNewVerifiedModule(); - module->AddEntryComputation(b.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(b.Build()); Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0}); Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1}); - *module->mutable_entry_computation_layout()->mutable_result_layout() = + *m->mutable_entry_computation_layout()->mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor})); const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0}); - ForceParameterLayout(module.get(), 0, r2_dim0major); - ForceParameterLayout(module.get(), 1, r2_dim0major); - TF_ASSERT_OK(AssignLayoutsToComputation(module.get())); + ForceParameterLayout(m.get(), 0, r2_dim0major); + ForceParameterLayout(m.get(), 1, r2_dim0major); + TF_ASSERT_OK(AssignLayoutsToComputation(m.get())); EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0)); EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(), diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 850501a4b5c..728a66b388f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", 
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", ], @@ -197,14 +198,17 @@ cc_library( hdrs = ["sort_util.h"], deps = [ ":ir_array", + ":kernel_support_library", ":llvm_loop", ":llvm_util", ":loop_emitter", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 2e5aebb74c2..df78726166e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" @@ -83,10 +84,9 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b) { +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); @@ -260,6 +260,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, /*AddNull=*/false); } +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + absl::string_view name) { + const int kNVPTXSharedMemoryAddrSpace = 3; + return new llvm::GlobalVariable( + *module, tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr, + llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); +} + llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, absl::string_view name, llvm::IRBuilder<>* b, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index f59baff263f..c604c7c870a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -101,10 +102,9 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. Typically, overloaded intrinsics have only a single // overloaded type. -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b); +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b); // Emit float max. 
Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise @@ -155,6 +155,11 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module); +// Allocates a tile of shared memory. +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + absl::string_view name); + // Inserts an allocate of the requested type at the entry point of the // function that the builder is currently building. The insert point // of the builder is set to the same place after calling this function diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 05ba4a40da4..fd16af67fe9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -18,7 +18,9 @@ limitations under the License. #include // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -28,10 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -39,147 +43,352 @@ namespace xla { namespace llvm_ir { namespace { -// Adds the inner comparison loop where we compare elements pointed to by -// 'keys_index' and 'compare_keys_index'. -void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, - const IrArray::Index& compare_keys_index, - const IrArray& keys_array, - const std::vector& values_arrays, - llvm::IRBuilder<>* b) { - // if (is_smaller_index && - // compare_keys[dimension_to_sort] < dimension_to_sort_bound) - llvm::Value* is_smaller_index = b->CreateICmpSLT( - keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - auto if_data = EmitIfThenElse( - b->CreateAnd(is_smaller_index, - b->CreateICmpSLT(compare_keys_index[dimension_to_sort], - keys_index.GetConstantWithIndexType( - dimension_to_sort_bound))), - "smaller_comparison_index", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_data.true_block, b); - auto key1 = keys_array.EmitReadArrayElement(keys_index, b); - auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); - auto compare_key1 = key1; - auto compare_key2 = key2; - auto key_type = keys_array.GetShape().element_type(); - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(key_type)) { - // We would like a total order of floating point numbers so that the sort - // has a predictable behavior in the presence of NaNs. Rather than using - // floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 
0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the obvious - // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning - // and end of the ordering. - auto k = b->getInt(llvm::APInt::getSignedMaxValue( - key1->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b->CreateSub(k, v), v); - }; - compare_key1 = b->CreateBitCast(key1, comparison_type); - compare_key2 = b->CreateBitCast(key2, comparison_type); - compare_key1 = maybe_flip(compare_key1); - compare_key2 = maybe_flip(compare_key2); - } else if (!primitive_util::IsSignedIntegralType(key_type)) { - is_signed_comparison = false; + +// Adds the inner comparison loop body where we compare elements. +void EmitCompareLoopBody( + int64 iteration_bound, PrimitiveType key_type, int64 num_values, + llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type, + std::function read_element, + std::function + write_element, + llvm::IRBuilder<>* b, bool needs_bounds_checks = true) { + auto index_typed_constant = [&](int64 value) { + return llvm::ConstantInt::get(index_type, value); + }; + // The 'xor_mask' determines which elements are compared against each other. + // Index 'current_keys_index' will be compared with 'current_keys_index' xor + // 'xor_mask'. This means that we will always compare a block of consecutive + // elements against elements from the adjacent block of the same size. When + // 'xor_mask' is a power of 2, it immediately identifies the size of such a + // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In + // that case, we essentially flip the last 'k' - 1 bits when computing the + // position of the element to compare to, so the block size is 2^(k - 1). + int64 block_size = xor_mask; + // Check if it is a value 2^k - 1. + if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) { + block_size = (xor_mask + 1) / 2; } - auto comparison = - b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1); - // If key2 < key1 - auto if_smaller_data = - EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_smaller_data.true_block, b); - // Swap key1 with key2. - keys_array.EmitWriteArrayElement(keys_index, key2, b); - keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); - for (const auto& values_array : values_arrays) { - // Also swap the values. - auto value1 = values_array.EmitReadArrayElement(keys_index, b); - auto value2 = values_array.EmitReadArrayElement(compare_keys_index, b); - values_array.EmitWriteArrayElement(keys_index, value2, b); - values_array.EmitWriteArrayElement(compare_keys_index, value1, b); + auto current_keys_index = element_pair_index; + if (block_size == 1) { + // If the block size is 1, we take every second element and compare it to + // the next one. + current_keys_index = + b->CreateMul(current_keys_index, index_typed_constant(2)); + } else if (block_size * 2 < iteration_bound) { + // current_keys_index iterates through the 'left' elements of the element + // pairs to be compared. We first need to compute the comparison block to + // which the element belongs. The block id of that block is index / + // block_size. 
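For illustration, a minimal standalone C++ sketch (not from this patch) of the index arithmetic described in the comment above and emitted as IR just below; `iteration_bound` and the sample masks are made-up values, and the variable names only mirror the emitted values:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  const int64_t iteration_bound = 8;  // stands in for dimension_to_sort_bound
  for (int64_t xor_mask : {1, 3, 2}) {
    // An xor_mask of the form 2^k - 1 flips the low k bits; the block size is
    // then 2^(k - 1), matching the check in EmitCompareLoopBody.
    int64_t block_size = xor_mask;
    if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
      block_size = (xor_mask + 1) / 2;
    }
    std::printf("xor_mask=%lld:", static_cast<long long>(xor_mask));
    for (int64_t pair = 0; pair < iteration_bound / 2; ++pair) {
      int64_t current = pair;
      if (block_size == 1) {
        current = pair * 2;  // every second element, compared with its neighbor
      } else if (block_size * 2 < iteration_bound) {
        const int64_t block_id = pair / block_size;
        const int64_t index_within_block = pair % block_size;
        current = block_id * 2 * block_size + index_within_block;
      }
      const int64_t compare = current ^ xor_mask;
      if (current < compare && compare < iteration_bound) {
        std::printf(" (%lld,%lld)", static_cast<long long>(current),
                    static_cast<long long>(compare));
      }
    }
    std::printf("\n");
  }
  return 0;
}

For xor_mask = 3 this prints the pairs (0,3) (1,2) (4,7) (5,6), i.e. each block of four elements is mirrored against itself, which is the merge step that masks of the form 2^k - 1 implement in the bitonic sort.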
+ auto block_id = + b->CreateUDiv(current_keys_index, index_typed_constant(block_size)); + // The index of the 'left' element within its block is simply the remainder + // when dividing by 'block_size'. + auto index_within_block = + b->CreateURem(current_keys_index, index_typed_constant(block_size)); + // The first element of the 'left' block of elements that is compared + // against elements from the adjacent 'right' block of elements is + // 'block_id' * (2 * 'block_size'). + auto first_element_in_block = + b->CreateMul(block_id, index_typed_constant(2 * block_size)); + current_keys_index = + b->CreateAdd(first_element_in_block, index_within_block); } + auto compare_keys_index = + b->CreateXor(current_keys_index, index_typed_constant(xor_mask)); + // current_keys_index < compare_keys_index + llvm::Value* is_smaller_index = + b->CreateICmpSLT(current_keys_index, compare_keys_index); + // compare_keys_index < iteration_bound + llvm::Value* index_is_inbounds = b->CreateICmpSLT( + compare_keys_index, index_typed_constant(iteration_bound)); + llvm::Value* do_comparison = + needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds) + : b->getInt1(true); + + // if (is_smaller_index && index_is_inbounds) + KernelSupportLibrary ksl(b); + ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + auto key1 = read_element(0, current_keys_index); + auto key2 = read_element(0, compare_keys_index); + auto compare_key1 = key1; + auto compare_key2 = key2; + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(key_type)) { + // We would like a total order of floating point numbers so that the + // sort has a predictable behavior in the presence of NaNs. Rather + // than using floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the + // obvious order, -0 is ordered before 0, and -NaN and NaN appear at + // the beginning and end of the ordering. + auto k = b->getInt(llvm::APInt::getSignedMaxValue( + key1->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b->CreateSub(k, v), v); + }; + compare_key1 = b->CreateBitCast(key1, comparison_type); + compare_key2 = b->CreateBitCast(key2, comparison_type); + compare_key1 = maybe_flip(compare_key1); + compare_key2 = maybe_flip(compare_key2); + } else if (!primitive_util::IsSignedIntegralType(key_type)) { + is_signed_comparison = false; + } + // If key2 < key1 + ksl.IfReturnVoid( + "is_smaller_than", + b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + compare_key2, compare_key1), + [&]() { + // Swap key1 with key2. + write_element(0, current_keys_index, key2); + write_element(0, compare_keys_index, key1); + for (int64 i = 1; i <= num_values; ++i) { + // Also swap the values. 
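For illustration, a self-contained sketch (not from this patch) of the integer total-order trick restated in the comment above; `TotalOrderKey` and the sample values are invented here, and the wrapping subtraction stands in for the wrapping sub the emitter produces:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Maps a float to an int32 whose signed ordering gives a total order on
// floats: finite values in the usual order, -0 before +0, NaNs at the ends.
int32_t TotalOrderKey(float f) {
  int32_t x;
  std::memcpy(&x, &f, sizeof(x));  // x = bit_cast<int32>(f)
  if (x < 0) {
    // y = 0x7FFFFFFF - x, computed with wrap-around like the emitted IR.
    const uint32_t y = 0x7FFFFFFFu - static_cast<uint32_t>(x);
    std::memcpy(&x, &y, sizeof(x));
  }
  return x;
}

int main() {
  const float samples[] = {-2.0f, -1.0f, -0.0f, 0.0f, 1.0f, 2.0f};
  for (float f : samples) {
    // The printed keys increase strictly across this list.
    std::printf("% .1f -> %d\n", f, static_cast<int>(TotalOrderKey(f)));
  }
  return 0;
}

The keys increase strictly across -2.0, -1.0, -0.0, +0.0, 1.0, 2.0, so a plain signed integer compare on the keys sorts the floats, with -0 ordered before +0 as the comment promises.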
+ auto value1 = read_element(i, current_keys_index); + auto value2 = read_element(i, compare_keys_index); + write_element(i, current_keys_index, value2); + write_element(i, compare_keys_index, value1); + } + }); + }); +} + +void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, + int64 dimension_to_sort, + int64 dimension_to_sort_bound, + PrimitiveType keys_type, + absl::Span xor_masks, + const std::vector& params, + const std::vector& param_shmem_buffers, + int64 tile_size, llvm::IRBuilder<>* b) { + KernelSupportLibrary ksl(b); + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); + llvm_ir::AddRangeMetadata(0, tile_size / 2, + llvm::cast(thread_id)); + thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(), + /*isSigned=*/true, "thread.id.x"); + + auto copy_loop_body = + [&](std::function + read_or_write) { + auto value_one = tiled_keys_index.GetConstantWithIndexType(1); + auto current_keys_index = + b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); + // We want to copy two adjacent elements. We first check whether the + // first index position is within bounds. + ksl.IfReturnVoid( + "smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + auto cache_index = b->CreateShl(thread_id, value_one); + read_or_write(cache_index, current_keys_index); + // Increment to go the next index position. + current_keys_index = b->CreateAdd(current_keys_index, value_one); + // Here we check whether the next index position is within bounds. + ksl.IfReturnVoid( + "inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); + }); + }; + + // Copy operand tiles from the operand buffers to shared memory. + IrArray::Index keys_index = tiled_keys_index; + for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = params[i].EmitReadArrayElement(keys_index, b); + b->CreateStore(value, + b->CreateGEP(param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), + cache_index})); + }); + } + // Wait until all reads have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + + // Now emit the bodies of the comparison loops. + auto read_element = [&](int64 operand, llvm::Value* index) { + return b->CreateLoad( + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + b->CreateStore( + value, + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + for (int64 xor_mask : xor_masks) { + // The index of the element pair to be compared within the tile stored in + // shared memory. We order the element pairs by the element with the smaller + // index. + auto element_pair_index = thread_id; + // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we don't + // need any bounds checks. + if (dimension_to_sort_bound % tile_size) { + // Otherwise we need a bounds check for the last tile. The last tile has + // size 'dimension_to_sort_bound' % 'tile_size'. 
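For illustration, a tiny standalone sketch (not from this patch) of which tile the bounds check above guards; the concrete sizes are hypothetical and `RoundDownToNearest` is reproduced inline:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t tile_size = 64;
  const int64_t dimension_to_sort_bound = 150;  // not a multiple of tile_size
  // RoundDownToNearest(dimension_to_sort_bound, tile_size)
  const int64_t last_tile_start =
      (dimension_to_sort_bound / tile_size) * tile_size;               // 128
  const int64_t last_tile_size = dimension_to_sort_bound % tile_size;  // 22
  std::printf("tiles start at 0, 64, %lld; only the tile at %lld is partial "
              "(%lld elements) and takes the bounds-checked path\n",
              static_cast<long long>(last_tile_start),
              static_cast<long long>(last_tile_start),
              static_cast<long long>(last_tile_size));
  return 0;
}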
+ ksl.IfReturnVoid( + "is_last_tile", + b->CreateICmpUGE( + b->CreateMul(tiled_keys_index[dimension_to_sort], + tiled_keys_index.GetConstantWithIndexType(2)), + tiled_keys_index.GetConstantWithIndexType( + RoundDownToNearest(dimension_to_sort_bound, tile_size))), + [&]() { + EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, + params.size() - 1, element_pair_index, xor_mask, + tiled_keys_index.GetType(), read_element, + write_element, b); + }, + [&]() { + EmitCompareLoopBody( + tile_size, keys_type, params.size() - 1, element_pair_index, + xor_mask, tiled_keys_index.GetType(), read_element, + write_element, b, /*needs_bounds_checks=*/false); + }); + } else { + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + element_pair_index, xor_mask, + tiled_keys_index.GetType(), read_element, + write_element, b, /*needs_bounds_checks=*/false); + } + // Wait until all comparisons have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + } + + // Copy the operand tiles back from shared memory to the operand buffers. + for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = b->CreateLoad(b->CreateGEP( + param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), cache_index})); + params[i].EmitWriteArrayElement(keys_index, value, b); + }); + } + // We should normally synchronize here to make sure all writes have happened. + // However the very next thing each thread does is reading 2 elements from the + // operand buffer and writing it into the same location in shared memory from + // which it previously copied it to the operand buffer, and we synchronize + // after this has happened. We can be sure that a thread always writes to the + // same location in shared memory because we have exactly tile_size / 2 many + // threads, and the linear index calculated by ParallelLoopEmitter uses + // linear_index = blockIdx.x * blockDim.x + threadIdx.x; } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions) { + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, + const int64 tile_size) { + // Iterate through the keys shape in physical order, but skip the dimension to + // sort and make it the innermost loop which is the loop where the comparisons + // happen. In the dimension to sort, if we use tiling, we iterate through it + // in tiles of 64 elements each, so we use another loop that happens within + // one thread to process this tile worth of data (thereby combining several + // comparison stages of the bitonic sort algorithm because they all happen + // within those 64 elements and are therefore independent of the other + // comparisons). + const Shape& keys_shape = keys_array.GetShape(); - - // Create loop nests which loop through the operand dimensions. The sort - // dimension is handled in the innermost loop which performs the sorting. 
- ForLoopNest loop_nest(name, b); - IrArray::Index keys_index = - loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys"); - if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) { - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b); - } - - // 'compare_keys_index' is the index of the element that 'keys_index' should - // be compared to. - IrArray::Index compare_keys_index(keys_index.GetType()); - for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) { + int64 rank = ShapeUtil::Rank(keys_shape); + int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + std::vector dimensions_in_iteration_order(rank); + std::vector iteration_order_to_logical_order(rank); + int64 dim = 0; + for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) { if (dimension != dimension_to_sort) { - compare_keys_index.push_back(keys_index[dimension]); - } else { - compare_keys_index.push_back(nullptr); + dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension); + iteration_order_to_logical_order[dim++] = dimension; + } + } + dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim; + iteration_order_to_logical_order[dim] = dimension_to_sort; + + Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(), + dimensions_in_iteration_order); + std::vector params(1, keys_array); + params.insert(params.end(), values_arrays.begin(), values_arrays.end()); + + // Allocate shared memory for the tiled compare loop. + std::vector param_shmem_buffers(params.size(), nullptr); + if (xor_masks.size() > 1) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + for (int64 i = 0; i < params.size(); ++i) { + llvm::Type* tile_type = + llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( + params[i].GetShape().element_type(), module), + tile_size); + param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( + module, tile_type, absl::StrCat(name, "_tile_param_", i)); } } - // Naive C++ code for the inner compare loop: - // - // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { - // int64 j = i ^ xor_mask; - // if (i < j && j < dimension_to_sort_bound) { - // int64 min_key = std::min(keys[i], keys[j]); - // keys[j] = std::max(keys[i], keys[j]); - // keys[i] = min_key; - // } - // } - // - // This follows the algorithm described on Wikipedia: - // https://en.wikipedia.org/wiki/Bitonic_sorter - - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(), - {dimension_to_sort_bound}); auto compare_loop_body_emitter = - [&](const IrArray::Index& compare_index) -> Status { - keys_index[dimension_to_sort] = compare_index[0]; - compare_keys_index[dimension_to_sort] = - b->CreateXor(compare_index[0], xor_mask); - EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, values_arrays, b); + [&](const IrArray::Index& tiles_index) -> Status { + // Naive C++ code for the inner compare loop: + // + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ xor_mask; + // /* emitted in EmitCompareLoopBody() */ + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + IrArray::Index keys_index(tiles_index.GetType(), rank); + for (int64 i = 0; i < rank; ++i) 
{ + keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; + } + if (xor_masks.size() > 1) { + EmitTiledCompareLoop(keys_index, dimension_to_sort, + dimension_to_sort_bound, keys_shape.element_type(), + xor_masks, params, param_shmem_buffers, tile_size, + b); + } else { + auto read_element = [&](int64 operand, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + return params[operand].EmitReadArrayElement(keys_index, b); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + keys_index[dimension_to_sort] = index; + params[operand].EmitWriteArrayElement(keys_index, value, b); + }; + EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), + values_arrays.size(), tiles_index[rank - 1], + xor_masks[0], tiles_index.GetType(), read_element, + write_element, b); + } return Status::OK(); }; - if (launch_dimensions != nullptr) { - TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter, - compare_shape, - *launch_dimensions, b) - .EmitLoop(name)); - } else { - TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b) - .EmitLoop(name)); - } - - // Set the IR builder insert point to the exit basic block of the outer most - // loop. This ensures later instructions are inserted after this loop nest. - b->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); + return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape, + launch_dimensions, b) + .EmitLoop(name); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 2f3bcda2307..556a217322d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -29,13 +30,14 @@ namespace xla { namespace llvm_ir { // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' // dimension of 'keys_array'. All other dimensions are kept as-is. This -// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, -// the inner compare loop will not be parallelized. +// implements the inner loop of BitonicSort. It is assumed that 'xor_masks' +// contains only powers of 2, or values 2^k - 1 (k > 0). Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions); + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, int64 tile_size); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/map_inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc index 84059dd0f71..fd18bfdc3e7 100644 --- a/tensorflow/compiler/xla/service/map_inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using MapInlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(MapInlinerTest, MapMax) { @@ -59,12 +59,12 @@ TEST_F(MapInlinerTest, MapMax) { HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); @@ -93,12 +93,12 @@ TEST_F(MapInlinerTest, MapConstant) { HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -131,12 +131,12 @@ TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 2ca527bc4cb..9ccdd7d8d81 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -257,7 +258,7 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, } void MultiOutputFusion::RecomputeReachability() { - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); } void MultiOutputFusion::UpdateReachability( @@ -317,9 +318,9 @@ bool MultiOutputFusion::Perform() { << instr2->fused_instructions_computation()->ToString( HloPrintOptions().set_indent_amount(1)); } + Update(instr1, instr2); HloInstruction* ret = Fuse(instr1, instr2); set_is_fused(ret == instr1 ? instr2 : instr1); - Update(instr1, instr2); changed = true; VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" << ret->fused_instructions_computation()->ToString( diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 9508ab2ed1d..1c7583ece72 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index c522e7ae23b..c227106511c 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -59,20 +59,15 @@ string CanonicalPlatformName(const string& name) { /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { - se::MultiPlatformManager::PlatformMap platform_map; - se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms( - [&platform_map](se::MultiPlatformManager::PlatformMap* map) { - platform_map = *map; - return se::port::Status::OK(); - }); - if (platform_map.empty()) { + std::vector all_platforms = + se::MultiPlatformManager::AllPlatforms(); + if (all_platforms.empty()) { LOG(WARNING) << "no executor platforms available: platform map is empty"; } // Gather all platforms which have an XLA compiler. std::vector platforms; - for (auto& platform_pair : platform_map) { - auto* platform = platform_pair.second; + for (se::Platform* platform : all_platforms) { auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { platforms.push_back(platform); @@ -222,8 +217,8 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { // fix the number of devices to one. However we do let the user override // this behavior to help run tests on the host that run models in parallel // across multiple devices. 
- device_count = legacy_flags::GetDebugOptionsFromFlags() - .xla_force_host_platform_device_count(); + device_count = + GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } std::vector stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 688cceff0cd..b70cb705747 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -111,7 +111,7 @@ StatusOr ReducePrecisionInsertion::insert_on_inputs( VLOG(2) << "Adding to operand " << i << ": " << operand; if (!is_valid_shape(operand->shape())) { - VLOG(2) << "Skipped: value is not an F32 vector"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } @@ -168,7 +168,7 @@ StatusOr ReducePrecisionInsertion::insert_on_outputs( << instruction->ToString(); if (!is_valid_shape(instruction->shape())) { - VLOG(2) << "Skipped: value is not an F32 nonscalar array"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 0b4e82e8d60..76c6a87f176 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass { // equivalent behavior can be obtained by adding ReducePrecision // instructions after the instructions that pull the F32 arrays out of // the tuples. - // - // TODO(b/64093391): Remove the IsScalar check once this won't cause - // failures on the GPU backend if the ReducePrecision instruction ends up - // inserted between a scalar constant and the init_value argument of a - // Reduce operation. - return shape.element_type() == PrimitiveType::F32 && - !ShapeUtil::IsScalar(shape); + return shape.element_type() == PrimitiveType::F32; } // Is this instruction one such that following or preceding it with a new diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index 69e4b534bd8..16fa80d53e7 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -54,7 +54,34 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // Confirm expected graph after adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); +} + +TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + + // Create a simple graph with a parameter feeding a unary cosine function. 
+ HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -84,7 +111,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -113,7 +140,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -146,7 +173,7 @@ TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) { HloInstruction* d = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -178,7 +205,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -215,7 +242,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -242,7 +269,7 @@ TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -268,7 +295,7 @@ TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -294,7 +321,7 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, a, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -321,7 +348,7 @@ TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 5, 10)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -349,7 +376,7 @@ TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -375,7 +402,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -411,7 +438,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -458,7 +485,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index fcf269eee92..341659b15c4 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,9 +34,10 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ReshapeMoverTest : public HloVerifiedTestBase {}; +class ReshapeMoverTest : public HloTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -50,12 +51,12 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); @@ -74,6 +75,7 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { // Verifies that the reshape is not moved, since rng0 is trivially reshapable // and therefore there is no nontrivial reshapes to move. TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto rng0 = builder.AddInstruction(HloInstruction::CreateRng( @@ -92,18 +94,19 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); } TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -117,12 +120,12 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -130,6 +133,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder 
builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -143,11 +147,11 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, param1))); @@ -177,6 +181,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { // | // reshape4 TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto const0 = builder.AddInstruction( @@ -196,12 +201,12 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, const0, reshape1, reshape2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(const0, reshape1, reshape2)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(op::Reshape(const0), param1, param2))); @@ -221,6 +226,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { // Verifies that the reshape0 does not sink below add, because param1 is not // trivially reshapable nor is a Reshape/Transpose. TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -232,11 +238,11 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); @@ -257,6 +263,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // Verifies that we don't unnecessarily sink reshapes, which are in fact // trivial reshapes. 
TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -275,12 +282,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); @@ -309,6 +316,7 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { // // (note that reshape1 here is trivial). TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -320,12 +328,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), const1)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, op::Reshape(const1)))); @@ -348,6 +356,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { // For now we treat it as non-trivial, so we verify that we don't sink the // reshapes in this case. 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -362,12 +371,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); @@ -376,6 +385,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -389,14 +399,14 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, param1))); @@ -405,6 +415,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7}); @@ -423,13 +434,13 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(pred, param0, param1))); @@ -438,6 +449,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { } TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto pred_shape = ShapeUtil::MakeShape(PRED, {}); @@ -452,11 +464,11 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); 
- auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); @@ -477,6 +489,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { // // We expect reshape{0,1} AND reshape{2,3} to be lifted. TEST_F(ReshapeMoverTest, MultiplePasses) { + auto m = CreateNewVerifiedModule(); auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7}); auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1}); auto shape3 = ShapeUtil::MakeShape(F32, {8, 7}); @@ -500,14 +513,14 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, reshape2, reshape3)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Add(op::Reshape(param2), op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -526,11 +539,11 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Transpose(op::Multiply())); } @@ -555,8 +568,8 @@ TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_FALSE(changed); } @@ -580,10 +593,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); } @@ -597,10 +610,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reshape(op::Add())); } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 6f9094a5c2e..75f7413b3c3 100644 --- 
a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -23,9 +23,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -292,7 +292,7 @@ StatusOr> Service::CreateModuleConfig( config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { - config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config->set_debug_options(GetDebugOptionsFromFlags()); } if (execute_backend_ != nullptr && @@ -760,38 +760,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - ExecuteGraphParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteGraphParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); -} - -Status Service::PickParallelResponse( - const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { - // The "result device" selection is a bit hacky, but better than assuming it - // is device 0. We have b/76035356 for restructuring the client API to clean - // up the current asymmetries and support more functionalities. - for (int64 i = 0; i < parallel_result.responses_size(); ++i) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica( - parallel_result.responses(i).output(), 0)); - const Shape& shape = buffer->on_host_shape(); - if (!ShapeUtil::IsEmptyTuple(shape)) { - *result = parallel_result.responses(i); - VLOG(3) << "Fetching result from device " << i << ": " - << ShapeUtil::HumanString(shape); - return Status::OK(); - } - } - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - VLOG(1) << "Defaulting to device 0 result"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -836,10 +804,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - VLOG(1) << "running execute-graph request"; - +Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { + VLOG(1) << "running compile request"; if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -847,22 +813,21 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, return InvalidArgument("programe shape may not be empty"); } - // If we received multiple device handles, we must partition the module. 
if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); + return InvalidArgument( + "The compile request does not support multiple device handles."); } - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - + std::vector argument_shapes; + absl::c_transform(arg->input_shape_with_layout(), + std::back_inserter(argument_shapes), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, CreateModuleConfig(arg->computation().host_program_shape(), - replicated_arguments.front(), - arg->execution_options())); + argument_shapes, &arg->execution_options())); + VLOG(3) << "Compile created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, @@ -871,6 +836,48 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); + + VLOG(1) << "successfully completed 'compile' request"; + return Status::OK(); +} + +Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { + VLOG(1) << "running execute request"; + if (!arg->has_handle()) { + return InvalidArgument("execution handle should not be empty"); + } + TF_ASSIGN_OR_RETURN(auto executable, + compilation_cache_.LookUp(arg->handle())); + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + // Check that the replicated_arguments has the same shape and layout as the + // module config used when creating the executable.
+ const int64 num_module_args = + executable->module_config().entry_computation_layout().parameter_count(); + if (num_module_args != arg->arguments_size()) { + return InvalidArgument( + "The executable expects %lld arguments, but sees %lld.", + num_module_args, arg->arguments_size()); + } + for (int64 i = 0; i < num_module_args; i++) { + const Shape& shape_module = + executable->module_config().entry_computation_layout().parameter_shape( + i); + const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape(); + if (!ShapeUtil::Equal(shape_module, shape_arg)) { + return InvalidArgumentStrCat( + "The executable expects the ", i, "th argument in shape ", + ShapeUtil::HumanStringWithLayout(shape_module), " but sees ", + ShapeUtil::HumanStringWithLayout(shape_arg)); + } + } + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( execute_backend_->default_stream_executor())); @@ -884,9 +891,10 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + arg->computation().name(), result->mutable_profile())); + ExecuteAndRegisterResult(executable.get(), replicated_arguments, + execute_backend_.get(), + "result of " + executable->module().name(), + result->mutable_profile())); if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN( @@ -898,7 +906,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } - VLOG(1) << "successfully completed 'execute-graph' request"; + VLOG(1) << "successfully completed 'execute' request"; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 8cf1a7b9f01..11e1a79552f 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -22,11 +22,12 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" +#include "tensorflow/compiler/xla/service/compilation_cache.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" @@ -90,11 +91,14 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Executes a computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. - // Returns global data output and execution timing. - Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + // Compiles a computation into an executable. The request contains the whole + // computation graph. Returns the handle to the executable. + Status Compile(const CompileRequest* arg, CompileResponse* result) override; + + // Executes an executable with the provided global data passed as immutable + // arguments. The request contains the handle to the executable. Returns + // global data output and execution timing.
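Taken together, the Compile and Execute handlers above replace the single ExecuteGraph round trip with a two-step protocol (the Execute declaration follows just below). A hypothetical caller-side sketch, using only the request/response fields exercised by the handlers above; error handling and ExecutionOptions are elided, and `hlo_module_proto`, `argument_shapes_with_layout` and `argument_handles` stand in for caller-side state:

  CompileRequest compile_request;
  *compile_request.mutable_computation() = hlo_module_proto;  // HloModuleProto of the graph
  for (const Shape& shape : argument_shapes_with_layout) {
    *compile_request.add_input_shape_with_layout() = shape;
  }
  CompileResponse compile_response;
  TF_RETURN_IF_ERROR(service->Compile(&compile_request, &compile_response));

  // The returned handle names the cached Executable and can be executed repeatedly.
  ExecuteRequest execute_request;
  *execute_request.mutable_handle() = compile_response.handle();
  for (const GlobalDataHandle& argument : argument_handles) {
    *execute_request.add_arguments() = argument;
  }
  ExecuteResponse execute_response;
  TF_RETURN_IF_ERROR(service->Execute(&execute_request, &execute_response));
  // execute_response.output() now names the result allocation held by the service.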
+ Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each @@ -179,10 +183,6 @@ class Service : public ServiceInterface { absl::Span arguments, const ExecutionOptions& execution_options); - // Picks a parallel response and fills the result. - Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, - ExecuteResponse* result); - // Prepare the executors for executing parallel. StatusOr> GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, @@ -254,11 +254,6 @@ class Service : public ServiceInterface { Backend* backend, absl::Span device_handles, absl::Span result_tags, ExecutionProfile* profile); - // Executes a single computation which has more than one target device. - // The N devices are expected to all return an empty tuple, but one, which - // will be the result of this computation. - Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. @@ -281,6 +276,9 @@ class Service : public ServiceInterface { ServiceOptions options_; + // Cache containing previously built Executables. + CompilationCache compilation_cache_; + // Tracks channels created via the API. ChannelTracker channel_tracker_; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 2f8f092303e..61a60ef9efa 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2031,6 +2031,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } +/* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( + const Shape& shape, int64 dimension) { + if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", + dimension); + } + + return ShapeUtil::MakeShape(S64, {}); +} + /* static */ StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides) { @@ -2833,6 +2843,15 @@ Status ValidateScatterDimensionNumbers( } } + // Validate window size. + auto window_size = dim_numbers.update_window_dims_size() + + dim_numbers.inserted_window_dims_size(); + if (window_size != ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Scatter op has window of size %d; doesn't match operand of rank %d.", + window_size, ShapeUtil::Rank(operand_shape)); + } + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. 
if (dim_numbers.scatter_dims_to_operand_dims_size() != scatter_indices_shape[dim_numbers.index_vector_dim()]) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index cd4e5ab52ca..31ef4b2e410 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -291,6 +291,9 @@ class ShapeInference { const Shape& updates_shape, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); + static StatusOr InferGetDimensionSizeShape(const Shape& shape, + int64 dimension); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7b65e8c1c9d..4639e32db4d 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -2673,5 +2673,23 @@ TEST_F(ScatterGatherShapeInferenceTest, << statusor.status(); } +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_InsufficientWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Scatter op has window of size 4; doesn't match operand of rank 5.")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 56952e3adae..28a30b5ee2d 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -157,4 +157,23 @@ void ScopedShapedBuffer::Deallocate() { } } +ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) { + const xla::Shape& sub_on_host_shape = + xla::ShapeUtil::GetSubshape(on_host_shape(), {index}); + const xla::Shape& sub_on_device_shape = + xla::ShapeUtil::GetSubshape(on_device_shape(), {index}); + + ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape, + memory_allocator(), device_ordinal()); + auto src_it = buffers().find(index); + auto dst_it = output.buffers().begin(); + while (dst_it != output.buffers().end()) { + dst_it->second = src_it->second; + src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0); + ++src_it; + ++dst_it; + } + return output; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e1d26da4a20..f5210c9cfa6 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -176,6 +176,11 @@ class ScopedShapedBuffer : public ShapedBuffer { // It's the caller's job to ensure that the memory contained therein is freed. TF_MUST_USE_RESULT ShapedBuffer release(); + // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer + // that holds ownership of the subtree. Sets the buffers corresponding to the + // subtree to null in 'this'. 
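The TakeSubTree declaration follows immediately below; as a usage illustration, a minimal sketch with arbitrary shapes, where `allocator` is assumed to be a caller-owned DeviceMemoryAllocator*:

  Shape elem = ShapeUtil::MakeShape(F32, {16});
  Shape tuple = ShapeUtil::MakeTupleShape({elem, elem, elem});
  ScopedShapedBuffer whole(tuple, tuple, allocator, /*device_ordinal=*/0);
  // ... populate the leaf buffers ...

  // Transfers ownership of every buffer under tuple index {1} to `part`; the
  // corresponding entries in `whole` are left null, so they are not freed twice.
  ShapeIndex subtree_index = {1};
  ScopedShapedBuffer part = whole.TakeSubTree(subtree_index);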
+ ScopedShapedBuffer TakeSubTree(ShapeIndexView index); + protected: void Deallocate(); diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index d69e6362e91..ca64bd3c8dd 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -107,5 +109,79 @@ TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { // TestAllocator's destructor checks that all memory was freed. } +TEST(ScopedShapedBufferTest, TestTakeSubTree) { + TestAllocator allocator; + + Shape s = ShapeUtil::MakeShape(F32, {1}); + s = xla::ShapeUtil::MakeTupleShape(std::vector(2, s)); + s = xla::ShapeUtil::MakeTupleShape(std::vector(3, s)); + + ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0); + sb.buffers().ForEachMutableElement( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + TF_ASSERT_OK_AND_ASSIGN( + OwningDeviceMemory m, + allocator.Allocate(/*device_ordinal=*/0, /*size=*/77)); + *buffer = m.Forget(); + }); + ShapeTree buffers = sb.buffers(); + + // Takes a subtree out of 'sb', and verifies the buffers are as expected. + xla::ShapeIndex subtree_index = {1}; + ScopedShapedBuffer output = sb.TakeSubTree(subtree_index); + + output.buffers().ForEachElement([&](const xla::ShapeIndex& sub_index, + const se::DeviceMemoryBase& buffer) { + xla::ShapeIndex orig_index = subtree_index; + for (int i : sub_index) { + orig_index.push_back(i); + } + EXPECT_TRUE(buffers.find(orig_index)->second.IsSameAs(buffer)); + }); + sb.buffers().ForEachElement( + [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) { + if (ShapeIndexView(index).StartsWith(subtree_index)) { + EXPECT_TRUE(buffer.is_null()); + } else { + EXPECT_TRUE(buffers.find(index)->second.IsSameAs(buffer)); + } + }); +} + +// Test TakeSubTree with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). +void BM_TakeSubTree(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + TestAllocator allocator; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. 
+ (void)shaped_buffer.TakeSubTree(/*index=*/{fan_out / 2}).release(); + } + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_TakeSubTree) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 79b5c09abb3..7a565bf0768 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -172,7 +172,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - auto module = CreateNewModule("fuse_with_constant_operands"); + auto module = CreateNewUnverifiedModule("fuse_with_constant_operands"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( @@ -247,7 +247,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -302,7 +302,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -362,7 +362,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -428,7 +428,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index d9ebebf74ed..10ef2d38fa2 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -48,7 +48,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { } void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); module_->AddEntryComputation(std::move(computation)); } @@ -809,7 +809,7 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { class PointsToAnalysisTestBase : public 
HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -1176,7 +1176,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -1211,7 +1211,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 516754e2110..65b0f8c8044 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloVerifiedTestBase { +class TupleSimplifierTest : public HloTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -65,10 +65,10 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, scalar_shape_, "param2")); builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -78,10 +78,10 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { HloInstruction::CreateParameter(0, tuple_shape_, "param")); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -98,12 +98,12 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), gte); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -125,13 +125,13 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { 
builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -157,12 +157,12 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0)); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), element); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -182,12 +182,12 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -207,19 +207,19 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { // Verify that the root computation can be excluded - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloInstruction* p0; HloInstruction* p1; @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module, /*change_expected=*/true, /*exclude_entry=*/true); + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 541b117e029..68e2569f66b 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" namespace xla { @@ -229,4 +232,96 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, return nullopt; } +// If the only user of this instruction is a get-tuple-element, return that +// get-tuple-element, otherwise return null. 
If this runs before CSE/DCE, we may +// get a false negative if there are several copies of the same GTE, or there +// are unused GTEs, but we can live with this. +static HloInstruction* GetOnlyGTE(HloInstruction* inst) { + if (inst->user_count() != 1) { + return nullptr; + } + + HloInstruction* user = inst->users().back(); + if (user->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; + } + return user; +} + +optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { + // If we know the exact trip count, it's also the upper bound. + auto exact_trip_count = ComputeWhileLoopTripCount(while_op); + if (exact_trip_count) { + VLOG(2) << "Loop has exact trip count."; + return exact_trip_count; + } + + // There is one more case we know how to handle. If the loop condition only + // looks at one element of the tuple, and the loop body sets this element to a + // constant, there are two options: + // 1) Evaluating the condition on this constant returns true. In this case, + // the loop either executes 0 times, or is an infinite loop, depending on the + // init value. + // 2) Evaluating the condition on this constant returns false. In this case, + // the loop executes 0 or 1 times, depending on the init value. This means + // that, regardless of the init value, the upper bound on the trip count is 1. + + // Check whether the condition depends on a single parameter, and find out + // which. + auto* while_cond = while_op->while_condition(); + auto* while_cond_param = while_cond->parameter_instruction(0); + auto* cond_gte = GetOnlyGTE(while_cond_param); + if (!cond_gte) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // Now check whether this gets set to a constant by the while body. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(3) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + int64 indvar_index = cond_gte->tuple_index(); + auto* while_body_indvar = while_body_root->operand(indvar_index); + if (while_body_indvar->opcode() != HloOpcode::kConstant) { + VLOG(3) << "While body does not set the IV to a constant: " + << while_body_indvar->ToString(); + return nullopt; + } + + // We have a constant. Evaluate the condition on this constant. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + Literal fake_input = Literal::CreateFromShape(while_cond_param->shape()); + TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(), + /*dest_shape_index=*/{indvar_index}, + /*src_shape_index=*/{})); + StatusOr eval_result = + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + + if (!eval_result.ok()) { + VLOG(2) << "Couldn't evaluate while loop condition."; + return nullopt; + } + + Literal cond_result_pred = std::move(eval_result.ValueOrDie()); + CHECK(ShapeUtil::Equal(cond_result_pred.shape(), + ShapeUtil::MakeShape(PRED, {}))); + + // Per the explanation above, if the evaluated condition returns false, the + // loop executes at most once. 
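This helper is consumed later in this change by WhileLoopInvariantCodeMotion; simplified from that hunk, the intended call pattern inside a pass is:

  // Inside a pass that is deciding whether to transform `while_instr`:
  auto upper_bound = ComputeWhileLoopTripCountUpperBound(while_instr);
  if (upper_bound && *upper_bound <= 1) {
    // A loop that runs at most once is not worth optimizing; skip it and let
    // simplification passes clean it up instead.
    return false;
  }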
+ bool cond_returns_true = cond_result_pred.GetFirstElement(); + if (!cond_returns_true) { + VLOG(2) << "Upper bound on the trip count is 1"; + return 1; + } + + VLOG(2) << "Loop has no known upper bound on the trip count."; + return nullopt; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index bf497f4892b..ac69a727bd6 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -28,6 +28,10 @@ namespace xla { absl::optional ComputeWhileLoopTripCount(HloInstruction *while_op, int64 max_value_returned = 128); +// Returns an upper bound on the trip count of the loop if it's statically +// known, nullopt otherwise. +absl::optional ComputeWhileLoopTripCountUpperBound( + HloInstruction *while_op); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc new file mode 100644 index 00000000000..1da0fbeac89 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class WhileLoopAnalysisTest : public HloTestBase {}; + +TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(-1) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); +} + +TEST_F(WhileLoopAnalysisTest, NoUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(42) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), absl::nullopt); +} + +TEST_F(WhileLoopAnalysisTest, ExactBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + index = s32[] get-tuple-element(p_body), index=1 + one = s32[] constant(1) + inc = s32[] add(index, one) + ROOT root = (f32[2], s32[]) tuple(val, inc) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] less-than(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] constant(0) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc 
b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 067cfcc17d6..8b381dec073 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -46,8 +46,9 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( return Status::OK(); } -StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( +StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( HloInstruction* while_instr) { + HloComputation* while_cond = while_instr->while_condition(); HloComputation* while_body = while_instr->while_body(); const HloInstruction& init_value = *while_instr->operand(0); @@ -57,24 +58,48 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( bool changed = false; - for (HloInstruction* invariant_gte : - WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { - int64 index = invariant_gte->tuple_index(); + absl::flat_hash_map> + conditional_gte_index_to_insts = + WhileUtil::GetGTEsMapForWhileConditional(*while_cond); + std::vector invariant_body_gtes = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + for (HloInstruction* invariant_body_gte : invariant_body_gtes) { + int64 index = invariant_body_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); - // Should have at least one user that's not while_body_root. - if (invariant_gte->user_count() <= 1) { + // Original value should be a constant. + if (invariant_value.opcode() != HloOpcode::kConstant) { continue; } - if (invariant_value.opcode() == HloOpcode::kConstant) { - auto* constant_instr = + // Sink into the while_body. + // Should have at least one user that's not while_body_root. + if (invariant_body_gte->user_count() > 1) { + HloInstruction* constant_instr = while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( - invariant_gte, constant_instr, while_body->root_instruction(), + invariant_body_gte, constant_instr, while_body->root_instruction(), index)); changed = true; } + + // Check if there is a corresponding GTE in while_conditional. + auto it = conditional_gte_index_to_insts.find(index); + if (it == conditional_gte_index_to_insts.end()) { + continue; + } + + for (HloInstruction* invariant_cond_gte : it->second) { + // Should have at least one user. + if (invariant_cond_gte->user_count() > 0) { + HloInstruction* constant_instr = while_cond->AddInstruction( + invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR( + invariant_cond_gte->ReplaceAllUsesWith(constant_instr)); + changed = true; + } + } } return changed; @@ -115,10 +140,8 @@ StatusOr WhileLoopConstantSinking::Run(HloModule* module) { } for (HloInstruction* while_instr : while_instrs) { - // We only sink into while loop bodies, but this can be extended to - // transform conditions as well. TF_ASSIGN_OR_RETURN(bool result, - TrySinkingConstantsIntoWhileBody(while_instr)); + TrySinkingConstantsIntoWhileLoop(while_instr)); changed |= result; } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 577bad6c706..a866bc1264b 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -23,8 +23,8 @@ limitations under the License. namespace xla { // Sinks while loop invariant values that happen to be constants into the while -// loop body. 
This is probably not a win in isolation but may unlock further -// optimizations like constant folding. +// loop body and conditional. This is probably not a win in isolation but may +// unlock further optimizations like constant folding. // // state = (..., const, ...) // while (pred(state)) { @@ -46,22 +46,19 @@ namespace xla { // tuple trivially loop invariant. WhileLoopSimplifier will later get rid of // `v`. // -// We only sink into while loop bodies, but this can be extended to transform -// conditions as well. -// // TODO(b/79121449): We should also sink broadcasts of constants. class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; absl::string_view name() const override { - return "while-loop-invariant-code-motion"; + return "while-loop-constant-sinking"; } StatusOr Run(HloModule* module) override; private: - StatusOr TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); + StatusOr TrySinkingConstantsIntoWhileLoop(HloInstruction* while_instr); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index d17b86fab5b..75d406435b6 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -242,5 +242,178 @@ ENTRY entry { } } } + +TEST_F(WhileLoopConstantSinkingTest, ConditionalSinkConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[]) p_body), index=1 + ROOT root = (f32[],f32[]) tuple(add, p_body.1) +} + +condition { + p_cond = (f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(10) + while_init = (f32[],f32[]) tuple(const_0, const_1) + ROOT while = (f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalTupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[],(f32[],f32[])) parameter(0) + p_b.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_b), index=0 + p_b.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_b), index=1 + p_b.1.0 = f32[] get-tuple-element((f32[],f32[]) p_b.1), index=0 + add = f32[] add(p_b.0, p_b.1.0) + ROOT root = (f32[],(f32[],f32[])) tuple(add, p_b.1) +} + +condition { + p_c = (f32[],(f32[],f32[])) parameter(0) + p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0 + p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1 + p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1 + ROOT result = pred[] less-than(p_c.0, p_c.1.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + 
const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) + ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::Lt(_, op::GetTupleElement(op::Constant()))); +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalDontCreateDeadConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1 + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add, p_body.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(10) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); + for (const HloInstruction* inst : while_condition->instructions()) { + if (inst->opcode() == HloOpcode::kConstant) { + EXPECT_GT(inst->user_count(), 0); + } + } +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalMultipleSameIndexGTEs) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add.0 = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1 + add.1 = f32[] add(p_body.1, const) + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add.0, add.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 + p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.0 = pred[] less-than(p_cond.0, p_cond.2) + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.1 = pred[] less-than(p_cond.1, p_cond.2.c) + ROOT result = pred[] and(lt.0, lt.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(0) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) 
while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant()))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 9795b2830b6..b7c28bfac78 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" @@ -143,6 +144,12 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( string while_instr_name = while_instr->ToString(print_no_metadata); VLOG(2) << "Trying to hoist from " << while_instr_name; + auto maybe_upper_bound = ComputeWhileLoopTripCountUpperBound(while_instr); + if (maybe_upper_bound && *maybe_upper_bound <= 1) { + VLOG(2) << "Loop has a trip count of at most 1, skipping."; + return false; + } + HloComputation* while_body = while_instr->while_body(); // Maps instructions in the while body to instructions hoisted outside the @@ -180,6 +187,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( return false; } + // LICM in the presence of domain instructions is complex, bail. + for (auto* instruction : while_body->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + return false; + } + } + // instructions_to_replace[i] is hoisted into a loop invariant instruction // replacement_instructions[i]. std::vector instructions_to_replace; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b7..046ccb2d3f2 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -26,7 +26,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { +class WhileLoopInvariantCodeMotionTest : public HloTestBase { public: // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. 
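Both passes touched by this change can be run back to back; a hypothetical pipeline ordering, not taken from this patch, assuming `module` is a valid HloModule* and the usual hlo_pass_pipeline.h include:

  HloPassPipeline pipeline("while-loop-cleanup");
  // Clone loop-invariant constants into the body and, after this change, into
  // the loop condition as well.
  pipeline.AddPass<WhileLoopConstantSinking>();
  // LICM now consults ComputeWhileLoopTripCountUpperBound and skips loops that
  // run at most once.
  pipeline.AddPass<WhileLoopInvariantCodeMotion>();
  TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));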
@@ -58,6 +58,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( } TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -76,19 +77,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -100,6 +100,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { } TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -135,19 +136,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -173,6 +173,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistTriviallyLoopVaryingComputation) { // Basic negative test: the add expression is not loop invariant. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -189,20 +190,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); @@ -210,6 +211,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistLoopVaryingComputationWithAlternatingTuples) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -228,25 +230,26 @@ TEST_F(WhileLoopInvariantCodeMotionTest, builder.AddInstruction( HloInstruction::CreateTuple({gte_1, gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = @@ -267,7 +270,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -277,14 +280,14 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, 
MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); ASSERT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), @@ -294,6 +297,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the // bitcast either. + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); @@ -317,7 +321,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -327,15 +331,15 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), @@ -346,6 +350,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); Shape while_shape = @@ -367,21 +372,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -396,6 +400,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -416,22 +421,23 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - while_body = module().AddEmbeddedComputation(builder.Build()); + while_body = m->AddEmbeddedComputation(builder.Build()); } HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -439,7 +445,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { HloComputation::Builder builder(TestName() + ".passthrough"); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "param")); - HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + HloComputation* result = m->AddEmbeddedComputation(builder.Build()); result->AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); @@ -450,11 +456,11 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - 
module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } @@ -482,14 +488,14 @@ ENTRY entry { )"; TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN( bool simplified_loop, - WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(m.get())); EXPECT_TRUE(simplified_loop); - HloComputation* while_body = module().GetComputationWithName("wide.body"); + HloComputation* while_body = m->GetComputationWithName("wide.body"); ASSERT_NE(while_body, nullptr); // We expect the while body to be the equivalent of: @@ -523,10 +529,44 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { } TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], f32[2], f32[2], s32[]) parameter(0) + val.0 = f32[2] get-tuple-element(p_body), index=0 + val.1 = f32[2] get-tuple-element(p_body), index=1 + add = f32[2] add(val.0, val.1) + const = s32[] constant(-1) + ROOT root = (f32[2], f32[2], f32[2], s32[]) tuple(val.0, val.1, add, const) + } + + condition { + p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=3 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], f32[2], f32[2], s32[]) tuple(param.0, param.0, param.0, param.1) + ROOT while = (f32[2], f32[2], f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(module.get())); EXPECT_FALSE(simplified_loop); } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 630d71e5ca2..6f924a29d8a 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -20,40 +20,14 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" namespace xla { using absl::optional; - -// Determines whether the given instruction is a send/recv node, or has a -// subcomputation which contains a send/recv node. 
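The comment above introduces the local helper that this change deletes; its implementation is removed in the next hunk and replaced by the hlo_query utility added to the includes. For reference, the equivalent query, as Run() uses it further down, has the following call shape; whether ContainsInstrWithOpcode also descends into called computations, as the deleted helper did, is not visible in this diff:

// Replacement for the deleted ContainsSendOrRecv(comp) helper.
bool has_send_or_recv = ContainsInstrWithOpcode(
    comp, {HloOpcode::kSend, HloOpcode::kSendDone, HloOpcode::kRecv,
           HloOpcode::kRecvDone});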
-static bool IsOrContainsSendOrRecv(const HloInstruction* instr); - -// Determines whether the given computation contains a send or recv node. -static bool ContainsSendOrRecv(const HloComputation* comp) { - for (const auto* instr : comp->instructions()) { - if (IsOrContainsSendOrRecv(instr)) { - return true; - } - } - return false; -} - -static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { - if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kSendDone || - instr->opcode() == HloOpcode::kRecv || - instr->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (const auto& subcomp : instr->called_computations()) { - if (ContainsSendOrRecv(subcomp)) { - return true; - } - } - return false; -} +using hlo_query::ContainsInstrWithOpcode; // Tries to remove elements in a while loop's tuple that aren't used within the // loop. @@ -253,7 +227,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. std::unique_ptr new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond), /*extras=*/{}); + make_while_computation_replacements(while_cond)); std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); @@ -266,8 +240,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements), - /*extras=*/{}); + while_body->CloneWithReplacements(std::move(while_body_replacements)); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -458,6 +431,180 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { return changed_cond || changed_body; } +// Converts a flat list of instructions into a tuple of the desired shape. For +// example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns +// a tuple of value ((A, B), C). +// +// desired_shape must be a tuple. (This precondition allows us to return a +// unique_ptr rather than a raw ptr.) +static std::unique_ptr UnflattenTupleInstr( + absl::Span instrs, const Shape& desired_shape, + std::vector>* new_instrs) { + CHECK(ShapeUtil::IsTuple(desired_shape)) + << ShapeUtil::HumanString(desired_shape); + + // For each child shape in `desired_shape`, slice out the correct number of + // `instrs` and call UnflattenTupleInstr recursively. At each step we remove + // elements from `instrs` so that it only contains instructions we have not + // yet processed. + std::vector elems; + for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { + const Shape& subshape = desired_shape.tuple_shapes(i); + if (!ShapeUtil::IsTuple(subshape)) { + elems.push_back(instrs[0]); + instrs.remove_prefix(1); + continue; + } + + // Count the number of leaf nodes underneath desired_shape[i]. 
+ int64 num_leaves = 0; + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + ++num_leaves; + } + }); + + std::unique_ptr subinstr = + UnflattenTupleInstr(instrs.subspan(0, num_leaves), + desired_shape.tuple_shapes(i), new_instrs); + elems.push_back(subinstr.get()); + new_instrs->push_back(std::move(subinstr)); + instrs.remove_prefix(num_leaves); + } + return HloInstruction::CreateTuple(elems); +} + +// Builds a vector whose elements are the values in the flattened tuple for +// `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the +// vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C). +static std::vector GetFlatTupleElems( + HloInstruction* instr, + std::vector>* new_instrs) { + const auto& shape = instr->shape(); + if (!ShapeUtil::IsTuple(shape)) { + return {instr}; + } + std::vector elems; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + new_instrs->push_back( + HloInstruction::CreateGetTupleElement(subshape, instr, i)); + auto* gte = new_instrs->back().get(); + auto flattened_subshape = GetFlatTupleElems(gte, new_instrs); + elems.insert(elems.end(), flattened_subshape.begin(), + flattened_subshape.end()); + } + return elems; +} + +static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + Shape while_shape = while_init->shape(); + if (!ShapeUtil::IsNestedTuple(while_shape)) { + return false; + } + + // Cowardly refuse to perform this optimization in the presence of kDomain + // instructions, which may reference other instructions in the loop and + // therefore make this complicated. + if (ContainsInstrWithOpcode(while_body, {HloOpcode::kDomain}) || + ContainsInstrWithOpcode(while_cond, {HloOpcode::kDomain})) { + return false; + } + + std::vector flattened_shape_elems; + ShapeUtil::ForEachSubshape(while_shape, + [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + flattened_shape_elems.push_back(s); + } + }); + Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems); + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. 
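UnflattenTupleInstr and GetFlatTupleElems above are inverse traversals of the same tuple tree: one rebuilds nesting from a flat list of leaves, the other flattens nesting back into that list. A minimal standalone sketch of the same recursion on plain containers (Node and Value are illustrative stand-ins, not XLA types; a cursor replaces the explicit per-subtree leaf count):

#include <vector>

// Stand-in for a (possibly nested) tuple shape: a leaf or a list of children.
struct Node {
  bool is_leaf = true;
  std::vector<Node> children;
};

// Stand-in for a value of that shape.
struct Value {
  bool is_leaf = true;
  int leaf = 0;
  std::vector<Value> elements;
};

// Rebuilds a nested Value of shape `desired`, consuming leaves from `flat`
// left to right -- the same idea as UnflattenTupleInstr's leaf counting.
Value Unflatten(const Node& desired, const int*& flat) {
  Value v;
  if (desired.is_leaf) {
    v.leaf = *flat++;
    return v;
  }
  v.is_leaf = false;
  for (const Node& child : desired.children) {
    v.elements.push_back(Unflatten(child, flat));
  }
  return v;
}

// Flattens a nested Value back into leaf order, mirroring GetFlatTupleElems.
void Flatten(const Value& v, std::vector<int>* out) {
  if (v.is_leaf) {
    out->push_back(v.leaf);
    return;
  }
  for (const Value& e : v.elements) Flatten(e, out);
}

// Example: shape ((x, x), x) with flat leaves {1, 2, 3} unflattens to
// ((1, 2), 3), and Flatten() recovers {1, 2, 3} again.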
+ std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + auto nested = [&](HloInstruction* instr) { + std::vector gtes; + const Shape& flat_shape = instr->shape(); + for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) { + gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement( + flat_shape.tuple_shapes(i), instr, i))); + } + auto nested_instr = + UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs); + CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape)) + << ShapeUtil::HumanString(nested_instr->shape()) << " vs " + << ShapeUtil::HumanString(while_shape); + return nested_instr; + }; + + auto flattened = [&](HloInstruction* instr) { + return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs)); + }; + + // Create a new while-condition computation, where parameter 0 has flat shape + // but all uses of it go through the nested shape. + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + // Create a new while-body computation, where parameter 0 has a flat shape and + // all uses of it go through the nested shape, and where the root has a flat + // shape constructed from the old nested root. + std::unique_ptr new_while_body = + while_body->CloneWithReplacementPairs( + { + while_body->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_body->parameter_instruction(0)->name()))), + }, + { + while_body->root_instruction(), + flattened(add_new_instr(while_body->root_instruction()->Clone())), + }); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile( + flattened_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + computation->AddInstruction(flattened(while_init))))))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return true; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -478,32 +625,46 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { for (HloInstruction* while_op : while_ops) { // We can't remove while loops that contain send/recv nodes, because we rely // on the particular loop structure around the node matching on the send and - // recv sides. Removing dead while params requires us to remove the loop + // recv sides. Other while simplifications require us to remove the loop // and replace it with a new one, so we can't do that either. 
- if (ContainsSendOrRecv(while_op->while_body()) || - ContainsSendOrRecv(while_op->while_condition())) { + if (ContainsInstrWithOpcode(while_op->while_body(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone})) { VLOG(2) << "Not attempting to simplify while loop because it contains a " "send/recv node: " << while_op->ToShortString(); continue; } - StatusOr result = TryPropagateConstant(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op)); + changed |= result; - result = TryRemoveWhileLoop(while_op); - TF_RETURN_IF_ERROR(result.status()); - if (result.ValueOrDie()) { - changed = true; - // Don't try to remove dead while params after successfully removing the - // while loop -- that would result in use-after-free nastiness. + TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); + changed |= result; + if (result) { + // Don't continue simplifying after successfully removing the while loop + // -- that would result in use-after-free nastiness. continue; } - result = TryRemoveDeadWhileParams(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); + changed |= result; + if (result) { + // Successfully flattening nested tuples results in us cloning and + // replacing the while loop, meaning that `while_op` is no longer valid. + continue; + } + + TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; + if (result) { + // Successfully removing dead while params results in us cloning and + // replacing the while loop, meaning that `while_op` is no longer valid. + continue; + } } XLA_VLOG_LINES(3, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 0bc5a0107bb..a378f179c63 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,11 +25,22 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. +// // - A while loop with static trip count of 1 is replaced by its body (sans // loop). +// // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. // +// - If the while loop's parameter is a nested tuple, it's flattened to a +// single-level tuple. This is good because it usually reduces the number of +// kTuple instructions, but also because it unlocks additional optimizations +// (e.g. removing unused loop parameters). +// +// Flattening nested while loop tuples adds a whole mess of likely unnecessary +// kGetTupleElement and kTuple operations to the graph. We expect that tuple +// simplifier will be run afterwards. +// class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 1c892ba179e..05005e0b262 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -17,9 +17,11 @@ limitations under the License. 
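To make the flattening described in the new while_loop_simplifier.h comment concrete, these are the loop-state shapes exercised by the FlattenNestedTuple test added later in this change:

  loop state before the pass: ((s32[1]), (s32[2], s32[3], (s32[4])))
  loop state after the pass:  (s32[1], s32[2], s32[3], s32[4])

The nested shape is rebuilt only at the loop boundary, so the entry computation's result keeps the original nested shape; the extra get-tuple-element/tuple glue this introduces is expected to be cleaned up by the tuple simplifier, as the header comment notes.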
#include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -27,18 +29,21 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopSimplifierTest : public HloVerifiedTestBase { +class WhileLoopSimplifierTest : public HloTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. - void MakeModuleWithSimpleLoop(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoop(int num_iters); // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to // the loop-condition through an element of a tuple which is the // loop-condition parameter. - void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string_template = R"( HloModule SimpleLoop SimpleLoop.body { @@ -67,10 +72,11 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( int num_iters) { string hlo_string_template = R"( HloModule SimpleLoopWithIndirectLoopBound @@ -104,60 +110,55 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationTupleElementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - 
HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationTupleELementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/2); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/2); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithControlDependencySimplifiedDependencyPreserved) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction()->control_predecessors(), ElementsAre(op::Constant())) << computation->ToString(); @@ -166,9 +167,8 @@ TEST_F(WhileLoopSimplifierTest, // Loops that contain send/recv nodes can't be simplified; the loop structure // around send/recv nodes must be preserved. 
TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -179,13 +179,12 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -194,7 +193,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // The limitation on not being able to simplify loops that contain infeeds (and @@ -202,16 +201,15 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { // fact that our infrastructure sees simplifying such a loop as tantamount to // removing the non-removable instruction. TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); auto token = while_body->AddInstruction(HloInstruction::CreateToken()); while_body->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A non-tuple shaped loop shouldn't be simplified or crash the compiler. 
@@ -236,8 +234,8 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A while loop that does nothing else besides swapping tuple elements @@ -268,8 +266,8 @@ TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Construct a loop where we assign a constant to tuple element 0 in each @@ -297,8 +295,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Nothing to simplify in a while loop whose tuple has 0 elements. @@ -320,8 +318,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // While loop where one tuple element is used twice in the body, and thus can't @@ -348,8 +346,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // This while loop has three tuple elements. Element 0 is unused and should be @@ -390,16 +388,15 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { } )"; - ParseAndVerifyModule(hlo_string); - HloModule* the_module = &module(); - EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. 
HloInstruction* new_while_op = - *std::find_if(the_module->entry_computation()->instructions().begin(), - the_module->entry_computation()->instructions().end(), + *std::find_if(m->entry_computation()->instructions().begin(), + m->entry_computation()->instructions().end(), [&](const HloInstruction* instr) { return (instr->opcode() == HloOpcode::kWhile && instr->name() != "while"); @@ -440,8 +437,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, @@ -473,8 +470,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { @@ -505,8 +502,65 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { + const string hlo_string = R"( + HloModule Test + Body { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ta = (s32[1]) get-tuple-element(param), index=0 + a = s32[1] get-tuple-element(ta), index=0 + a.1 = s32[1] add(a, a) + tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1 + ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + } + Cond { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + a = s32[1] constant({0}) + b = s32[2] constant({0,1}) + c = s32[3] constant({0,1,2}) + d = s32[4] constant({0,1,2,3}) + ta = (s32[1]) tuple(a) + td = (s32[4]) tuple(d) + tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td) + init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. 
+ EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + + const auto& instrs = m->entry_computation()->instructions(); + HloInstruction* new_while = + *absl::c_find_if(instrs, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + Shape flat_tuple = + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") + .ValueOrDie(); + SCOPED_TRACE(m->ToString()); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") + .ValueOrDie())); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1f583ca44b7..039ccda7322 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -270,4 +272,17 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { return result; } +/*static*/ absl::flat_hash_map> +WhileUtil::GetGTEsMapForWhileConditional( + const HloComputation& while_conditional) { + absl::flat_hash_map> result; + for (HloInstruction* user : + while_conditional.parameter_instruction(0)->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + result[user->tuple_index()].push_back(user); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 524dcec5f12..cba41ccd8b1 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -85,6 +87,13 @@ class WhileUtil { // Assumes `while_body` is the body computation of the while loop in question. static std::vector GetInvariantGTEsForWhileBody( const HloComputation& while_body); + + // Returns a map of index to GetTupleElement instructions in + // `while_conditional` that access elements in the parameter tuple. Assumes + // `while_conditional` is the conditional computation of the while loop in + // question. 
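The declaration below documents the new WhileUtil helper implemented in the while_util.cc hunk above. A hypothetical caller might use it as follows (`while_op` and the local names are illustrative; the map's value type is the InlinedVector of get-tuple-element instructions per the includes added above):

const HloComputation& cond = *while_op->while_condition();
auto index_to_gtes = WhileUtil::GetGTEsMapForWhileConditional(cond);
for (const auto& entry : index_to_gtes) {
  int64 tuple_index = entry.first;
  for (HloInstruction* gte : entry.second) {
    // Each `gte` is a get-tuple-element that reads element `tuple_index` of
    // the condition's parameter tuple.
  }
}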
+ static absl::flat_hash_map> + GetGTEsMapForWhileConditional(const HloComputation& while_conditional); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index b9ef18892d7..a546a6d39cc 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -45,7 +45,8 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - auto module = CreateNewModule("zero_sized_elimination_test_module"); + auto module = + CreateNewUnverifiedModule("zero_sized_elimination_test_module"); module->AddEntryComputation(builder_.Build()); return ZeroSizedHloElimination{}.Run(module.get()); } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 14c35e7b84f..33edbd1b20d 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -47,8 +47,11 @@ class ServiceInterface { virtual Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; - virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status Compile(const CompileRequest* arg, + CompileResponse* result) = 0; + + virtual Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 17120e610cb..d0c35d8dee4 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -74,6 +74,11 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { return out; } +bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { + return size() >= prefix.size() && + indices_.subspan(0, prefix.size()) == prefix.indices_; +} + namespace { // Returns whether the given primitive type corresponds to an array shape. diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 191ab04759f..a7a3026cf3f 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -147,6 +147,9 @@ class ShapeIndexView { string ToString() const; + // Returns true if this shape index starts with 'prefix'. 
+ bool StartsWith(ShapeIndexView prefix) const; + private: absl::Span indices_; }; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d395c9a4cee..db34d34f969 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -44,7 +44,7 @@ cc_library( testonly = True, srcs = ["xla_internal_test_main.cc"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -117,12 +117,12 @@ cc_library( deps = [ ":literal_test_util", ":test_utils", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", @@ -141,44 +141,6 @@ cc_library( ], ) -cc_library( - name = "hlo_verified_test_base", - testonly = True, - srcs = ["hlo_verified_test_base.cc"], - hdrs = ["hlo_verified_test_base.h"], - deps = [ - ":hlo_test_base", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test( - name = "hlo_verified_test_base_test", - srcs = ["hlo_verified_test_base_test.cc"], - deps = [ - ":hlo_test_base", - ":hlo_verified_test_base", - ":test_macros_cpu", - ":test_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -868,7 +830,8 @@ xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], - shard_count = 25, + shard_count = 40, + tags = ["optonly"], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 9966e4606ef..9930bfc95c2 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { ShapeUtil::MakeShape(F32, {}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); LOG(INFO) << hlo_module->ToString(); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 3aebf784664..211d004ec8c 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -596,6 +596,272 @@ TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) { this->RunTest(); } +template +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, 
TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
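// Hand-derived check of the expected values in these depthwise tests, where
// feature_group_count equals the input feature count so each output channel
// sees only the matching input channel:
//  - Spatial output is 2x2: a 3x3 kValid window fits 4 - 3 + 1 = 2 positions
//    along each input dimension of size 4.
//  - Convolve2D_1x4x4x5_3x3x1x5 uses iota-filled input and filter, so channel
//    0 of output position (0, 0) is
//      1*1 + 6*6 + 11*11 + 21*16 + 26*21 + 31*26 + 41*31 + 46*36 + 51*41 = 6864,
//    the first element of expected_r1 in that test.
//  - The 512-, 160-, and 1024-channel variants use all-ones inputs and
//    all-twos filters, so every output element is 9 * 1 * 2 = 18.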
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 1024}; + std::vector filter_dims = {3, 3, 1, 1024}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+      ConvolutionDimensionNumbers dnums;
+      dnums.set_input_batch_dimension(0);
+      dnums.set_output_batch_dimension(0);
+      dnums.add_input_spatial_dimensions(1);
+      dnums.add_output_spatial_dimensions(1);
+      dnums.add_input_spatial_dimensions(2);
+      dnums.add_output_spatial_dimensions(2);
+      dnums.set_input_feature_dimension(3);
+      dnums.set_output_feature_dimension(3);
+      dnums.add_kernel_spatial_dimensions(0);
+      dnums.add_kernel_spatial_dimensions(1);
+      dnums.set_kernel_input_feature_dimension(2);
+      dnums.set_kernel_output_feature_dimension(3);
+
+      ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
+                                /*feature_group_count=*/1024);
+    }
+
+    std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
+                               static_cast<T>(1));
+    auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
+    auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
+
+    std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
+                                static_cast<T>(2));
+    auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
+    auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
+
+    std::vector<T> output_elems(4096, static_cast<T>(18));
+
+    auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
+    auto expected_r4 =
+        expected_r1.Reshape({1, 2, 2, 1024}).ConsumeValueOrDie();
+
+    auto input_literal =
+        client_->TransferToServer(input_r4).ConsumeValueOrDie();
+    auto filter_literal =
+        client_->TransferToServer(filter_r4).ConsumeValueOrDie();
+
+    ComputeAndCompareLiteral(&builder, expected_r4,
+                             {input_literal.get(), filter_literal.get()},
+                             error_spec_);
+
+    auto filter_r = filter_r1.Reshape(filter_dims);
+  }
+};
+
+TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) {
+  this->RunTest();
+}
+
 template <typename T>
 class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
  public:
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 1407e68d9a3..3622f2c1e84 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -45,7 +45,7 @@ class CopyOpTest : public HloTestBase {
     builder.AddInstruction(HloInstruction::CreateUnary(
         constant->shape(), HloOpcode::kCopy, constant));
     auto computation = builder.Build();
-    auto module = CreateNewModule();
+    auto module = CreateNewUnverifiedModule();
     module->AddEntryComputation(std::move(computation));
 
     Literal result = ExecuteAndTransfer(std::move(module), {});
@@ -98,7 +98,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
 
   auto computation = builder.Build();
 
-  auto module = CreateNewModule();
+  auto module = CreateNewUnverifiedModule();
   module->AddEntryComputation(std::move(computation));
 
   Literal result = ExecuteAndTransfer(std::move(module), {&literal});
@@ -119,7 +119,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
 
   auto computation = builder.Build();
 
-  auto module = CreateNewModule();
+  auto module = CreateNewUnverifiedModule();
   module->AddEntryComputation(std::move(computation));
 
   Literal result = ExecuteAndTransfer(std::move(module), {});
   LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
@@ -143,7 +143,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
 
   std::unique_ptr<HloComputation> computation = builder.Build();
 
-  auto module = CreateNewModule();
+  auto module = CreateNewUnverifiedModule();
   module->AddEntryComputation(std::move(computation));
 
   Literal result = ExecuteAndTransfer(std::move(module), {});
@@ -175,7 +175,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1,
size_t n2, size_t n3) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -209,7 +209,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); Literal result = ExecuteAndTransfer(std::move(module), {}); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 001490c6a8c..738b6442354 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -70,7 +70,7 @@ class CustomCallTest : public HloTestBase { }; XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -85,7 +85,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); Array2D array(2, 2); @@ -106,7 +106,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( @@ -130,7 +130,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = @@ -155,7 +155,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 4d4b676a538..d1fddf9d6b4 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -81,7 +81,7 @@ class FusionTest : public HloTestBase { } auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -183,7 +183,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -231,7 +231,7 @@ XLA_TEST_F(FusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -266,7 +266,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); // Build simple fusion computation: y = x^2 (elementwise). 
auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto two = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); @@ -290,7 +290,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( @@ -314,7 +314,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto single_element_array = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( @@ -329,7 +329,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -344,7 +344,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( @@ -359,7 +359,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( @@ -374,7 +374,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -389,7 +389,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( @@ -404,7 +404,7 @@ XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = 
builder.AddInstruction( @@ -419,7 +419,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -434,7 +434,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -449,7 +449,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -465,7 +465,7 @@ XLA_TEST_F(FusionTest, Reverse) { XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -483,7 +483,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -501,7 +501,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( @@ -519,7 +519,7 @@ XLA_TEST_F(FusionTest, SliceNegate) { XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( @@ -541,7 +541,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( @@ -559,7 +559,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = 
CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -587,7 +587,7 @@ std::unique_ptr MakeReduceTestComputation() { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -607,7 +607,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -630,7 +630,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( @@ -682,7 +682,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { // into a fusion, it should remain shared, rather than being duplicated // within the fusion. XLA_TEST_F(FusionTest, SharedConstant) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 7ab2ecda586..d8fa00272f8 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -85,6 +85,23 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); + } + return verifier_.Run(this).status(); +} + +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? 
"" : absl::StrCat(" (", message, ")")) + << ": " << status; + } +} + HloTestBase::HloTestBase(bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, std::function @@ -100,17 +117,48 @@ HloTestBase::HloTestBase(se::Platform* test_platform, bool allow_mixed_precision_in_hlo_verifier, std::function instruction_can_change_layout_func) - : test_runner_(test_platform), reference_runner_(reference_platform) { + : test_runner_(test_platform), + reference_runner_(reference_platform), + verifier_layout_sensitive_(verifier_layout_sensitive), + allow_mixed_precision_in_hlo_verifier_( + allow_mixed_precision_in_hlo_verifier) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, instruction_can_change_layout_func); } -std::unique_ptr HloTestBase::CreateNewModule(const string& name) { +std::unique_ptr HloTestBase::CreateNewUnverifiedModule( + const string& name) { return absl::make_unique(name, GetModuleConfigForTest()); } +std::unique_ptr HloTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + +StatusOr> +HloTestBase::ParseAndReturnUnverifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique(TestName(), config); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + return std::move(module); +} + +StatusOr> +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique( + TestName(), config, verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + /* static */ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, HloModule* module) { @@ -135,7 +183,7 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { } DebugOptions HloTestBase::GetDebugOptionsForTest() { - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 217428befa4..366726d90b4 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -38,6 +38,31 @@ limitations under the License. namespace xla { +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. 
+ void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + // A base class for tests which build and/or run HLO code. The class includes // support for running an HLO module on two platforms and compare the results. // This is a lower level of abstraction than using the client interface and @@ -72,7 +97,27 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - std::unique_ptr CreateNewModule(const string& name = TestName()); + // + // This returns a vanilla HloModule that doesn't run the HLO verifier on + // destruction. + std::unique_ptr CreateNewUnverifiedModule( + const string& name = TestName()); + + // Like CreateNewUnverifiedModule, except the HloModule returned here runs the + // HLO verifier on destruction. + std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + + // Parses the given string and returns module as a vanilla, unverified + // HloModule. + StatusOr> ParseAndReturnUnverifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); + + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -247,6 +292,8 @@ class HloTestBase : public ::testing::Test { HloRunner test_runner_; HloRunner reference_runner_; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc deleted file mode 100644 index 8bd0a729b77..00000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -Status VerifiedHloModule::Verify() { - if (computation_count() == 0) { - // The computation was never built. Nothing to verify. 
- return Status::OK(); - } - return verifier_.Run(this).status(); -} - -void VerifiedHloModule::VerifyOrAddFailure(const string& message) { - Status status = Verify(); - if (!status.ok()) { - ADD_FAILURE() << "HloVerifier failed on module " << name() - << (message.empty() ? "" : absl::StrCat(" (", message, ")")) - << ": " << status; - } -} - -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), - verifier_layout_sensitive_(layout_sensitive), - allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} - -HloModule& HloVerifiedTestBase::module() { - if (!module_) { - module_ = CreateNewVerifiedModule(TestName()); - } - return *module_; -} - -HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(CreateNewVerifiedModule(name)); - return modules_.back().get(); -} - -void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config) { - CHECK(!module_) << "Called ParseModule when test already has a module."; - module_ = CreateNewVerifiedModule(TestName()); - TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); - module_->VerifyOrAddFailure("after parsing"); -} - -StatusOr> -HloVerifiedTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) { - auto module = CreateNewVerifiedModule(TestName()); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); - return std::move(module); -} - -std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( - const string& name) { - return absl::make_unique( - name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h deleted file mode 100644 index 388a99bb364..00000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ -#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ - -#include -#include -#include - -#include "absl/base/macros.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { - -// An HLO module derived class which verifies itself on destruction. This class -// is intended to be used in unit tests. Any verification errors are raised via -// ADD_FAILURE. 
-class VerifiedHloModule : public HloModule { - public: - VerifiedHloModule(const string& name, const HloModuleConfig& config, - bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) - : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} - - ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - - // Verifies the module using HloVerifier and returns the status. - Status Verify(); - - // Verifies the module and flags any error with ADD_FAILURE. 'message' is - // included in the failure message. - void VerifyOrAddFailure(const string& message); - - private: - HloVerifier verifier_; -}; - -// A base class for HLO tests that stores a default VerifiedHloModule. -class HloVerifiedTestBase : public HloTestBase { - protected: - HloVerifiedTestBase(bool layout_sensitive = false, - bool allow_mixed_precision = false); - - // Constructs a default shape verifier. - std::unique_ptr MakeShapeVerifier(); - - // Returns the default HloModule, lazily creating it if necessary via - // HloTestBase::CreateNewModule(). - ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule& module(); - - ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") - void ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Creates a new module for a test, and stores it in modules_ so it can be - // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent - // creation of unverified modules. - ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule* CreateNewModule(const string& name = TestName()); - - // Creates and returns a verified HLO module with the given name. - std::unique_ptr CreateNewVerifiedModule( - const string& name = TestName()); - - private: - // It is confusing to store modules created by module() and CreateNewModule() - // in different fields, but it allows us to migrate tests to - // HloVerifiedTestBase more easily, so it's a win because we can verify more - // modules. See b/80488902. - // - // Lazily populated. Access via module(). - std::unique_ptr module_; - - // Populated by calls to CreateNewModule. - std::vector> modules_; - - bool verifier_layout_sensitive_; - bool allow_mixed_precision_in_hlo_verifier_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc deleted file mode 100644 index 5c0263e811f..00000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -// This class includes unit tests which are expected to fail because invalid HLO -// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to -// include the necessary gunit parts to test this test machinery (needs the -// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the -// disabled tests enabled and failures can be manually compared against -// expectations. -class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; - -XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { - // Test shouldn't fail if no module is created at all. -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { - // Use module() to lazily create an empty module, build it up, and verify no - // failures. - HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { - // Use module() to lazily create an empty module and build up an invalid - // module. - HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); - - *hlo_module.entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { - // Call CreateNewModule and build up a valid module. - HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { - // Call CreateNewModule and build up a invalid module. 
- HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); - - *module->entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndVerifyModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - ParseAndVerifyModule(hlo_string); - EXPECT_EQ(module().entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - EXPECT_EQ(module->entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} - -RANDOM GARBAGE -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleBad - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[1234] add(x,y) -} -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index c622b295094..a78ccacec11 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -68,7 +68,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); compiler->SetPreOptimizationHook(pre_opt_hook); @@ -90,7 +90,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - std::unique_ptr hlo_module = CreateNewModule(); + std::unique_ptr hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto module_group = absl::make_unique("test_module_group"); @@ -124,9 +124,9 @@ class LLVMCompilerTest : public ::testing::Test { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - static std::unique_ptr CreateNewModule() { + static std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); return absl::make_unique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc 
b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index ca7637a0cfa..3f5135438fc 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -62,7 +62,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); const Shape elem_shape2 = @@ -122,7 +122,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest1D(bool manual_fusion, int size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape_F32 = ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 2cc33ab0963..3fb69419e73 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -166,6 +166,26 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } +TEST_F(SliceTest, SliceOfReshape) { + Array2D values(2 * 3 * 24, 7); + values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR2FromArray2D(&builder, values); + auto reshape = Reshape(original, {24, 3, 2, 7}); + Slice(reshape, {0, 0, 0, 0}, {11, 3, 2, 7}, {1, 1, 1, 1}); + ComputeAndCompare(&builder, {}); +} + +TEST_F(SliceTest, SliceOfCollapsingReshape) { + Array4D values(2, 3, 5, 7); + values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR4FromArray4D(&builder, values); + auto reshape = Reshape(original, {2 * 3 * 5, 7}); + Slice(reshape, {0, 0}, {4, 7}, {1, 1}); + ComputeAndCompare(&builder, {}); +} + XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { Array4D values(2, 4, 6, 8); values.FillRandom(3.14f); @@ -253,7 +273,6 @@ XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } - // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. 
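// For instance, a hypothetical entry such as {1024, 128, 256, 2} would slice
// elements [128, 256) of a 1024-element input with a stride of 2.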
// clang-format off diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index b34fd0f2e87..a2b7c26331b 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -28,7 +28,7 @@ namespace { class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateToken()); @@ -38,8 +38,22 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } +XLA_TEST_F(TokenHloTest, TokenInTuple) { + std::unique_ptr module = CreateNewUnverifiedModule(); + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateTuple({token})); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + Literal token_literal = LiteralUtil::CreateToken(); + EXPECT_TRUE( + LiteralTestUtil::Equal(result, LiteralUtil::MakeTuple({&token_literal}))); +} + XLA_TEST_F(TokenHloTest, TokenTree) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); @@ -54,7 +68,7 @@ XLA_TEST_F(TokenHloTest, TokenTree) { } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); @@ -75,7 +89,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { } XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateParameter( 0, @@ -95,7 +109,7 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { } XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 376559500ef..ca036f1ae0d 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -91,8 +91,8 @@ Status ParseOneProfileOutputLine( string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; - string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; - string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?"; + string match_bytes_per_cycle = "([0-9.TGMKi]*)(?:B/cycle)?"; // The underlined part is what we're trying to match with match_opcode: // @@ -307,6 +307,7 @@ 
XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { string profile_output; ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape, matrix_shape); + SCOPED_TRACE(profile_output); std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); @@ -318,14 +319,13 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); - auto while_body_profile_end = std::find_if( - while_body_profile_start, profile_output_lines.end(), - [](absl::string_view s) { - return absl::StartsWith(s, "********** microseconds report **********"); - }); + auto while_body_profile_end = + std::find_if(while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds "); + }); - // We emit a blank line before the "********** microseconds report **********" - // line. + // We emit a blank line before the "microseconds report" line. while_body_profile_end--; ASSERT_NE(while_body_profile_end, profile_output_lines.end()); @@ -380,7 +380,7 @@ static std::pair AddXlaHloProfileFlag(int argc, char** argv) { GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index 15603619b62..dca0aa52a53 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -15,14 +15,14 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { LOG(ERROR) << "\n" << usage; @@ -49,7 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { // different API than Tensorflow's. 
testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) - base::SetFlag(&FLAGS_benchmarks, pattern); + absl::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); #else tensorflow::testing::Benchmark::Run(pattern); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 3a086c66bbb..8926bbed2b5 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -33,6 +33,7 @@ cc_library( name = "dumped_computation_to_graphviz_library", srcs = ["dumped_computation_to_graphviz.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -40,7 +41,6 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", @@ -78,6 +78,7 @@ cc_library( name = "replay_computation_library", srcs = ["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -91,7 +92,6 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", @@ -207,13 +207,13 @@ tf_cc_binary( name = "dumped_computation_to_tf_graphdef", srcs = ["dumped_computation_to_tf_graphdef.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index c866a13de75..b623556468f 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -54,7 +54,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = client->GetComputationStats(computation, debug_options) @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 07ef5ff656b..f8bb9a6b1e2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -53,7 +53,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); ComputationStats stats = @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 109411f99b6..47be9f5adf1 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -47,8 +47,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -191,8 +191,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, // Run the computation num_runs times, and return the result from the last // execution. - const bool xla_hlo_profile = - legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 65948ef4b0c..28df3b03f39 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -322,6 +322,34 @@ message UnregisterRequest { message UnregisterResponse { } +message CompileRequest { + // The graph to be compiled. + HloModuleProto computation = 1; + + // Options that affect how XLA compiles code to service this request. + ExecutionOptions execution_options = 2; + + // The layouts of the input arguments. If not set, the default layout will be + // used. Although the real arguments are not needed in compilation, the + // layouts of the arguments can affect the compilation. + repeated Shape input_shape_with_layout = 3; +} + +message CompileResponse { + // The handle to the executable. + ExecutionHandle handle = 1; +} + +message ExecuteRequest { + ExecutionHandle handle = 1; + + // The shape and layout of the arguments must be the same as the those of the + // executable's parameters. + repeated GlobalDataHandle arguments = 2; +} + +// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace +// the uses with calls to Compile and Execute. message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index b6bd919e2b2..683ccc40f16 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -332,11 +332,13 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. 
repeated LiteralProto tuple_literals = 10; - // The F16s and BF16s are encoded in little endian byte order + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 16 + // Next = 18 } message WindowDimension { diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 9e3d2454d16..67f475846e5 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -12,6 +12,7 @@ cc_library( hdrs = ["xrt_state_ops.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -21,7 +22,6 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 5678f0905ff..6ab77fbaaf0 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -6,6 +6,24 @@ import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; +message DeviceAssignment { + message ComputationDevice { + message DeviceMeshCoordinates { + // The mesh coordinates for the device. Usually (X, Y, Core), in the order + // in which they are returned in the TopologyProto. + // X = value(0) + // Y = value(1) + // Core = value(2) + repeated int32 value = 1; + } + // As many replicas as there are in the replicated computation. + repeated DeviceMeshCoordinates replica_devices = 1; + } + // As many ComputationDevice as many there are computations (number + // of cores per replica). + repeated ComputationDevice computation_devices = 1; +} + // Options for an XLA compilation. message XLAComputationConfig { // The number of replicas the computation will be run on. If this is @@ -23,6 +41,11 @@ message XLAComputationConfig { // computation. per_core_args_and_result_shapes is optional for a // single-core computation. repeated xla.ProgramShape per_core_program_shape = 5; + // Describes how replicated computation instances should be assigned to + // devices. There are num_cores_per_replica computations, and each one will be + // sent and executed to the set of replica device numbers described in the + // DeviceAssignment proto. + DeviceAssignment device_assignment = 6; } // Options and XLA computation for a compilation. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f45010ec26e..1fffbb5f660 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -142,7 +142,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): name="StatsAccumulator/{}".format(self._name)) # Allocate both stats accumulator and quantile accumulator on the same # device so that we can build splits with fewer RPCs. 
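    # Colocation is requested here by entering ops.colocate_with on the stats
    # accumulator's resource handle before the quantile accumulator below is
    # constructed.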
- with ops.colocate_with(self._stats_accumulator.resource()): + with ops.colocate_with(self._stats_accumulator.resource_handle): self._quantile_accumulator = quantile_ops.QuantileAccumulator( init_stamp_token, epsilon=epsilon, @@ -268,8 +268,8 @@ class DenseSplitHandler(InequalitySplitHandler): handler = make_dense_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, @@ -447,8 +447,8 @@ class SparseSplitHandler(InequalitySplitHandler): handler = make_sparse_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 05ce0884ccf..356ae337685 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -34,7 +34,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -62,7 +62,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2, 1], @@ -91,7 +91,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -123,7 +123,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -133,7 +133,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): (stamp_token, num_updates, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. 
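The split handlers and tests switch from calling the accumulator's resource() method to reading its resource_handle property, and from the private _create_op to the public initializer. One generic way to stage such a rename without breaking existing callers is to keep the old method as a thin shim; the sketch below is only an illustration of that migration pattern, not code from this change (the patch itself removes resource() outright).

import warnings


class Accumulator:
    """Illustrative only: exposes resource_handle as a property."""

    def __init__(self):
        self._resource_handle = object()   # stand-in for the real handle tensor

    @property
    def resource_handle(self):
        return self._resource_handle

    def resource(self):
        """Deprecated shim so old call sites keep working during migration."""
        warnings.warn("resource() is deprecated; use resource_handle",
                      DeprecationWarning, stacklevel=2)
        return self.resource_handle


acc = Accumulator()
assert acc.resource() is acc.resource_handle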
with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -164,7 +164,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. op1 = accumulator.add( stamp_token=0, @@ -175,7 +175,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): deserialize = ( - accumulator.deserialize( + accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], @@ -223,7 +223,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -261,7 +261,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -299,7 +299,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -336,7 +336,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -349,7 +349,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -386,7 +386,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. 
op1 = accumulator.add( stamp_token=0, @@ -399,7 +399,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): 0.08]]]) with ops.control_dependencies([op1]): - deserialize = accumulator.deserialize( + deserialize = accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 25b2c9e2fd7..fca22c71a83 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader # pylint: enable=unused-import @@ -31,6 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") @@ -82,6 +85,44 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): tree_ensemble_config=restored_tensors[1]) +class TreeEnsembleVariable(tracking.TrackableResource): + """A Tree ensemble model.""" + + def __init__(self, stamp_token, tree_ensemble_config, name, container=None): + self._stamp_token = stamp_token + self._tree_ensemble_config = tree_ensemble_config + self._name = name + self._container = container + self._init_op = None + super(TreeEnsembleVariable, self).__init__() + + def create_resource(self): + return gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._container, shared_name=self._name, name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_ensemble_variable( + self.resource_handle, self._stamp_token, self._tree_ensemble_config) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_model_ops.tree_ensemble_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return { + "tree_ensemble_variable": + functools.partial( + TreeEnsembleVariableSavable, + tree_ensemble_handle=self.resource_handle, + create_op=self.initializer) + } + + def tree_ensemble_variable(stamp_token, tree_ensemble_config, name, @@ -99,12 +140,11 @@ def tree_ensemble_variable(stamp_token, A `Tensor` of type mutable `string`. The handle to the tree ensemble. """ with ops.name_scope(name, "TreeEnsembleVariable") as name: - resource_handle = gen_model_ops.decision_tree_ensemble_resource_handle_op( - container, shared_name=name, name=name) - create_op = gen_model_ops.create_tree_ensemble_variable( - resource_handle, stamp_token, tree_ensemble_config) - is_initialized_op = gen_model_ops.tree_ensemble_is_initialized_op( - resource_handle) + tree_ensemble_var = TreeEnsembleVariable(stamp_token, tree_ensemble_config, + name, container) + resource_handle = tree_ensemble_var.resource_handle + create_op = tree_ensemble_var.initializer + is_initialized_op = tree_ensemble_var.is_initialized() # Adds the variable to the savable list. 
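TreeEnsembleVariable (and, further below, QuantileAccumulator and StatsAccumulator) now follow the TrackableResource contract: create_resource() builds the handle, initialize() builds the init op, the initializer property caches it, and is_initialized() reports state. The following self-contained sketch mirrors that contract with plain Python objects so the control flow is visible without TensorFlow; the base class and the "op" stand-ins are assumptions made for illustration.

class TrackableResourceSketch:
    """Plain-Python stand-in for the TrackableResource contract (illustration only)."""

    def __init__(self):
        self._init_op = None
        self._resource_handle = self.create_resource()

    @property
    def resource_handle(self):
        return self._resource_handle

    @property
    def initializer(self):
        # Lazily build the init op once and reuse it, as the patch does.
        if self._init_op is None:
            self._init_op = self.initialize()
        return self._init_op

    def create_resource(self):
        raise NotImplementedError

    def initialize(self):
        raise NotImplementedError

    def is_initialized(self):
        raise NotImplementedError


class CounterResource(TrackableResourceSketch):
    """Hypothetical resource: the 'handle' is a dict and the init op is a callable."""

    def create_resource(self):
        return {}

    def initialize(self):
        return lambda: self.resource_handle.setdefault("value", 0)

    def is_initialized(self):
        return "value" in self.resource_handle


counter = CounterResource()
assert not counter.is_initialized()
counter.initializer()                                 # run the cached init op
assert counter.is_initialized()
assert counter.initializer is counter.initializer    # built only once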
saveable = TreeEnsembleVariableSavable(resource_handle, create_op, resource_handle.name) diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 19b6b3296db..0c319cc9bd1 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,12 +33,60 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") -class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): +class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for QuantileAccumulator.""" + + def __init__(self, resource_handle, create_op, name): + self._resource_handle = resource_handle + self._create_op = create_op + stamp_token, state, are_buckets_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_serialize(resource_handle)) + # slice_spec is useful for saving a slice from a variable. + # It's not meaningful in quantile accumulator. + slice_spec = "" + def make_save_spec(tensor, suffix): + return saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name + suffix) + + specs = [make_save_spec(stamp_token, "_stamp")] + specs += [make_save_spec(state, "_state")] + specs += [make_save_spec(are_buckets_ready, "_are_buckets_ready")] + specs += [make_save_spec(buckets, "buckets")] + super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle, + specs, name) + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated quantile accumulator from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. + + Returns: + The operation that restores the state of the quantile accumulator. + """ + # Read the restored tensors with the same order that were added to saving + # spec. + stamp_token = restored_tensors[:1] + state = restored_tensors[1:2] + are_buckets_ready = restored_tensors[2:3] + buckets = restored_tensors[3] + with ops.control_dependencies([self._create_op]): + return gen_quantile_ops.quantile_accumulator_deserialize( + self._resource_handle, + stamp_token=stamp_token, + stream_state=state, + are_buckets_ready=are_buckets_ready, + buckets=buckets) + + +class QuantileAccumulator(tracking.TrackableResource): """A resource that allows distributed quantile computation.""" def __init__(self, @@ -61,82 +109,64 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): generate_quantiles: Generate quantiles instead of approximate boundaries. If true, exactly `num_quantiles` will be produced in the final summary. 
""" + self._init_stamp_token = init_stamp_token self._epsilon = epsilon + self._num_quantiles = num_quantiles + self._max_elements = max_elements + self._container = container self._generate_quantiles = generate_quantiles + super(QuantileAccumulator, self).__init__() name = _PATTERN.sub("", name) with ops.name_scope(name, "QuantileAccumulator") as name: - self._quantile_accumulator_handle = ( - gen_quantile_ops.quantile_stream_resource_handle_op( - container=container, shared_name=name, name=name)) - self._create_op = gen_quantile_ops.create_quantile_accumulator( - self._quantile_accumulator_handle, - init_stamp_token, - epsilon=epsilon, - max_elements=max_elements, - num_quantiles=num_quantiles, - generate_quantiles=generate_quantiles) - is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( - self._quantile_accumulator_handle) - resources.register_resource(self._quantile_accumulator_handle, - self._create_op, is_initialized_op) - self._make_savable(name) + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self._init_op, + is_initialized_op) + self._saveable = QuantileAccumulatorSaveable(self.resource_handle, + self._init_op, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) - def _make_savable(self, name): - stamp_token, state, are_buckets_ready, buckets = ( - gen_quantile_ops.quantile_accumulator_serialize( - self._quantile_accumulator_handle)) - # slice_spec is useful for saving a slice from a variable. - # It's not meaningful in quantile accumulator. - slice_spec = "" - def make_save_spec(tensor, suffix): - return saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name + suffix) + def create_resource(self): + return gen_quantile_ops.quantile_stream_resource_handle_op( + container=self._container, shared_name=self._name, name=self._name) - specs = [make_save_spec(stamp_token, "_stamp")] - specs += [make_save_spec(state, "_state")] - specs += [make_save_spec(are_buckets_ready, "_are_buckets_ready")] - specs += [make_save_spec(buckets, "buckets")] - super(QuantileAccumulator, - self).__init__(self._quantile_accumulator_handle, specs, name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + def initialize(self): + return gen_quantile_ops.create_quantile_accumulator( + self.resource_handle, + self._init_stamp_token, + epsilon=self._epsilon, + max_elements=self._max_elements, + num_quantiles=self._num_quantiles, + generate_quantiles=self._generate_quantiles) - def restore(self, restored_tensors, unused_restored_shapes): - """Restores the associated quantile accumulator from 'restored_tensors'. + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op - Args: - restored_tensors: the tensors that were loaded from a checkpoint. - unused_restored_shapes: the shapes this object should conform to after - restore. + def is_initialized(self): + return gen_quantile_ops.quantile_accumulator_is_initialized( + self.resource_handle) - Returns: - The operation that restores the state of the quantile accumulator. - """ - # Read the restored tensors with the same order that were added to saving - # spec. 
- stamp_token = restored_tensors[:1] - state = restored_tensors[1:2] - are_buckets_ready = restored_tensors[2:3] - buckets = restored_tensors[3] - with ops.control_dependencies([self._create_op]): - return gen_quantile_ops.quantile_accumulator_deserialize( - self._quantile_accumulator_handle, - stamp_token=stamp_token, - stream_state=state, - are_buckets_ready=are_buckets_ready, - buckets=buckets) + def _gather_saveables_for_checkpoint(self): + return {"quantile_accumulator": self._saveable} def get_buckets(self, stamp_token): """Returns quantile buckets created during previous flush.""" are_buckets_ready, buckets = ( gen_quantile_ops.quantile_accumulator_get_buckets( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token)) return are_buckets_ready[0], buckets[0] def schedule_get_buckets(self): """Returns a scheduled read of buckets created during previous flush.""" return batch_ops_utils.ScheduledStampedResourceOp( - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, op=gen_quantile_ops.quantile_accumulator_get_buckets) def _make_summary(self, column, example_weights): @@ -161,14 +191,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Adds quantile summary to its stream in resource.""" summary = self._make_summary(column, example_weights) return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) def add_prebuilt_summary(self, stamp_token, summary): """Adds quantile summary to its stream in resource.""" return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) @@ -177,7 +207,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): summary = self._make_summary(column, example_weights) return batch_ops_utils.ScheduledStampedResourceOp( op=gen_quantile_ops.quantile_accumulator_add_summaries, - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, summaries=summary) def flush(self, stamp_token, next_stamp_token): @@ -190,17 +220,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): The flush operation.
""" return gen_quantile_ops.quantile_accumulator_flush( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) def flush_summary(self, stamp_token, next_stamp_token): """Finalizes quantile summary stream and resets it for next iteration.""" result = gen_quantile_ops.quantile_accumulator_flush_summary( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result - - def resource(self): - return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index 2e94e353f32..ad1191d4123 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,12 +26,83 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") -class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): +class StatsAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for StatsAccumulator.""" + + def __init__(self, resource_handle, create_op, is_scalar, name): + self._create_op = create_op + self._resource_handle = resource_handle + self._is_scalar = is_scalar + slice_spec = "" + saver_name = self._resource_handle.name + (stamp_token, num_updates, partition_ids, feature_ids, gradients, + hessians) = self.serialize() + specs = [ + saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, + saver_name + "_stamp"), + saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, + saver_name + "_num_updates"), + saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, + saver_name + "_partition_ids"), + saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, + saver_name + "_feature_ids"), + saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, + saver_name + "_gradients"), + saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, + saver_name + "hessians"), + ] + super(StatsAccumulatorSaveable, self).__init__(self._resource_handle, specs, + name) + + def serialize(self): + """Serializes the stats accumulator state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( + self._resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( + self._resource_handle) + + def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, + gradients, hessians): + """Resets the stats accumulator with the serialized state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( + self._resource_handle, stamp_token, num_updates, partition_ids, + feature_ids, gradients, hessians) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( + self._resource_handle, stamp_token, num_updates, partition_ids, + feature_ids, gradients, hessians) + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree ensemble from 'restored_tensors'. 
+ + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree ensemble variable. + """ + with ops.control_dependencies([self._create_op]): + return self.deserialize( + stamp_token=restored_tensors[0], + num_updates=restored_tensors[1], + partition_ids=restored_tensors[2], + feature_ids=restored_tensors[3], + gradients=restored_tensors[4], + hessians=restored_tensors[5]) + + +class StatsAccumulator(tracking.TrackableResource): """A resource that allows to accumulate gradients and hessians. For consistency guarantees, we use read and write stamp tokens. @@ -58,58 +129,69 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): Returns: A `Tensor` of type mutable `string`. The handle to the stats accumulator. """ + self._stamp_token = stamp_token + self._gradient_shape = gradient_shape + self._hessian_shape = hessian_shape + self._container = container + + if (gradient_shape == tensor_shape.scalar() and + hessian_shape == tensor_shape.scalar()): + self._is_scalar = True + else: + self._is_scalar = False + if name is not None: name = _PATTERN.sub("", name) with ops.name_scope(name, "StatsAccumulator") as name: - # Both values are scalars. - if (gradient_shape == tensor_shape.scalar() and - hessian_shape == tensor_shape.scalar()): - self._is_scalar = True - self._resource_handle = (gen_stats_accumulator_ops. - stats_accumulator_scalar_resource_handle_op( - container, name, name=name)) - - create_op = gen_stats_accumulator_ops.create_stats_accumulator_scalar( - self._resource_handle, stamp_token) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( - self._resource_handle)) - else: - self._is_scalar = False - self._resource_handle = (gen_stats_accumulator_ops. 
- stats_accumulator_tensor_resource_handle_op( - container, name, name=name)) - create_op = gen_stats_accumulator_ops.create_stats_accumulator_tensor( - self._resource_handle, stamp_token, gradient_shape.as_list(), - hessian_shape.as_list()) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( - self._resource_handle)) - - self._create_op = create_op - slice_spec = "" - saver_name = self._resource_handle.name - (stamp_token, num_updates, partition_ids, feature_ids, gradients, - hessians) = self.serialize() - specs = [ - saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, - saver_name + "_stamp"), - saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, - saver_name + "_num_updates"), - saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, - saver_name + "_partition_ids"), - saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, - saver_name + "_feature_ids"), - saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, - saver_name + "_gradients"), - saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, - saver_name + "hessians"), - ] - - super(StatsAccumulator, self).__init__(self._resource_handle, specs, name) - resources.register_resource(self._resource_handle, create_op, + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self.initializer, is_initialized_op) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + self._saveable = StatsAccumulatorSaveable( + self.resource_handle, self.initializer, self._is_scalar, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + def create_resource(self): + if self._is_scalar: + return ( + gen_stats_accumulator_ops.stats_accumulator_scalar_resource_handle_op( + self._container, self._name, name=self._name)) + else: + return ( + gen_stats_accumulator_ops.stats_accumulator_tensor_resource_handle_op( + self._container, self._name, name=self._name)) + + def initialize(self): + if self._is_scalar: + return gen_stats_accumulator_ops.create_stats_accumulator_scalar( + self.resource_handle, self._stamp_token) + else: + return gen_stats_accumulator_ops.create_stats_accumulator_tensor( + self.resource_handle, self._stamp_token, + self._gradient_shape.as_list(), self._hessian_shape.as_list()) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( + self.resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( + self.resource_handle) + + @property + def saveable(self): + return self._saveable + + def _gather_saveables_for_checkpoint(self): + return {"stats_accumulator", self.saveable} def add(self, stamp_token, partition_ids, feature_ids, gradients, hessians): """Updates the stats accumulator.""" @@ -117,11 +199,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): partition_ids, feature_ids, gradients, hessians)) if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_add( - [self._resource_handle], stamp_token, [partition_ids], [feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) else: return gen_stats_accumulator_ops.stats_accumulator_tensor_add( - [self._resource_handle], stamp_token, 
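Both accumulators and the tree ensemble variable now expose their checkpoint state through _gather_saveables_for_checkpoint, which the object-based checkpointing machinery expects to be a dictionary keyed by a short name, mapping to a SaveableObject or a factory for one (model_ops.py above uses functools.partial for the factory form). The sketch below shows that return shape with stand-in classes; the class names are hypothetical and only the dict-of-factories structure is the point.

import functools


class FakeSaveable:
    """Stand-in for a BaseSaverBuilder.SaveableObject (illustration only)."""

    def __init__(self, handle, create_op, name):
        self.handle, self.create_op, self.name = handle, create_op, name


class FakeTrackable:
    def __init__(self):
        self.resource_handle = object()
        self.initializer = object()

    def _gather_saveables_for_checkpoint(self):
        # A dict keyed by attribute name, mapping to a saveable or a factory for one.
        return {
            "tree_ensemble_variable": functools.partial(
                FakeSaveable,
                handle=self.resource_handle,
                create_op=self.initializer),
        }


saveables = FakeTrackable()._gather_saveables_for_checkpoint()
saveable = saveables["tree_ensemble_variable"](name="ensemble")
print(saveable.name)   # ensemble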
[partition_ids], [feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) def schedule_add(self, partition_ids, feature_ids, gradients, hessians): @@ -131,7 +213,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): if self._is_scalar: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_scalar_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -139,7 +221,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): else: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_tensor_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -153,55 +235,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): return gen_stats_accumulator_ops.stats_accumulator_tensor_make_summary( partition_ids, feature_ids, gradients, hessians) - def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, - gradients, hessians): - """Resets the stats accumulator with the serialized state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - def flush(self, stamp_token, next_stamp_token): """Flushes the stats accumulator.""" if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_flush( - self._resource_handle, stamp_token, next_stamp_token) + self.resource_handle, stamp_token, next_stamp_token) else: return gen_stats_accumulator_ops.stats_accumulator_tensor_flush( - self._resource_handle, stamp_token, next_stamp_token) - - def serialize(self): - """Serializes the stats accumulator state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( - self._resource_handle) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( - self._resource_handle) - - def restore(self, restored_tensors, unused_restored_shapes): - """Restores the associated tree ensemble from 'restored_tensors'. - - Args: - restored_tensors: the tensors that were loaded from a checkpoint. - unused_restored_shapes: the shapes this object should conform to after - restore. Not meaningful for trees. - - Returns: - The operation that restores the state of the tree ensemble variable. 
- """ - with ops.control_dependencies([self._create_op]): - return self.deserialize( - stamp_token=restored_tensors[0], - num_updates=restored_tensors[1], - partition_ids=restored_tensors[2], - feature_ids=restored_tensors[3], - gradients=restored_tensors[4], - hessians=restored_tensors[5]) - - def resource(self): - return self._resource_handle + self.resource_handle, stamp_token, next_stamp_token) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 1cf61a10ba2..ab5713fbe26 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -992,7 +992,7 @@ class GradientBoostedDecisionTreeModel(object): # Get accumulated steps and examples for the current layer. _, _, _, _, acc_examples, acc_steps = ( - steps_accumulator.serialize()) + steps_accumulator.saveable.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) ensemble_update_ops.append( diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index 5ecd4f34183..40b1e667ee6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -25,6 +25,13 @@ import six from tensorflow.python.training.server_lib import ClusterSpec +def _format_master_url(master, rpc_layer=None): + if rpc_layer: + return '%s://%s' % (rpc_layer, master) + else: + return master + + @six.add_metaclass(abc.ABCMeta) class ClusterResolver(object): """Abstract class for all implementations of ClusterResolvers. @@ -57,12 +64,13 @@ class ClusterResolver(object): 'cluster_spec is not implemented for {}.'.format(self)) @abc.abstractmethod - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Retrieves the name or URL of the session master. Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC protocol for the given cluster. Returns: The name or URL of the session master. @@ -77,10 +85,18 @@ class ClusterResolver(object): class SimpleClusterResolver(ClusterResolver): """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - def __init__(self, cluster_spec, master=''): + def __init__(self, cluster_spec, master='', task_type=None, task_index=None, + environment='', num_accelerators_per_worker=0, + rpc_layer=None): """Creates a SimpleClusterResolver from a ClusterSpec.""" super(SimpleClusterResolver, self).__init__() + self._task_type = task_type + self._task_index = task_index + self._environment = environment + self._num_accelerators_per_worker = num_accelerators_per_worker + self._rpc_layer = rpc_layer + if not isinstance(cluster_spec, ClusterSpec): raise TypeError('cluster_spec must be a ClusterSpec.') self._cluster_spec = cluster_spec @@ -93,12 +109,13 @@ class SimpleClusterResolver(ClusterResolver): """Returns the ClusterSpec passed into the constructor.""" return self._cluster_spec - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Returns the master address to use when creating a session. 
Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC used by distributed TensorFlow. Returns: The name or URL of the session master. @@ -106,10 +123,52 @@ class SimpleClusterResolver(ClusterResolver): If a task_type and task_index is given, this will override the `master` string passed into the initialization function. """ - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + else: + master = self._master - return self._master + return _format_master_url(master, rpc_layer or self._rpc_layer) + + @property + def task_type(self): + return self._task_type + + @property + def task_index(self): + return self._task_index + + @task_type.setter + def task_type(self, task_type): + self._task_type = task_type + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + return self._environment + + def num_accelerators_per_worker(self, session_config=None): + """Returns the number of accelerator cores per worker. + + Args: + session_config: Unused. The SimpleClusterResolver does not do automatic + detection of accelerators, so a TensorFlow session will never be + created, and thus a `session_config` is never necessary here, and will + be ignored. + """ + del session_config + return self._num_accelerators_per_worker + + @property + def rpc_layer(self): + return self._rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer class UnionClusterResolver(ClusterResolver): @@ -119,13 +178,22 @@ class UnionClusterResolver(ClusterResolver): merges the underlying ClusterResolvers, and returns one unified ClusterSpec when cluster_spec is called. The details of the merge function is documented in the cluster_spec function. + + For additional Cluster Resolver properties such as task type, task index, + rpc layer, environment, etc..., we will return the value from the first + ClusterResolver in the union. """ - def __init__(self, *args): + def __init__(self, *args, **kwargs): """Initializes a UnionClusterResolver with other ClusterResolvers. Args: *args: `ClusterResolver` objects to be unionized. + **kwargs: + rpc_layer - (Optional) Override value for the RPC layer used by + TensorFlow. + task_type - (Optional) Override value for the current task type. + task_index - (Optional) Override value for the current task index. Raises: TypeError: If any argument is not a subclass of `ClusterResolvers`. @@ -133,6 +201,13 @@ class UnionClusterResolver(ClusterResolver): """ super(UnionClusterResolver, self).__init__() + self._rpc_layer = kwargs.pop('rpc_layer', None) + self._task_type = kwargs.pop('task_type', None) + self._task_index = kwargs.pop('task_index', None) + + if kwargs: + raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs)) + if not args: raise ValueError('At least one ClusterResolver is required.') @@ -216,7 +291,7 @@ class UnionClusterResolver(ClusterResolver): return ClusterSpec(merged_cluster) - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Returns the master address to use when creating a session. 
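master() on the cluster resolvers now takes an rpc_layer argument and only looks up the cluster address when both task_type and task_index are given as non-None, so a task index of 0 is no longer treated as missing. A small, self-contained sketch of that address formatting follows, with a plain dict standing in for ClusterSpec; the function names are illustrative.

def format_master_url(master, rpc_layer=None):
    return '%s://%s' % (rpc_layer, master) if rpc_layer else master


def resolve_master(cluster, default_master='', task_type=None, task_index=None,
                   rpc_layer=None):
    """`cluster` is a {job_name: [addresses]} dict standing in for ClusterSpec."""
    if task_type is not None and task_index is not None:
        master = cluster[task_type][task_index]
    else:
        master = default_master
    return format_master_url(master, rpc_layer)


cluster = {'worker': ['worker0:2222', 'worker1:2222', 'worker2:2222']}
print(resolve_master(cluster, task_type='worker', task_index=0, rpc_layer='grpc'))
# grpc://worker0:2222 -- index 0 is honoured because the check is `is not None`
print(resolve_master(cluster, default_master='worker1:2222'))
# worker1:2222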
This usually returns the master from the first ClusterResolver passed in, @@ -225,11 +300,45 @@ class UnionClusterResolver(ClusterResolver): Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC protocol for the given cluster. Returns: The name or URL of the session master. """ - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + return _format_master_url(master, rpc_layer or self._rpc_layer) - return self._cluster_resolvers[0].master() + return self._cluster_resolvers[0].master(rpc_layer=rpc_layer) + + @property + def task_type(self): + return self._task_type or self._cluster_resolvers[0].task_type + + @property + def task_index(self): + return self._task_index or self._cluster_resolvers[0].task_index + + @task_type.setter + def task_type(self, task_type): + self._task_type = task_type + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + return self._cluster_resolvers[0].environment + + def num_accelerators_per_worker(self, session_config=None): + return self._cluster_resolvers[0].num_accelerators_per_worker( + session_config) + + @property + def rpc_layer(self): + return self._rpc_layer or self._cluster_resolvers[0].rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py index c004b2e2d3b..b94c9612b5b 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -57,6 +57,62 @@ class UnionClusterResolverTest(test.TestCase): actual_cluster_spec = union_resolver.cluster_spec() self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testInitSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + self.assertEqual(simple_resolver.task_type, "ps") + self.assertEqual(simple_resolver.task_index, 1) + self.assertEqual(simple_resolver.environment, "cloud") + self.assertEqual(simple_resolver.num_accelerators_per_worker(), 8) + self.assertEqual(simple_resolver.rpc_layer, "grpc") + + def testOverrideSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + simple_resolver.task_type = "worker" + simple_resolver.task_index = 2 + simple_resolver.rpc_layer = "http" + + self.assertEqual(simple_resolver.task_type, "worker") + self.assertEqual(simple_resolver.task_index, 2) + self.assertEqual(simple_resolver.rpc_layer, "http") + + def testSimpleOverrideMasterWithTaskIndexZero(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": 
["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec) + actual_master = simple_resolver.master("worker", 0, rpc_layer="grpc") + self.assertEqual(actual_master, "grpc://worker0:2222") + + def testSimpleOverrideMasterWithRpcLayer(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec) + actual_master = simple_resolver.master("worker", 2, rpc_layer="grpc") + self.assertEqual(actual_master, "grpc://worker2:2222") + def testSimpleOverrideMaster(self): base_cluster_spec = server_lib.ClusterSpec({ "ps": ["ps0:2222", "ps1:2222"], @@ -65,7 +121,42 @@ class UnionClusterResolverTest(test.TestCase): simple_resolver = SimpleClusterResolver(base_cluster_spec) actual_master = simple_resolver.master("worker", 2) - self.assertEquals(actual_master, "worker2:2222") + self.assertEqual(actual_master, "worker2:2222") + + def testUnionClusterResolverGetProperties(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + resolver1 = SimpleClusterResolver(cluster_spec_1, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + cluster_spec_2 = server_lib.ClusterSpec({ + "ps": ["ps2:2222", "ps3:2222"], + "worker": ["worker3:2222", "worker4:2222", "worker5:2222"] + }) + resolver2 = SimpleClusterResolver(cluster_spec_2, task_type="worker", + task_index=2, environment="local", + num_accelerators_per_worker=16, + rpc_layer="http") + + union_resolver = UnionClusterResolver(resolver1, resolver2) + + self.assertEqual(union_resolver.task_type, "ps") + self.assertEqual(union_resolver.task_index, 1) + self.assertEqual(union_resolver.environment, "cloud") + self.assertEqual(union_resolver.num_accelerators_per_worker(), 8) + self.assertEqual(union_resolver.rpc_layer, "grpc") + + union_resolver.task_type = "worker" + union_resolver.task_index = 2 + union_resolver.rpc_layer = "http" + + self.assertEqual(union_resolver.task_type, "worker") + self.assertEqual(union_resolver.task_index, 2) + self.assertEqual(union_resolver.rpc_layer, "http") def testTwoNonOverlappingJobMergedClusterResolver(self): cluster_spec_1 = server_lib.ClusterSpec({ @@ -116,10 +207,13 @@ class UnionClusterResolverTest(test.TestCase): union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) unspecified_master = union_cluster.master() - self.assertEquals(unspecified_master, "") + self.assertEqual(unspecified_master, "") specified_master = union_cluster.master("worker", 1) - self.assertEquals(specified_master, "worker1:2222") + self.assertEqual(specified_master, "worker1:2222") + + rpc_master = union_cluster.master("worker", 1, rpc_layer="grpc") + self.assertEqual(rpc_master, "grpc://worker1:2222") def testOverlappingJobMergedClusterResolver(self): cluster_spec_1 = server_lib.ClusterSpec({ diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 5083e4d10ba..195b68959b6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -30,6 +30,10 @@ except ImportError: _GOOGLE_API_CLIENT_INSTALLED = False +def 
_format_master_url(master, rpc_layer=None): + return '%s://%s' % (rpc_layer, master) if rpc_layer else master + + class GceClusterResolver(ClusterResolver): """Cluster Resolver for Google Compute Engine. @@ -45,7 +49,10 @@ class GceClusterResolver(ClusterResolver): zone, instance_group, port, - job_name='worker', + task_type='worker', + task_index=0, + rpc_layer='grpc', + num_accelerators_per_worker=0, credentials='default', service=None): """Creates a new GceClusterResolver object. @@ -55,13 +62,22 @@ class GceClusterResolver(ClusterResolver): each instance in the instance group. Args: - project: Name of the GCE project - zone: Zone of the GCE instance group - instance_group: Name of the GCE instance group + project: Name of the GCE project. + zone: Zone of the GCE instance group. + instance_group: Name of the GCE instance group. port: Port of the listening TensorFlow server (default: 8470) - job_name: Name of the TensorFlow job this set of instances belongs to + task_type: Name of the TensorFlow job this GCE instance group of VM + instances belong to. + task_index: The task index for this particular VM, within the GCE + instance group. In particular, every single instance should be assigned + a unique ordinal index within an instance group manually so that they + can be distinguished from each other. + rpc_layer: The RPC layer TensorFlow should use to communicate across + instances. + num_accelerators_per_worker: Number of accelerators (GPUs) present per + instance. credentials: GCE Credentials. If nothing is specified, this defaults to - GoogleCredentials.get_application_default() + GoogleCredentials.get_application_default(). service: The GCE API object returned by the googleapiclient.discovery function. (Default: discovery.build('compute', 'v1')). If you specify a custom service object, then the credentials parameter will be ignored. @@ -72,7 +88,9 @@ class GceClusterResolver(ClusterResolver): self._project = project self._zone = zone self._instance_group = instance_group - self._job_name = job_name + self._task_type = task_type + self._task_index = task_index + self._rpc_layer = rpc_layer self._port = port self._credentials = credentials @@ -133,10 +151,58 @@ class GceClusterResolver(ClusterResolver): previous_response=response) worker_list.sort() - return ClusterSpec({self._job_name: worker_list}) + return ClusterSpec({self._task_type: worker_list}) - def master(self, task_type=None, task_index=None): - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + def master(self, task_type=None, task_index=None, rpc_layer=None): + task_type = task_type if task_type is not None else self._task_type + task_index = task_index if task_index is not None else self._task_index + + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + if rpc_layer or self._rpc_layer: + return '%s://%s' % (rpc_layer or self._rpc_layer, master) + else: + return master return '' + + @property + def task_type(self): + return self._task_type + + @property + def task_index(self): + return self._task_index + + @task_type.setter + def task_type(self, task_type): + raise RuntimeError( + 'You cannot reset the task_type of the GceClusterResolver after it has ' + 'been created.') + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + """Returns the current environment which TensorFlow is running in. 
+ + For users in the GCE environment, the environment property is always an + empty string, and Google users will not use this ClusterResolver for running + on internal systems. + """ + return '' + + @property + def rpc_layer(self): + return self._rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer + + def num_accelerators_per_worker(self, session_config=None): + del session_config # Unused, since this is set manually in __init__. + return self._num_accelerators_per_worker + diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py index 87b83031224..c691552e860 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py @@ -135,12 +135,86 @@ class GceClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testMasterRetrieval(self): + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_index=0, + port=8470, + credentials=None, + service=self.standard_mock_service_client()) + self.assertEqual(gce_cluster_resolver.master(), 'grpc://10.123.45.67:8470') + + def testMasterRetrievalWithCustomTasks(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + self.assertEqual( + gce_cluster_resolver.master('worker', 2, 'test'), + 'test://10.3.4.5:8470') + + def testOverrideParameters(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_type='testworker', + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + gce_cluster_resolver.task_index = 1 + gce_cluster_resolver.rpc_layer = 'test' + + self.assertEqual(gce_cluster_resolver.task_type, 'testworker') + self.assertEqual(gce_cluster_resolver.task_index, 1) + self.assertEqual(gce_cluster_resolver.rpc_layer, 'test') + self.assertEqual(gce_cluster_resolver.master(), 'test://10.2.3.4:8470') + + def testOverrideParametersWithZeroOrEmpty(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_type='', + task_index=1, + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + self.assertEqual(gce_cluster_resolver.master( + task_type='', task_index=0), 'grpc://10.1.2.3:8470') + def testCustomJobNameAndPortRetrieval(self): gce_cluster_resolver = GceClusterResolver( project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='custom', + task_type='custom', port=2222, credentials=None, service=self.standard_mock_service_client()) 
@@ -196,7 +270,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='worker', + task_type='worker', port=8470, credentials=None, service=self.gen_standard_mock_service_client(worker1_name_to_ip)) @@ -205,7 +279,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='worker', + task_type='worker', port=8470, credentials=None, service=self.gen_standard_mock_service_client(worker2_name_to_ip)) @@ -214,7 +288,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='ps', + task_type='ps', port=2222, credentials=None, service=self.gen_standard_mock_service_client(ps_name_to_ip)) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index c4ac9d07001..1f6803a9ff9 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -50,6 +50,34 @@ class TPUClusterResolver(ClusterResolver): Cloud Platform project. """ + def _tpuService(self): + """Creates a new Cloud TPU API object. + + This works around an issue where the underlying HTTP connection sometimes + times out when the script has been running for too long. Other methods in + this object calls this method to get a new API object whenever they need + to communicate with the Cloud API. + + Returns: + A Google Cloud TPU API object. + """ + if self._service: + return self._service + + credentials = self._credentials + if credentials is None or credentials == 'default': + credentials = GoogleCredentials.get_application_default() + + if self._discovery_url: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials, + discoveryServiceUrl=self._discovery_url) + else: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials) + def _requestComputeMetadata(self, path): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) @@ -81,7 +109,7 @@ class TPUClusterResolver(ClusterResolver): return None @staticmethod - def _discoveryUrl(): + def _environmentDiscoveryUrl(): return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) def __init__(self, @@ -154,49 +182,42 @@ class TPUClusterResolver(ClusterResolver): self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name - self._credentials = credentials + # Whether we should actually attempt to contact Cloud APIs should_resolve = self._shouldResolve() + # We error out if we are in a non-Cloud environment which cannot talk to the + # Cloud APIs using the standard class and a special object is not passed in. + self._service = service + if (self._service is None and should_resolve and + not _GOOGLE_API_CLIENT_INSTALLED): + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + # We save user-passed credentials, unless the user didn't pass in anything. 
+ self._credentials = credentials + if (credentials == 'default' and should_resolve and + _GOOGLE_API_CLIENT_INSTALLED): + self._credentials = None + + # Automatically detect project and zone if unspecified. if not project and should_resolve: project = compat.as_str( self._requestComputeMetadata('project/project-id')) - if not zone and should_resolve: zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone = zone_path.split('/')[-1] - self._project = project self._zone = zone - if credentials == 'default' and should_resolve: - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None and should_resolve: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the TPU cluster resolver. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip.') - - final_discovery_url = self._discoveryUrl() or discovery_url - if final_discovery_url: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials, - discoveryServiceUrl=final_discovery_url) - else: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) - else: - self._service = service + self._discovery_url = self._environmentDiscoveryUrl() or discovery_url self._coordinator_name = coordinator_name - if coordinator_name and not coordinator_address and (should_resolve or - in_gke): + if (coordinator_name and not coordinator_address and + (should_resolve or in_gke)): self._start_local_server() else: self._coordinator_address = coordinator_address @@ -270,7 +291,8 @@ class TPUClusterResolver(ClusterResolver): # Case 1. 
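TPUClusterResolver now defers building the Cloud TPU API client to _tpuService(), which reuses an injected service object (handy for tests), resolves 'default' credentials only when a client is actually needed, and honours a discovery URL override. The sketch below captures that lazy, injectable-client shape with stand-in objects; it does not call the real googleapiclient and the names are assumptions for illustration.

class LazyClientHolder:
    """Illustration of the deferred-client pattern behind _tpuService()."""

    def __init__(self, service=None, credentials='default', discovery_url=None):
        self._service = service            # pre-built client injected by tests
        self._credentials = credentials
        self._discovery_url = discovery_url

    def _default_credentials(self):
        return 'application-default-credentials'       # stand-in for GoogleCredentials

    def _build_client(self, credentials, discovery_url):
        return ('client', credentials, discovery_url)  # stand-in for discovery.build

    def service(self):
        if self._service:
            return self._service                       # reuse whatever was injected
        credentials = self._credentials
        if credentials is None or credentials == 'default':
            credentials = self._default_credentials()
        return self._build_client(credentials, self._discovery_url)


holder = LazyClientHolder(discovery_url='https://{api}.internal/{apiVersion}')
print(holder.service())
injected = LazyClientHolder(service='fake-client-from-test')
print(injected.service())    # 'fake-client-from-test', nothing is built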
full_name = 'projects/%s/locations/%s/nodes/%s' % ( self._project, self._zone, compat.as_text(self._tpu)) - request = self._service.projects().locations().nodes().get(name=full_name) + service = self._tpuService() + request = service.projects().locations().nodes().get(name=full_name) response = request.execute() if 'state' in response and response['state'] != 'READY': diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index ad4f6432630..478c82967ba 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -459,10 +459,10 @@ class TPUClusterResolverTest(test.TestCase): del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] - def testDiscoveryUrl(self): + def testEnvironmentDiscoveryUrl(self): os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' self.assertEqual('https://{api}.internal/{apiVersion}', - TPUClusterResolver._discoveryUrl()) + TPUClusterResolver._environmentDiscoveryUrl()) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index fbdca497fcc..a63366e1361 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -59,8 +59,6 @@ option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires M # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) -set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") -set(tensorflow_CUDNN_VERSION "7" CACHE STRING "cuDNN version to build against") if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) @@ -72,25 +70,25 @@ endif() if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option # for targets that link ${CMAKE_THREAD_LIBS_INIT}. - find_package (Threads) + find_package (Threads REQUIRED) # Options for linking CUDA/CUDNN libraries - option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/) + option(tensorflow_PATH_CUDA_LIB "Additional library search path for cudnn, nccl, culibos" /usr/local/cuda/lib64/) option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/) if (NOT tensorflow_CUDNN_INCLUDE) # option's default value is OFF. Fill it with real default values set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) - option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + option(tensorflow_PATH_CUDNN_LIB "Override PATH_CUDA_LIB for cudnn" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_CUDNN_LIB) # option's default value is OFF. 
Fill it with real default values - set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) - option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_CUDNN_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_CUDNN_LIB) + option(tensorflow_PATH_NCCL_LIB "Override PATH_CUDA_LIB for nccl" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_NCCL_LIB) # option's default value is OFF. Fill it with real default values - set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_NCCL_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_NCCL_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. Fill it with real default values @@ -210,14 +208,17 @@ endif() include(CheckCXXCompilerFlag) # OpenMP Support -CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) -if (GCC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") -endif() -CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) -if (MSVC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") -endif() +if (WIN32) + CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) + if (MSVC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") + endif() +else (WIN32) + CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) + if (GCC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + endif() +endif (WIN32) # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) @@ -377,29 +378,19 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - # later command will make use of the value in tensorflow_CUDA_VERSION - find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED EXACT) - - # Test compatibility of compiler on CUDA - try_compile(CUDA_TEST_COMPILE_C - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.c - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - try_compile(CUDA_TEST_COMPILE_CXX - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.cc - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - if(NOT (CUDA_TEST_COMPILE_C AND CUDA_TEST_COMPILE_CXX)) - message(FATAL_ERROR "Selected compiler (or version) is not supported for CUDA") + # minimum 9.1 in cuda version + find_package(CUDA 9.1 REQUIRED) + if(NOT CUDA_FOUND) + message(FATAL_ERROR "CUDA not found.") endif() - # by default we assume compute cabability 3.5 and 5.2. If you change this change it in - # CUDA_NVCC_FLAGS and cuda_config.h below - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_37,code=\"sm_37,compute_37\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_52,code=\"sm_52,compute_52\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_60,code=\"sm_60,compute_60\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_61,code=\"sm_61,compute_61\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_70,code=\"sm_70,compute_70\") + # use cmake internal CUDA_ARCH_NAME switch + # e.g. 
CUDA_ARCH_NAME="Auto" will autodetect + # CUDA_ARCH_NAME="All" will use all arches + cuda_select_nvcc_arch_flags(NVCC_ARCH_FLAGS ${CUDA_ARCH_NAME}) + list(APPEND CUDA_NVCC_FLAGS ${NVCC_ARCH_FLAGS}) + message(STATUS "Using CUDA arch flags: ${NVCC_ARCH_FLAGS_readable}") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) @@ -423,43 +414,94 @@ if (tensorflow_ENABLE_GPU) else (WIN32) set(CUDNN_INCLUDE "${tensorflow_CUDNN_INCLUDE}") - find_library(nccl_STATIC_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT nccl_STATIC_LIBRARY) + if (tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl.so PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT nccl_LIBRARY) message(FATAL_ERROR "NCCL is required for GPU-build") - else (NOT nccl_STATIC_LIBRARY) - message("nccl-static: ${nccl_STATIC_LIBRARY}") + else (NOT nccl_LIBRARY) + message("nccl: ${nccl_LIBRARY}") # something like /usr/lib64/libnccl_static.a - endif (NOT nccl_STATIC_LIBRARY) + endif (NOT nccl_LIBRARY) - find_library(cudnn_STATIC_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT cudnn_STATIC_LIBRARY) + if (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn.so PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT cudnn_LIBRARY) message(FATAL_ERROR "CUDNN is required for GPU-build") - else (NOT cudnn_STATIC_LIBRARY) - message("cudnn-static: ${cudnn_STATIC_LIBRARY}") - endif (NOT cudnn_STATIC_LIBRARY) + else (NOT cudnn_LIBRARY) + file(READ ${CUDNN_INCLUDE}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) + # fetch cudnn version + string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}") + string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}") + string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}") + if(NOT CUDNN_VERSION_MAJOR) + set(CUDNN_VERSION "???") + else() + set(CUDNN_VERSION "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}") + endif() + message(STATUS "cudnn library: ${cudnn_LIBRARY} (found version: \"${CUDNN_VERSION}\")") + endif (NOT cudnn_LIBRARY) - find_library(culibos_STATIC_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT culibos_STATIC_LIBRARY) + if (tensorflow_BUILD_SHARED_LIB) + # shared first (if exists) else static one + find_library(culibos_LIBRARY NAMES libculibos.so libculibos.a PATHS 
${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + # only static version + find_library(culibos_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT culibos_LIBRARY) message(FATAL_ERROR "CULIBOS is required for GPU-build") - else (NOT culibos_STATIC_LIBRARY) - message("culibos-static: ${culibos_STATIC_LIBRARY}") - endif (NOT culibos_STATIC_LIBRARY) + else (NOT culibos_LIBRARY) + message("culibos: ${culibos_LIBRARY}") + endif (NOT culibos_LIBRARY) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} - ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) + ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_LIBRARY} ${culibos_LIBRARY} ${nccl_LIBRARY}) endif (WIN32) include_directories(${CUDNN_INCLUDE}) # Remove "." from CUDA version variable. - string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + string(REPLACE "." "" short_CUDA_VER ${CUDA_VERSION}) + + # List of enumerated CUDA caps + string(REPLACE " " ";" NVCC_ARCH_LIST "${NVCC_ARCH_FLAGS_readable}") + set(list ${NVCC_ARCH_LIST}) + + # Construct capability string + foreach(NVCC_ARCH ${NVCC_ARCH_LIST}) + if (NVCC_ARCH MATCHES "sm_") + string(REGEX REPLACE "^.sm*" "" NVCC_ARCH ${NVCC_ARCH}) + math(EXPR NVCC_ARCH_MAJOR "${NVCC_ARCH} / 10") + math(EXPR NVCC_ARCH_MINOR "(${NVCC_ARCH} - (${NVCC_ARCH_MAJOR}*10))") + if (TF_CUDA_CAP) + set(TF_CUDA_CAP "${TF_CUDA_CAP},CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + else (TF_CUDA_CAP) + set(TF_CUDA_CAP "CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + endif (TF_CUDA_CAP) + endif() + endforeach() # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" - "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.7\"),CudaVersion(\"5.2\"),CudaVersion(\"6.0\"),CudaVersion(\"6.1\"),CudaVersion(\"7.0\")\n" + "#define TF_CUDA_CAPABILITIES ${TF_CUDA_CAP}\n" "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" - "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" + "#define TF_CUDNN_VERSION \"64_${CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -494,14 +536,14 @@ if (tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll cudart_dll_name=cudart64_${short_CUDA_VER}.dll - cuda_version_number=${tensorflow_CUDA_VERSION} + cuda_version_number=${CUDA_VERSION} nvcuda_dll_name=nvcuda.dll cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=${tensorflow_CUDA_VERSION} - cudnn_version_number=${tensorflow_CUDNN_VERSION}) + cuda_version_number=${CUDA_VERSION} + cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index c6c5021f60b..4546dbdecc0 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -20,6 +20,7 @@ if (systemlib_ABSEIL_CPP) 
absl_dynamic_annotations absl_malloc_internal absl_throw_delegate + absl_int128 absl_strings str_format_internal absl_bad_optional_access) @@ -50,6 +51,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) @@ -60,6 +62,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/absl_malloc_internal.lib ${abseil_cpp_BUILD}/absl/base/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) @@ -71,6 +74,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/libabsl_dynamic_annotations.a ${abseil_cpp_BUILD}/absl/base/libabsl_malloc_internal.a ${abseil_cpp_BUILD}/absl/base/libabsl_throw_delegate.a + ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index d94b703700c..96160568fa7 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -57,6 +57,7 @@ tensorflow/python/ops tensorflow/python/ops/distributions tensorflow/python/ops/linalg tensorflow/python/ops/losses +tensorflow/python/ops/signal tensorflow/python/platform tensorflow/python/profiler tensorflow/python/profiler/internal @@ -377,8 +378,6 @@ tensorflow/contrib/seq2seq/python/ops tensorflow/contrib/session_bundle tensorflow/contrib/session_bundle/example tensorflow/contrib/signal -tensorflow/contrib/signal/python -tensorflow/contrib/signal/python/ops tensorflow/contrib/slim tensorflow/contrib/slim/python tensorflow/contrib/slim/python/slim diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 88b4a6165c0..d66e39ac07c 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,14 +68,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/unique_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" 
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index ef337b3a15c..9cfa8b90749 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -89,7 +89,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index ef487d3509b..df7b854afcc 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -373,8 +373,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index 8d13dc7316a..3b49755afcf 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -49,7 +48,7 @@ class XLACompileContextTest(test.TestCase): histogram_summary = summary.histogram('histogram_summary', dummy_tensor) image_summary = summary.image('image_summary', dummy_tensor) scalar_summary = summary.scalar('scalar_summary', dummy_tensor) - tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor) + tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor) summary.merge( [ audio_summary, histogram_summary, image_summary, scalar_summary, 
diff --git a/tensorflow/contrib/copy_graph/python/__init__.py b/tensorflow/contrib/copy_graph/python/__init__.py index b9ff28eb0d7..5c1048e02a3 100644 --- a/tensorflow/contrib/copy_graph/python/__init__.py +++ b/tensorflow/contrib/copy_graph/python/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/copy_graph/python/util/__init__.py b/tensorflow/contrib/copy_graph/python/util/__init__.py index b9ff28eb0d7..5c1048e02a3 100644 --- a/tensorflow/contrib/copy_graph/python/util/__init__.py +++ b/tensorflow/contrib/copy_graph/python/util/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index 57ffaa87e45..670b5494327 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -63,8 +63,8 @@ cuda_py_test( ], shard_count = 6, tags = [ - "no_oss", # b/117989214 "noasan", # http://b/62067814 + "requires-gpu-sm35", ], ) diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 5d8c6191f8d..53202322686 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -24,6 +24,10 @@ @@CudnnGRUSaveable @@CudnnRNNReluSaveable @@CudnnRNNTanhSaveable +@@CudnnParamsFormatConverterLSTM +@@CudnnParamsFormatConverterGRU +@@CudnnParamsFormatConverterTanh +@@CudnnParamsFormatConverterRelu """ from __future__ import absolute_import @@ -48,6 +52,10 @@ _allowed_symbols = [ "CudnnGRUSaveable", "CudnnRNNReluSaveable", "CudnnRNNTanhSaveable", + "CudnnParamsFormatConverterLSTM", + "CudnnParamsFormatConverterGRU", + "CudnnParamsFormatConverterTanh", + "CudnnParamsFormatConverterRelu", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index c59d3682d40..ae839108ebe 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -202,12 +202,13 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): dtype=dtype) random_seed.set_random_seed(1234) params_size_t = model.params_size() - params = variables.Variable( + params = variables.VariableV1( random_ops.random_uniform([params_size_t], dtype=dtype), dtype=dtype, validate_shape=False) saveable = _CreateParamsSavable(params, model) - weights, biases = saveable._OpaqueParamsToCanonical() + weights, biases = saveable.format_converter._opaque_to_cu_canonical( + saveable._variables) reset_params = state_ops.assign( params, array_ops.zeros([params_size_t], dtype=dtype), @@ -248,7 +249,7 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): params_size_t = model.params_size() names = ["rnn_1", "rnn_2"] param_vars = [ - variables.Variable( + variables.VariableV1( random_ops.random_uniform([params_size_t], dtype=dtype), dtype=dtype, validate_shape=False) for name in names @@ -256,8 +257,10 @@ class 
CudnnRNNTestSaveRestore(TensorFlowTestCase): saveables = [] for name, params in zip(names, param_vars): saveables.append(_CreateParamsSavable(params, model, name, name)) - weights1, biases1 = saveables[0]._OpaqueParamsToCanonical() - weights2, biases2 = saveables[1]._OpaqueParamsToCanonical() + weights1, biases1 = saveables[0].format_converter._opaque_to_cu_canonical( + saveables[0]._variables) + weights2, biases2 = saveables[1].format_converter._opaque_to_cu_canonical( + saveables[1]._variables) reset_params = [ state_ops.assign( params, @@ -304,7 +307,7 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): direction=direction, dtype=dtype) params_size_t = model.params_size() - params = variables.Variable( + params = variables.VariableV1( array_ops.ones([params_size_t], dtype=dtype), validate_shape=False, dtype=dtype) @@ -422,21 +425,21 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): cudnn_rnn_ops.CUDNN_LSTM, constant_op.constant([4]), 200, 200, direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + _ = model.params_size() with self.assertRaisesRegexp( ValueError, "Shape must be rank 0 but is rank 1"): model = _CreateModel( cudnn_rnn_ops.CUDNN_LSTM, 4, constant_op.constant([200]), 200, direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + _ = model.params_size() with self.assertRaisesRegexp( ValueError, "Shape must be rank 0 but is rank 1"): model = _CreateModel( cudnn_rnn_ops.CUDNN_LSTM, 4, 200, constant_op.constant([200]), direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + _ = model.params_size() class CudnnRNNTestInference(TensorFlowTestCase): @@ -458,7 +461,7 @@ class CudnnRNNTestInference(TensorFlowTestCase): params_size_t = model.params_size() input_data = array_ops.ones([seq_length, batch_size, input_size]) input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - params = variables.Variable( + params = variables.VariableV1( array_ops.ones([params_size_t]), validate_shape=False) if has_input_c: input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units]) @@ -584,20 +587,20 @@ class CudnnRNNTestTraining(TensorFlowTestCase): dtype=dtype, dropout=dropout) params_size_t = model.params_size() - input_data = variables.Variable( + input_data = variables.VariableV1( random_ops.random_uniform( [seq_length, batch_size, input_size], dtype=dtype), dtype=dtype) - input_h = variables.Variable( + input_h = variables.VariableV1( random_ops.random_uniform( [num_layers * dir_count, batch_size, num_units], dtype=dtype), dtype=dtype) - params = variables.Variable( + params = variables.VariableV1( random_ops.random_uniform([params_size_t], dtype=dtype), validate_shape=False, dtype=dtype) if has_input_c: - input_c = variables.Variable( + input_c = variables.VariableV1( random_ops.random_uniform( [num_layers * dir_count, batch_size, num_units], dtype=dtype), dtype=dtype) @@ -639,7 +642,8 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTraining(self): + def DISABLED_testSimpleTraining(self): + # TODO(jamesqin): fix b/117989214 test_configs = [ { "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py index f09466b631f..60229af374b 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py 
@@ -27,5 +27,10 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibl from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterGRU +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterLSTM +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterRelu +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterTanh from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable + diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index a324c6e7d76..8bbcc7cd039 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -388,11 +388,11 @@ class _CudnnRNN(base_layer.Layer): output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. Raises: - ValueError: initial_state is not a tuple. + TypeError: initial_state is not a tuple. """ if initial_state is not None and not isinstance(initial_state, tuple): - raise ValueError("Invalid initial_state type: %s, expecting tuple.", - type(initial_state)) + raise TypeError("Invalid initial_state type: %s, expecting tuple." % + initial_state) dtype = self.dtype inputs = ops.convert_to_tensor(inputs, dtype=dtype) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 2c92f317883..d06d0c6bdaa 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -74,7 +74,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): - """Cudnn Compatible GRUCell. + r"""Cudnn Compatible GRUCell. A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with `tf.contrib.cudnn_rnn.CudnnGRU`. The latter's params can be used by @@ -177,172 +177,60 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): return new_h, new_h -# TODO(yaozhang): make sure we only save the canonical version of params and -# don't save the platform-specific version to avoid potential race -# conditions where params is updated by both versions when being restored. -# Currently, checkpointing will function properly, despite that we save both -# versions, because Saver restores customized savables after Variables. -# However, it is good to not rely on this restoring order of Saver and to -# avoid unnecessary storage. Add a test to check only the canonical version is -# saved. -class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): - """Abstract SaveableObject implementation handling Cudnn opaque params.""" +class CudnnParamsFormatConverter(object): + """Abstract class that converts between params of Cudnn Rnn and TF Rnn.""" def __init__(self, - opaque_params, num_layers, num_units, input_size, input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - scope=None, - name="cudnn_rnn_saveable"): - """Creates a CudnnOpaqueParamsSaveable object. 
- - CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file - and is used to save/restore the weights and biases parameters in a - canonical format which is directly consumable by platform-independent tf - RNN cells. Parameters are saved as tensors layer by layer with weight - tensors followed by bias tensors, and forward direction followed by - backward direction (if applicable). When restoring, a user could name - param_variables as desired, and restore weight and bias tensors to these - variables. - - For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per - bias for each layer: tensor 0 is applied to the input from the previous - layer and tensor 1 to the recurrent input. - - For CudnnLSTM, there are 8 tensors per weight and per bias for each - layer: tensor 0-3 are applied to the input from the previous layer and - tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; - tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; - tensor 3 and 7 the output gate. - - For CudnnGRU, there are 6 tensors per weight and per bias for each layer: - tensor 0-2 are applied to the input from the previous layer and - tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; - tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + direction=CUDNN_RNN_UNIDIRECTION): + """Constructor. Args: - opaque_params: a variable, Cudnn RNN opaque params. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and the actual computation before the first layer. It could be one + of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input' + (default) always applies a linear projection of input onto RNN hidden + state. (standard RNN behavior). * 'skip_input' is only allowed when + input_size == num_units; * 'auto_select' implies 'skip_input' when + input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' - scope: string of VariableScope, the scope of equivalent subgraph - consisting only platform-independent tf RNN cells. - name: the name of the CudnnOpaqueParamsSaveable object. + 'unidirectional' or 'bidirectional' """ - # Define in subclasses. 
self._num_layers = num_layers self._input_size = input_size self._num_units = num_units self._input_mode = input_mode self._direction = direction - if scope is not None: - scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope - self._scope = scope_name or None - else: - self._scope = None - - self._variables = opaque_params self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 self._num_params = ( self._num_params_per_layer * self._num_layers * self._num_dirs) - weights, biases = self._OpaqueParamsToCanonical() - (weights, weight_names), (biases, bias_names) = self._TransformCanonical( - weights, biases) - # We currently don't use slice_spec. It might be useful in a distributed - # setting where each parameter server node stores a slice of variable, - # instead of having the master pull all slices and then save them. - slice_spec = "" - params = weights + biases - self._weight_names = weight_names - self._bias_names = bias_names - self._param_names = weight_names + bias_names - prefixed_param_names = weight_names + bias_names - if self._scope: - prefixed_param_names = [ - "%s/%s" % (self._scope, pn) for pn in prefixed_param_names] - specs = [ - saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) - for param, param_name in zip(params, prefixed_param_names) - ] - super(CudnnOpaqueParamsSaveable, self).__init__( - array_ops.identity(self._variables), specs, name) + def tf_canonical_to_opaque(self, tf_canonicals): + r"""Converts tf canonical weights to cudnn opaque param.""" + cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals) + cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights] + opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases) + return opaque_params - def restore(self, restored_tensors, restored_shapes): - weights, biases = self._ReverseTransformCanonical(restored_tensors) - weights = [array_ops.reshape(w, [-1]) for w in weights] - opaque_params = self._CanonicalToOpaqueParams(weights, biases) + def opaque_to_tf_canonical(self, opaque_param): + r"""Converts cudnn opaque param to tf canonical weights.""" + cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param) + weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases) + return weights, biases - return state_ops.assign( - self._variables, opaque_params, validate_shape=False) - - def _checkpointable_save(self, save_buffer): - weights, biases = self._OpaqueParamsToCanonical() - with ops.device("gpu:0"): - (weights, _), (biases, _) = self._TransformCanonical( - weights, biases) - for name, tensor in zip(self._param_names, weights + biases): - save_buffer[name] = array_ops.identity(tensor) - - def _checkpointable_restore(self, restore_buffer): - tensors = [array_ops.identity(restore_buffer[name]) - for name in self._param_names] - return self.restore( - restored_tensors=tensors, - restored_shapes=None # Unused - ) - - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. - - When saving or restoring, converts to or from the opaque buffer - format. Weights are saved and loaded in the configuration expected by - cuDNN-compatible cells. - - Args: - checkpointable: An object inheriting from `CheckpointableBase` to add - dependencies too (typically the cuDNN `Layer`). - dtype: The dtype for the canonical parameter Tensors. 
- """ - split_dependencies = split_dependency.split_dependency( - component_names=self._param_names, - component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) - - def _checkpointable_track_params(self, checkpointable, params): - """Tracks parameters in a canonical configuration.""" - return # NotImplementedError raised by the Layer. - - def _TFCanonicalNamePrefix(self, layer, is_fwd=True): - if self._direction == CUDNN_RNN_UNIDIRECTION: - return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) - else: - if is_fwd: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % - (layer, self._rnn_cell_name)) - else: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % - (layer, self._rnn_cell_name)) - - def _OpaqueParamsToCanonical(self): + def _opaque_to_cu_canonical(self, opaque_param): """Converts opaque params to Cudnn canonical format. + Args: + opaque_param: An opaque tensor storing cudnn rnn params (weights and + biases). Returns: 2 list for weights and biases respectively. """ @@ -351,14 +239,14 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, - params=self._variables, + params=opaque_param, num_params=self._num_params, rnn_mode=self._rnn_mode, input_mode=self._input_mode, direction=self._direction) return (weights, biases) - def _CanonicalToOpaqueParams(self, cu_weights, cu_biases): + def _cu_canonical_to_opaque(self, cu_weights, cu_biases): """Converts from Cudnn canonical format to opaque params. Args: @@ -378,7 +266,7 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): input_mode=self._input_mode, direction=self._direction) - def _TransformCanonical(self, cu_weights, cu_biases): + def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. The elements of argument lists are laid out in the following format: @@ -398,46 +286,43 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): cu_weights: a list of tensors of Cudnn canonical weights. cu_biases: a list of tensors of Cudnn canonical biases. Returns: - 2 tuples, one for weights and the other for bias. - Each tuple has two lists: the 1st for transformed tf canonical tensors - and the 2nd for the names of the tensors under which they are saved. + 1 tuple, tf canonical weights and biases. 
""" tf_weights, tf_biases = [], [] - tf_weights_names, tf_bias_names = [], [] layer_weights_num = self._num_params_per_layer * self._num_dirs layer_biases_num = layer_weights_num for i in range(self._num_layers): - layer_weights = cu_weights[i * layer_weights_num: - (i + 1) * layer_weights_num] + layer_weights = cu_weights[i * layer_weights_num:(i + 1) * + layer_weights_num] layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: - prefix = self._TFCanonicalNamePrefix(i) - self._TransformSingleLayerCanonical(layer_weights, layer_biases, prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) + self._cu_canonical_to_tf_canonical_single_layer( + layer_weights, layer_biases, tf_weights, tf_biases) else: - fw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=True) - bw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=False) - fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] fw_biases = layer_biases[:len(layer_biases) // 2] bw_biases = layer_biases[len(layer_biases) // 2:] - self._TransformSingleLayerCanonical(fw_weights, fw_biases, fw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) + self._cu_canonical_to_tf_canonical_single_layer( + fw_weights, + fw_biases, + tf_weights, + tf_biases, + ) - self._TransformSingleLayerCanonical(bw_weights, bw_biases, bw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) - return (tf_weights, tf_weights_names), (tf_biases, tf_bias_names) + self._cu_canonical_to_tf_canonical_single_layer( + bw_weights, + bw_biases, + tf_weights, + tf_biases, + ) + return (tf_weights, tf_biases) - def _TransformSingleLayerCanonical(self, cu_weights, cu_biases, prefix, - tf_weights, tf_weights_names, tf_biases, - tf_bias_names): + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): r"""Transform single layer Cudnn canonicals to tf canonicals. The elements of cu_weights, cu_biases are laid out in the following format: @@ -447,15 +332,12 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Args: cu_weights: a list of tensors, single layer weights. cu_biases: a list of tensors, single layer biases. - prefix: the shared prefix of all tensor names. tf_weights: a list where transformed weights are stored. - tf_weights_names: a list where names of transformed weights are stored. tf_biases: a list where transformed biases are stored. - tf_bias_names: a list where names of transformed biases are stored. """ raise NotImplementedError("Abstract method") - def _ReverseTransformCanonical(self, tf_canonicals): + def _tf_canonical_to_cu_canonical(self, tf_canonicals): r"""Transform from tf canonical to Cudnn canonical. This is the reverse routine of _TransformCanonical(). 
@@ -502,30 +384,27 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return cu_weights, cu_biases def _cudnn_to_tf_weights(self, *cu_weights): - r"""Stitching cudnn canonical weights to generate tf canonical weights.""" + r"""Stitches cudnn canonical weights to generate tf canonical weights.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_weights(self, layer, *tf_weights): - r"""Reverse the operations in StitchWeights().""" + r"""Reverses the operations in StitchWeights().""" raise NotImplementedError("Abstract method") def _cudnn_to_tf_biases(self, *biases): - r"""Stitching cudnn canonical biases to generate tf canonical biases.""" + r"""Stitches cudnn canonical biases to generate tf canonical biases.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_biases(self, *tf_biases): - r"""Reverse the operations in StitchBiases().""" + r"""Reverses the operations in StitchBiases().""" raise NotImplementedError("Abstract method") -class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn LSTM opaque params.""" - +class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF LSTM.""" _rnn_mode = CUDNN_LSTM _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) - def _cudnn_to_tf_gate_params(self, *cu_gate_order): i_g, f_g, c_g, o_g = cu_gate_order return [i_g, c_g, f_g, o_g] @@ -603,44 +482,16 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # Return ifco order for Cudnn LSTM. return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): - (w,) = self._cudnn_to_tf_weights(*weights) - (b,) = self._cudnn_to_tf_biases(*biases) - + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): + (w,) = self._cudnn_to_tf_weights(*cu_weights) + (b,) = self._cudnn_to_tf_biases(*cu_biases) tf_weights.append(w) - tf_weights_names.append(prefix + "/kernel") - tf_biases.append(b) - tf_bias_names.append(prefix + "/bias") - - def _checkpointable_track_params(self, checkpointable, params): - """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" - biases = [] - weights = [] - for name in self._weight_names: - weights.append(params[name]) - for name in self._bias_names: - biases.append(params[name]) - assert len(params) == len(weights) + len(biases) - if len(weights) == 1 and len(biases) == 1: - # For single-layer cells, allow substituting a cell with no MultiRNNCell - # wrapping. 
- kernel, = weights # pylint: disable=unbalanced-tuple-unpacking - bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access - assert len(biases) == len(weights) - for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.Checkpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access - cell.bias = bias - cell.kernel = kernel -class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn GRU opaque params.""" +class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF GRU.""" _rnn_mode = CUDNN_GRU _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER @@ -702,29 +553,18 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): b_ri, b_rr = array_ops.split(br, 2, axis=0) return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): # pylint: disable=invalid-name - W_ir, w_h, r_h = self._cudnn_to_tf_weights(*weights) - b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*biases) + W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights) + b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases) # pylint: enable=invalid-name - tf_weights.extend([W_ir, w_h, r_h]) - tf_weights_names.append(prefix + "/gates/kernel") - tf_weights_names.append(prefix + "/candidate/input_projection/kernel") - tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") - tf_biases.extend([b_ir, b_wh, b_rh]) - tf_bias_names.append(prefix + "/gates/bias") - tf_bias_names.append(prefix + "/candidate/input_projection/bias") - tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") -class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" - - _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) +class CudnnParamsFormatConverterBasic(CudnnParamsFormatConverterLSTM): + """Helper class that converts between params of Cudnn and TF Relu/Tanh RNN.""" def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -766,18 +606,270 @@ class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): return b_i, b_h -class CudnnRNNTanhSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" +class CudnnParamsFormatConverterTanh(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Tanh RNN.""" _rnn_mode = CUDNN_RNN_TANH _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER -class CudnnRNNReluSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Relu opaque params.""" +class CudnnParamsFormatConverterRelu(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Relu RNN.""" _rnn_mode = CUDNN_RNN_RELU _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER +# TODO(yaozhang): make sure we only save the canonical version of params and +# don't save the platform-specific version to avoid potential race +# conditions where 
params is updated by both versions when being restored. +# Currently, checkpointing will function properly, despite that we save both +# versions, because Saver restores customized savables after Variables. +# However, it is good to not rely on this restoring order of Saver and to +# avoid unnecessary storage. Add a test to check only the canonical version is +# saved. +class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): + """Abstract SaveableObject implementation handling Cudnn opaque params.""" + + def __init__(self, + opaque_params, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + scope=None, + name="cudnn_rnn_saveable"): + """Creates a CudnnOpaqueParamsSaveable object. + + CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file + and is used to save/restore the weights and biases parameters in a + canonical format which is directly consumable by platform-independent tf + RNN cells. Parameters are saved as tensors layer by layer with weight + tensors followed by bias tensors, and forward direction followed by + backward direction (if applicable). When restoring, a user could name + param_variables as desired, and restore weight and bias tensors to these + variables. + + For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per + bias for each layer: tensor 0 is applied to the input from the previous + layer and tensor 1 to the recurrent input. + + For CudnnLSTM, there are 8 tensors per weight and per bias for each + layer: tensor 0-3 are applied to the input from the previous layer and + tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; + tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; + tensor 3 and 7 the output gate. + + For CudnnGRU, there are 6 tensors per weight and per bias for each layer: + tensor 0-2 are applied to the input from the previous layer and + tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; + tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + + Args: + opaque_params: a variable, Cudnn RNN opaque params. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + scope: string of VariableScope, the scope of equivalent subgraph + consisting only platform-independent tf RNN cells. + name: the name of the CudnnOpaqueParamsSaveable object. + """ + # Define in subclasses. 
+ self._num_layers = num_layers + self._input_size = input_size + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + if scope is not None: + scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope + self._scope = scope_name or None + else: + self._scope = None + + self._variables = opaque_params + self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + # Defined in subclasses. + self._format_converter = None + + tf_weights, tf_biases = ( + self.format_converter.opaque_to_tf_canonical(self._variables)) + tf_weight_names, tf_bias_names = self._tf_canonical_names() + # We currently don't use slice_spec. It might be useful in a distributed + # setting where each parameter server node stores a slice of variable, + # instead of having the master pull all slices and then save them. + slice_spec = "" + params = tf_weights + tf_biases + self._weight_names = tf_weight_names + self._bias_names = tf_bias_names + self._param_names = tf_weight_names + tf_bias_names + prefixed_param_names = tf_weight_names + tf_bias_names + if self._scope: + prefixed_param_names = [ + "%s/%s" % (self._scope, pn) for pn in prefixed_param_names + ] + specs = [ + saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) + for param, param_name in zip(params, prefixed_param_names) + ] + super(CudnnOpaqueParamsSaveable, self).__init__( + array_ops.identity(self._variables), specs, name) + + @property + def format_converter(self): + if self._format_converter is None: + self._format_converter = self._format_converter_cls( + self._num_layers, self._num_units, self._input_size, self._input_mode, + self._direction) + return self._format_converter + + def restore(self, restored_tensors, restored_shapes): + opaque_params = self.format_converter.tf_canonical_to_opaque( + restored_tensors) + return state_ops.assign( + self._variables, opaque_params, validate_shape=False) + + def _checkpointable_save(self, save_buffer): + weights, biases = self.format_converter.opaque_params_to_tf_canonical( + self._variables) + for name, tensor in zip(self._param_names, weights + biases): + save_buffer[name] = array_ops.identity(tensor) + + def _checkpointable_restore(self, restore_buffer): + tensors = [ + array_ops.identity(restore_buffer[name]) for name in self._param_names + ] + return self.restore( + restored_tensors=tensors, + restored_shapes=None # Unused + ) + + def _add_checkpointable_dependencies(self, checkpointable, dtype): + """Add canonical weight dependencies to `checkpointable`. + + When saving or restoring, converts to or from the opaque buffer + format. Weights are saved and loaded in the configuration expected by + cuDNN-compatible cells. + + Args: + checkpointable: An object inheriting from `CheckpointableBase` to add + dependencies too (typically the cuDNN `Layer`). + dtype: The dtype for the canonical parameter Tensors. + """ + split_dependencies = split_dependency.split_dependency( + component_names=self._param_names, + component_dtypes=(dtype,) * len(self._param_names), + fill_save_buffer_fn=self._checkpointable_save, + consume_restore_buffer_fn=self._checkpointable_restore) + self._checkpointable_track_params(checkpointable, split_dependencies) + + def _checkpointable_track_params(self, checkpointable, params): + """Tracks parameters in a canonical configuration.""" + return # NotImplementedError raised by the Layer. 
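The docstring restored above spells out the cuDNN per-layer layout (for LSTM, tensors 0-3 act on the layer input and 4-7 on the recurrent input, in input/forget/new-memory/output gate order). The per-gate reshuffle the LSTM converter performs is tiny; here is a plain-Python sketch mirroring the `_cudnn_to_tf_gate_params` context earlier in this file, with the inverse direction inferred from the fact that the swap is symmetric:

```python
def cudnn_to_tf_gate_params(i_g, f_g, c_g, o_g):
  # cuDNN canonical gate order is (i, f, c, o); TF-canonical LSTM kernels use
  # (i, c, f, o), so only the middle two gates swap places.
  return [i_g, c_g, f_g, o_g]

def tf_to_cudnn_gate_params(i_g, c_g, f_g, o_g):
  # The swap is its own inverse, so packing params back for cuDNN is symmetric.
  return [i_g, f_g, c_g, o_g]

round_trip = tf_to_cudnn_gate_params(*cudnn_to_tf_gate_params('i', 'f', 'c', 'o'))
assert round_trip == ['i', 'f', 'c', 'o']
```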
+ + def _tf_canonical_names(self): + tf_weights_names, tf_biases_names = [], [] + for i in range(self._num_layers): + if self._direction == CUDNN_RNN_UNIDIRECTION: + prefix = self._tf_canonical_name_prefix(i) + self._tf_canonical_names_single_layer(prefix, tf_weights_names, + tf_biases_names) + else: + fwd_prefix = self._tf_canonical_name_prefix(i, is_fwd=True) + bak_prefix = self._tf_canonical_name_prefix(i, is_fwd=False) + + self._tf_canonical_names_single_layer(fwd_prefix, tf_weights_names, + tf_biases_names) + self._tf_canonical_names_single_layer(bak_prefix, tf_weights_names, + tf_biases_names) + return tf_weights_names, tf_biases_names + + def _tf_canonical_name_prefix(self, layer, is_fwd=True): + if self._direction == CUDNN_RNN_UNIDIRECTION: + return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) + else: + if is_fwd: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % + (layer, self._rnn_cell_name)) + else: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % + (layer, self._rnn_cell_name)) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_biases_names): + raise NotImplementedError("Abstract method") + + +class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn LSTM opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterLSTM + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/kernel") + tf_bias_names.append(prefix + "/bias") + + def _checkpointable_track_params(self, checkpointable, params): + """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" + biases = [] + weights = [] + for name in self._weight_names: + weights.append(params[name]) + for name in self._bias_names: + biases.append(params[name]) + assert len(params) == len(weights) + len(biases) + if len(weights) == 1 and len(biases) == 1: + # For single-layer cells, allow substituting a cell with no MultiRNNCell + # wrapping. 
+ kernel, = weights # pylint: disable=unbalanced-tuple-unpacking + bias, = biases # pylint: disable=unbalanced-tuple-unpacking + checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access + checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + assert len(biases) == len(weights) + for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): + cell = checkpointable_lib.Checkpointable() + checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell.bias = bias + cell.kernel = kernel + + +class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn GRU opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterGRU + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleGRUCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/gates/kernel") + tf_weights_names.append(prefix + "/candidate/input_projection/kernel") + tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") + + tf_bias_names.append(prefix + "/gates/bias") + tf_bias_names.append(prefix + "/candidate/input_projection/bias") + tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") + + +class CudnnRNNTanhSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterTanh + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + +class CudnnRNNReluSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterRelu + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + _cudnn_rnn_common_doc_string = """ Cudnn RNN has an opaque parameter buffer that can be used for inference and training. 
But it is possible that the layout of the parameter buffers @@ -850,7 +942,7 @@ def _get_num_params(rnn_mode, num_layers, direction): elif rnn_mode == CUDNN_RNN_TANH: num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: - raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) + raise ValueError("Invalid \'rnn_mode\': %s" % rnn_mode) num_params = num_layers * num_params_per_layer if direction != CUDNN_RNN_UNIDIRECTION: num_params *= 2 @@ -918,7 +1010,7 @@ def _cudnn_rnn(inputs, "seed2": seed2, "name": name } - if use_cudnn_v2 is not "1": + if use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) @@ -1582,7 +1674,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): """ if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s", direction) + raise ValueError("Invalid direction: %s" % direction) super(_CudnnRNNNoInputC, self).__init__( self._rnn_mode, diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index f82453f3b5e..a938f8629d8 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -46,6 +46,9 @@ Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` Take a very simple model consisting of a single layer: ```python +import tensorflow as tf +from tensorflow import keras + inputs = tf.keras.layers.Input(shape=(1,)) predictions = tf.keras.layers.Dense(1)(inputs) model = tf.keras.models.Model(inputs=inputs, outputs=predictions) @@ -90,8 +93,8 @@ Similarly, we can also call `evaluate` and `predict` as before using appropriate datasets. ```python -model.evaluate(eval_dataset) -model.predict(predict_dataset) +model.evaluate(eval_dataset, steps=1) +model.predict(predict_dataset, steps=1) ``` That's all you need to train your model with Keras on multiple GPUs with diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 22736c799d2..4094e52169a 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -374,9 +374,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_pip", - # TODO(b/118820960): Re-enable this test in guitar. - "manual", - "noguitar", ], ) @@ -470,6 +467,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_oss", # http://b/119349471 "no_pip", ], ) @@ -492,6 +490,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_oss", # http://b/119349471 "no_pip", ], ) @@ -757,8 +756,6 @@ cuda_py_test( "no_oss", # TODO(b/117919883): Fix python error. "no_pip", "no_windows_gpu", - # TODO(b/118815591): Re-enable this test in guitar.) 
- "noguitar", "notsan", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index b311644cb22..d38bdb592a3 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -69,7 +69,7 @@ class CheckpointUtilsWithDistributionStrategyTest( with ops.Graph().as_default() as g, distribution.scope(): if in_replica_mode: - distribution.call_for_each_replica(init_and_verify, g) + distribution.call_for_each_replica(init_and_verify, args=[g]) else: init_and_verify(g) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index d9339f8f75a..efa99d1fc52 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -205,7 +205,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): def distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. - return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices, True) def configure(self, diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 19b59513d81..e3d919dd0d4 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -54,8 +54,6 @@ class CollectiveAllReduceStrategyTestBase( self._run_options = config_pb2.RunOptions() self._run_options.experimental.collective_graph_key = 6 - self._sess_config = config_pb2.ConfigProto() - # We use a different key_base for each test so that collective keys won't be # reused. 
# TODO(yuefengz, tucker): enable it to reuse collective keys in different @@ -66,9 +64,10 @@ class CollectiveAllReduceStrategyTestBase( def _get_test_object(self, task_type, task_id, num_gpus=0): distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( num_gpus_per_worker=num_gpus) + session_config = config_pb2.ConfigProto() if task_type and task_id is not None: distribution.configure( - session_config=self._sess_config, + session_config=session_config, cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) @@ -82,14 +81,16 @@ class CollectiveAllReduceStrategyTestBase( distribution._collective_keys = collective_keys distribution._cross_tower_ops._collective_keys = collective_keys if task_type and task_id is not None: - return distribution, 'grpc://' + self._cluster_spec[task_type][task_id] + return distribution, 'grpc://' + self._cluster_spec[task_type][ + task_id], session_config else: - return distribution, '' + return distribution, '', session_config def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess, \ d.scope(): l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) @@ -117,7 +118,7 @@ class CollectiveAllReduceStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -154,7 +155,8 @@ class CollectiveAllReduceStrategyTestBase( return error_after < error_before def _test_complex_model(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) def model_fn(): """Mnist model with synthetic input.""" @@ -193,7 +195,7 @@ class CollectiveAllReduceStrategyTestBase( return train_op with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess: with d.scope(): train_op = d.call_for_each_replica(model_fn) @@ -204,10 +206,10 @@ class CollectiveAllReduceStrategyTestBase( return True def _test_variable_initialization(self, task_type, task_id, num_gpus): - distribution, master_target = self._get_test_object(task_type, task_id, - num_gpus) + distribution, master_target, config = self._get_test_object( + task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess, \ distribution.scope(): diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 63a163e76cd..a5137165403 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -335,17 +335,13 @@ tpu_strategy_one_step = NamedDistribution( "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) -# Note that we disable prefetching for testing since prefetching makes -# the input non-deterministic. 
mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index bae0f474d27..b5b349aa64e 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -62,26 +62,26 @@ def validate_destinations(destinations): raise ValueError("destinations can not be empty") -def _make_tensor_into_per_device(input_tensor): - """Converts a single tensor into a PerDevice object.""" +def _make_tensor_into_per_replica(input_tensor): + """Converts a single tensor into a PerReplica object.""" if isinstance(input_tensor, (tuple, list)): - raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, " + raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, " "got %r but expected a object that is not a tuple or list." % (input_tensor,)) - if isinstance(input_tensor, value_lib.PerDevice): + if isinstance(input_tensor, value_lib.PerReplica): return input_tensor try: device = input_tensor.device except AttributeError: - raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object " + raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object " "because it doesn't have device set.") - return value_lib.PerDevice({device: input_tensor}) + return value_lib.PerReplica({device: input_tensor}) def _normalize_value_destination_pairs(value_destination_pairs): - """Converts each tensor into a PerDevice object in the input list.""" + """Converts each tensor into a PerReplica object in the input list.""" result = [] if not isinstance(value_destination_pairs, (list, tuple)): raise ValueError("`value_destination_pairs` should be a list or tuple") @@ -93,8 +93,8 @@ def _normalize_value_destination_pairs(value_destination_pairs): raise ValueError("Each element of `value_destination_pairs` should be a " "tuple of size 2.") - per_device = _make_tensor_into_per_device(pair[0]) - result.append((per_device, pair[1])) + per_replica = _make_tensor_into_per_replica(pair[0]) + result.append((per_replica, pair[1])) return result @@ -105,7 +105,7 @@ def _validate_value_destination_pairs(value_destination_pairs): if not isinstance(value_destination_pairs, (list, tuple)): return False if not all([isinstance(pair, tuple) for pair in value_destination_pairs]): return False - if not all([isinstance(v[0], value_lib.PerDevice) + if not all([isinstance(v[0], value_lib.PerReplica) for v in value_destination_pairs]): return False return True @@ -149,26 +149,16 @@ def _simple_broadcast(value, destinations): return value_lib.Mirrored(index) -def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, +def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, aggregation): # pylint: disable=g-missing-docstring all_values = [] count = 0 - for v in per_device_value._index.values(): # pylint: disable=protected-access - if isinstance(v, value_lib.MapOutput): - v_list = v.get() - if not v_list: - continue - count += len(v_list) - # Sum within each device 
before aggregating across devices. - # TODO(yuefengz): Check whether it helps to use accumulation_fn here. - v = cross_tower_utils.aggregate_tensors_or_indexed_slices( - v_list, math_ops.add_n) - else: - count += 1 + for v in per_replica_value._index.values(): # pylint: disable=protected-access + count += 1 all_values.append(v) if not all_values: - raise ValueError("`per_device_value` must be non-empty") + raise ValueError("`per_replica_value` must be non-empty") with ops.device(reduce_to_device): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): @@ -189,8 +179,8 @@ class CrossDeviceOps(object): def __init__(self): pass - def reduce(self, aggregation, per_device_value, destinations): - """Reduce `per_device_value` to `destinations`. + def reduce(self, aggregation, per_replica_value, destinations): + """Reduce `per_replica_value` to `destinations`. It runs the reduction operation defined by `aggregation` and put the result on `destinations`. @@ -198,23 +188,23 @@ class CrossDeviceOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - per_device_value: a PerDevice object or a tensor with device set. + per_replica_value: a PerReplica object or a tensor with device set. destinations: the reduction destinations. Returns: a Mirrored object. Raises: - ValueError: if per_device_value is not a PerDevice object. + ValueError: if per_replica_value is not a PerReplica object. """ - if not isinstance(per_device_value, value_lib.PerDevice): - per_device_value = _make_tensor_into_per_device(per_device_value) + if not isinstance(per_replica_value, value_lib.PerReplica): + per_replica_value = _make_tensor_into_per_replica(per_replica_value) validate_destinations(destinations) - return self._reduce(aggregation, per_device_value, destinations) + return self._reduce(aggregation, per_replica_value, destinations) def batch_reduce(self, aggregation, value_destination_pairs): - """Reduce PerDevice objects in a batch. + """Reduce PerReplica objects in a batch. Reduce each first element in `value_destination_pairs` to each second element which indicates the destinations. @@ -222,7 +212,7 @@ class CrossDeviceOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - value_destination_pairs: a list or a tuple of tuples of PerDevice objects + value_destination_pairs: a list or a tuple of tuples of PerReplica objects (or tensors with device set if there is one device) and destinations. Returns: @@ -230,11 +220,11 @@ class CrossDeviceOps(object): Raises: ValueError: if `value_destination_pairs` is not a list or a tuple of - tuples of PerDevice objects and destinations + tuples of PerReplica objects and destinations """ if not _validate_value_destination_pairs(value_destination_pairs): # If the first element of each pair is a tensor, we try to turn it into a - # PerDevice object. + # PerReplica object. 
value_destination_pairs = _normalize_value_destination_pairs( value_destination_pairs) @@ -256,7 +246,7 @@ class CrossDeviceOps(object): validate_destinations(destinations) return self._broadcast(tensor, destinations) - def _reduce(self, aggregation, per_device_value, destinations): + def _reduce(self, aggregation, per_replica_value, destinations): raise NotImplementedError( "_reduce method must be implemented in descendants.") @@ -286,13 +276,13 @@ class ReductionToOneDeviceCrossDeviceOps(CrossDeviceOps): self.accumulation_fn = accumulation_fn super(ReductionToOneDeviceCrossDeviceOps, self).__init__() - def _reduce(self, aggregation, per_device_value, destinations): + def _reduce(self, aggregation, per_replica_value, destinations): if check_destinations(destinations): devices = get_devices_from(destinations) else: - devices = get_devices_from(per_device_value) + devices = get_devices_from(per_replica_value) reduce_to_device = self.reduce_to_device or devices[0] - reduced = _simple_reduce(per_device_value, reduce_to_device, + reduced = _simple_reduce(per_replica_value, reduce_to_device, self.accumulation_fn, aggregation) return self.broadcast(reduced, devices) @@ -303,7 +293,7 @@ class ReductionToOneDeviceCrossDeviceOps(CrossDeviceOps): ] -def _group_value_by_device(per_device_values): +def _group_value_by_device(per_replica_values): """Group values into sublists by their devices. This grouping is needed to call the all-reduce library because it expects a @@ -315,18 +305,18 @@ def _group_value_by_device(per_device_values): ] Args: - per_device_values: a list of PerDevice obejcts. + per_replica_values: a list of PerReplica obejcts. Returns: a list of lists, each sublist has components for its corresponding device of - PerDevice objects, paired with a None. + PerReplica objects, paired with a None. """ - destinations = per_device_values[0].devices + destinations = per_replica_values[0].devices grouped = [[] for _ in range(len(destinations))] - for per_device_value in per_device_values: + for per_replica_value in per_replica_values: # pylint: disable=protected-access - for i, v in enumerate(per_device_value._index.values()): - assert per_device_value.devices == destinations + for i, v in enumerate(per_replica_value._index.values()): + assert per_replica_value.devices == destinations grouped[i].append((v, None)) return grouped @@ -354,8 +344,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, a list of Mirrored objects. 
""" index = [{} for _ in range(len(grouped_reduced[0]))] - for d, per_device_reduced in enumerate(grouped_reduced): - for i, (v, _) in enumerate(per_device_reduced): + for d, per_replica_reduced in enumerate(grouped_reduced): + for i, (v, _) in enumerate(per_replica_reduced): if aggregation == vs.VariableAggregation.MEAN: index[i][destinations[d]] = v / ( len(destinations) * num_between_graph_workers) @@ -567,13 +557,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossDeviceOps, self).__init__() - def _reduce(self, aggregation, per_device_value, destinations): + def _reduce(self, aggregation, per_replica_value, destinations): contains_indexed_slices = cross_tower_utils.contains_indexed_slices( - per_device_value) - if (_devices_match(per_device_value, destinations) + per_replica_value) + if (_devices_match(per_replica_value, destinations) and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(aggregation, [per_device_value])[0] + return self._batch_all_reduce(aggregation, [per_replica_value])[0] else: if contains_indexed_slices: logging.log_first_n( @@ -583,9 +573,9 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): if check_destinations(destinations): devices = get_devices_from(destinations) else: - devices = get_devices_from(per_device_value) + devices = get_devices_from(per_replica_value) reduce_to_device = devices[0] - reduced = _simple_reduce(per_device_value, reduce_to_device, + reduced = _simple_reduce(per_replica_value, reduce_to_device, math_ops.add_n, aggregation) return self.broadcast(reduced, devices) @@ -609,16 +599,16 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): for t, v in value_destination_pairs ] - def _batch_all_reduce(self, aggregation, per_device_values): + def _batch_all_reduce(self, aggregation, per_replica_values): """All reduce algorithm in a batch.""" logging.log_first_n( logging.INFO, "batch_all_reduce invoked for batches size = %d with " "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " "agg_small_grads_max_group = %d" % - (len(per_device_values), self._all_reduce_alg, self._num_packs, + (len(per_replica_values), self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) - destinations = per_device_values[0].devices - grouped = _group_value_by_device(per_device_values) + destinations = per_replica_values[0].devices + grouped = _group_value_by_device(per_replica_values) device_grad_packs, tensor_packer = _pack_tensors( grouped, self._num_packs, self._agg_small_grads_max_bytes, @@ -639,7 +629,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): destinations, device_grad_packs)) reduced = _unpack_tensors(reduced, tensor_packer) - return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, + return _ungroup_and_make_mirrored(reduced, per_replica_values[0].devices, aggregation) @@ -723,18 +713,18 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps): validate_and_complete_spec(spec) for spec in all_reduce_spec ] - def _batch_all_reduce(self, aggregation, per_device_values): + def _batch_all_reduce(self, aggregation, per_replica_values): """All reduce algorithm in a batch.""" logging.log_first_n( logging.INFO, "distributed batch_all_reduce invoked for batches size = %d with " "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " "and agg_small_grads_max_group = %d" % - (len(per_device_values), self._all_reduce_spec, self._num_packs, + 
(len(per_replica_values), self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) - destinations = sorted(per_device_values[0].devices) - device_grads = _group_value_by_device(per_device_values) + destinations = sorted(per_replica_values[0].devices) + device_grads = _group_value_by_device(per_replica_values) # The all reduce library requires fully defined shapes. # TODO(yuefengz): when tensor sharding is not needed, static shapes are not @@ -805,16 +795,16 @@ class CollectiveAllReduce(CrossDeviceOps): super(CollectiveAllReduce, self).__init__() # TODO(yuefengz, tucker): is indexed slices supported by collective ops? - def _reduce(self, aggregation, per_device_value, destinations): - if cross_tower_utils.contains_indexed_slices(per_device_value): + def _reduce(self, aggregation, per_replica_value, destinations): + if cross_tower_utils.contains_indexed_slices(per_replica_value): raise ValueError( "`IndexSlices` is not supported for Collective All-Reduce.") if context.executing_eagerly(): raise ValueError( "Eager execution is not supported for Collective All-Reduce") - all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] - if _devices_match(per_device_value, destinations): + all_reduced = self._batch_all_reduce(aggregation, [per_replica_value])[0] + if _devices_match(per_replica_value, destinations): return all_reduced else: index = {} @@ -852,7 +842,7 @@ class CollectiveAllReduce(CrossDeviceOps): for t, v in value_destination_pairs ] - def _batch_all_reduce(self, aggregation, per_device_values): + def _batch_all_reduce(self, aggregation, per_replica_values): """All-reduce across all workers in a batch.""" if context.executing_eagerly(): raise ValueError( @@ -860,9 +850,9 @@ class CollectiveAllReduce(CrossDeviceOps): logging.log_first_n( logging.INFO, "Collective All-reduce invoked with batches size = %d, " - "num_workers = %d" % (len(per_device_values), self._num_workers), 10) + "num_workers = %d" % (len(per_replica_values), self._num_workers), 10) - grouped_by_device = _group_value_by_device(per_device_values) + grouped_by_device = _group_value_by_device(per_replica_values) grouped_by_var = list(zip(*grouped_by_device)) # grouped_by_var is grouped by variables and takes the following format: @@ -892,7 +882,7 @@ class CollectiveAllReduce(CrossDeviceOps): new_device_grads = [list(x) for x in zip(*reduced_gv_list)] return _ungroup_and_make_mirrored( new_device_grads, - per_device_values[0].devices, + per_replica_values[0].devices, aggregation, num_between_graph_workers=self._num_workers) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 6a9e8e00c02..3e274ba67ca 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -40,12 +40,12 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util -def _make_per_device(values, devices, regroup=False): +def _make_per_replica(values, devices, regroup=False): devices = cross_tower_ops_lib.get_devices_from(devices) assert len(values) == len(devices) - # We simulate the result of regroup called on PerDevice which strips the - # PerDevice wrapper if it has only one value. + # We simulate the result of regroup called on PerReplica which strips the + # PerReplica wrapper if it has only one value. 
if len(values) == 1 and regroup: with ops.device(devices[0]): placed_v = array_ops.identity(values[0]) @@ -56,7 +56,7 @@ def _make_per_device(values, devices, regroup=False): with ops.device(d): placed_v = array_ops.identity(v) index[d] = placed_v - return value_lib.PerDevice(index) + return value_lib.PerReplica(index) # pylint: disable=g-doc-args,g-doc-return-or-yield @@ -122,11 +122,11 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): devices = distribution.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] - per_device = _make_per_device(values, devices) + per_replica = _make_per_replica(values, devices) mean = (len(devices) - 1.) / 2. values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) + per_replica_2 = _make_per_replica(values_2, devices) mean_2 = mean + 1. destination_mirrored = _fake_mirrored(1., devices) @@ -144,39 +144,41 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): self._assert_values_equal( cross_tower_ops.reduce( vs.VariableAggregation.MEAN, - per_device, + per_replica, destinations=destinations), _fake_mirrored(mean, destinations)) self._assert_values_equal( cross_tower_ops.reduce( vs.VariableAggregation.MEAN, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations)) self._assert_values_equal( cross_tower_ops.reduce( - vs.VariableAggregation.SUM, per_device, + vs.VariableAggregation.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices), destinations)) self._assert_values_equal( cross_tower_ops.reduce( vs.VariableAggregation.SUM, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, d1), (per_device_2, d2)]), + cross_tower_ops.batch_reduce( + vs.VariableAggregation.MEAN, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ]) self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), (per_device_2, d2)]), + cross_tower_ops.batch_reduce( + vs.VariableAggregation.SUM, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices), d1), _fake_mirrored(mean_2 * len(devices), d2) @@ -277,9 +279,9 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): devices = ["/cpu:0", "/gpu:0"] t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) result = cross_tower_ops_lib._simple_reduce( - per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) + per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) # Test that the result is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. 
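For context on the rename exercised by the hunks above and below, here is a minimal sketch of how the `PerReplica` container (formerly `PerDevice`) flows through a `CrossDeviceOps` reduction after this change. It is an illustrative example only, assuming the `tensorflow.contrib.distribute.python` module layout shown in the file paths of this diff; it mirrors the calls used in `cross_tower_ops_test.py` rather than defining any new API.

```python
# Illustrative sketch, assuming the contrib module paths touched by this diff.
import tensorflow as tf
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.ops import variable_scope as vs

devices = ["/cpu:0", "/gpu:0"]

# One tensor per replica device, wrapped in the PerReplica container that
# replaces the old PerDevice class throughout this patch.
per_replica = value_lib.PerReplica({
    d: tf.constant(float(i)) for i, d in enumerate(devices)
})

# Reduce to a single device and broadcast back to the destinations; the
# result is a Mirrored value with one copy per destination device.
reduce_ops = cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()
summed = reduce_ops.reduce(
    vs.VariableAggregation.SUM, per_replica, destinations=devices)
```

With `VariableAggregation.MEAN` the same call divides by the number of destinations, which is the behavior `_ungroup_and_make_mirrored` implements in the hunks above.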
@@ -311,13 +313,14 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) t1 = _make_indexed_slices( [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) if batch_reduce: - result = cross_tower_ops_instance.batch_reduce(aggregation, - [(per_device, devices)]) + result = cross_tower_ops_instance.batch_reduce( + aggregation, [(per_replica, devices)]) else: - result = cross_tower_ops_instance.reduce(aggregation, per_device, devices) + result = cross_tower_ops_instance.reduce( + aggregation, per_replica, devices) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] @@ -478,11 +481,11 @@ class MultiWorkerCollectiveAllReduceTest( # Collective ops doesn't support scalar tensors, so we have to construct # 1-d tensors. values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_device = _make_per_device(values, devices, regroup=True) + per_replica = _make_per_replica(values, devices, regroup=True) mean = np.array([(len(devices) - 1.) / 2.]) values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) + per_replica_2 = _make_per_replica(values_2, devices) mean_2 = np.array([mean[0] + 1.]) destination_mirrored = _fake_mirrored(1., devices) @@ -500,26 +503,26 @@ class MultiWorkerCollectiveAllReduceTest( self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.MEAN, - per_device, + per_replica, destinations=destinations), _fake_mirrored(mean, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.MEAN, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.SUM, - per_device, + per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.SUM, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), sess) @@ -528,16 +531,16 @@ class MultiWorkerCollectiveAllReduceTest( for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, d1), - (per_device_2, d2)]), + [(per_replica, d1), + (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ], sess) self._assert_values_equal( collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), - (per_device_2, d2)]), + [(per_replica, d1), + (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices) * num_workers, d1), _fake_mirrored(mean_2 * len(devices) * num_workers, d2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 35324d15d44..50b3cf31e59 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -667,7 +667,5 @@ def contains_indexed_slices(value): return any(contains_indexed_slices(v) for v in value) elif isinstance(value, value_lib.DistributedValues): return contains_indexed_slices(list(value._index.values())) # pylint: 
disable=protected-access - elif isinstance(value, value_lib.MapOutput): - return contains_indexed_slices(value.get()) else: return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py index d25964fa41a..e46240abbfa 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -98,24 +98,13 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerDevice(self): + def testContainsIndexedSlices_PerReplica(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerDeviceMapOutput(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({ - "/gpu:0": value_lib.MapOutput([t0]), - "/cpu:0": value_lib.MapOutput([t1])}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( mode=["graph", "eager"], diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 018512ae5a2..8f82b4c92aa 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -300,10 +300,8 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - try: - train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) - except TypeError: - train_distribute = train_distribute_cls(num_gpus_per_worker=2) + train_distribute = train_distribute_cls( + num_gpus_per_worker=context.num_gpus()) if eval_distribute_cls: eval_distribute = eval_distribute_cls( diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index c7036daa3e3..0fd3acd0451 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -61,7 +61,6 @@ def get_input_datasets(use_bfloat16=False): # train dataset train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.repeat() - train_ds = train_ds.shuffle(100) train_ds = train_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y)) train_ds = train_ds.batch(64, drop_remainder=True) diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py index f07ec8234df..ac1ccd64b32 100644 --- a/tensorflow/contrib/distribute/python/input_ops.py +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import 
division from __future__ import print_function +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.framework import ops @@ -27,9 +28,8 @@ from tensorflow.python.platform import tf_logging # TODO(priyag): Any other reader datasets to consider here? _READER_DATASET_OPS = [ - "TextLineDataset", - "TFRecordDataset", - "FixedLengthRecordDataset" + "TextLineDataset", "TFRecordDataset", "FixedLengthRecordDataset", + "FixedLengthRecordDatasetV2" ] @@ -75,6 +75,8 @@ def auto_shard_dataset(dataset, num_shards, index): # instead of updating in-place. return dataset._clone( filenames=dataset._filenames.shard(num_shards, index)) + elif isinstance(dataset, dataset_ops.RangeDataset): + return dataset.shard(num_shards, index) elif hasattr(dataset, "_map_func"): # TODO(priyag): Make this check more robust by enforcing some common # property on all map/flatmap/interleave datasets. diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index f4c222f26c3..46a1cf41c55 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -157,7 +157,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): dist = mirrored_strategy.MirroredStrategy(devices) with dist.scope(): (var, m, v, op, counter) = dist.call_for_each_replica( - create_fn, dist.worker_device_index, run_concurrently=False) + create_fn, args=[dist.worker_device_index]) self.evaluate(variables.global_variables_initializer()) var_val = [2.0, 2.0, 2.0] self.assertAllClose( diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 4cd8ac14100..0db5844e4c4 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -47,7 +47,6 @@ _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 -_TOLERANCE = 1e-5 # TODO(anjalisridhar): Add a decorator that will allow us to run these tests as @@ -213,10 +212,76 @@ def multi_input_output_model(): return model +def get_correctness_test_inputs(use_numpy, with_distribution, + x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + global_batch_size = 64 + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. + if with_distribution: + batch_size //= with_distribution.num_replicas_in_sync + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': 1, + 'shuffle': False, + } + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + # TODO(b/119318587): We should not require batch_size when distribution + # is enabled. + 'batch_size': (len(x_predict) // with_distribution.num_replicas_in_sync + if with_distribution else None), + 'x': np.array(x_predict, dtype=np.float32), + } + else: + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. 
+ train_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(train_dataset, batch_size, with_distribution) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': 1, + 'shuffle': False, + 'steps_per_epoch': len(x_train) // global_batch_size, + } + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + predict_batch_size = len(x_predict) + if with_distribution: + predict_batch_size //= with_distribution.num_replicas_in_sync + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, with_distribution) + predict_inputs = { + 'batch_size': None, + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + strategies = [combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.tpu_strategy, # steps_per_run=2 combinations.tpu_strategy_one_step] @@ -245,6 +310,13 @@ def strategy_and_optimizer_combinations(): mode=['graph']) +def strategy_and_inputs(): + return combinations.combine( + distribution=strategies, + use_numpy=[True, False], + mode=['graph']) + + class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): @@ -413,8 +485,8 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, with self.assertRaisesRegexp(ValueError, 'is smaller than the number ' 'of replicas'): - # The batch size(32) * num_replicas(3) is 96 which is greater than the - # number of input samples(64). + # The batch size(32) * num_replicas_in_sync(3) is 96 which is greater + # than the number of input samples(64). distributed_training_utils.get_input_batch_params(inputs, 32, strategy) @@ -598,36 +670,33 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): - loss = 'mse' - user_controlled_model = get_model() - user_controlled_optimizer = gradient_descent.GradientDescentOptimizer( - 0.001) - user_controlled_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - user_controlled_model.compile(user_controlled_optimizer, loss, - metrics=user_controlled_metrics, - distribute=distribution) + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) interleaved_model = get_model() - interleaved_optimizer = gradient_descent.GradientDescentOptimizer(0.001) - interleaved_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - interleaved_model.compile(interleaved_optimizer, loss, - metrics=interleaved_metrics, - distribute=distribution) + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) dataset = get_dataset(distribution) # Call fit with validation interleaved - interleaved_output = interleaved_model.fit(dataset, epochs=2, - steps_per_epoch=2, verbose=0, - validation_data=dataset, - validation_steps=2) + interleaved_output = interleaved_model.fit( + dataset, epochs=2, steps_per_epoch=2, verbose=1, + validation_data=dataset, validation_steps=2, shuffle=False) # Manually control the validation running 
after each epoch. user_controlled_output = [] for _ in range(2): user_controlled_model.fit( - dataset, epochs=1, steps_per_epoch=2, verbose=0) + dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False) user_controlled_output.append( user_controlled_model.evaluate(dataset, steps=2)) @@ -1019,26 +1088,36 @@ class TestDistributionStrategyCorrectness(test.TestCase, distribute=distribution) batch_size = 64 - batch_size //= distribution.num_replicas + batch_size //= distribution.num_replicas_in_sync train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) self.assertEqual(history.history['binary_accuracy'], [1.0]) - @combinations.generate(strategy_combinations()) - def test_correctness(self, distribution): + @combinations.generate(strategy_and_inputs()) + def test_correctness(self, distribution, use_numpy): with self.cached_session(): + tolerance = 1e-5 + + if isinstance(distribution, mirrored_strategy.MirroredStrategy): + # TODO(b/119257215): use the default one once the flakyness is fixed. + tolerance = 1e-4 + keras.backend.set_image_data_format('channels_last') - num_samples = 10000 np.random.seed(_RANDOM_SEED) random_seed.set_random_seed(_RANDOM_SEED) - # Train and predict datasets are created with the same input numpy arrays. + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 x_train = np.random.rand(num_samples, 1) y_train = 3 * x_train x_train = x_train.astype('float32') y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] # The model is built once and the initial weights are saved. # This is used to initialize the model for both the distribution and @@ -1052,49 +1131,38 @@ class TestDistributionStrategyCorrectness(test.TestCase, initial_weights = model.get_weights() def fit_and_predict(with_distribution=None): + # We have initialized the model to the same weight for the distribution + # and non-distribution run. model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, optimizer=gradient_descent.GradientDescentOptimizer(0.5), distribute=with_distribution) - batch_size = 64 - if with_distribution: - batch_size //= with_distribution.num_replicas - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, - y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - # We have initialized the model to the same weight for the distribution - # and non-distribution run. If you want to initialize the model to - # random weights for each run, you need to run the model through the - # entire dataset at least once to ensure that the weights converge to - # the same value. 
- model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, with_distribution, + x_train, y_train, x_predict)) + model.fit(**training_inputs) + eval_result = model.evaluate(**eval_inputs) weights = model.get_weights() - x_predict = [[1.], [2.], [3.], [4.]] - predict_batch_size = 4 - if with_distribution: - predict_batch_size //= with_distribution.num_replicas - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, distribution) - predict_result = model.predict(predict_dataset, steps=1) + predict_result = model.predict(**predict_inputs) - return weights, predict_result + return weights, eval_result, predict_result - wts_with_ds, predict_with_ds = fit_and_predict( + wts_with_ds, eval_with_ds, predict_with_ds = fit_and_predict( with_distribution=distribution) - wts_without_ds, predict_without_ds = fit_and_predict( + wts_without_ds, eval_without_ds, predict_without_ds = fit_and_predict( with_distribution=None) - # Verify that the weights are the same within some limits of tolerance. + # Verify that the weights, eval results, predict outputs are the same + # within some limits of tolerance. self.assertAllClose( - wts_with_ds, wts_without_ds, atol=_TOLERANCE, rtol=_TOLERANCE) - # Verify that the predicted outputs are the same within some limits of - # tolerance. + wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) self.assertAllClose( - predict_with_ds, predict_without_ds, atol=_TOLERANCE, rtol=_TOLERANCE) + eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) + self.assertAllClose( + predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) # TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 9e1a7ad3932..c28ab416518 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -100,7 +100,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): value, update = distribution.call_for_each_replica( - metric_fn, inputs) + metric_fn, args=[inputs]) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) @@ -111,14 +111,14 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): # In each run, we run multiple steps, and each steps consumes as many # batches as number of replicas. batches_per_update = ( - distribution.num_replicas * distribution.steps_per_run) + distribution.num_replicas_in_sync * distribution.steps_per_run) else: value, update = distribution.call_for_each_replica( metric_fn, iterator.get_next()) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, - # replace "distribution.num_replicas" with "1". - batches_per_update = distribution.num_replicas + # replace "distribution.num_replicas_in_sync" with "1". 
+ batches_per_update = distribution.num_replicas_in_sync self.evaluate(iterator.initializer) self.evaluate(distribution.initialize()) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 165732d578f..c6562463edb 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -22,7 +22,6 @@ from absl.testing import parameterized import numpy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example from tensorflow.python.data.ops import dataset_ops @@ -67,8 +66,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, *inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -111,7 +109,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def run_step(): return distribution.group( distribution.call_for_each_replica( - model_fn, iterator.get_next(), run_concurrently=layer.built)) + model_fn, args=(iterator.get_next(),))) if not context.executing_eagerly(): with self.cached_session() as sess: @@ -162,8 +160,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, *inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -221,7 +218,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm, update_ops_in_cross_replica_mode): """Verifies that moving mean updates are reduced across replicas.""" with distribution.scope(): - num_replicas = len(distribution.worker_devices) + num_replicas = distribution.num_replicas_in_sync model_fn, dataset_fn, batchnorm = batchnorm_example( optimizer_fn, batch_per_epoch=num_replicas, @@ -229,17 +226,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm=renorm, update_ops_in_replica_mode=not update_ops_in_cross_replica_mode) - # Make sure prefetching is disabled since that makes the - # specific input on each device to be non deterministic, and - # this test relies on specific input being on each device. 
- if isinstance(distribution, mirrored_strategy.MirroredStrategy): - self.assertFalse(distribution._prefetch_on_device) - def step_fn(ctx, *inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=batchnorm.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) if update_ops_in_cross_replica_mode: fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) @@ -334,8 +324,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, x, y): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, x, y, run_concurrently=False)) + distribution.call_for_each_replica(model_fn, args=(x, y))) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -369,10 +358,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 # with sum loss reduction, or 10.6 with mean. if loss_reduction == losses_impl.Reduction.SUM: - # Note that the "distribution.num_replicas" factor will go away once - # we split the input across replicas, instead of pulling a complete + # Note that the "distribution.num_replicas_in_sync" factor will go away + # once we split the input across replicas, instead of pulling a complete # batch of input per replica. - self.assertNear(weight, 2 + 21.2 * distribution.num_replicas, 0.0001) + self.assertNear(weight, 2 + 21.2 * distribution.num_replicas_in_sync, + 0.0001) else: # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) @@ -420,7 +410,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(output_context, *inputs): (train_op, loss) = distribution.call_for_each_replica( - model_fn, output_context, *inputs, run_concurrently=False) + model_fn, args=(output_context,) + inputs) output_context.set_last_step_output( name="cross_replica_loss_agg", output=loss, @@ -491,7 +481,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def _verify_loss_output(self, initial_loss, loss_output, aggregated, distribution): if not aggregated: - self.assertEqual(distribution.num_replicas, + self.assertEqual(distribution.num_replicas_in_sync, len(distribution.unwrap(loss_output))) loss_output = distribution.reduce( aggregation=variables_lib.VariableAggregation.MEAN, diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index c23de069498..2d75024e7a0 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -73,17 +73,14 @@ class _RequestedStop(Exception): # TODO(yuefengz): maybe create a common class for those who need to call this # _call_for_each_replica. -def _call_for_each_replica(distribution, fn, *args, **kwargs): +def _call_for_each_replica(distribution, fn, args, kwargs): """Run `fn` in separate threads, once per replica/worker device. Args: distribution: the DistributionStrategy object. fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. + args: positional arguments for `fn` + kwargs: keyword arguments for `fn`. 
Returns: Merged return value of `fn` across all replicas. @@ -92,16 +89,12 @@ def _call_for_each_replica(distribution, fn, *args, **kwargs): RuntimeError: If fn() calls get_replica_context().merge_call() a different number of times from the available devices. """ - run_concurrently = kwargs.pop("run_concurrently", True) + # TODO(josh11b): Add this option once we add synchronization to variable + # creation. Until then, this is pretty unsafe to use. + run_concurrently = False if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False # Needed for per-thread device, etc. contexts in graph mode. ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) @@ -192,7 +185,7 @@ def _reduce_non_distributed_value(distribution, aggregation, value, raise ValueError("You are passing a `DistributedValue` to " "`_reduce_non_distributed_value`, which is not allowed.") - # If the same value is present on all replicas then the PerDevice value will + # If the same value is present on all replicas then the PerReplica value will # be a single value. We also handle the case when `value` is a single value # and equal to 0. if value == 0: @@ -348,8 +341,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): specified. cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not set, the `configure` method will try to find the best one. - prefetch_on_device: optional boolean to specify whether to prefetch input - data to devices. auto_shard_dataset: whether to auto-shard the dataset when there are multiple workers. cross_tower_ops: Deprecated alias for `cross_device_ops`. @@ -360,14 +351,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus=None, num_gpus_per_worker=None, cross_device_ops=None, - prefetch_on_device=None, auto_shard_dataset=False, cross_tower_ops=None): super(MirroredStrategy, self).__init__() assert not (cross_device_ops and cross_tower_ops) self._cross_tower_ops = cross_device_ops or cross_tower_ops - self._prefetch_on_device = prefetch_on_device self._auto_shard_dataset = auto_shard_dataset # Remember num GPUs which might be needed by `configure` method. if num_gpus is not None and num_gpus_per_worker is not None: @@ -402,7 +391,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # TODO(josh11b): Require at least 2 devices? 
self._devices = [device_util.resolve(d) for d in devices] self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)}) + self._device_index = values.PerReplica( + {d: i for i, d in enumerate(devices)}) def _initialize_multi_worker(self, num_gpus, cluster_spec): """Initializes the object for multi-worker training.""" @@ -417,19 +407,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if num_gpus is None: raise ValueError("`num_gpus` is required if `cluster_spec` is given.") if num_gpus > 0: - self._worker_device_map = { - worker: [ + self._worker_devices = [ + (worker, [ device_util.canonicalize(worker + "/device:GPU:%d" % gpu) for gpu in range(num_gpus) - ] for worker in self._workers - } + ]) for worker in self._workers + ] else: - self._worker_device_map = { - worker: [device_util.canonicalize(worker, "/device:CPU:0")] + self._worker_devices = [ + (worker, [device_util.canonicalize(worker, "/device:CPU:0")]) for worker in self._workers - } + ] - devices = nest.flatten(self._worker_device_map) + devices = nest.flatten([l for _, l in self._worker_devices]) # Setting `_default_device` will add a device scope in the # distribution.scope. We set the default device to the first worker. When @@ -446,7 +436,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # TODO(josh11b): Require at least 2 devices? self._devices = [device_util.resolve(d) for d in devices] self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice( + self._device_index = values.PerReplica( {d: i for i, d in enumerate(devices)}) def _create_variable(self, next_creator, *args, **kwargs): @@ -490,12 +480,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def distribute_dataset(self, dataset_fn): if self._cluster_spec: return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device, self._auto_shard_dataset) + partial(self._call_dataset_fn, dataset_fn), self._worker_devices, + auto_shard=self._auto_shard_dataset) else: - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), self._devices, - self._prefetch_on_device) + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), self._devices) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. def _run_steps_on_dataset(self, fn, iterator, iterations, @@ -546,10 +535,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access output = last_step_tensor_outputs_dict[name] # For outputs that have already been aggregated, wrap them in a Mirrored - # container, else in a PerDevice container. + # container, else in a PerReplica container. 
if aggregation is variables_lib.VariableAggregation.NONE: last_step_tensor_outputs_dict[name] = values.regroup( - {d: t for d, t in zip(self._devices, output)}, values.PerDevice) + {d: t for d, t in zip(self._devices, output)}, values.PerReplica) else: assert len(output) == 1 last_step_tensor_outputs_dict[name] = output[0] @@ -562,23 +551,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._get_cross_tower_ops().broadcast(tensor, destinations or self._devices) - def _call_for_each_replica(self, fn, *args, **kwargs): - return _call_for_each_replica(self, fn, *args, **kwargs) - - def map(self, map_over, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - index = {} - for i, m in enumerate(map_over): - d = self._devices[i % len(self._devices)] - with ops.device(d): - l = index.get(d, []) - l.append(fn(m, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs))) - index[d] = l - # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput - # in addition to PerDevice data. - return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) + def _call_for_each_replica(self, fn, args, kwargs): + return _call_for_each_replica(self, fn, args, kwargs) def configure(self, session_config=None, @@ -617,9 +591,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) if not isinstance(value, values.DistributedValues): - # This function handles reducing values that are not PerDevice or Mirrored - # values. For example, the same value could be present on all replicas in - # which case `value` would be a single value or value could be 0. + # This function handles reducing values that are not PerReplica or + # Mirrored values. For example, the same value could be present on all + # replicas in which case `value` would be a single value or value could + # be 0. return _reduce_non_distributed_value(self, aggregation, value, destinations) if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: @@ -818,7 +793,7 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext): `MirroredStrategy.call_for_each_replica()`). 
""" - def _merge_call(self, fn, *args, **kwargs): + def _merge_call(self, fn, args, kwargs): """Delegate to the main thread to actually perform merge_call().""" t = threading.current_thread() # a _MirroredReplicaThread t.merge_fn = fn @@ -837,5 +812,9 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext): @property def device(self): + raise RuntimeError("Use .devices instead") + + @property + def devices(self): distribute_lib.require_replica_context(self) - return self._distribution_strategy.worker_devices[self._replica_id] + return [self._distribution_strategy.worker_devices[self._replica_id]] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index b8e7edaaf82..1fd18e09c01 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -78,11 +78,6 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self._test_minimize_loss_graph( self._get_distribution_strategy(), soft_placement=soft_placement) - def testMapReduce(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_map_reduce(self._get_distribution_strategy()) - def testDeviceIndex(self): if not GPU_TEST: self.skipTest("Not GPU test") @@ -120,7 +115,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(), self.assertRaises(AssertionError): - dist.call_for_each_replica(run_fn, dist.worker_device_index) + dist.call_for_each_replica(run_fn, args=(dist.worker_device_index,)) @test_util.run_in_graph_and_eager_modes def testReduceToCpu(self): @@ -132,7 +127,8 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): - result = dist.call_for_each_replica(run_fn, dist.worker_device_index) + result = dist.call_for_each_replica( + run_fn, args=(dist.worker_device_index,)) reduced = dist.reduce( variable_scope.VariableAggregation.SUM, result, @@ -152,7 +148,8 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): - result = dist.call_for_each_replica(run_fn, dist.worker_device_index) + result = dist.call_for_each_replica( + run_fn, args=(dist.worker_device_index,)) reduced = dist.reduce( variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, result, @@ -207,7 +204,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) self.assertEquals("foo:0", result.name) @@ -225,7 +222,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) # Default name of "Variable" will be used. 
self.assertEquals("Variable:0", result.name) @@ -246,7 +243,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) for i, v in enumerate(result): self.assertIsInstance(v, values.MirroredVariable) self.assertEquals("foo" + str(i) + ":0", v.name) @@ -269,7 +266,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) for v in result: self.assertIsInstance(v, values.MirroredVariable) self.assertEquals(4, len(result)) @@ -293,7 +290,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with dist.scope(): result = dist.call_for_each_replica( - model_fn, dist.worker_device_index, run_concurrently=False) + model_fn, args=(dist.worker_device_index,)) self.assertIsInstance(result, values.MirroredVariable) # The resulting mirrored variable will use the name from the first device. self.assertEquals("foo_0:0", result.name) @@ -329,8 +326,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): features = iterator.get_next() with dist.scope(): - result = dist.call_for_each_replica( - model_fn, features, run_concurrently=False) + result = dist.call_for_each_replica(model_fn, args=(features,)) suffixes = ["", "_1", "_2"] for (kernel, bias), suffix in zip(result, suffixes): self.assertIsInstance(kernel, values.MirroredVariable) @@ -368,7 +364,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v = variable_scope.variable(1.0, name="var-main0") self.assertEquals("var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) @@ -411,7 +407,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v = variable_scope.get_variable("var-main0", [1]) self.assertEquals("main/var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) @@ -448,7 +444,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): devices = ["/device:GPU:0", "/device:CPU:0"] dist = mirrored_strategy.MirroredStrategy(devices) with dist.scope(): - v0, v1 = dist.call_for_each_replica(create_fn, run_concurrently=False) + v0, v1 = dist.call_for_each_replica(create_fn) self.evaluate(v0.initializer) self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) @@ -465,7 +461,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return update0, update1 update0a, update1a = dist.call_for_each_replica( - update_member_fn, dist.worker_device_index, run_concurrently=False) + update_member_fn, args=(dist.worker_device_index,)) # Update "sync on read" variable. 
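All of these call sites move to the explicit `args=` form and drop `run_concurrently`. A condensed sketch of the new convention, assuming one GPU plus the CPU as in these tests (`model_fn` here is an illustrative placeholder, not a test helper):

from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.ops import variable_scope

def model_fn(factor):
  # Runs once per replica; the variable becomes a MirroredVariable.
  v = variable_scope.variable(1.0, name="v")
  return v * factor

dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
  # Before this change: dist.call_for_each_replica(model_fn, 2.0,
  #                                                run_concurrently=False)
  result = dist.call_for_each_replica(model_fn, args=(2.0,))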
self.evaluate(dist.group(update0a)) @@ -491,7 +487,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return update0, update1 update0b, update1b = dist.call_for_each_replica( - update_state_ops_fn, dist.worker_device_index, run_concurrently=False) + update_state_ops_fn, args=(dist.worker_device_index,)) self.evaluate(dist.group(update0b)) # Update "sync on read" variable. @@ -588,7 +584,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) self.assertEquals("foo:0", result.name) @@ -611,7 +607,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): "/device:GPU:0": "bar" }) with self.assertRaises(RuntimeError): - _ = dist.call_for_each_replica(model_fn, names, run_concurrently=False) + _ = dist.call_for_each_replica(model_fn, args=(names,)) @test_util.run_in_graph_and_eager_modes(config=config) def testReplicaLocalVariable(self): @@ -652,7 +648,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Create "sum" and "mean" versions of ReplicaLocalVariables. ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( dist.call_for_each_replica( - model_fn, dist.worker_device_index, run_concurrently=False)) + model_fn, args=(dist.worker_device_index,))) # Should see the same wrapping instance in all replicas. self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) @@ -709,7 +705,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with context.graph_mode(), dist.scope(): with ops.name_scope("main"): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) @@ -730,7 +726,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) @@ -760,7 +756,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with context.graph_mode(), dist.scope(): with ops.name_scope("main"): a = variable_scope.variable(1.0, name="a") - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) @@ -793,7 +789,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with context.graph_mode(), dist.scope(): with ops.name_scope("main"): a = variable_scope.get_variable("a", [1]) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) @@ -824,7 +820,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) # Two variables are created by the RNN layer. 
self.assertEquals(2, len(result)) for v in result: @@ -851,7 +847,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return var.assign(value) with dist.scope(): - ret_v_sum = dist.call_for_each_replica(model_fn, run_concurrently=False) + ret_v_sum = dist.call_for_each_replica(model_fn) update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) # Initialize variables. @@ -894,7 +890,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -908,7 +904,7 @@ class MirroredVariableUpdateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes(config=config) def testAssignMirroredVarReplicaContextWithSum(self): - # Test that we don't reduce a non-per-device value with the "sum" + # Test that we don't reduce a non-per-replica value with the "sum" # aggregation type. self._skip_eager_if_gpus_less_than(1) def var_fn(): @@ -920,7 +916,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -942,7 +938,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -960,7 +956,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -971,8 +967,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(0.5, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -986,7 +981,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -994,8 +989,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(5.0, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ 
-1008,7 +1002,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -1036,7 +1030,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -1047,8 +1041,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign_add(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(1.5, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -1062,7 +1055,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -1070,8 +1063,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign_add(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(6.0, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -1084,7 +1076,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) @@ -1104,7 +1096,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) @@ -1115,8 +1107,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign_sub(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(4.5, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -1130,7 +1121,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) 
self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) @@ -1138,8 +1129,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign_sub(1.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(4.0, self.evaluate(mirrored_var)) @@ -1211,8 +1201,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn, - run_concurrently=False) + replica_local_var = dist.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) @@ -1243,8 +1232,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn, - run_concurrently=False) + replica_local_var = dist.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) @@ -1307,8 +1295,7 @@ class MirroredStrategyDefunTest(test.TestCase): mock_model = MockModel(two_variables) self.evaluate(variables.global_variables_initializer()) - result = dist.call_for_each_replica(model_fn, mock_model, *inputs, - run_concurrently=False) + result = dist.call_for_each_replica(model_fn, args=[mock_model] + inputs) for device in devices: device_result = values.select_device(device, result) device_expected_result = values.select_device(device, expected_result) @@ -1320,11 +1307,10 @@ class MirroredStrategyDefunTest(test.TestCase): # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. 
- per_device_graph_functions = dist.call_for_each_replica( - defun.get_concrete_function, - mock_model, *inputs, run_concurrently=False) + per_replica_graph_functions = dist.call_for_each_replica( + defun.get_concrete_function, args=[mock_model] + inputs) for device in devices: - graph_function = per_device_graph_functions.get(device=device) + graph_function = per_replica_graph_functions.get(device=device) self.assertEqual(set(mock_model.variables), set(graph_function.graph.variables)) @@ -1398,16 +1384,16 @@ class MirroredStrategyDefunTest(test.TestCase): two_variables=True) @test_util.run_in_graph_and_eager_modes() - def testPassPerDevice(self): + def testPassPerReplica(self): self._skip_eager_if_gpus_less_than(1) @function.defun def fn1(mock_model, factor): return mock_model(factor) - factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) - expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, - "GPU:0": 3.0 * 1.25}) + factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) self._call_and_check(fn1, [factors], expected_result, [fn1]) @test_util.run_in_graph_and_eager_modes() @@ -1429,8 +1415,7 @@ class MirroredStrategyDefunTest(test.TestCase): gradients_fn = backprop.implicit_grad(loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = dist.call_for_each_replica( - gradients_fn, None, run_concurrently=False) + grads_and_vars = dist.call_for_each_replica(gradients_fn, args=(None,)) optimizer = gradient_descent.GradientDescentOptimizer(0.25) update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 2bfe0f3e7a6..bea684e77ca 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -40,9 +40,6 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - def testDeviceIndex(self): self._test_device_index(self._get_distribution_strategy()) @@ -83,7 +80,8 @@ class VariableCreatorStackTest(test.TestCase): with context.graph_mode(), \ dist.scope(), \ variable_scope.variable_creator_scope(main_thread_creator): - result = dist.call_for_each_replica(model_fn, dist.worker_device_index) + result = dist.call_for_each_replica( + model_fn, args=(dist.worker_device_index,)) result = dist.unwrap(result) expected = ["main_thread:thread_0", "main_thread:thread_1"] self.assertEquals(expected, result) diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index 815644421e3..7ecc852d205 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -93,7 +93,8 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var = variables.Variable([10.0, 11.0]) val = constant_op.constant([1.0, 2.0]) decay = 0.25 - # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. 
assign = moving_averages.assign_moving_average( var, val, decay, zero_debias=False) @@ -121,7 +122,8 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var = variables.Variable([0.0, 0.0]) val = array_ops.placeholder(dtypes.float32) decay = 0.25 - # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. assign = moving_averages.assign_moving_average(var, val, decay) variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 8bdf0012087..a0d8f938874 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -25,8 +25,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -40,10 +38,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # doing something that won't work with other DistributionStrategy # implementations? - def __init__(self, device, prefetch_on_device=None): + def __init__(self, device): super(OneDeviceStrategy, self).__init__() self._device = device - self._prefetch_on_device = prefetch_on_device self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): @@ -62,9 +59,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return next_creator(*args, **kwargs) def distribute_dataset(self, dataset_fn): - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), [self._device], - self._prefetch_on_device) + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), [self._device]) def _broadcast(self, tensor, destinations): del destinations @@ -117,29 +113,13 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access return ctx - def _call_for_each_replica(self, fn, *args, **kwargs): - # We don't run `fn` in multiple threads in OneDeviceStrategy. 
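With `map()`, `MapOutput`, and `prefetch_on_device` removed, the single-device strategy reduces to device placement plus the `PerReplicaDataset` wrapper. A minimal graph-mode sketch, assuming the internal module paths used elsewhere in these tests:

from tensorflow.contrib.distribute.python import one_device_strategy
from tensorflow.python.data.ops import dataset_ops

strategy = one_device_strategy.OneDeviceStrategy("/device:CPU:0")
dataset = strategy.distribute_dataset(
    lambda: dataset_ops.Dataset.range(10).batch(2))
iterator = dataset.make_initializable_iterator()  # PerReplicaDataIterator
batch = iterator.get_next()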
- kwargs.pop("run_concurrently", None) + def _call_for_each_replica(self, fn, args, kwargs): with ops.device(self._device), _OneDeviceReplicaContext(self): return fn(*args, **kwargs) - def map(self, map_over, fn, *args, **kwargs): - with ops.device(self._device): - return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - def _reduce(self, aggregation, value, destinations): - del destinations - if not isinstance(value, values.MapOutput): - return value - l = value.get() - assert l - with ops.device(self._device): - if aggregation == vs.VariableAggregation.SUM: - return math_ops.add_n(l) - elif aggregation == vs.VariableAggregation.MEAN: - return math_ops.add_n(l) / len(l) - else: - assert False + del aggregation, destinations + return value def _update(self, var, options, fn, *args, **kwargs): # The implementations of _update() and _update_non_slot() are identical @@ -171,6 +151,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def num_replicas(self): return 1 + @property + def num_replicas_in_sync(self): + return 1 + @property def worker_devices(self): return [self._device] @@ -188,6 +172,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): + """ReplicaContext for OneDeviceStrategy.""" def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( @@ -195,4 +180,8 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): @property def device(self): - return self._distribution_strategy.worker_devices[0] + raise RuntimeError("Use .devices instead") + + @property + def devices(self): + return [self._distribution_strategy.worker_devices[0]] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 3fb92273924..95f4cdb7868 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -35,9 +35,6 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - def testDeviceIndex(self): self._test_device_index(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 0554f4a83bd..fa4705af7cb 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -51,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def run_step(): return control_flow_ops.group(distribution.unwrap( distribution.call_for_each_replica( - model_fn, iterator.get_next(), run_concurrently=layer.built))) + model_fn, args=(iterator.get_next(),)))) if not context.executing_eagerly(): with self.cached_session() as sess: diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 2aa7f1ae5d6..790b37f8601 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -64,7 +64,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): Operations that occur only on the first replica (such as incrementing the global 
step), will occur on the first replica *of every worker*. - It is expected to call `call_for_each_replica(fn, *args, **kwargs)` for any + It is expected to call `call_for_each_replica(fn, ...)` for any operations which potentially can be replicated across replicas (i.e. multiple GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra caution needs to be taken: @@ -223,7 +223,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" - return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._compute_devices, True) def _broadcast(self, tensor, destinations): @@ -231,10 +231,13 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): destinations = self._compute_devices return self._cross_tower_ops.broadcast(tensor, destinations) + def _allow_variable_partition(self): + return not context.executing_eagerly() + # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through # this creator, such as "MutableHashTable". def _create_variable(self, next_creator, *args, **kwargs): - if self.num_replicas > 1: + if self.num_replicas_in_sync > 1: aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) if aggregation not in ( vs.VariableAggregation.NONE, @@ -288,9 +291,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): with ops.device(self._variable_device): return var_creator(*args, **kwargs) - def _call_for_each_replica(self, fn, *args, **kwargs): + def _call_for_each_replica(self, fn, args, kwargs): # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica(self, fn, *args, **kwargs) + return mirrored_strategy._call_for_each_replica(self, fn, args, kwargs) def _verify_destinations_not_different_worker(self, destinations): if not self._cluster_spec: @@ -336,9 +339,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "You cannot update variable with a Mirrored object with multiple " "components %r when using ParameterServerStrategy. You must " "specify a single value or a Mirrored with a single value." % x) - elif isinstance(x, values.PerDevice): + elif isinstance(x, values.PerReplica): raise ValueError( - "You cannot update variable with a PerDevice object %r when using " + "You cannot update variable with a PerReplica object %r when using " "ParameterServerStrategy. You must specify a single value or a " "Mirrored with a single value" % x) else: diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index a9f643c6ecc..81a23c89030 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -37,6 +37,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -85,8 +86,7 @@ class ParameterServerStrategyTestBase( config=sess_config) as sess, \ d.scope(): - # Define a variable outside the call_for_each_replica scope. This is not - # recommended. + # Define a variable outside the call_for_each_replica scope. 
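The new `_allow_variable_partition()` hook (graph mode only) is what lets a `partitioner` take effect under this strategy; the test added in the next hunk exercises it end to end. A condensed sketch, where `d` is assumed to be an already configured `ParameterServerStrategy`:

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

partitioner = partitioned_variables.fixed_size_partitioner(
    len(d.parameter_devices))
with d.scope():
  n = variable_scope.get_variable(
      "n",
      initializer=constant_op.constant([10.0, 20.0]),
      aggregation=variable_scope.VariableAggregation.SUM,
      partitioner=partitioner)
  # Each partition of `n` lands on a different /job:ps task.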
n = variable_scope.get_variable('n', initializer=10.0) self.assertEqual(n.device, '/job:ps/task:0') @@ -178,6 +178,75 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) + def _test_device_assignment_distributed_enable_partitioner( + self, task_type, task_id, num_gpus): + d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + num_shards = len(d.parameter_devices) + partitioner = partitioned_variables.fixed_size_partitioner(num_shards) + with ops.Graph().as_default(), \ + self.cached_session(target=self._default_target, + config=sess_config) as sess, \ + d.scope(): + + n = variable_scope.get_variable( + 'n', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + + for part_id, var in enumerate(n): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + + def model_fn(): + a = constant_op.constant([3.0, 5.0]) + # The device scope is ignored for variables but not for normal ops. + with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable( + 'x', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + x_add = x.assign_add(a, name='x_add') + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + for part_id, var in enumerate(x): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + self.assertEqual(var.device, x_add[part_id].device) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x_add[0]): + y = variable_scope.get_variable( + 'y', + initializer=constant_op.constant([20.0, 10.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + y_add = y.assign_add( + [array_ops.identity(x_add[0]), + array_ops.identity(x_add[1])]) + + for part_id, var in enumerate(y): + self.assertEqual(var.device, '/job:ps/task:0') + self.assertEqual(y_add[part_id].device, var.device) + self.assertEqual(var.device, x_add[0].device) + + return x_add, y_add + + x, y = d.call_for_each_replica(model_fn) + + if context.num_gpus() >= 1: + variables.global_variables_initializer().run() + x_val, y_val = sess.run([x, y]) + if num_gpus < 1: + self.assertEqual(x_val, [13.0, 25.0]) + self.assertEqual(y_val, [33.0, 35.0]) + else: + x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus] + y_expect = [ + 20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus + ] + self.assertEqual(x_val, x_expect) + self.assertEqual(y_val, y_expect) + def _test_device_assignment_local(self, d, compute_device='CPU', @@ -345,11 +414,11 @@ class ParameterServerStrategyTestBase( self._finish_condition.release() x_val, y_val, z_val = sess.run([x, y, z]) - self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas) - self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync) self.assertEqual(z_val, 30.0 + 1.0 * num_workers) - return (x_val == 10.0 + 1.0 * num_workers * d.num_replicas and - y_val == 20.0 + 1.0 * num_workers * d.num_replicas and + return (x_val == 10.0 + 1.0 * num_workers * d.num_replicas_in_sync and + y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): @@ -394,7 +463,7 @@ class 
ParameterServerStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -479,6 +548,12 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, def testDeviceAssignmentDistributed(self, num_gpus): self._test_device_assignment_distributed('worker', 1, num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + self._test_device_assignment_distributed_enable_partitioner( + 'worker', 1, num_gpus) + def testSimpleBetweenGraph(self): self._run_between_graph_clients(self._test_simple_increment, self._cluster_spec, context.num_gpus()) diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index a5adaac47ce..3dc815f0371 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -90,7 +90,6 @@ class StandardSingleLossStep(StandardInputStep): super(StandardSingleLossStep, self).__init__(dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer - self._is_run_concurrently = False self._iterations_per_step = iterations_per_step def __call__(self): @@ -101,14 +100,11 @@ class StandardSingleLossStep(StandardInputStep): gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) grads_and_vars = self.distribution.call_for_each_replica( - gradients_fn, - ctx, *inputs, - run_concurrently=self._is_run_concurrently) + gradients_fn, args=(ctx,) + inputs) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. # Otherwise, multiple sets of mirrored variables are going to be # created. - self._is_run_concurrently = True return self._optimizer._distributed_apply( # pylint: disable=protected-access self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 60ef0a2106a..3c0c10430eb 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -104,7 +104,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one, run_concurrently=l.built) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] @@ -160,7 +160,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. 
before_list = [] @@ -189,15 +189,6 @@ class DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) - def _test_map_reduce(self, d, in_graph=None): - with d.scope(): - map_in = [constant_op.constant(i) for i in range(10)] - map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out, - "/device:CPU:0") - expected = 90 # 2 * (0 + 1 + ... + 9) - self.assertEqual(expected, observed.numpy()) - def _test_device_index(self, d): with d.scope(): expected_devices = [False] * len(d.worker_devices) @@ -207,7 +198,7 @@ class DistributionTestBase(test.TestCase): self.assertFalse(expected_devices[device_id]) expected_devices[device_id] = True - d.call_for_each_replica(mark_devices_fn, d.worker_device_index) + d.call_for_each_replica(mark_devices_fn, args=(d.worker_device_index,)) self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) def _test_replica_id(self, d): diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 65ef21df09b..f5b4531ba8c 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -141,7 +141,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # parallelism. device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) if "device:TPU:" in d.name} - self._device_index = values.PerDevice(device_map) + self._device_index = values.PerReplica(device_map) self._host_device = self.get_host_cpu_device(0) self._tpu_devices = sorted(device_map.keys()) # Only create variables for the number of replicas we're running. @@ -215,12 +215,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return enqueue_op_per_host def distribute_dataset(self, dataset_fn): - worker_map = { - self.get_host(hid): [self.get_host_cpu_device(hid)] + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) for hid in range(self.num_hosts) - } + ] return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), worker_map) + functools.partial(self._call_dataset_fn, dataset_fn), worker_devices) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have @@ -308,7 +308,8 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # For outputs that have already been aggregated, take the first value # from the list as each value should be the same. Else return the full # list of values. - # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value. + # TODO(josh11b): If aggregation is NONE, we should return a PerReplica + # value. if aggregation is not variables_lib.VariableAggregation.NONE: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] @@ -316,10 +317,9 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return ctx - def _call_for_each_replica(self, fn, *args, **kwargs): + def _call_for_each_replica(self, fn, args, kwargs): # TODO(jhseu): Consider making it so call_for_each_replica implies that # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. 
- kwargs.pop("run_concurrently", None) with _TPUReplicaContext(self): return fn(*args, **kwargs) @@ -445,7 +445,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should - # be represented using a PerDevice wrapper instead of a list with + # be represented using a PerReplica wrapper instead of a list with # one entry per device. return val return [val] @@ -544,5 +544,9 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): @property def device(self): + raise RuntimeError("Use .devices instead") + + @property + def devices(self): distribute_lib.require_replica_context(self) - return self._distribution_strategy.worker_devices[self._replica_id] + return [self._distribution_strategy.worker_devices[self._replica_id]] diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 42fb92014a0..a1629735353 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -51,7 +51,7 @@ from tensorflow.python.util import nest # TODO(josh11b): Should device values be strings or DeviceSpec objects? # Not sure DeviceSpec objects are usable as a dict key. class DistributedValues(object): - """Holds a map from device to values. Either PerDevice or Mirrored.""" + """Holds a map from device to values. Either PerReplica or Mirrored.""" def __init__(self, index): self._index = {device_util.canonicalize(key): value @@ -62,7 +62,8 @@ class DistributedValues(object): if device is None: replica_context = distribution_strategy_context.get_replica_context() if replica_context: - device = replica_context.device + # TODO(josh11b): support model parallelism better here + device = replica_context.devices[0] else: device = distribute_lib.get_update_device() if device is None: @@ -75,10 +76,6 @@ class DistributedValues(object): ValueError("Device %s not found in %s (current device %s)" % (device, self._index.keys(), device_util.current())), e) - def on_device(self, device): - device = device_util.canonicalize(device) - return device in self._index - @property def devices(self): return list(self._index.keys()) @@ -167,12 +164,12 @@ class DistributedDelegate(DistributedValues): # TODO(josh11b): Even more operator overloads. -class PerDevice(DistributedValues): +class PerReplica(DistributedValues): """Holds a map from device to unsynchronized values.""" pass -# Note that unlike PerDevice, Mirrored values inherit from +# Note that unlike PerReplica, Mirrored values inherit from # DistributedDelegate and so can be used directly in cross-replica mode. class Mirrored(DistributedDelegate): """Holds a map from device to values which are kept in sync.""" @@ -482,7 +479,8 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): if device is None: replica_context = distribution_strategy_context.get_replica_context() if replica_context: - device = replica_context.device + # TODO(josh11b): support model parallelism better here + device = replica_context.devices[0] else: device = distribute_lib.get_update_device() if device is None: @@ -583,7 +581,8 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): # update_non_slot() function (like OptimizerV2._finish), which can # update several non-slot variables in one call. 
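`ReplicaContext.device` now raises, and callers pick an entry from the `devices` list instead (one device per replica in these strategies, with model parallelism left as a TODO). A hedged sketch of the lookup pattern, where `fn` stands for a function run via `call_for_each_replica` and the import path is assumed from the modules used here:

from tensorflow.python.training import distribution_strategy_context

def fn():
  replica_ctx = distribution_strategy_context.get_replica_context()
  # Old: replica_ctx.device, which now raises RuntimeError("Use .devices instead").
  device = replica_ctx.devices[0]
  return device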
def _assign_func(self, *args, **kwargs): - if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy": + strategy = distribution_strategy_context.get_distribution_strategy() + if strategy.__class__.__name__ != "TPUStrategy": raise ValueError("You may only assign to a TPUMirroredVariable within a " "TPUStrategy.") f = kwargs.pop("f") @@ -776,6 +775,18 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): def op(self): return self._primary_var.op + # pylint: disable=protected-access + @property + def _save_slice_info(self): + return self._primary_var._save_slice_info + + def _get_save_slice_info(self): + return self._primary_var._get_save_slice_info() + + def _set_save_slice_info(self, save_slice_info): + return self._primary_var._set_save_slice_info(save_slice_info) + # pylint: enable=protected-access + @property def _in_graph_mode(self): return self._primary_var._in_graph_mode # pylint: disable=protected-access @@ -861,7 +872,7 @@ def _assert_replica_context(): "Replica-local variables may only be assigned in a replica context.") -class ReplicaLocalVariable(DistributedVariable, PerDevice, +class ReplicaLocalVariable(DistributedVariable, PerReplica, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" @@ -942,9 +953,9 @@ def _devices_match(d1, d2): return device_util.canonicalize(d1) == device_util.canonicalize(d2) -def regroup(per_device, wrap_class=PerDevice): - """Makes device->nest map into a nest of PerDevice/Mirrored values.""" - items = list(per_device.items()) +def regroup(per_replica, wrap_class=PerReplica): + """Makes device->nest map into a nest of PerReplica/Mirrored values.""" + items = list(per_replica.items()) assert items v0 = items[0][1] # First value @@ -1005,7 +1016,7 @@ def regroup(per_device, wrap_class=PerDevice): # want to return the containing MirroredVariable, after a bunch of # sanity checking. In particular, each component should have the # same container, and the devices of the variables should match the - # keys of the per-device dictionary. + # keys of the per-replica dictionary. 
+ # keys of the per-replica dictionary.
if hasattr(v0, "_distributed_container"): # pylint: disable=protected-access assert not isinstance(v0, MirroredVariable), ( @@ -1021,11 +1032,11 @@ def regroup(per_device, wrap_class=PerDevice): return distributed_container # pylint: enable=protected-access - return wrap_class(per_device) + return wrap_class(per_replica) def select_device(device, structured): - """Specialize a nest of regular & per-device values for one device.""" + """Specialize a nest of regular & per-replica values for one device.""" def _get(x): return x.get(device) if isinstance(x, DistributedValues) else x @@ -1070,8 +1081,8 @@ def update_regroup(strategy, updates, should_group): return nest.pack_sequence_as(regrouped, grouped_flat) -class PerDeviceDataIterator(object): - """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" +class PerReplicaDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`.""" def __init__(self, iterator, devices, prefetch_on_device=None): self._iterator = iterator @@ -1114,8 +1125,8 @@ class PerDeviceDataIterator(object): return self._iterator.output_types -class PerDeviceDataset(object): - """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" +class PerReplicaDataset(object): + """Like `tf.data.Dataset` split devices, producing `PerReplica` data.""" def __init__(self, dataset, devices, prefetch_on_device=None): self._devices = devices @@ -1136,20 +1147,20 @@ class PerDeviceDataset(object): self._dataset = dataset.batch(len(devices), drop_remainder=True) def make_one_shot_iterator(self): - """Get a one time use iterator for the distributed PerDeviceDataset.""" + """Get a one time use iterator for the distributed PerReplicaDataset.""" # Graph mode with one shot iterator is disabled. if not context.executing_eagerly(): raise ValueError("Cannot create a one shot iterator. Please use " "`make_initializable_iterator()` instead.") # Eager mode prefetching would error out in constructor. Only remaining # case is non-prefetching in eager mode. We delegate to - # PerDeviceDataIterator to handle that case. + # PerReplicaDataIterator to handle that case. dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator( + return PerReplicaDataIterator( dataset_iterator, self._devices, prefetch_on_device=False) def make_initializable_iterator(self): - """Get an initializable iterator for the distributed PerDeviceDataset.""" + """Get an initializable iterator for the distributed PerReplicaDataset.""" # Eager mode generates already initialized iterators. Hence we cannot create # an initializable iterator. if context.executing_eagerly(): @@ -1160,7 +1171,7 @@ class PerDeviceDataset(object): self._dataset, self._devices) else: dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator( + return PerReplicaDataIterator( dataset_iterator, self._devices, prefetch_on_device=self._prefetch_on_device) @@ -1169,43 +1180,47 @@ class PerDeviceDataset(object): class MultiWorkerDataIterator(object): """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" - def __init__(self, iterators, worker_device_map): + def __init__(self, iterators, worker_device_pairs): """Initialize the MultiWorkerDataIterator object. Args: - iterators: a dict mapping from each worker to an iterator for - that worker. - worker_device_map: a dict mapping from each worker's devices to a list of - devices that belong to this worker. + iterators: a list of worker, iterator pairs. 
+ worker_device_pairs: a list of (worker's devices, a list of + devices that belong to this worker) pairs. Raises: - ValueError: if iterators and worker_device_map are not compatible. + ValueError: if iterators and worker_device_pairs are not compatible. """ - self._iterators = iterators - self._worker_device_map = worker_device_map - if set(self._iterators) != set(self._worker_device_map): - raise ValueError("iterators and worker_device_map are not compatible.") + if [d for d, _ in iterators] != [d for d, _ in worker_device_pairs]: + raise ValueError("iterators and worker_device_pairs are not compatible.") + self._workers = [d for d, _ in iterators] + self._iterators = [i for _, i in iterators] + self._worker_devices = [l for _, l in worker_device_pairs] @property def initializer(self): return control_flow_ops.group( - [iterator.initializer for iterator in self._iterators.values()]) + [iterator.initializer for iterator in self._iterators]) def get_iterator(self, worker): - return self._iterators.get(worker) + for i, w in enumerate(self._workers): + if worker == w: + return self._iterators[i] + return None @property def output_shapes(self): - return self._iterators.values()[0].output_shapes + return self._iterators[0].output_shapes @property def output_types(self): - return self._iterators.values()[0].output_types + return self._iterators[0].output_types def get_next(self, name=None): """Scatter the input across hosts and devices.""" index = {} - for worker, iterator in six.iteritems(self._iterators): + worker_info = zip(self._workers, self._iterators, self._worker_devices) + for worker, iterator, worker_devices in worker_info: if name is not None: d = tf_device.DeviceSpec.from_string(worker) new_name = "%s_%s_%d" % (name, d.job, d.task) @@ -1214,13 +1229,12 @@ class MultiWorkerDataIterator(object): with ops.device(worker): data_per_worker = iterator.get_next(name=new_name) - worker_devices = self._worker_device_map[worker] - # Ungroup these per-device value so as to get a flat map from devices to + # Ungroup these per-replica value so as to get a flat map from devices to # values. for d in worker_devices: v = select_device(d, data_per_worker) if d in index: - raise ValueError("Duplicated devices in worker_device_map: %r" % v) + raise ValueError("Duplicated devices in worker_device_pairs: %r" % v) index[d] = v return regroup(index) @@ -1229,153 +1243,48 @@ class MultiWorkerDataIterator(object): class MultiWorkerDataset(object): """Like a `tf.data.Dataset` that distributes data to different workers. - Each worker gets one shard of the input dataset. It is currently not working - in - eager mode. + Each worker gets one shard of the input dataset. This currently does not work + in eager mode. """ - def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None, + def __init__(self, dataset_fn, worker_device_pairs, prefetch_on_device=None, auto_shard=False): """Initialize the MultiWorkerDataset object. Args: dataset_fn: a function that returns a `tf.data.Dataset`. - worker_device_map: a dict mapping from each worker to a list of devices - that belong to this worker. + worker_device_pairs: a list of (worker, list of devices on that worker) + pairs. prefetch_on_device: whether to prefetch to devices. auto_shard: whether to auto-shard the dataset. """ - self._worker_device_map = worker_device_map - self._datasets = {} + self._worker_device_pairs = worker_device_pairs + self._datasets = [] # TODO(yuefengz, priyag): support different set of jobs for input # processing. 
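The worker-keyed dict becomes an ordered list of (worker, devices) pairs throughout. A rough construction sketch; the worker addresses and `dataset_fn` below are illustrative placeholders only:

from tensorflow.contrib.distribute.python import values
from tensorflow.python.data.ops import dataset_ops

def dataset_fn():
  return dataset_ops.Dataset.range(100).batch(4)

worker_device_pairs = [
    ("/job:worker/task:0", ["/job:worker/task:0/device:GPU:0",
                            "/job:worker/task:0/device:GPU:1"]),
    ("/job:worker/task:1", ["/job:worker/task:1/device:GPU:0",
                            "/job:worker/task:1/device:GPU:1"]),
]
multi_worker_dataset = values.MultiWorkerDataset(
    dataset_fn, worker_device_pairs, auto_shard=True)
iterator = multi_worker_dataset.make_initializable_iterator()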
- for i, (worker, worker_devices) in enumerate( - six.iteritems(worker_device_map)): + for i, (worker, worker_devices) in enumerate(worker_device_pairs): with ops.device(worker): worker_input = dataset_fn() if auto_shard: worker_input = input_ops.auto_shard_dataset( - worker_input, len(worker_device_map), i) - self._datasets[worker] = PerDeviceDataset( + worker_input, len(worker_device_pairs), i) + dataset = PerReplicaDataset( worker_input, worker_devices, prefetch_on_device=prefetch_on_device) + self._datasets.append((worker, dataset)) def make_one_shot_iterator(self): - iterators = {} - for worker, dataset in six.iteritems(self._datasets): + iterators = [] + for worker, dataset in self._datasets: with ops.device(worker): - iterators[worker] = dataset.make_one_shot_iterator() - return MultiWorkerDataIterator(iterators, self._worker_device_map) + iterators.append((worker, dataset.make_one_shot_iterator())) + return MultiWorkerDataIterator(iterators, self._worker_device_pairs) def make_initializable_iterator(self): - iterators = {} - for worker, dataset in six.iteritems(self._datasets): + iterators = [] + for worker, dataset in self._datasets: with ops.device(worker): - iterators[worker] = dataset.make_initializable_iterator() - return MultiWorkerDataIterator(iterators, self._worker_device_map) - - -class _PerKey(object): - """Holds data associated by keys.""" - - def __init__(self, *index): - # pylint: disable=protected-access - self._index = list(index) - - def get(self, iteration): - return array_ops.gather(self._index, iteration) - - def get_shape(self): - return self._index[-1][-1].get_shape() - - def get_dtype(self): - return self._index[-1][-1].dtype - - def __str__(self): - return "%s:%s" % (self.__class__.__name__, self._index) - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self._index) - - -class PerIteration(_PerKey): - """Holds input for multiple iterations at once.""" - - def __init__(self, *index): - # pylint: disable=protected-access - super(PerIteration, self).__init__(*[batch._index for batch in index]) - - -class Batches(_PerKey): - pass - - -class MultiIterator(object): - """Iterator that returns results of multiple get_next()s.""" - - def __init__(self, dataset_iterator, iterations, batches_per_iteration): - self._dataset_iterator = dataset_iterator - self._iterations = iterations - self._batches_per_iteration = batches_per_iteration - - def get_next(self, name=None): - """Return PerIteration with `iterations x batches_per_iteration` inputs.""" - data = [] - for _ in range(self._batches_per_iteration): - batch = [] - for _ in range(self._iterations): - batch.append(self._dataset_iterator.get_next(name=name)) - data.append(batch) - - # Here is an example. Suppose each get_next returns a tuple of two tensors. 
- # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: - # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] - # - # After the first `map_structure` it gets transformed to: - # [(Batches(a, A), Batches(z, Z)), - # (Batches(b, B), Batches(y, Y)), - # (Batches(c, C), Batches(x, X))] - # - # After the second `map_structure` it gets transformed to a tuple of: - # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), - # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) - - data = nest.map_structure(Batches, *data) - data = nest.map_structure(PerIteration, *data) - - return data - - @property - def initializer(self): - return self._dataset_iterator.initializer - - -class PerIterationDataset(object): - """A dataset that returns MultiIterators.""" - - def __init__(self, dataset, iterations, batches_per_iteration): - self._dataset = dataset - self._iterations = iterations - self._batches_per_iteration = batches_per_iteration - - def make_one_shot_iterator(self): - iterator = self._dataset.make_one_shot_iterator() - return MultiIterator(iterator, self._iterations, - self._batches_per_iteration) - - def make_initializable_iterator(self): - iterator = self._dataset.make_initializable_iterator() - return MultiIterator(iterator, self._iterations, - self._batches_per_iteration) - - -class MapOutput(object): - """Map can result in multiple outputs per device.""" - - def __init__(self, l): - self._l = l - - def get(self): - return self._l + iterators.append((worker, dataset.make_initializable_iterator())) + return MultiWorkerDataIterator(iterators, self._worker_device_pairs) class MultiStepContext(object): @@ -1430,13 +1339,13 @@ class MultiStepContext(object): output: The tensors that should be outputted with `name`. See below for actual types supported. aggregation: Aggregation method to use to aggregate outputs from multiple - replicas. Required if `set_last_step_output` is called in a replica context. - Optional in cross_replica_context. + replicas. Required if `set_last_step_output` is called in a replica + context. Optional in cross_replica_context. When present, the outputs from all the replicas are aggregated using the current distribution strategy's `reduce` method. Hence, the type of `output` must be what's supported by the corresponding `reduce` method. For e.g. if using MirroredStrategy and aggregation is set, output - must be a `PerDevice` value. + must be a `PerReplica` value. The aggregation method is also recorded in a dictionary `_last_step_outputs_aggregations` for later interpreting of the outputs as already reduced or not. @@ -1482,7 +1391,7 @@ class MultiStepContext(object): def value_container(val): - """Returns the container that this per-device `value` belongs to. + """Returns the container that this per-replica `value` belongs to. Args: val: A value returned by `call_for_each_replica()` or a variable @@ -1528,8 +1437,8 @@ class AggregatingVariable(checkpointable.CheckpointableBase): # We are calling an assign function in an update context. return f(self._v, *args, **kwargs) - # We are calling an assign function in cross replica context, wrap it in an - # update call. + # We are calling an assign function in cross replica context, wrap it in + # an update call. 
return distribution_strategy_context.get_distribution_strategy().update( self, f, *args, **kwargs) else: diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index d514e6f4c15..268393ee801 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import os from tensorflow.contrib.distribute.python import mirrored_strategy @@ -190,10 +189,10 @@ def _make_mirrored(): class RegroupAndSelectDeviceTest(test.TestCase): - def _is_per_device(self, result, expected, klass=values.PerDevice): + def _is_per_replica(self, result, expected, klass=values.PerReplica): self.assertIsInstance(result, klass) # We canonicalize the devices to match the device strings returned - # by PerDevice, which also does device string canonicalization. + # by PerReplica, which also does device string canonicalization. devices = [device_util.canonicalize(_device_str(i)) for i in range(len(expected))] self.assertEqual(set(devices), set(result.devices)) @@ -206,18 +205,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): _nested_value("2")}) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"]) - self._is_per_device(result[2], ["h1", "h2"]) + self._is_per_replica(result[0], ["a1", "a2"]) + self._is_per_replica(result[2], ["h1", "h2"]) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"]) - self._is_per_device(result[1][2], ["g1", "g2"]) + self._is_per_replica(result[1][0], ["b1", "b2"]) + self._is_per_replica(result[1][2], ["g1", "g2"]) self.assertIsInstance(result[1][1], dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"]) - self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"]) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"]) # Also test that we can undo the merge using select_device() self.assertEqual(_nested_value("1"), @@ -238,18 +237,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): values.Mirrored) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) - self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) - self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored) self.assertIsInstance(result[1][1], dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) - self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored) # Also test that we can undo the merge using select_device() self.assertEqual(_nested_value("1"), @@ -275,7 +274,7 @@ class 
RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): ("b", foo)}) self.assertIsInstance(result, tuple) self.assertEqual(2, len(result)) - self._is_per_device(result[0], ["a", "b"]) + self._is_per_replica(result[0], ["a", "b"]) self.assertIs(foo, result[1]) # Test select_device(), should undo the merge done by regroup(). @@ -341,53 +340,30 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -class PerDeviceDatasetTest(test.TestCase): +class PerReplicaDatasetTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True - def _test_iterator_no_prefetch(self, devices, dataset, expected_values): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) + def _test_iterator(self, devices, dataset, expected_values): + per_replica_dataset = values.PerReplicaDataset(dataset, devices) if context.executing_eagerly(): - iterator = per_device_dataset.make_one_shot_iterator() + iterator = per_replica_dataset.make_one_shot_iterator() else: - iterator = per_device_dataset.make_initializable_iterator() + iterator = per_replica_dataset.make_initializable_iterator() self.evaluate([iterator.initializer]) for expected_value in expected_values: next_element = iterator.get_next() - actual = self.evaluate([ - values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, actual) + computed_value = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() self.evaluate([ values.select_device(d, next_element) for d in devices]) - def _test_iterator_with_prefetch(self, devices, dataset, expected_values): - if not context.executing_eagerly(): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=True) - iterator = per_device_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - self.evaluate([ - values.select_device(d, next_element) for d in devices]) - - def _test_iterator(self, devices, dataset, expected_values): - self._test_iterator_no_prefetch(devices, dataset, expected_values) - self._test_iterator_with_prefetch(devices, dataset, expected_values) - @test_util.run_in_graph_and_eager_modes def testOneDevice(self): devices = ["/device:CPU:0"] @@ -442,9 +418,8 @@ class PerDeviceDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices( random_ops.random_uniform((10,))) - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) - iterator = per_device_dataset.make_initializable_iterator() + per_replica_dataset = values.PerReplicaDataset(dataset, devices) + iterator = per_replica_dataset.make_initializable_iterator() self.evaluate(iterator.initializer) next_element = iterator.get_next() @@ -463,7 +438,7 @@ class PerDeviceDatasetTest(test.TestCase): class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - def _test_iterator(self, iterator, devices, expected_values): + def _test_iterator(self, sess, iterator, devices, expected_values): next_element = iterator.get_next() for device in devices: v = 
values.select_device(device, next_element) @@ -472,73 +447,79 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): self.assertTrue(element.device in device) for expected_value in expected_values: - actual = self.evaluate( + actual = sess.run( [values.select_device(d, next_element) for d in devices]) self.assertEqual(expected_value, actual) with self.assertRaises(errors.OutOfRangeError): - self.evaluate([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_device(d, next_element) for d in devices]) - def _test_dataset(self, dataset_fn, worker_device_map, devices, - expected_values): + def _test_dataset(self, dataset_fn, worker_devices, devices, + expected_values, auto_shard=True): multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) - multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator() - self._test_iterator(multi_worker_iterator, devices, expected_values) + dataset_fn, worker_devices, auto_shard=auto_shard) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.cached_session() as sess: + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, expected_values) def _cpu_devices(self): - worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])]) + worker_devices = [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] devices = [ "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def _cpu_and_one_gpu_devices(self): - # The worker_device_map doesn't have to be a OrderDict object, this is just - # to simplify the testing so that we can pass expected values as a list - # instead of a dict. 
- worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", [ + worker_devices = [ + ("/job:worker/replica:0/task:0", [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0" - ]), ("/job:worker/replica:0/task:1", [ + ]), + ("/job:worker/replica:0/task:1", [ "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" - ])]) + ]) + ] devices = [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def testDataDistributionOneDevicePerWorker(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) + def testDataDistributionNoAutoShard(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_dataset(dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], + auto_shard=False) + def testDataDistributionTwoDevicePerWorker(self): - self.skipTest("Temporarily disabled.") if context.num_gpus() < 1: self.skipTest("A GPU is not available for this test.") - worker_device_map, devices = self._cpu_and_one_gpu_devices() + worker_devices, devices = self._cpu_and_one_gpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, [[0, 2, 1, 3], [4, 6, 5, 7]]) def testTupleDataset(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): @@ -550,41 +531,38 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): expected_values = [ [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) ] - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, expected_values) def testInitializableIterator(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() - with context.graph_mode(): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) # After re-initializing the iterator, should be able to iterate again. 
- self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) def testValueErrorForIterator(self): - self.skipTest("Temporarily disabled.") # Incompatiable arguments. with self.assertRaises(ValueError): values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) # Test duplicated devices under same worker. - worker_device_map, _ = self._cpu_devices() - worker_device_map["/job:worker/replica:0/task:0"].append( - "/job:worker/replica:0/task:0/device:CPU:0") + worker_devices, _ = self._cpu_devices() + worker_devices[0][1].append("/job:worker/replica:0/task:0/device:CPU:0") with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() with self.assertRaises(ValueError): multi_worker_iterator.get_next() diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 6e775afb69a..67ffb939663 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -544,4 +544,19 @@ class SequenceNumericColumn( return fc.SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=seq_length) + # TODO(b/119409767): Implement parents, _{get,from}_config. + @property + def parents(self): + """See 'FeatureColumn` base class.""" + raise NotImplementedError() + + def _get_config(self): + """See 'FeatureColumn` base class.""" + raise NotImplementedError() + + @classmethod + def _from_config(cls, config, custom_objects=None, columns_by_name=None): + """See 'FeatureColumn` base class.""" + raise NotImplementedError() + # pylint: enable=protected-access diff --git a/tensorflow/contrib/framework/python/framework/experimental_test.py b/tensorflow/contrib/framework/python/framework/experimental_test.py index cfdc7df7d8f..00e04b83ac4 100644 --- a/tensorflow/contrib/framework/python/framework/experimental_test.py +++ b/tensorflow/contrib/framework/python/framework/experimental_test.py @@ -44,17 +44,18 @@ class ExperimentalTest(test.TestCase): # Assert function docs are properly updated. self.assertEqual("_fn", _fn.__name__) - self.assertEqual("fn doc. (experimental)" - "\n" - "\nTHIS FUNCTION IS EXPERIMENTAL. It may change or " - "be removed at any time, and without warning." - "\n" - "\nArgs:" - "\n arg0: Arg 0." - "\n arg1: Arg 1." - "\n" - "\nReturns:" - "\n Sum of args.", _fn.__doc__) + self.assertEqual( + "fn doc. (experimental)" + "\n" + "\nWarning: THIS FUNCTION IS EXPERIMENTAL. It may change " + "or be removed at any time, and without warning." + "\n" + "\nArgs:" + "\n arg0: Arg 0." + "\n arg1: Arg 1." + "\n" + "\nReturns:" + "\n Sum of args.", _fn.__doc__) # Assert calling new fn issues log warning. 
self.assertEqual(3, _fn(1, 2)) diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index d8584e4e6b7..b3f48ec1dd9 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -52,9 +52,10 @@ Status GdrServer::Init() { [this](const WorkerEnv* env) { return new GdrRendezvousMgr(env, remote_memory_manager_.get()); }; - WorkerCreationFunction worker_func = [this](WorkerEnv* env) { + WorkerCreationFunction worker_func = [this](WorkerEnv* env, + const ConfigProto& config) { return std::unique_ptr( - new GdrWorker(env, remote_memory_manager_.get())); + new GdrWorker(env, config, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index ce1d8d2d730..867cb83f420 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -39,9 +39,9 @@ limitations under the License. namespace tensorflow { -GdrWorker::GdrWorker(WorkerEnv* worker_env, +GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config, RemoteMemoryManager* remote_memory_manager) - : GrpcWorker(worker_env), + : GrpcWorker(worker_env, config), remote_memory_manager_(remote_memory_manager), recv_tensor_recent_request_ids_(100000) {} diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index 65105ed9973..39f11e6bde5 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -25,7 +25,8 @@ namespace tensorflow { class GdrWorker : public GrpcWorker { public: - GdrWorker(WorkerEnv* env, RemoteMemoryManager* remote_memory_manager); + GdrWorker(WorkerEnv* env, const ConfigProto& config, + RemoteMemoryManager* remote_memory_manager); // Serve the RecvTensorRequest but omit the tensor content and transmit it // out-of-band using GPU Direct RDMA whenever possible. 
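The `values.py` changes above replace the `worker_device_map` dict with an ordered list of `(worker, devices)` pairs. A minimal Python sketch of the new calling convention for `MultiWorkerDataset`, using the two single-CPU workers from `values_test.py` (the import paths and dataset function are illustrative assumptions, not part of this patch):

```python
# Sketch: building worker_device_pairs for the updated MultiWorkerDataset.
# Assumes two workers with one CPU each, matching _cpu_devices() in values_test.py.
from tensorflow.contrib.distribute.python import values
from tensorflow.python.data.ops import dataset_ops

worker_device_pairs = [
    ("/job:worker/replica:0/task:0",
     ["/job:worker/replica:0/task:0/device:CPU:0"]),
    ("/job:worker/replica:0/task:1",
     ["/job:worker/replica:0/task:1/device:CPU:0"]),
]

dataset_fn = lambda: dataset_ops.Dataset.range(8)

# With auto_shard=True each worker receives its own shard of the 8 elements.
multi_worker_dataset = values.MultiWorkerDataset(
    dataset_fn, worker_device_pairs, auto_shard=True)
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
```

As in `MultiWorkerDatasetTest`, the iterator's `initializer` is run inside a session before `get_next()` is called.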
diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD index 9393b702d11..2698b83a56a 100644 --- a/tensorflow/contrib/ignite/BUILD +++ b/tensorflow/contrib/ignite/BUILD @@ -22,48 +22,92 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dataset_ops", + ":igfs_ops", ], ) tf_custom_op_library( - name = "_dataset_ops.so", - srcs = ["ops/dataset_ops.cc"], - deps = [":dataset_kernels"], + name = "_ignite_ops.so", + srcs = [ + "kernels/igfs/igfs.h", + "ops/dataset_ops.cc", + "ops/igfs_ops.cc", + ], + deps = [ + ":dataset_kernels", + ":igfs_kernels", + ], ) tf_gen_op_libs( op_lib_names = ["dataset_ops"], ) +tf_gen_op_libs( + op_lib_names = ["igfs_ops"], + deps = [":igfs_kernels"], +) + cc_library( - name = "dataset_kernels", + name = "ignite_client", srcs = [ - "kernels/ignite_dataset_ops.cc", - "kernels/ignite_client.h", - "kernels/ignite_byte_swapper.h", - "kernels/ignite_plain_client.h", - "kernels/ignite_ssl_wrapper.h", - "kernels/ignite_ssl_wrapper.cc", - "kernels/ignite_binary_object_parser.h", - "kernels/ignite_binary_object_parser.cc", - "kernels/ignite_dataset.h", - "kernels/ignite_dataset.cc", - "kernels/ignite_dataset_iterator.h", - "kernels/ignite_dataset_iterator.cc", + "kernels/client/ignite_client.h", + "kernels/client/ignite_byte_swapper.h", + "kernels/client/ignite_plain_client.h", + "kernels/client/ignite_ssl_wrapper.h", + "kernels/client/ignite_ssl_wrapper.cc", ] + if_not_windows([ - "kernels/ignite_plain_client_unix.cc", + "kernels/client/ignite_plain_client_unix.cc", ]) + if_windows([ - "kernels/ignite_plain_client_windows.cc", + "kernels/client/ignite_plain_client_windows.cc", ]), copts = if_windows([ "-DWIN32_LEAN_AND_MEAN", ]), deps = [ "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", "@boringssl//:ssl", "@protobuf_archive//:protobuf_headers", ], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/dataset/ignite_binary_object_parser.cc", + "kernels/dataset/ignite_binary_object_parser.h", + "kernels/dataset/ignite_dataset.cc", + "kernels/dataset/ignite_dataset.h", + "kernels/dataset/ignite_dataset_iterator.cc", + "kernels/dataset/ignite_dataset_iterator.h", + "kernels/dataset/ignite_dataset_ops.cc", + ], + deps = [ + ":ignite_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "igfs_kernels", + srcs = [ + "kernels/igfs/igfs.cc", + "kernels/igfs/igfs.h", + "kernels/igfs/igfs_client.cc", + "kernels/igfs/igfs_client.h", + "kernels/igfs/igfs_extended_tcp_client.cc", + "kernels/igfs/igfs_extended_tcp_client.h", + "kernels/igfs/igfs_messages.cc", + "kernels/igfs/igfs_messages.h", + "kernels/igfs/igfs_random_access_file.cc", + "kernels/igfs/igfs_random_access_file.h", + "kernels/igfs/igfs_writable_file.cc", + "kernels/igfs/igfs_writable_file.h", + ], + deps = [":ignite_client"], alwayslink = 1, ) @@ -82,10 +126,29 @@ py_library( ], ) +py_library( + name = "igfs_ops", + srcs = [ + "python/ops/igfs_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":igfs_op_loader", + "//tensorflow/python:util", + "//tensorflow/python/data/util:nest", + ], +) + tf_gen_op_wrapper_py( name = "gen_dataset_ops", out = "python/ops/gen_dataset_ops.py", - deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"], + deps = [":dataset_ops_op_lib"], +) + +tf_gen_op_wrapper_py( + name = "gen_igfs_ops", + out = "python/ops/gen_igfs_ops.py", + deps = [":igfs_ops_op_lib"], ) tf_kernel_library( @@ -97,13 +160,22 @@ 
tf_kernel_library( alwayslink = 1, ) +tf_kernel_library( + name = "igfs_ops_kernels", + deps = [ + ":igfs_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "ignite_op_loader", srcs = ["python/ops/ignite_op_loader.py"], - dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"], + dso = [":_ignite_ops.so"], kernels = [ ":dataset_ops_kernels", - "//tensorflow/contrib/ignite:dataset_ops_op_lib", + ":dataset_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ @@ -113,6 +185,22 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "igfs_op_loader", + srcs = ["python/ops/igfs_op_loader.py"], + dso = [":_ignite_ops.so"], + kernels = [ + ":igfs_ops_kernels", + ":igfs_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_igfs_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + # The Apache Ignite servers have to setup before the test and tear down # after the test manually. The docker engine has to be installed. # @@ -122,8 +210,11 @@ tf_custom_op_py_library( # To tear down Apache Ignite servers: # $ bash ./python/tests/stop_ignite.sh tf_py_test( - name = "ignite_dataset_test", - srcs = ["python/tests/ignite_dataset_test.py"], + name = "ignite_test", + srcs = [ + "python/tests/igfs_test.py", + "python/tests/ignite_dataset_test.py", + ], additional_deps = [ ":ignite", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index 55c89d27996..c7db0b77e25 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -1,19 +1,32 @@ -# Ignite Dataset +# Apache Ignite Integration -- [Overview](#overview) -- [Features](#features) - * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) - * [Structured Objects](#structured-objects) - * [Distributed Training](#distributed-training) - * [SSL Connection](#ssl-connection) - * [Windows Support](#windows-support) -- [Try it out](#try-it-out) -- [Limitations](#limitations) +- [Overview](#overview) +- [Features](#features) + * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) + * [Structured Objects](#structured-objects) + * [Distributed Training](#distributed-training) + * [Distributed File System](#distributed-file-system) + * [SSL Connection](#ssl-connection) + * [Windows Support](#windows-support) +- [Try it out](#try-it-out) + * [Ignite Dataset](#ignite-dataset) + * [IGFS](#igfs) +- [Limitations](#limitations) ## Overview -[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for -transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows to use Apache Ignite as a data source for neural network training, inference and all other computations supported by TensorFlow. +[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed +database, caching, and processing platform for transactional, analytical, and +streaming workloads, delivering in-memory speeds at petabyte scale. This contrib +package contains an integration between Apache Ignite and TensorFlow. 
The + integration is based on + [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow + side and + [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) + from Apache Ignite side. It allows using Apache Ignite as a data source for + neural network training, inference and all other computations supported by + TensorFlow. Another part of this module is an integration with a distributed file + system based on Apache Ignite. ## Features @@ -134,6 +147,23 @@ Ignite Dataset allows using these two aspects of distributed neural network trai High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. +### Distributed File System + +In addition to database functionality, Apache Ignite provides a distributed file +system called [IGFS](https://ignite.apache.org/features/igfs.html). IGFS +delivers functionality similar to Hadoop HDFS, but entirely in-memory. In fact, in +addition to its own APIs, IGFS implements the Hadoop FileSystem API and can be +transparently plugged into Hadoop or Spark deployments. This contrib package +contains an integration between IGFS and TensorFlow. The integration is based +on the [custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) +from TensorFlow side and +[IGFS Native API](https://ignite.apache.org/features/igfs.html) from Apache +Ignite side. It has numerous uses, for example: checkpoints of state can be +saved to IGFS for reliability and fault-tolerance, and training processes can +communicate with TensorBoard by writing event files to a directory, which +TensorBoard watches. IGFS allows this communication to work even when +TensorBoard runs in a different process or machine. + ### SSL Connection Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. @@ -141,9 +171,12 @@ Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikip ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> ->>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite") ->>> ... +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES", + certfile="client.pem", + cert_password="password", + username="ignite", + password="ignite") ``` ### Windows Support @@ -152,7 +185,16 @@ Ignite Dataset is fully compatible with Windows. You can use it as part of Tenso ## Try it out -The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded [MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine: +The following examples will help you easily start working with this module. + +### Ignite Dataset + +The simplest way to try Ignite Dataset is to run a +[Docker](https://www.docker.com/) container with Apache Ignite and loaded +[MNIST](http://yann.lecun.com/exdb/mnist/) data and then interact with +it using Ignite Dataset.
Such a container is available on Docker Hub: +[dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). +You need to start this container on your machine: ``` docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist @@ -162,6 +204,35 @@ After that you will be able to work with it following way: ![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") +### IGFS + +The simplest way to try IGFS with TensorFlow is to run a +[Docker](https://www.docker.com/) container with Apache Ignite and IGFS enabled +and then interact with it using TensorFlow +[tf.gfile](https://www.tensorflow.org/api_docs/python/tf/gfile). Such a container +is available on Docker Hub: +[dmitrievanthony/ignite-with-igfs](https://hub.docker.com/r/dmitrievanthony/ignite-with-igfs/). +You need to start this container on your machine: + +``` +docker run -it -p 10500:10500 dmitrievanthony/ignite-with-igfs +``` + +After that you will be able to work with it in the following way: + +```python +>>> import tensorflow as tf +>>> import tensorflow.contrib.ignite.python.ops.igfs_ops +>>> +>>> with tf.gfile.Open("igfs:///hello.txt", mode='w') as w: +>>> w.write("Hello, world!") +>>> +>>> with tf.gfile.Open("igfs:///hello.txt", mode='r') as r: +>>> print(r.read()) + +Hello, world! +``` + ## Limitations Presently, Ignite Dataset works with assumption that all objects in the cache have the same structure (homogeneous objects) and the cache contains at least one object. Another limitation concerns structured objects, Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of an object structure. diff --git a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h similarity index 67% rename from tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h rename to tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h index 46df3e39dc4..aac950fcc2a 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ #include #include "tensorflow/core/platform/byte_order.h" @@ -25,76 +25,75 @@ class ByteSwapper { public: ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; } - inline void SwapIfRequiredInt16(int16_t *x) const { + void SwapIfRequiredInt16(int16_t *x) const { if (swap_) { Swap16(x); } } - inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const { + void SwapIfRequiredUnsignedInt16(uint16_t *x) const { if (swap_) { Swap16(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt32(int32_t *x) const { + void SwapIfRequiredInt32(int32_t *x) const { if (swap_) { Swap32(x); } } - inline void SwapIfRequiredFloat(float *x) const { + void SwapIfRequiredFloat(float *x) const { if (swap_) { Swap32(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt64(int64_t *x) const { + void SwapIfRequiredInt64(int64_t *x) const { if (swap_) { Swap64(x); } } - inline void SwapIfRequiredDouble(double *x) const { + void SwapIfRequiredDouble(double *x) const { if (swap_) { Swap64(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const { + void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap16(&x[i]); } } - inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, - int32_t length) const { + void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap16(reinterpret_cast(&x[i])); } } - inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { + void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap32(&x[i]); } } - inline void SwapIfRequiredFloatArr(float *x, int32_t length) const { + void SwapIfRequiredFloatArr(float *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap32(reinterpret_cast(&x[i])); } } - inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { + void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap64(&x[i]); } } - inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const { + void SwapIfRequiredDoubleArr(double *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap64(reinterpret_cast(&x[i])); @@ -102,16 +101,16 @@ class ByteSwapper { } private: - inline void Swap16(int16_t *x) const { + void Swap16(int16_t *x) const { *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF); } - inline void Swap32(int32_t *x) const { + void Swap32(int32_t *x) const { *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) | (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF); } - inline void Swap64(int64_t *x) const { + void Swap64(int64_t *x) const { *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) | (((*x >> 16) & 0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) | (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) | @@ -123,4 +122,4 @@ class ByteSwapper { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ diff --git 
a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/client/ignite_client.h similarity index 74% rename from tensorflow/contrib/ignite/kernels/ignite_client.h rename to tensorflow/contrib/ignite/kernels/client/ignite_client.h index 459b50b48fd..0da80769260 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_client.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_client.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -32,44 +32,44 @@ class Client { virtual Status ReadData(uint8_t *buf, const int32_t length) = 0; virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0; - inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); } + Status ReadByte(uint8_t *data) { return ReadData(data, 1); } - inline Status ReadShort(int16_t *data) { + Status ReadShort(int16_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2)); byte_swapper_.SwapIfRequiredInt16(data); return Status::OK(); } - inline Status ReadInt(int32_t *data) { + Status ReadInt(int32_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4)); byte_swapper_.SwapIfRequiredInt32(data); return Status::OK(); } - inline Status ReadLong(int64_t *data) { + Status ReadLong(int64_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8)); byte_swapper_.SwapIfRequiredInt64(data); return Status::OK(); } - inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } + Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } - inline Status WriteShort(const int16_t data) { + Status WriteShort(const int16_t data) { int16_t tmp = data; byte_swapper_.SwapIfRequiredInt16(&tmp); return WriteData((uint8_t *)&tmp, 2); } - inline Status WriteInt(const int32_t data) { + Status WriteInt(const int32_t data) { int32_t tmp = data; byte_swapper_.SwapIfRequiredInt32(&tmp); return WriteData((uint8_t *)&tmp, 4); } - inline Status WriteLong(const int64_t data) { + Status WriteLong(const int64_t data) { int64_t tmp = data; byte_swapper_.SwapIfRequiredInt64(&tmp); return WriteData((uint8_t *)&tmp, 8); @@ -81,4 +81,4 @@ class Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h similarity index 80% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client.h rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h index 75424c19ee4..54658324604 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" namespace tensorflow { @@ -40,4 +40,4 @@ class PlainClient : public Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc index cf672942c61..54efb5b6176 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" #include #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc index dad5aace5fa..a99a3ada558 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" #define WIN32_LEAN_AND_MEAN #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc index ceb479b0846..8f09c24a3be 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h" #include #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h similarity index 82% rename from tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h rename to tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h index 0406644bbaa..543e03d1efc 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" #include @@ -48,4 +48,4 @@ class SslWrapper : public Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc similarity index 99% rename from tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc index 2c8a7d44b07..4218ec05f2c 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h similarity index 87% rename from tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h index eb1f856643a..3e8a1a19623 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ #include -#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -78,4 +78,4 @@ enum ObjectType { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_dataset.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc index c4a7d3c513a..ace96e7b09f 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h similarity index 91% rename from tensorflow/contrib/ignite/kernels/ignite_dataset.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h index 66bfdf2e2a1..db3bafb11f2 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ #include "tensorflow/core/framework/dataset.h" @@ -60,4 +60,4 @@ class IgniteDataset : public DatasetBase { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc index 5da9127aa6a..ce8972f1e7f 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h" -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" -#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h similarity index 87% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h index c499e2c9ccf..5868c2cb67f 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" -#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h" #include "tensorflow/core/platform/mutex.h" namespace tensorflow { @@ -96,4 +96,4 @@ constexpr int32_t kMinResLength = 12; } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc index f75b1c5ff55..f2108775e29 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" -#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/strings/numbers.h" diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs.cc new file mode 100644 index 00000000000..ae2dbcc2cf5 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs.cc @@ -0,0 +1,331 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/file_system_helper.h"
+
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs.h"
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h"
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h"
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h"
+
+namespace tensorflow {
+
+static string GetEnvOrElse(const string &env, string default_value) {
+  const char *env_c_str = env.c_str();
+  return getenv(env_c_str) != nullptr ? getenv(env_c_str) : default_value;
+}
+
+static string MakeRelative(const string &a, const string &b) {
+  string max = a;
+  string min = b;
+  bool first = b.size() > a.size();
+
+  if (first) {
+    max = b;
+    min = a;
+  }
+
+  auto r = mismatch(min.begin(), min.end(), max.begin());
+  return string((first ? r.first : r.second), first ? min.end() : max.end());
+}
+
+string IGFS::TranslateName(const string &name) const {
+  StringPiece scheme, namenode, path;
+  io::ParseURI(name, &scheme, &namenode, &path);
+  return string(path.data(), path.length());
+}
+
+IGFS::IGFS()
+    : host_(GetEnvOrElse("IGFS_HOST", "localhost")),
+      port_([] {
+        int port;
+        if (strings::safe_strto32(GetEnvOrElse("IGFS_PORT", "10500").c_str(),
+                                  &port)) {
+          return port;
+        } else {
+          LOG(WARNING)
+              << "IGFS_PORT environment variable had an invalid value: "
+              << getenv("IGFS_PORT") << "\nUsing default port 10500.";
+          return 10500;
+        }
+      }()),
+      fs_name_(GetEnvOrElse("IGFS_FS_NAME", "default_fs")) {
+  LOG(INFO) << "IGFS created [host=" << host_ << ", port=" << port_
+            << ", fs_name=" << fs_name_ << "]";
+}
+
+IGFS::~IGFS() {
+  LOG(INFO) << "IGFS destroyed [host=" << host_ << ", port=" << port_
+            << ", fs_name=" << fs_name_ << "]";
+}
+
+Status IGFS::NewRandomAccessFile(const string &file_name,
+                                 std::unique_ptr<RandomAccessFile> *result) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<OpenReadResponse> open_read_response(true);
+  TF_RETURN_IF_ERROR(client->OpenRead(&open_read_response, path));
+
+  int64 resource_id = open_read_response.res.stream_id;
+  result->reset(new IGFSRandomAccessFile(path, resource_id, std::move(client)));
+
+  LOG(INFO) << "New random access file completed successfully [file_name="
+            << file_name << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::NewWritableFile(const string &file_name,
+                             std::unique_ptr<WritableFile> *result) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<ExistsResponse> exists_response(false);
+  TF_RETURN_IF_ERROR(client->Exists(&exists_response, path));
+
+  if (exists_response.res.exists) {
+    CtrlResponse<DeleteResponse> del_response(false);
+    TF_RETURN_IF_ERROR(client->Delete(&del_response, path, false));
+  }
+
+  CtrlResponse<OpenCreateResponse> open_create_resp(false);
+  TF_RETURN_IF_ERROR(client->OpenCreate(&open_create_resp, path));
+
+  int64 resource_id = open_create_resp.res.stream_id;
+  result->reset(new IGFSWritableFile(path, resource_id, std::move(client)));
+
+  LOG(INFO) << "New writable file completed successfully [file_name="
+            << file_name << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::NewAppendableFile(const string &file_name,
+                               std::unique_ptr<WritableFile> *result) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<ExistsResponse> exists_response(false);
+  TF_RETURN_IF_ERROR(client->Exists(&exists_response, file_name));
+
+  if (exists_response.res.exists) {
+    CtrlResponse<DeleteResponse> del_response(false);
+    TF_RETURN_IF_ERROR(client->Delete(&del_response, file_name, false));
+  }
+
+  CtrlResponse<OpenAppendResponse> open_append_resp(false);
+  TF_RETURN_IF_ERROR(client->OpenAppend(&open_append_resp, file_name));
+
+  result->reset(new IGFSWritableFile(TranslateName(file_name),
+                                     open_append_resp.res.stream_id,
+                                     std::move(client)));
+
+  LOG(INFO) << "New appendable file completed successfully [file_name="
+            << file_name << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::NewReadOnlyMemoryRegionFromFile(
+    const string &file_name, std::unique_ptr<ReadOnlyMemoryRegion> *result) {
+  return errors::Unimplemented("IGFS does not support ReadOnlyMemoryRegion");
+}
+
+Status IGFS::FileExists(const string &file_name) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  const string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<ExistsResponse> exists_response(false);
+  TF_RETURN_IF_ERROR(client->Exists(&exists_response, path));
+
+  if (!exists_response.res.exists)
+    return errors::NotFound("File ", path, " not found");
+
+  LOG(INFO) << "File exists completed successfully [file_name=" << file_name
+            << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::GetChildren(const string &file_name,
+                         std::vector<string> *result) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string path = TranslateName(file_name);
+  path = path + "/";
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<ListPathsResponse> list_paths_response(false);
+  TF_RETURN_IF_ERROR(client->ListPaths(&list_paths_response, path));
+
+  *result = std::vector<string>();
+  std::vector<IGFSPath> entries = list_paths_response.res.entries;
+
+  for (IGFSPath &value : entries)
+    result->push_back(MakeRelative(value.path, path));
+
+  LOG(INFO) << "Get children completed successfully [file_name=" << file_name
+            << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::GetMatchingPaths(const string &pattern,
+                              std::vector<string> *results) {
+  return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
+}
+
+Status IGFS::DeleteFile(const string &file_name) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<DeleteResponse> del_response(false);
+  TF_RETURN_IF_ERROR(client->Delete(&del_response, path, false));
+
+  if (!del_response.res.exists)
+    return errors::NotFound("File ", path, " not found");
+
+  LOG(INFO) << "Delete file completed successfully [file_name=" << file_name
+            << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::CreateDir(const string &file_name) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  const string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<MakeDirectoriesResponse> mkdir_response(false);
+  TF_RETURN_IF_ERROR(client->MkDir(&mkdir_response, path));
+
+  if (!mkdir_response.res.successful)
+    return errors::Unknown("Can't create directory ", path);
+
+  LOG(INFO) << "Create dir completed successfully [file_name=" << file_name
+            << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::DeleteDir(const string &file_name) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<ListFilesResponse> list_files_response(false);
+  TF_RETURN_IF_ERROR(client->ListFiles(&list_files_response, path));
+
+  if (!list_files_response.res.entries.empty()) {
+    return errors::FailedPrecondition("Can't delete a non-empty directory");
+  } else {
+    CtrlResponse<DeleteResponse> del_response(false);
+    TF_RETURN_IF_ERROR(client->Delete(&del_response, path, true));
+  }
+
+  LOG(INFO) << "Delete dir completed successfully [file_name=" << file_name
+            << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::GetFileSize(const string &file_name, uint64 *size) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<InfoResponse> info_response(false);
+  TF_RETURN_IF_ERROR(client->Info(&info_response, path));
+
+  *size = info_response.res.file_info.length;
+
+  LOG(INFO) << "Get file size completed successfully [file_name=" << file_name
+            << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::RenameFile(const string &src, const string &dst) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string src_path = TranslateName(src);
+  string dst_path = TranslateName(dst);
+
+  if (FileExists(dst).ok()) DeleteFile(dst);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<RenameResponse> rename_response(false);
+  TF_RETURN_IF_ERROR(client->Rename(&rename_response, src_path, dst_path));
+
+  if (!rename_response.res.successful)
+    return errors::NotFound("File ", src_path, " not found");
+
+  LOG(INFO) << "Rename file completed successfully [src=" << src
+            << ", dst=" << dst << "]";
+
+  return Status::OK();
+}
+
+Status IGFS::Stat(const string &file_name, FileStatistics *stats) {
+  std::unique_ptr<IGFSClient> client = CreateClient();
+  string path = TranslateName(file_name);
+
+  CtrlResponse<HandshakeResponse> handshake_response(true);
+  TF_RETURN_IF_ERROR(client->Handshake(&handshake_response));
+
+  CtrlResponse<InfoResponse> info_response(false);
+  TF_RETURN_IF_ERROR(client->Info(&info_response, path));
+
+  IGFSFile info = info_response.res.file_info;
+
+  *stats = FileStatistics(info.length, info.modification_time * 1000000,
+                          (info.flags & 0x1) != 0);
+
+  LOG(INFO) << "Stat completed successfully [file_name=" << file_name << "]";
+
+  return Status::OK();
+}
+
+std::unique_ptr<IGFSClient> IGFS::CreateClient() const {
+  return std::unique_ptr<IGFSClient>(
+      new IGFSClient(host_, port_, fs_name_, ""));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs.h b/tensorflow/contrib/ignite/kernels/igfs/igfs.h
new file mode 100644
index 00000000000..4c347e937f7
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/igfs/igfs.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_
+
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h"
+#include "tensorflow/core/platform/file_system.h"
+
+namespace tensorflow {
+
+class IGFS : public FileSystem {
+ public:
+  IGFS();
+  ~IGFS();
+  Status NewRandomAccessFile(
+      const string& file_name,
+      std::unique_ptr<RandomAccessFile>* result) override;
+  Status NewWritableFile(const string& fname,
+                         std::unique_ptr<WritableFile>* result) override;
+  Status NewAppendableFile(const string& fname,
+                           std::unique_ptr<WritableFile>* result) override;
+  Status NewReadOnlyMemoryRegionFromFile(
+      const string& fname,
+      std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
+  Status FileExists(const string& fname) override;
+  Status GetChildren(const string& dir, std::vector<string>* result) override;
+  Status GetMatchingPaths(const string& pattern,
+                          std::vector<string>* results) override;
+  Status DeleteFile(const string& fname) override;
+  Status CreateDir(const string& name) override;
+  Status DeleteDir(const string& name) override;
+  Status GetFileSize(const string& fname, uint64* size) override;
+  Status RenameFile(const string& src, const string& target) override;
+  Status Stat(const string& fname, FileStatistics* stat) override;
+  string TranslateName(const string& name) const override;
+
+ private:
+  std::unique_ptr<IGFSClient> CreateClient() const;
+
+  const string host_;
+  const int port_;
+  const string fs_name_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_
diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc
new file mode 100644
index 00000000000..3f97c34fdd8
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h"
+
+namespace tensorflow {
+
+IGFSClient::IGFSClient(const string &host, int port, const string &fs_name,
+                       const string &user_name)
+    : fs_name_(fs_name),
+      user_name_(user_name),
+      client_(ExtendedTCPClient(host, port, true)) {
+  client_.Connect();
+}
+
+IGFSClient::~IGFSClient() { client_.Disconnect(); }
+
+Status IGFSClient::SendRequestGetResponse(const Request &request,
+                                          Response *response) {
+  TF_RETURN_IF_ERROR(request.Write(&client_));
+  client_.reset();
+
+  if (response != nullptr) {
+    TF_RETURN_IF_ERROR(response->Read(&client_));
+    client_.reset();
+  }
+
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h
new file mode 100644
index 00000000000..bbec7b00077
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h"
+
+namespace tensorflow {
+
+class IGFSClient {
+ public:
+  IGFSClient(const string &host, int port, const string &fs_name,
+             const string &user_name);
+  ~IGFSClient();
+
+  Status Handshake(CtrlResponse<HandshakeResponse> *res) {
+    return SendRequestGetResponse(HandshakeRequest(fs_name_, {}), res);
+  }
+
+  Status ListFiles(CtrlResponse<ListFilesResponse> *res, const string &path) {
+    return SendRequestGetResponse(ListFilesRequest(user_name_, path), res);
+  }
+
+  Status ListPaths(CtrlResponse<ListPathsResponse> *res, const string &path) {
+    return SendRequestGetResponse(ListPathsRequest(user_name_, path), res);
+  }
+
+  Status Info(CtrlResponse<InfoResponse> *res, const string &path) {
+    return SendRequestGetResponse(InfoRequest(user_name_, path), res);
+  }
+
+  Status OpenCreate(CtrlResponse<OpenCreateResponse> *res, const string &path) {
+    return SendRequestGetResponse(OpenCreateRequest(user_name_, path), res);
+  }
+
+  Status OpenAppend(CtrlResponse<OpenAppendResponse> *res, const string &path) {
+    return SendRequestGetResponse(OpenAppendRequest(user_name_, path), res);
+  }
+
+  Status OpenRead(CtrlResponse<OpenReadResponse> *res, const string &path) {
+    return SendRequestGetResponse(OpenReadRequest(user_name_, path), res);
+  }
+
+  Status Exists(CtrlResponse<ExistsResponse> *res, const string &path) {
+    return SendRequestGetResponse(ExistsRequest(user_name_, path), res);
+  }
+
+  Status MkDir(CtrlResponse<MakeDirectoriesResponse> *res, const string &path) {
+    return SendRequestGetResponse(MakeDirectoriesRequest(user_name_, path),
+                                  res);
+  }
+
+  Status Delete(CtrlResponse<DeleteResponse> *res, const string &path,
+                bool recursive) {
+    return SendRequestGetResponse(DeleteRequest(user_name_, path, recursive),
+                                  res);
+  }
+
+  Status WriteBlock(int64_t stream_id, const uint8_t *data, int32_t len) {
+    return SendRequestGetResponse(WriteBlockRequest(stream_id, data, len),
+                                  nullptr);
+  }
+
+  Status ReadBlock(ReadBlockCtrlResponse *res, int64_t stream_id, int64_t pos,
+                   int32_t length) {
+    return SendRequestGetResponse(ReadBlockRequest(stream_id, pos, length),
+                                  res);
+  }
+
+  Status Close(CtrlResponse<CloseResponse> *res, int64_t stream_id) {
+    return SendRequestGetResponse(CloseRequest(stream_id), res);
+  }
+
+  Status Rename(CtrlResponse<RenameResponse> *res, const string &source,
+                const string &dest) {
+    return SendRequestGetResponse(RenameRequest(user_name_, source, dest), res);
+  }
+
+ private:
+  Status SendRequestGetResponse(const Request &request, Response *response);
+
+  const string fs_name_;
+  const string user_name_;
+  ExtendedTCPClient client_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc
new file mode 100644
index 00000000000..ea63436546d
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc
@@ -0,0 +1,144 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h"
+
+namespace tensorflow {
+
+ExtendedTCPClient::ExtendedTCPClient(const string &host, int port,
+                                     bool big_endian)
+    : PlainClient(host, port, big_endian), pos_(0) {}
+
+Status ExtendedTCPClient::ReadData(uint8_t *buf, const int32_t length) {
+  TF_RETURN_IF_ERROR(PlainClient::ReadData(buf, length));
+  pos_ += length;
+
+  return Status::OK();
+}
+
+Status ExtendedTCPClient::WriteData(const uint8_t *buf, const int32_t length) {
+  TF_RETURN_IF_ERROR(PlainClient::WriteData(buf, length));
+  pos_ += length;
+
+  return Status::OK();
+}
+
+Status ExtendedTCPClient::Ignore(int n) {
+  uint8_t buf[n];
+  return ReadData(buf, n);
+}
+
+Status ExtendedTCPClient::SkipToPos(int target_pos) {
+  return Ignore(std::max(0, target_pos - pos_));
+}
+
+Status ExtendedTCPClient::ReadBool(bool *res) {
+  uint8_t buf = 0;
+  TF_RETURN_IF_ERROR(ReadData(&buf, 1));
+  *res = buf != 0;
+
+  return Status::OK();
+}
+
+Status ExtendedTCPClient::ReadNullableString(string *res) {
+  bool is_empty = false;
+  TF_RETURN_IF_ERROR(ReadBool(&is_empty));
+
+  if (!is_empty) {
+    TF_RETURN_IF_ERROR(ReadString(res));
+  }
+
+  return Status::OK();
+}
+
+Status ExtendedTCPClient::ReadString(string *res) {
+  int16_t length;
+  TF_RETURN_IF_ERROR(ReadShort(&length));
+
+  uint8_t *buf = new uint8_t[length];
+  Status status = ReadData(buf, length);
+
+  if (status.ok()) res->assign(reinterpret_cast<char *>(buf), length);
+
+  delete[] buf;
+  return status;
+}
+
+Status ExtendedTCPClient::ReadStringMap(std::map<string, string> *res) {
+  int size;
+  TF_RETURN_IF_ERROR(ReadInt(&size));
+
+  for (int i = 0; i < size; i++) {
+    string key;
+    string val;
+    TF_RETURN_IF_ERROR(ReadString(&key));
TF_RETURN_IF_ERROR(ReadString(&val)); + + res->insert(std::pair(std::move(key), std::move(val))); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteSize(std::map::size_type s) { + return WriteInt(s); +} + +Status ExtendedTCPClient::FillWithZerosUntil(int n) { + int to_skip = std::max(0, n - pos_); + + for (int i = 0; i < to_skip; i++) { + TF_RETURN_IF_ERROR(WriteByte(0)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteBool(bool val) { + return WriteByte((char)(val ? 1 : 0)); +} + +Status ExtendedTCPClient::WriteString(string str) { + if (!str.empty()) { + TF_RETURN_IF_ERROR(WriteBool(false)); + size_t l = str.length(); + if (l > std::numeric_limits::max()) + return errors::InvalidArgument("String is too long"); + + TF_RETURN_IF_ERROR(WriteShort(l)); + TF_RETURN_IF_ERROR(WriteData(reinterpret_cast(str.c_str()), + str.length())); + } else { + TF_RETURN_IF_ERROR(WriteBool(true)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteStringMap(std::map map) { + std::map::size_type size = map.size(); + TF_RETURN_IF_ERROR(WriteSize(size)); + + for (auto &x : map) { + TF_RETURN_IF_ERROR(WriteString(x.first)); + TF_RETURN_IF_ERROR(WriteString(x.second)); + } + + return Status::OK(); +} + +void ExtendedTCPClient::reset() { pos_ = 0; } + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h new file mode 100644 index 00000000000..c5de342fd0c --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" + +namespace tensorflow { + +class ExtendedTCPClient : public PlainClient { + public: + ExtendedTCPClient(const string &host, int port, bool big_endian); + Status ReadData(uint8_t *buf, const int32_t length) override; + Status WriteData(const uint8_t *buf, const int32_t length) override; + Status Ignore(int n); + Status SkipToPos(int target_pos); + Status ReadBool(bool *res); + Status ReadNullableString(string *res); + Status ReadString(string *res); + Status ReadStringMap(std::map *res); + Status WriteSize(std::map::size_type s); + Status FillWithZerosUntil(int n); + Status WriteBool(bool val); + Status WriteString(string str); + Status WriteStringMap(std::map map); + void reset(); + + private: + int pos_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc new file mode 100644 index 00000000000..9c63f40f35f --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc @@ -0,0 +1,344 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +Status IGFSPath::Read(ExtendedTCPClient *client) { + return client->ReadNullableString(&path); +} + +Status IGFSFile::Read(ExtendedTCPClient *client) { + int32_t block_size; + int64_t group_block_size; + std::map properties = {}; + int64_t access_time; + + bool has_path; + TF_RETURN_IF_ERROR(client->ReadBool(&has_path)); + if (has_path) { + IGFSPath path = {}; + TF_RETURN_IF_ERROR(path.Read(client)); + } + + TF_RETURN_IF_ERROR(client->ReadInt(&block_size)); + TF_RETURN_IF_ERROR(client->ReadLong(&group_block_size)); + TF_RETURN_IF_ERROR(client->ReadLong(&length)); + TF_RETURN_IF_ERROR(client->ReadStringMap(&properties)); + TF_RETURN_IF_ERROR(client->ReadLong(&access_time)); + TF_RETURN_IF_ERROR(client->ReadLong(&modification_time)); + TF_RETURN_IF_ERROR(client->ReadByte(&flags)); + + return Status::OK(); +} + +Request::Request(int32_t command_id) : command_id_(command_id) {} + +Status Request::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(client->WriteByte(0)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(8)); + TF_RETURN_IF_ERROR(client->WriteInt(command_id_)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(24)); + + return Status::OK(); +} + +Status Response::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->Ignore(1)); + TF_RETURN_IF_ERROR(client->SkipToPos(8)); + TF_RETURN_IF_ERROR(client->ReadInt(&req_id)); + TF_RETURN_IF_ERROR(client->SkipToPos(24)); + TF_RETURN_IF_ERROR(client->ReadInt(&res_type)); + + bool has_error; + TF_RETURN_IF_ERROR(client->ReadBool(&has_error)); + + if (has_error) { + int32_t error_code; + string error_msg; + TF_RETURN_IF_ERROR(client->ReadString(&error_msg)); + TF_RETURN_IF_ERROR(client->ReadInt(&error_code)); + + return errors::Unknown("Error [code=", error_code, ", message=\"", + error_msg, "\"]"); + } + + TF_RETURN_IF_ERROR(client->SkipToPos(header_size_ + 5)); + TF_RETURN_IF_ERROR(client->ReadInt(&length)); + TF_RETURN_IF_ERROR(client->SkipToPos(header_size_ + response_header_size_)); + + return Status::OK(); +} + +PathCtrlRequest::PathCtrlRequest(int32_t command_id_, const string &user_name, + const string &path, + const string &destination_path, bool flag, + bool collocate, + const std::map &properties) + : Request(command_id_), + user_name_(user_name), + path_(path), + destination_path_(destination_path), + flag_(flag), + collocate_(collocate), + props_(properties) {} + +Status PathCtrlRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(Request::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteString(user_name_)); + TF_RETURN_IF_ERROR(WritePath(client, path_)); + TF_RETURN_IF_ERROR(WritePath(client, destination_path_)); + TF_RETURN_IF_ERROR(client->WriteBool(flag_)); + TF_RETURN_IF_ERROR(client->WriteBool(collocate_)); + TF_RETURN_IF_ERROR(client->WriteStringMap(props_)); + + return Status::OK(); +} + +Status PathCtrlRequest::WritePath(ExtendedTCPClient *client, + const string &path) const { + TF_RETURN_IF_ERROR(client->WriteBool(!path.empty())); + if (!path.empty()) TF_RETURN_IF_ERROR(client->WriteString(path)); + + return Status::OK(); +} + +Status StreamCtrlRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(client->WriteByte(0)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(8)); + TF_RETURN_IF_ERROR(client->WriteInt(command_id_)); + TF_RETURN_IF_ERROR(client->WriteLong(stream_id_)); + 
TF_RETURN_IF_ERROR(client->WriteInt(length_)); + + return Status::OK(); +} + +StreamCtrlRequest::StreamCtrlRequest(int32_t command_id_, int64_t stream_id, + int32_t length) + : Request(command_id_), stream_id_(stream_id), length_(length) {} + +DeleteRequest::DeleteRequest(const string &user_name, const string &path, + bool flag) + : PathCtrlRequest(DELETE_ID, user_name, path, {}, flag, true, {}) {} + +Status DeleteResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&exists)); + + return Status::OK(); +} + +ExistsRequest::ExistsRequest(const string &user_name, const string &path) + : PathCtrlRequest(EXISTS_ID, user_name, path, {}, false, true, {}) {} + +Status ExistsResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&exists)); + + return Status::OK(); +} + +HandshakeRequest::HandshakeRequest(const string &fs_name, const string &log_dir) + : Request(HANDSHAKE_ID), fs_name_(fs_name), log_dir_(log_dir) {} + +Status HandshakeRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(Request::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteString(fs_name_)); + TF_RETURN_IF_ERROR(client->WriteString(log_dir_)); + + return Status::OK(); +} + +Status HandshakeResponse::Read(ExtendedTCPClient *client) { + int64_t block_size; + bool sampling; + + TF_RETURN_IF_ERROR(client->ReadNullableString(&fs_name)); + TF_RETURN_IF_ERROR(client->ReadLong(&block_size)); + + bool has_sampling_; + TF_RETURN_IF_ERROR(client->ReadBool(&has_sampling_)); + + if (has_sampling_) { + TF_RETURN_IF_ERROR(client->ReadBool(&sampling)); + } + + return Status::OK(); +} + +ListRequest::ListRequest(int32_t command_id_, const string &user_name, + const string &path) + : PathCtrlRequest(command_id_, user_name, path, {}, false, true, {}) {} + +ListFilesRequest::ListFilesRequest(const string &user_name, const string &path) + : ListRequest(LIST_FILES_ID, user_name, path) {} + +ListPathsRequest::ListPathsRequest(const string &user_name, const string &path) + : ListRequest(LIST_PATHS_ID, user_name, path) {} + +OpenCreateRequest::OpenCreateRequest(const string &user_name, + const string &path) + : PathCtrlRequest(OPEN_CREATE_ID, user_name, path, {}, false, true, {}) {} + +Status OpenCreateRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteInt(replication_)); + TF_RETURN_IF_ERROR(client->WriteLong(blockSize_)); + + return Status::OK(); +} + +Status OpenCreateResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + + return Status::OK(); +} + +OpenAppendRequest::OpenAppendRequest(const string &user_name, + const string &path) + : PathCtrlRequest(OPEN_APPEND_ID, user_name, path, {}, false, true, {}) {} + +Status OpenAppendRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + return Status::OK(); +} + +Status OpenAppendResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + + return Status::OK(); +} + +OpenReadRequest::OpenReadRequest(const string &user_name, const string &path, + bool flag, + int32_t sequential_reads_before_prefetch) + : PathCtrlRequest(OPEN_READ_ID, user_name, path, {}, flag, true, {}), + sequential_reads_before_prefetch_(sequential_reads_before_prefetch) {} + +OpenReadRequest::OpenReadRequest(const string &user_name, const string &path) + : OpenReadRequest(user_name, path, false, 0) {} + +Status 
OpenReadRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + if (flag_) { + TF_RETURN_IF_ERROR(client->WriteInt(sequential_reads_before_prefetch_)); + } + + return Status::OK(); +} + +Status OpenReadResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + TF_RETURN_IF_ERROR(client->ReadLong(&length)); + + return Status::OK(); +} + +InfoRequest::InfoRequest(const string &user_name, const string &path) + : PathCtrlRequest(INFO_ID, user_name, path, {}, false, true, {}) {} + +Status InfoResponse::Read(ExtendedTCPClient *client) { + file_info = IGFSFile(); + TF_RETURN_IF_ERROR(file_info.Read(client)); + + return Status::OK(); +} + +MakeDirectoriesRequest::MakeDirectoriesRequest(const string &user_name, + const string &path) + : PathCtrlRequest(MKDIR_ID, user_name, path, {}, false, true, {}) {} + +Status MakeDirectoriesResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +CloseRequest::CloseRequest(int64_t streamId) + : StreamCtrlRequest(CLOSE_ID, streamId, 0) {} + +Status CloseResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +ReadBlockRequest::ReadBlockRequest(int64_t stream_id, int64_t pos, + int32_t length) + : StreamCtrlRequest(READ_BLOCK_ID, stream_id, length), pos(pos) {} + +Status ReadBlockRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(StreamCtrlRequest::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteLong(pos)); + + return Status::OK(); +} + +Status ReadBlockResponse::Read(ExtendedTCPClient *client, int32_t length, + uint8_t *dst) { + TF_RETURN_IF_ERROR(client->ReadData(dst, length)); + successfully_read = length; + + return Status::OK(); +} + +Status ReadBlockResponse::Read(ExtendedTCPClient *client) { + return Status::OK(); +} + +std::streamsize ReadBlockResponse::GetSuccessfullyRead() { + return successfully_read; +} + +ReadBlockCtrlResponse::ReadBlockCtrlResponse(uint8_t *dst) + : CtrlResponse(false), dst(dst) {} + +Status ReadBlockCtrlResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(Response::Read(client)); + + res = ReadBlockResponse(); + TF_RETURN_IF_ERROR(res.Read(client, length, dst)); + + return Status::OK(); +} + +WriteBlockRequest::WriteBlockRequest(int64_t stream_id, const uint8_t *data, + int32_t length) + : StreamCtrlRequest(WRITE_BLOCK_ID, stream_id, length), data(data) {} + +Status WriteBlockRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(StreamCtrlRequest::Write(client)); + TF_RETURN_IF_ERROR(client->WriteData((uint8_t *)data, length_)); + + return Status::OK(); +} + +RenameRequest::RenameRequest(const string &user_name, const string &path, + const string &destination_path) + : PathCtrlRequest(RENAME_ID, user_name, path, destination_path, false, true, + {}) {} + +Status RenameResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h new file mode 100644 index 00000000000..44a2928a2b2 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h @@ -0,0 +1,356 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h" + +namespace tensorflow { + +enum CommandId { + HANDSHAKE_ID = 0, + EXISTS_ID = 2, + INFO_ID = 3, + RENAME_ID = 6, + DELETE_ID = 7, + MKDIR_ID = 8, + LIST_PATHS_ID = 9, + LIST_FILES_ID = 10, + OPEN_READ_ID = 13, + OPEN_APPEND_ID = 14, + OPEN_CREATE_ID = 15, + CLOSE_ID = 16, + READ_BLOCK_ID = 17, + WRITE_BLOCK_ID = 18, +}; + +class IGFSPath { + public: + Status Read(ExtendedTCPClient *client); + + string path; +}; + +class IGFSFile { + public: + Status Read(ExtendedTCPClient *client); + + int64_t length; + int64_t modification_time; + uint8_t flags; +}; + +class Request { + public: + Request(int32_t command_id); + virtual Status Write(ExtendedTCPClient *client) const; + + protected: + const int32_t command_id_; +}; + +class Response { + public: + virtual Status Read(ExtendedTCPClient *client); + + int32_t res_type; + int32_t req_id; + int32_t length; + + protected: + static const int32_t header_size_ = 24; + static const int32_t response_header_size_ = 9; +}; + +class PathCtrlRequest : public Request { + public: + PathCtrlRequest(int32_t command_id, const string &user_name, + const string &path, const string &destination_path, bool flag, + bool collocate, const std::map &properties); + Status Write(ExtendedTCPClient *client) const override; + + protected: + Status WritePath(ExtendedTCPClient *client, const string &path) const; + + const string user_name_; + const string path_; + const string destination_path_; + const bool flag_; + const bool collocate_; + const std::map props_; +}; + +class StreamCtrlRequest : public Request { + public: + StreamCtrlRequest(int32_t command_id, int64_t stream_id, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + protected: + int64_t stream_id_; + int32_t length_; +}; + +template +class CtrlResponse : public Response { + public: + CtrlResponse(bool optional) : optional_(optional) {} + Status Read(ExtendedTCPClient *client) override { + TF_RETURN_IF_ERROR(Response::Read(client)); + + if (optional_) { + TF_RETURN_IF_ERROR(client->ReadBool(&has_content)); + + if (!has_content) return Status::OK(); + } + + res = R(); + has_content = true; + TF_RETURN_IF_ERROR(res.Read(client)); + + return Status::OK(); + } + + R res; + bool has_content; + + private: + bool optional_; +}; + +template +class ListResponse { + public: + Status Read(ExtendedTCPClient *client) { + int32_t len; + TF_RETURN_IF_ERROR(client->ReadInt(&len)); + + entries.clear(); + + for (int32_t i = 0; i < len; i++) { + T f = {}; + TF_RETURN_IF_ERROR(f.Read(client)); + entries.push_back(f); + } + + return Status::OK(); + } + + std::vector entries; +}; + +class DeleteRequest : public PathCtrlRequest { + public: + DeleteRequest(const string &user_name, const string 
&path, bool flag); +}; + +class DeleteResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool exists; +}; + +class ExistsRequest : public PathCtrlRequest { + public: + explicit ExistsRequest(const string &user_name, const string &path); +}; + +class ExistsResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool exists; +}; + +class HandshakeRequest : public Request { + public: + HandshakeRequest(const string &fs_name, const string &log_dir); + Status Write(ExtendedTCPClient *client) const override; + + private: + string fs_name_; + string log_dir_; +}; + +class HandshakeResponse { + public: + Status Read(ExtendedTCPClient *client); + + string fs_name; +}; + +class ListRequest : public PathCtrlRequest { + public: + explicit ListRequest(int32_t command_id, const string &user_name, + const string &path); +}; + +class ListFilesRequest : public ListRequest { + public: + ListFilesRequest(const string &user_name, const string &path); +}; + +class ListFilesResponse : public ListResponse {}; + +class ListPathsRequest : public ListRequest { + public: + ListPathsRequest(const string &user_name, const string &path); +}; + +class ListPathsResponse : public ListResponse {}; + +class OpenCreateRequest : public PathCtrlRequest { + public: + OpenCreateRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; + + private: + int32_t replication_; + int64_t blockSize_; +}; + +class OpenCreateResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; +}; + +class OpenAppendRequest : public PathCtrlRequest { + public: + explicit OpenAppendRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; +}; + +class OpenAppendResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; +}; + +class OpenReadRequest : public PathCtrlRequest { + public: + OpenReadRequest(const string &user_name, const string &path, bool flag, + int32_t seqReadsBeforePrefetch); + OpenReadRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; + + protected: + /** Sequential reads before prefetch. */ + int32_t sequential_reads_before_prefetch_; +}; + +class OpenReadResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; + int64_t length; +}; + +class InfoRequest : public PathCtrlRequest { + public: + InfoRequest(const string &user_name, const string &path); +}; + +class InfoResponse { + public: + Status Read(ExtendedTCPClient *client); + + IGFSFile file_info; +}; + +class MakeDirectoriesRequest : public PathCtrlRequest { + public: + MakeDirectoriesRequest(const string &userName, const string &path); +}; + +class MakeDirectoriesResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +/** Stream control requests. 
**/ + +class CloseRequest : public StreamCtrlRequest { + public: + explicit CloseRequest(int64_t stream_id); +}; + +class CloseResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +class ReadBlockRequest : public StreamCtrlRequest { + public: + ReadBlockRequest(int64_t stream_id, int64_t pos, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + private: + int64_t pos; +}; + +class ReadBlockResponse { + public: + Status Read(ExtendedTCPClient *client, int32_t length, uint8_t *dst); + Status Read(ExtendedTCPClient *client); + std::streamsize GetSuccessfullyRead(); + + private: + int32_t length; + std::streamsize successfully_read; +}; + +class ReadBlockCtrlResponse : public CtrlResponse { + public: + ReadBlockCtrlResponse(uint8_t *dst); + Status Read(ExtendedTCPClient *client) override; + + private: + uint8_t *dst; +}; + +class WriteBlockRequest : public StreamCtrlRequest { + public: + WriteBlockRequest(int64_t stream_id, const uint8_t *data, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + private: + const uint8_t *data; +}; + +class RenameRequest : public PathCtrlRequest { + public: + RenameRequest(const string &user_name, const string &path, + const string &destination_path); +}; + +class RenameResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc new file mode 100644 index 00000000000..a4c898f14e6 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +IGFSRandomAccessFile::IGFSRandomAccessFile(const string &file_name, + int64_t resource_id, + std::unique_ptr &&client) + : file_name_(file_name), + resource_id_(resource_id), + client_(std::move(client)) {} + +IGFSRandomAccessFile::~IGFSRandomAccessFile() { + CtrlResponse close_response = {false}; + Status status = client_->Close(&close_response, resource_id_); + + if (!status.ok()) LOG(ERROR) << status.ToString(); +} + +Status IGFSRandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, + char *scratch) const { + ReadBlockCtrlResponse response = ReadBlockCtrlResponse((uint8_t *)scratch); + TF_RETURN_IF_ERROR(client_->ReadBlock(&response, resource_id_, offset, n)); + + std::streamsize sz = response.res.GetSuccessfullyRead(); + if (sz == 0) return errors::OutOfRange("End of file"); + + *result = StringPiece(scratch, sz); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h new file mode 100644 index 00000000000..b21369ff8a3 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFSRandomAccessFile : public RandomAccessFile { + public: + IGFSRandomAccessFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client); + ~IGFSRandomAccessFile() override; + Status Read(uint64 offset, size_t n, StringPiece *result, + char *scratch) const override; + + private: + const string file_name_; + const int64_t resource_id_; + std::unique_ptr client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc new file mode 100644 index 00000000000..c15ecb7deeb --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +IGFSWritableFile::IGFSWritableFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client) + : file_name_(file_name), + resource_id_(resource_id), + client_(std::move(client)) {} + +IGFSWritableFile::~IGFSWritableFile() { + if (resource_id_ >= 0) { + CtrlResponse close_response = {false}; + + Status status = client_->Close(&close_response, resource_id_); + if (!status.ok()) LOG(ERROR) << status.ToString(); + } +} + +Status IGFSWritableFile::Append(StringPiece data) { + return client_->WriteBlock(resource_id_, (uint8_t *)data.data(), data.size()); +} + +Status IGFSWritableFile::Close() { + int64_t resource_to_be_closed = resource_id_; + resource_id_ = -1; + + CtrlResponse close_response = {false}; + return client_->Close(&close_response, resource_to_be_closed); +} + +Status IGFSWritableFile::Flush() { return Sync(); } + +Status IGFSWritableFile::Sync() { + CtrlResponse close_response = {false}; + TF_RETURN_IF_ERROR(client_->Close(&close_response, resource_id_)); + + CtrlResponse open_append_resp(false); + TF_RETURN_IF_ERROR(client_->OpenAppend(&open_append_resp, file_name_)); + + resource_id_ = open_append_resp.res.stream_id; + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h new file mode 100644 index 00000000000..b406db17e0e --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFSWritableFile : public WritableFile { + public: + IGFSWritableFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client); + ~IGFSWritableFile() override; + Status Append(StringPiece data) override; + Status Close() override; + Status Flush() override; + Status Sync() override; + + private: + const string file_name_; + int64_t resource_id_; + std::unique_ptr client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ diff --git a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc b/tensorflow/contrib/ignite/ops/igfs_ops.cc similarity index 62% rename from tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc rename to tensorflow/contrib/ignite/ops/igfs_ops.cc index f3b24b2341e..473bddff08b 100644 --- a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc +++ b/tensorflow/contrib/ignite/ops/igfs_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 Google Inc. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" +#include "tensorflow/core/platform/env.h" + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs.h" namespace tensorflow { -namespace fuzzing { -class FuzzDecodeJpeg : public FuzzStringInputOp { - SINGLE_INPUT_OP_BUILDER(DT_STRING, DecodeJpeg); -}; +REGISTER_FILE_SYSTEM("igfs", IGFS); -STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeJpeg); - -} // end namespace fuzzing -} // end namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py b/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py new file mode 100644 index 00000000000..8e1d6707d64 --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python helper for loading IGFS ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_ignite_ops.so")) diff --git a/tensorflow/contrib/ignite/python/ops/igfs_ops.py b/tensorflow/contrib/ignite/python/ops/igfs_ops.py new file mode 100644 index 00000000000..12b973b7077 --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/igfs_ops.py @@ -0,0 +1,40 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ignite File System for checkpointing and communication with TensorBoard. + +Apache Ignite is a memory-centric distributed database, caching, and +processing platform for transactional, analytical, and streaming workloads, +delivering in-memory speeds at petabyte scale. In addition to database +functionality Apache Ignite provides a distributed file system called +IGFS (https://ignite.apache.org/features/igfs.html). IGFS delivers a similar +functionality to Hadoop HDFS, but only in-memory. In fact, in addition to +its own APIs, IGFS implements Hadoop FileSystem API and can be transparently +plugged into Hadoop or Spark deployments. This contrib package contains an +integration between IGFS and TensorFlow. 
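+
+A minimal usage sketch is shown below; the igfs:// path is only an example,
+and it assumes an IGFS server is reachable at the host and port given by the
+IGFS_HOST and IGFS_PORT environment variables (defaulting to localhost and
+10500):
+
+  import tensorflow.contrib.ignite.python.ops.igfs_ops  # pylint: disable=unused-import
+  from tensorflow.python.platform import gfile
+
+  with gfile.Open("igfs:///tmp/hello.txt", mode="w") as w:
+    w.write("Hello, world.")
+  with gfile.Open("igfs:///tmp/hello.txt", mode="r") as r:
+    print(r.read())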
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader + +file_system_library = os.path.join(resource_loader.get_data_files_path(), + "../../_ignite_ops.so") +load_library.load_file_system_library(file_system_library) diff --git a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py index c9af7386cf0..e450e2d84ba 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py @@ -21,4 +21,4 @@ from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader _dataset_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_dataset_ops.so")) + resource_loader.get_path_to_datafile("../../_ignite_ops.so")) diff --git a/tensorflow/contrib/signal/python/__init__.py b/tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh old mode 100644 new mode 100755 similarity index 73% rename from tensorflow/contrib/signal/python/__init__.py rename to tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh index e672d1146c5..5e39e16c052 --- a/tensorflow/contrib/signal/python/__init__.py +++ b/tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Signal ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-igfs.xml & +sleep 5 # Wait Apache Ignite to be started + +tail -f nohup.out diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml new file mode 100644 index 00000000000..5d81bf33226 --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 127.0.0.1 + + + + + + + + + diff --git a/tensorflow/contrib/ignite/python/tests/igfs_test.py b/tensorflow/contrib/ignite/python/tests/igfs_test.py new file mode 100644 index 00000000000..cacfc568942 --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/igfs_test.py @@ -0,0 +1,215 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for IGFS.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.contrib.ignite.python.ops.igfs_ops # pylint: disable=unused-import +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test + + +class IGFSTest(test.TestCase): + """The Apache Ignite servers have to setup before the test and tear down + + after the test manually. The docker engine has to be installed. + + To setup Apache Ignite servers: + $ bash start_ignite.sh + + To tear down Apache Ignite servers: + $ bash stop_ignite.sh + """ + + def test_create_file(self): + """Test create file. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_create_file/1" + self.assertFalse(gfile.Exists(file_name)) + # Create file. + with gfile.Open(file_name, mode="w") as w: + w.write("") + # Check that file was created. + self.assertTrue(gfile.Exists(file_name)) + + def test_write_read_file(self): + """Test write/read file. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_write_read_file/1" + rows = 10000 + self.assertFalse(gfile.Exists(file_name)) + # Write data. + with gfile.Open(file_name, mode="w") as w: + for i in range(rows): + w.write("This is row\n") + # Read data. + with gfile.Open(file_name, mode="r") as r: + lines = r.readlines() + # Check that data is equal. + self.assertEqual(rows, len(lines)) + for i in range(rows): + self.assertEqual("This is row\n", lines[i]) + + def test_delete_recursively(self): + """Test delete recursively. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_delete_recursively/" + file_name = "igfs:///test_delete_recursively/1" + self.assertFalse(gfile.Exists(dir_name)) + self.assertFalse(gfile.Exists(file_name)) + gfile.MkDir(dir_name) + with gfile.Open(file_name, mode="w") as w: + w.write("") + self.assertTrue(gfile.Exists(dir_name)) + self.assertTrue(gfile.Exists(file_name)) + # Delete directory recursively. + gfile.DeleteRecursively(dir_name) + # Check that directory was deleted. + self.assertFalse(gfile.Exists(dir_name)) + self.assertFalse(gfile.Exists(file_name)) + + def test_copy(self): + """Test copy. + + """ + # Setup and check preconditions. + src_file_name = "igfs:///test_copy/1" + dst_file_name = "igfs:///test_copy/2" + self.assertFalse(gfile.Exists(src_file_name)) + self.assertFalse(gfile.Exists(dst_file_name)) + with gfile.Open(src_file_name, mode="w") as w: + w.write("42") + self.assertTrue(gfile.Exists(src_file_name)) + self.assertFalse(gfile.Exists(dst_file_name)) + # Copy file. + gfile.Copy(src_file_name, dst_file_name) + # Check that files are identical. + self.assertTrue(gfile.Exists(src_file_name)) + self.assertTrue(gfile.Exists(dst_file_name)) + with gfile.Open(dst_file_name, mode="r") as r: + data = r.read() + self.assertEqual("42", data) + + def test_is_directory(self): + """Test is directory. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_is_directory/1" + file_name = "igfs:///test_is_directory/2" + with gfile.Open(file_name, mode="w") as w: + w.write("") + gfile.MkDir(dir_name) + # Check that directory is a directory. + self.assertTrue(gfile.IsDirectory(dir_name)) + # Check that file is not a directory. 
+ self.assertFalse(gfile.IsDirectory(file_name)) + + def test_list_directory(self): + """Test list directory. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_list_directory/" + file_names = [ + "igfs:///test_list_directory/1", "igfs:///test_list_directory/2/3" + ] + ch_dir_names = [ + "igfs:///test_list_directory/4", + ] + for file_name in file_names: + with gfile.Open(file_name, mode="w") as w: + w.write("") + for ch_dir_name in ch_dir_names: + gfile.MkDir(ch_dir_name) + ls_expected_result = file_names + ch_dir_names + # Get list of files in directory. + ls_result = gfile.ListDirectory(dir_name) + # Check that list of files is correct. + self.assertEqual(len(ls_expected_result), len(ls_result)) + for e in ["1", "2", "4"]: + self.assertTrue(e in ls_result) + + def test_make_dirs(self): + """Test make dirs. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_make_dirs/" + self.assertFalse(gfile.Exists(dir_name)) + # Make directory. + gfile.MkDir(dir_name) + # Check that directory was created. + self.assertTrue(gfile.Exists(dir_name)) + + def test_remove(self): + """Test remove. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_remove/1" + self.assertFalse(gfile.Exists(file_name)) + with gfile.Open(file_name, mode="w") as w: + w.write("") + self.assertTrue(gfile.Exists(file_name)) + # Remove file. + gfile.Remove(file_name) + # Check that file was removed. + self.assertFalse(gfile.Exists(file_name)) + + def test_rename_file(self): + """Test rename file. + + """ + # Setup and check preconditions. + src_file_name = "igfs:///test_rename_file/1" + dst_file_name = "igfs:///test_rename_file/2" + with gfile.Open(src_file_name, mode="w") as w: + w.write("42") + self.assertTrue(gfile.Exists(src_file_name)) + # Rename file. + gfile.Rename(src_file_name, dst_file_name) + # Check that only new name of file is available. + self.assertFalse(gfile.Exists(src_file_name)) + self.assertTrue(gfile.Exists(dst_file_name)) + with gfile.Open(dst_file_name, mode="r") as r: + data = r.read() + self.assertEqual("42", data) + + def test_rename_dir(self): + """Test rename dir. + + """ + # Setup and check preconditions. + src_dir_name = "igfs:///test_rename_dir/1" + dst_dir_name = "igfs:///test_rename_dir/2" + gfile.MkDir(src_dir_name) + # Rename directory. + gfile.Rename(src_dir_name, dst_dir_name) + # Check that only new name of directory is available. + self.assertFalse(gfile.Exists(src_dir_name)) + self.assertTrue(gfile.Exists(dst_dir_name)) + self.assertTrue(gfile.IsDirectory(dst_dir_name)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh index a67bd44f2fb..112e0dea844 100755 --- a/tensorflow/contrib/ignite/python/tests/start_ignite.sh +++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh @@ -20,3 +20,7 @@ SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )" # Start Apache Ignite with plain client listener. docker run -itd --name ignite-plain -p 42300:10800 \ -v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh + +# Start Apache Ignite with IGFS. 
+docker run -itd --name ignite-igfs -p 10500:10500 \ +-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-igfs.sh \ No newline at end of file diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh index 8f03dbd1ede..35b0f32d1b3 100755 --- a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh +++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh @@ -15,5 +15,4 @@ # ============================================================================== docker rm -f ignite-plain -docker rm -f ignite-ssl -docker rm -f ignite-ssl-auth +docker rm -f ignite-igfs \ No newline at end of file diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index c8812d4b23f..588f15b867c 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -70,7 +70,10 @@ py_test( "python/ops/core_test.py", ], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_windows", # TODO: needs investigation on Windows + "noasan", # TODO(b/119323169) + ], deps = [ ":_typecheck", ":core", diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index eab93f2cc5e..e779eff6890 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -42,6 +42,7 @@ tensorflow/core/kernels/conv_grad_filter_ops.cc tensorflow/core/kernels/conv_grad_input_ops.cc tensorflow/core/kernels/conv_grad_ops.cc tensorflow/core/kernels/conv_ops.cc +tensorflow/core/kernels/conv_ops_3d.cc tensorflow/core/kernels/conv_ops_fused.cc tensorflow/core/kernels/conv_ops_using_gemm.cc tensorflow/core/kernels/crop_and_resize_op.cc @@ -163,6 +164,7 @@ tensorflow/core/kernels/pack_op.cc tensorflow/core/kernels/pad_op.cc tensorflow/core/kernels/padding_fifo_queue.cc tensorflow/core/kernels/padding_fifo_queue_op.cc +tensorflow/core/kernels/pooling_ops_3d.cc tensorflow/core/kernels/pooling_ops_common.cc tensorflow/core/kernels/population_count_op.cc tensorflow/core/kernels/quantization_utils.cc @@ -248,6 +250,7 @@ tensorflow/core/kernels/spectrogram_op.cc tensorflow/core/kernels/split_lib_cpu.cc tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc +tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index b313024e285..45a60d79482 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -51,7 +51,7 @@ The pruning library allows for specification of the following hyper parameters: | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. 
| -| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | +| threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. | | block_height|integer | 1 | Number of rows in a block for block sparse matrices| diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index d2b81164176..f6b4373edd0 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -204,7 +204,7 @@ def get_pruning_hparams(): begin_pruning_step=0, end_pruning_step=-1, weight_sparsity_map=[''], - threshold_decay=0.9, + threshold_decay=0.0, pruning_frequency=10, nbins=256, block_height=1, @@ -456,13 +456,14 @@ class Pruning(object): pool_window = [self._block_dim[0], self._block_dim[1]] pool_fn = pruning_utils.factorized_pool - + squeeze_axis = None if not self._spec.use_tpu: pool_fn = nn_ops.pool abs_weights = array_ops.reshape( abs_weights, [1, abs_weights.get_shape()[0], abs_weights.get_shape()[1], 1]) + squeeze_axis = [0, 3] pooled_weights = pool_fn( abs_weights, @@ -473,7 +474,7 @@ class Pruning(object): name=weights.op.name + '_pooled') if pooled_weights.get_shape().ndims != 2: - pooled_weights = array_ops.squeeze(pooled_weights) + pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis) smoothed_threshold, new_mask = self._update_mask(pooled_weights, threshold) diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 91b0bb7f600..14fc51229ab 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -188,7 +188,6 @@ def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: values = ops.convert_to_tensor(values, name='values') values = array_ops.reshape(values, [-1]) - value_range = ops.convert_to_tensor(value_range, name='value_range') nbins_float = np.float32(nbins) # Map tensor values that fall within value_range to [0, 1]. @@ -250,7 +249,6 @@ def compute_cdf(values, value_range, **kwargs): name = kwargs.get('name', None) with ops.name_scope(name, 'cdf', [values, value_range, nbins]): values = ops.convert_to_tensor(values, name='values') - value_range = ops.convert_to_tensor(value_range, name='value_range') nbins_float = np.float32(nbins) # Map tensor values that fall within value_range to [0, 1]. 
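The README table and the get_pruning_hparams() change above lower the threshold_decay default from 0.9 to 0.0. A minimal sketch of building a pruning spec with that value set explicitly (hedged illustration only: the step counts and target sparsity below are made-up example values, not part of this change):

# Illustrative only; hyperparameter values are examples, not recommendations.
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.python.training import training_util

global_step = training_util.get_or_create_global_step()
hparams = pruning.get_pruning_hparams().parse(
    "begin_pruning_step=1000,end_pruning_step=20000,"
    "target_sparsity=0.9,threshold_decay=0.0,pruning_frequency=10")
# Masks are added to the graph separately via pruning.apply_mask on layer weights;
# the update op below then adjusts them on the configured schedule.
mask_update_op = pruning.Pruning(
    hparams, global_step=global_step).conditional_mask_update_op()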
@@ -336,7 +334,7 @@ def factorized_pool(input_tensor, padding=padding) return array_ops.squeeze( - array_ops.transpose(width_pooling, perm=[0, 1, 3, 2])) + array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]), axis=[0, 1]) def determine_partitioned_axis(partitioned_variable): diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index 0aca8434976..d6f2bfcb6c2 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -85,8 +85,28 @@ class PruningUtilsTest(test.TestCase): @parameterized.named_parameters( - ("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]), - ("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", [8, 1])) + ("Input_32x32_block_1x1", [32, 32], [1, 1]), + # block size 6x6 + ("Input_3x3_block_6x6", [3, 3], [6, 6]), + ("Input_32x32_block_6x6", [32, 32], [6, 6]), + ("Input_2x32_block_6x6", [2, 32], [6, 6]), + ("Input_32x2_block_6x6", [32, 2], [6, 6]), + ("Input_30x30_block_6x6", [30, 30], [6, 6]), + # block size 4x4 + ("Input_32x32_block_4x4", [32, 32], [4, 4]), + ("Input_2x32_block_4x4", [2, 32], [4, 4]), + ("Input_32x2_block_4x4", [32, 2], [4, 4]), + ("Input_30x30_block_4x4", [30, 30], [4, 4]), + # block size 1x4 + ("Input_32x32_block_1x4", [32, 32], [1, 4]), + ("Input_2x32_block_1x4", [2, 32], [1, 4]), + ("Input_32x2_block_1x4", [32, 2], [1, 4]), + ("Input_30x30_block_1x4", [30, 30], [1, 4]), + # block size 4x1 + ("Input_32x32_block_4x1", [32, 32], [4, 1]), + ("Input_2x32_block_4x1", [2, 32], [4, 1]), + ("Input_32x2_block_4x1", [32, 2], [4, 1]), + ("Input_30x30_block_4x1", [30, 30], [4, 1])) class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): def _compare_pooling_methods(self, weights, pooling_kwargs): @@ -97,9 +117,11 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): array_ops.reshape( weights, [1, weights.get_shape()[0], - weights.get_shape()[1], 1]), **pooling_kwargs)) + weights.get_shape()[1], 1]), **pooling_kwargs), + axis=[0, 3]) pooled_weights_factorized_pool = pruning_utils.factorized_pool( weights, **pooling_kwargs) + self.assertAllClose(pooled_weights_tf.eval(), pooled_weights_factorized_pool.eval()) @@ -113,8 +135,8 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): [expanded_tensor, kronecker_product]) self.assertAllEqual(expanded_tensor_val, kronecker_product_val) - def testFactorizedAvgPool(self, window_shape): - weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + def testFactorizedAvgPool(self, input_shape, window_shape): + weights = variable_scope.get_variable("weights", shape=input_shape) pooling_kwargs = { "window_shape": window_shape, "pooling_type": "AVG", @@ -123,8 +145,8 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): } self._compare_pooling_methods(weights, pooling_kwargs) - def testFactorizedMaxPool(self, window_shape): - weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + def testFactorizedMaxPool(self, input_shape, window_shape): + weights = variable_scope.get_variable("weights", shape=input_shape) pooling_kwargs = { "window_shape": window_shape, "pooling_type": "MAX", @@ -133,8 +155,8 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): } self._compare_pooling_methods(weights, pooling_kwargs) - def testExpandTensor(self, block_dim): - weights = random_ops.random_normal(shape=[1024, 512]) + def 
testExpandTensor(self, input_shape, block_dim): + weights = random_ops.random_normal(shape=input_shape) self._compare_expand_tensor_with_kronecker_product(weights, block_dim) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index 9ce50bfe105..b7fd2d2fb9d 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -106,6 +106,32 @@ class MovingAverageOptimizer(optimizer.Optimizer): self._swapped_variable_name_map[v_avg.op.name] = v.op.name return control_flow_ops.group(train_op, ma_op, name='train_with_avg') + def _find_swapped_variable(self, v_name_to_tensor, v_name, tensor): + """Returns name of swapped variable for given tensor. + + Args: + v_name_to_tensor: Mapping from variable names to tensors. + v_name: name of the variable for which swapped variable should be returned + tensor: Tensor which correspond to variable for which swapped variable + should be returned. + + Returns: + Tensor which correspond to swapped variable. + + Raises: + ValueError: If swapped variable could not be found in v_name_to_tensor. + """ + swapped_v_name = self._swapped_variable_name_map.get(v_name, None) + if swapped_v_name is None: + return tensor + else: + if swapped_v_name in v_name_to_tensor: + return v_name_to_tensor[swapped_v_name] + else: + raise ValueError( + ('Variable to swap %s is not part of variables to save. ' + 'This breaks MovingAverageOptimizer.') % swapped_v_name) + def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): """Create a saver swapping moving averages and variables. @@ -141,33 +167,33 @@ class MovingAverageOptimizer(optimizer.Optimizer): if not isinstance(var_list, dict): var_list = saver.BaseSaverBuilder.OpListToDict(var_list) - # OpListToDict converts variables to tensors. We make sure we can get - # the unique variable name for normal and resource vaiables. - def get_v_name(tensor): - if tensor.op.type == 'ReadVariableOp': - return tensor.op.inputs[0].op.name - else: - return tensor.op.name - v_name_to_tensor = {} - for tensor in six.itervalues(var_list): - v_name = get_v_name(tensor) - v_name_to_tensor[v_name] = tensor + for k, tensor_or_list in six.iteritems(var_list): + # For each partitioned variable OpListToDict returns list of constituent + # parts instead of single tensor. + if (isinstance(tensor_or_list, list) + or isinstance(tensor_or_list, variables.PartitionedVariable)): + for tensor in tensor_or_list: + v_name = tensor.op.name + v_name_to_tensor[v_name] = tensor + else: + v_name_to_tensor[k] = tensor_or_list # Now swap variables and moving averages swapped_var_list = {} - for k, tensor in six.iteritems(var_list): - v_name = get_v_name(tensor) - swapped_v_name = self._swapped_variable_name_map.get(v_name, None) - tensor_to_save = tensor - if swapped_v_name is not None: - if swapped_v_name in v_name_to_tensor: - tensor_to_save = v_name_to_tensor[swapped_v_name] - else: - raise ValueError( - ('Variable to swap %s is not part of variables to save. 
' - 'This breaks MovingAverageOptimizer.') % swapped_v_name) - swapped_var_list[k] = tensor_to_save + for k, tensor_or_list in six.iteritems(var_list): + if isinstance(tensor_or_list, list): + tensor_list_to_save = [] + for tensor in tensor_or_list: + v_name = tensor.op.name + swapped_variable = self._find_swapped_variable(v_name_to_tensor, + v_name, + tensor) + tensor_list_to_save.append(swapped_variable) + swapped_var_list[k] = tensor_list_to_save + else: + swapped_var_list[k] = self._find_swapped_variable( + v_name_to_tensor, k, tensor_or_list) # Build the swapping saver. return saver.Saver(swapped_var_list, name=name, **kwargs) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py index f22e7245285..643403eea6f 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py @@ -26,6 +26,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -43,97 +45,171 @@ class MovingAverageOptimizerTest(test.TestCase): # Test that MovingAverageOptimizer works with resource variables. self._helpTestRun(use_resource=True) - def _helpTestRun(self, use_resource=False): + def testRunUsePartitionedVars(self): + # Test that MovingAverageOptimizer works with partitioned variables. + self._helpTestRun(use_partitioned_vars=True) + + def testRunUseResourcePartitionedVars(self): + # Test that MovingAverageOptimizer works with resource and partitioned + # variables. + self._helpTestRun(use_partitioned_vars=True, use_resource=True) + + def _helpTestRun(self, use_resource=False, use_partitioned_vars=False): + # Partitioned variables are represented as a "collection" of partitions. + # To simplify the test and reuse as much code as possible we employ + # following test strategy for partitioned variables. + # + # In the case of non-partitioned variables test runs on variables with + # shape [2]. + # + # In the case of partitioned variables we use shape [4] with two partitions, + # thus each partition has shape [2]. + # For partitioned variables the test is run twice (for loop over + # variable_part_names), first time on the first partition of each variable, + # second time on the second partition of each variable. 
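The comment block above explains the partitioning strategy in prose; a tiny self-contained sketch of what fixed_size_partitioner does to a [4]-shaped variable (names and values are illustrative, not taken from the test):

# Sketch only: shows the part naming ('.../part_0', '.../part_1') the test loops over.
import tensorflow as tf

with tf.Graph().as_default():
  v = tf.get_variable(
      "v",
      initializer=tf.constant([1.0, 2.0, 3.0, 4.0]),
      partitioner=tf.fixed_size_partitioner(num_shards=2))
  part_names = [p.op.name for p in v]  # ['v/part_0', 'v/part_1'], each of shape [2]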
+ variable_part_names = ['part_0', 'part_1'] if use_partitioned_vars else [''] for sequential_update in [True, False]: for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.session(graph=ops.Graph()) as sess: - orig_val0 = [1.0, 2.0] - orig_val1 = [3.0, 4.0] - var0 = variable_scope.get_variable( - 'var0', - initializer=constant_op.constant(orig_val0, dtype=dtype), - use_resource=use_resource) - var1 = variable_scope.get_variable( - 'var1', - initializer=constant_op.constant(orig_val1, dtype=dtype), - use_resource=use_resource) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + for var_part_name in variable_part_names: + with self.session(graph=ops.Graph()) as sess: + orig_val0 = [1.0, 2.0] + orig_val1 = [3.0, 4.0] + grads0 = [0.1, 0.1] + grads1 = [0.01, 0.01] + if use_partitioned_vars: + # Use partitioned variables. + # Create partitioned and duplicate each value used as initial + # value of variables. + partitioner = partitioned_variables.fixed_size_partitioner( + num_shards=2) + orig_val0 = orig_val0 * 2 + orig_val1 = orig_val1 * 2 + grads0 = grads0 * 2 + grads1 = grads1 * 2 + else: + # Regular (non-partitioned) variables. + partitioner = None + var0 = variable_scope.get_variable( + 'var0', + initializer=constant_op.constant(orig_val0, dtype=dtype), + use_resource=use_resource, + partitioner=partitioner) + var1 = variable_scope.get_variable( + 'var1', + initializer=constant_op.constant(orig_val1, dtype=dtype), + use_resource=use_resource, + partitioner=partitioner) + # Make a fake loss, such that gradient(loss, var0) == grads0 + # and gradient(loss, var1) == grads1 + grads0 = constant_op.constant(grads0, dtype=dtype) + grads1 = constant_op.constant(grads1, dtype=dtype) + loss = (math_ops.reduce_sum(grads0 * var0) + + math_ops.reduce_sum(grads1 * var1)) - opt = moving_average_optimizer.MovingAverageOptimizer( - gradient_descent.GradientDescentOptimizer(learning_rate=2.0), - average_decay=0.5, - sequential_update=sequential_update) - save_dir = tempfile.mkdtemp( - prefix=os.path.join(self.get_temp_dir(), 'run_1')) - save_path = os.path.join(save_dir, 'model') - update = opt.apply_gradients( - list(six.moves.zip([grads0, grads1], [var0, var1]))) - global_vars = variables.global_variables() - ema_var0 = [ - v for v in global_vars - if v.op.name == 'var0/ExponentialMovingAverage' - ][0] - ema_var1 = [ - v for v in global_vars - if v.op.name == 'var1/ExponentialMovingAverage' - ][0] - perturb = control_flow_ops.group([ - state_ops.assign_add(var0, [1.0, 1.0]), - state_ops.assign_add(var1, [2.0, 2.0]), - state_ops.assign_add(ema_var0, [3.0, 3.0]), - state_ops.assign_add(ema_var1, [4.0, 4.0]) - ]) + opt = moving_average_optimizer.MovingAverageOptimizer( + gradient_descent.GradientDescentOptimizer(learning_rate=2.0), + average_decay=0.5, + sequential_update=sequential_update) + save_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'run_1')) + save_path = os.path.join(save_dir, 'model') - # Test that saver with missing ema variables will fail. - with self.assertRaisesRegexp(ValueError, r'Variable to swap'): - opt.swapping_saver(var_list=[var0]) + update = opt.minimize(loss) - train_saver = opt.swapping_saver() - train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) - inference_saver = saver.Saver() - variables.global_variables_initializer().run() - # Step 1. 
- update.run() - self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) - if sequential_update: - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) - # Test that the swapping saver save/restore operation is identity. - train_saver.save(sess, save_path) - train_saver.restore(sess, save_path) - self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) - if sequential_update: - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) - # Test that the subset saver saves the EMA variable as well. - if sequential_update: - subset_save_path = save_path + '_subset' - train_saver_subset.save(sess, subset_save_path) - perturb.run() - self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) - self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) - self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) - # Restoring should only restore var0 and ema_var0. - train_saver_subset.restore(sess, subset_save_path) + # Get variables and their EMAs. In case of partitioned variables + # get proper part of each variable. + def _get_variable(var_name, part_name, ema): + """Returns variable of it's moving average by name.""" + matches = [ + v for v in variables.global_variables() + if ((var_name in v.op.name) + and (part_name in v.op.name) + and (('ExponentialMovingAverage' in v.op.name) == ema)) + ] + self.assertEqual(len(matches), 1) + return matches[0] + var0 = _get_variable('var0', var_part_name, ema=False) + var1 = _get_variable('var1', var_part_name, ema=False) + ema_var0 = _get_variable('var0', var_part_name, ema=True) + ema_var1 = _get_variable('var1', var_part_name, ema=True) + + perturb = control_flow_ops.group([ + state_ops.assign_add(var0, [1.0, 1.0]), + state_ops.assign_add(var1, [2.0, 2.0]), + state_ops.assign_add(ema_var0, [3.0, 3.0]), + state_ops.assign_add(ema_var1, [4.0, 4.0]) + ]) + + # Test that saver with missing ema variables will fail. + with self.assertRaisesRegexp(ValueError, r'Variable to swap'): + opt.swapping_saver(var_list=[var0]) + + train_saver = opt.swapping_saver() + train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) + inference_saver = saver.Saver() + variables.global_variables_initializer().run() + # Step 1. + update.run() self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) - self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) - # Restore back to previous state. + self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the swapping saver save/restore operation is identity. 
+ train_saver.save(sess, save_path) train_saver.restore(sess, save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the subset saver saves the EMA variable as well. + if sequential_update: + subset_save_path = save_path + '_subset' + train_saver_subset.save(sess, subset_save_path) + perturb.run() + self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) + self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restoring should only restore var0 and ema_var0. + train_saver_subset.restore(sess, subset_save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restore back to previous state. + train_saver.restore(sess, save_path) - # If updates are parallel, this is not always true after the 1st step. - if sequential_update: + # If updates are parallel, + # this is not always true after the 1st step. + if sequential_update: + # Test that the normal saver will have the averaged variables. + # We test that the average values are between the original value + # and the most recent variable values (since they are an average + # of the two). + val0 = var0.eval() + val1 = var1.eval() + train_saver.save(sess, save_path) + inference_saver.restore(sess, save_path) + avg_val0 = var0.eval() + avg_val1 = var1.eval() + for i in six.moves.range(len(val0)): + self.assertLess(val0[i], avg_val0[i]) + self.assertLess(avg_val0[i], orig_val0[i]) + self.assertLess(val1[i], avg_val1[i]) + self.assertLess(avg_val1[i], orig_val1[i]) + train_saver.restore(sess, save_path) + # Step 2. + update.run() # Test that the normal saver will have the averaged variables. - # We test that the average values are between the original value - # and the most recent variable values (since they are an average - # of the two). + # We test that the average values are between the original value and + # the most recent variable values (since they are an average of the + # two). val0 = var0.eval() val1 = var1.eval() + self.assertAllCloseAccordingToType([0.6, 1.6], val0) + self.assertAllCloseAccordingToType([2.96, 3.96], val1) train_saver.save(sess, save_path) inference_saver.restore(sess, save_path) avg_val0 = var0.eval() @@ -143,26 +219,6 @@ class MovingAverageOptimizerTest(test.TestCase): self.assertLess(avg_val0[i], orig_val0[i]) self.assertLess(val1[i], avg_val1[i]) self.assertLess(avg_val1[i], orig_val1[i]) - train_saver.restore(sess, save_path) - # Step 2. - update.run() - # Test that the normal saver will have the averaged variables. - # We test that the average values are between the original value and - # the most recent variable values (since they are an average of the - # two). 
- val0 = var0.eval() - val1 = var1.eval() - self.assertAllCloseAccordingToType([0.6, 1.6], val0) - self.assertAllCloseAccordingToType([2.96, 3.96], val1) - train_saver.save(sess, save_path) - inference_saver.restore(sess, save_path) - avg_val0 = var0.eval() - avg_val1 = var1.eval() - for i in six.moves.range(len(val0)): - self.assertLess(val0[i], avg_val0[i]) - self.assertLess(avg_val0[i], orig_val0[i]) - self.assertLess(val1[i], avg_val1[i]) - self.assertLess(avg_val1[i], orig_val1[i]) def testFailWhenSaverCreatedBeforeInitialized(self): with self.cached_session(): diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py index 44a8890cb10..155ff5b3f4f 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f789c83e005..467dd86d8fd 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -790,14 +790,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Scale loss for number of replicas (callable-loss case). In this case, # we have to be careful to call distribute_lib.get_loss_reduction() # *after* loss() is evaluated, so we know what loss reduction it uses. - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) - if scale_loss_by_num_replicas: - num_replicas = distribute_ctx.get_distribution_strategy().num_replicas - if num_replicas > 1: - loss_value *= 1. / num_replicas + loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas) if var_list is None: var_list = tape.watched_variables() @@ -808,14 +801,7 @@ class OptimizerV2(optimizer_v1.Optimizer): "be a function when eager execution is enabled.") # Scale loss for number of replicas (non-callable-loss case). - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) - if scale_loss_by_num_replicas: - num_replicas = distribute_ctx.get_distribution_strategy().num_replicas - if num_replicas > 1: - loss *= 1. / num_replicas + loss = self._scale_loss(loss, scale_loss_by_num_replicas) if gate_gradients not in [ optimizer_v1.Optimizer.GATE_NONE, optimizer_v1.Optimizer.GATE_OP, @@ -857,6 +843,20 @@ class OptimizerV2(optimizer_v1.Optimizer): ]) return grads_and_vars + @staticmethod + def _scale_loss(loss_value, scale_loss_by_num_replicas): + """Scale loss for the number of replicas.""" + if scale_loss_by_num_replicas is None: + scale_loss_by_num_replicas = ( + distribute_lib.get_loss_reduction() == variable_scope + .VariableAggregation.MEAN) + if scale_loss_by_num_replicas: + num_replicas = \ + distribute_ctx.get_distribution_strategy().num_replicas_in_sync + if num_replicas > 1: + loss_value *= 1. / num_replicas + return loss_value + def apply_gradients(self, grads_and_vars, global_step=None, name=None): """Apply gradients to variables. 
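The duplicated scaling branches above are folded into a single _scale_loss helper. A standalone sketch of the rule it applies (plain Python; in the real helper the reduction mode and replica count come from the active distribution strategy):

def scale_loss(loss_value, num_replicas_in_sync, reduction_is_mean):
  # Mirrors the branch in OptimizerV2._scale_loss: under MEAN loss reduction the
  # per-replica loss is divided by the number of replicas training in sync.
  if reduction_is_mean and num_replicas_in_sync > 1:
    loss_value *= 1.0 / num_replicas_in_sync
  return loss_value

assert scale_loss(8.0, 4, reduction_is_mean=True) == 2.0
assert scale_loss(8.0, 4, reduction_is_mean=False) == 8.0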
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index 057e851aba6..15ae95f13cf 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -141,7 +141,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, // const int gid = batch_id * cell_size * 4 + act_id; const int cid = batch_id * cell_size + act_id; - Eigen::internal::scalar_sigmoid_op sigmoid_op; + Eigen::internal::scalar_logistic_op sigmoid_op; Eigen::internal::scalar_tanh_op tanh_op; Eigen::scalar_clip_op clip_op; diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 291ff83791c..f0947fe423f 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -82,7 +82,6 @@ py_library( name = "keras_saved_model", srcs = ["python/saved_model/keras_saved_model.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], visibility = ["//visibility:public"], deps = [ "//tensorflow/python:array_ops", @@ -103,7 +102,7 @@ py_test( size = "medium", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], + tags = ["no_windows"], deps = [ ":keras_saved_model", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 4970ebc3199..a65b2ce4661 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -345,21 +345,22 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): inputs, outputs = load_model(sess, output_path, model_fn_lib.ModeKeys.EVAL) - sess.run(outputs['metrics/mae/update_op'], { - inputs[input_name]: input_arr, - inputs[target_name]: target_arr - }) + # First obtain the loss and predictions, and run the metric update op by + # feeding in the inputs and targets. + loss, predictions, _ = sess.run( + (outputs['loss'], outputs['predictions/' + output_name], + outputs['metrics/mae/update_op']), + {inputs[input_name]: input_arr, inputs[target_name]: target_arr}) - eval_results = sess.run(outputs, {inputs[input_name]: input_arr, - inputs[target_name]: target_arr}) + # The metric value should be run after the update op, to ensure that it + # reflects the correct value. 
+ metric_value = sess.run(outputs['metrics/mae/value']) self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) - self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05) - self.assertAllClose( - ref_mae, eval_results['metrics/mae/value'], atol=1e-05) - self.assertAllClose( - ref_predict, eval_results['predictions/' + output_name], atol=1e-05) + self.assertAllClose(ref_loss, loss, atol=1e-05) + self.assertAllClose(ref_mae, metric_value, atol=1e-05) + self.assertAllClose(ref_predict, predictions, atol=1e-05) # Load train graph, and check for the train op, and prediction values with session.Session(graph=ops.Graph()) as sess: diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 6bd58c4d322..5e4f130b314 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -4,129 +4,11 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "py_test") # @unused - py_library( name = "signal_py", - srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:spectral_ops", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", - "//third_party/py/numpy", - ], -) - -py_library( - name = "test_util", - srcs = ["python/kernel_tests/test_util.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:tf_optimizer", - "//tensorflow/python:training", - ], -) - -cuda_py_tests( - name = "mel_ops_test", - srcs = ["python/kernel_tests/mel_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - ], -) - -cuda_py_tests( - name = "mfcc_ops_test", - srcs = ["python/kernel_tests/mfcc_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:spectral_ops_test_util", - ], -) - -cuda_py_tests( - name = "reconstruction_ops_test", - srcs = ["python/kernel_tests/reconstruction_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "shape_ops_test", - srcs = ["python/kernel_tests/shape_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "spectral_ops_test", - size = "large", - srcs = ["python/kernel_tests/spectral_ops_test.py"], - additional_deps = [ - 
":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:spectral_ops_test_util", - ], - tags = ["nomac"], -) - -cuda_py_tests( - name = "window_ops_test", - srcs = ["python/kernel_tests/window_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", + "//tensorflow/python/ops/signal", ], ) diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index d088e744346..d01f5ccf51c 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -14,6 +14,9 @@ # ============================================================================== """Signal processing operations. +`tf.contrib.signal` has been renamed to `tf.signal`. `tf.contrib.signal` will be +removed in TensorFlow 2.0. + See the [Contrib Signal](https://tensorflow.org/api_guides/python/contrib.signal) guide. @@ -39,18 +42,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix -from tensorflow.contrib.signal.python.ops.mfcc_ops import mfccs_from_log_mel_spectrograms -from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add -from tensorflow.contrib.signal.python.ops.shape_ops import frame -# `frame` used to be named `frames`, which is a noun and not a verb. -# Keep an alias to `frames` for backwards compatibility. -from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames -from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft -from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft_window_fn -from tensorflow.contrib.signal.python.ops.spectral_ops import stft -from tensorflow.contrib.signal.python.ops.window_ops import hamming_window -from tensorflow.contrib.signal.python.ops.window_ops import hann_window +from tensorflow.python.ops.signal.mel_ops import linear_to_mel_weight_matrix +from tensorflow.python.ops.signal.mfcc_ops import mfccs_from_log_mel_spectrograms +from tensorflow.python.ops.signal.reconstruction_ops import overlap_and_add +from tensorflow.python.ops.signal.shape_ops import frame +from tensorflow.python.ops.signal.spectral_ops import inverse_stft +from tensorflow.python.ops.signal.spectral_ops import inverse_stft_window_fn +from tensorflow.python.ops.signal.spectral_ops import stft +from tensorflow.python.ops.signal.window_ops import hamming_window +from tensorflow.python.ops.signal.window_ops import hann_window from tensorflow.python.util.all_util import remove_undocumented + +# `frame` used to be named `frames`, which is a noun and not a verb. +# Keep an alias to `frames` for backwards compatibility. 
+frames = frame + remove_undocumented(__name__) diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 596c59ead34..290c16fe396 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops # pylint: disable=unused-import @@ -28,10 +30,12 @@ from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_mod # pylint: enable=unused-import from tensorflow.contrib.util import loader +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking _model_ops = loader.load_op_library( @@ -88,6 +92,59 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject): params=self.params.serialized_params_proto) +class TreeVariable(tracking.TrackableResource): + """A tree model.""" + + def __init__(self, params, tree_config, stats_handle, name, container=None): + self._params = params + self._tree_config = tree_config + self._stats_handle = stats_handle + self._name = name + self._container = container + self._init_op = None + super(TreeVariable, self).__init__() + self._resource_handle = self.create_resource() + + def create_resource(self): + if context.executing_eagerly(): + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). + shared_name = "tree_variable_%d" % (ops.uid(),) + else: + shared_name = self._name + return gen_model_ops.decision_tree_resource_handle_op( + self._container, shared_name=shared_name, name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_variable( + self.resource_handle, + self._tree_config, + params=self._params.serialized_params_proto) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_model_ops.tree_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return { + "tree_variable": + functools.partial( + TreeVariableSavable, + params=self._params, + tree_handle=self.resource_handle, + stats_handle=self._stats_handle, + create_op=self._init_op) + } + + def tree_variable(params, tree_config, stats_handle, name, container=None): r"""Creates a tree model and returns a handle to it. @@ -102,18 +159,13 @@ def tree_variable(params, tree_config, stats_handle, name, container=None): A `Tensor` of type mutable `string`. The handle to the tree. 
""" with ops.name_scope(name, "TreeVariable") as name: - resource_handle = gen_model_ops.decision_tree_resource_handle_op( - container, shared_name=name, name=name) - - create_op = gen_model_ops.create_tree_variable( - resource_handle, - tree_config, - params=params.serialized_params_proto) - is_initialized_op = gen_model_ops.tree_is_initialized_op(resource_handle) + tree_var = TreeVariable(params, tree_config, stats_handle, name, container) + resource_handle = tree_var.resource_handle + create_op = tree_var.initializer + is_initialized_op = tree_var.is_initialized() # Adds the variable to the savable list. - saveable = TreeVariableSavable(params, resource_handle, stats_handle, - create_op, - resource_handle.name) + saveable = tree_var._gather_saveables_for_checkpoint()["tree_variable"]( # pylint: disable=protected-access + name=resource_handle.name) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) resources.register_resource(resource_handle, create_op, is_initialized_op) return resource_handle diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py index 44d486edecc..9184198cd4c 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.tensor_forest.python.ops import gen_stats_ops # pylint: disable=unused-import from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import finalize_tree @@ -25,10 +27,12 @@ from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import process_in # pylint: enable=unused-import from tensorflow.contrib.util import loader +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking _stats_ops = loader.load_op_library( @@ -84,8 +88,58 @@ class FertileStatsVariableSavable(saver.BaseSaverBuilder.SaveableObject): params=self.params.serialized_params_proto) -def fertile_stats_variable(params, stats_config, name, - container=None): +class FertileStatsVariable(tracking.TrackableResource): + """A Fertile stats variable.""" + + def __init__(self, params, stats_config, name, container=None): + self._params = params + self._stats_config = stats_config + self._name = name + self._container = container + self._init_op = None + super(FertileStatsVariable, self).__init__() + self._resource_handle = self.create_resource() + + def create_resource(self): + if context.executing_eagerly(): + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). 
+ shared_name = "fertile_stats_variable_%d" % (ops.uid(),) + else: + shared_name = self._name + return gen_stats_ops.fertile_stats_resource_handle_op( + self._container, shared_name=shared_name, name=self._name) + + def initialize(self): + return gen_stats_ops.create_fertile_stats_variable( + self.resource_handle, + self._stats_config, + params=self._params.serialized_params_proto) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_stats_ops.fertile_stats_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return { + "fertile_stats_variable": + functools.partial( + FertileStatsVariableSavable, + params=self._params, + stats_handle=self.resource_handle, + create_op=self.initializer) + } + + +def fertile_stats_variable(params, stats_config, name, container=None): r"""Creates a stats object and returns a handle to it. Args: @@ -98,17 +152,15 @@ def fertile_stats_variable(params, stats_config, name, A `Tensor` of type mutable `string`. The handle to the stats. """ with ops.name_scope(name, "FertileStatsVariable") as name: - resource_handle = gen_stats_ops.fertile_stats_resource_handle_op( - container, shared_name=name, name=name) - - create_op = gen_stats_ops.create_fertile_stats_variable( - resource_handle, stats_config, - params=params.serialized_params_proto) - is_initialized_op = gen_stats_ops.fertile_stats_is_initialized_op( - resource_handle) + fertile_stats_var = FertileStatsVariable(params, stats_config, name, + container) + resource_handle = fertile_stats_var.resource_handle + create_op = fertile_stats_var.initializer + is_initialized_op = fertile_stats_var.is_initialized() # Adds the variable to the savable list. 
- saveable = FertileStatsVariableSavable(params, resource_handle, create_op, - resource_handle.name) + saveable = ( + fertile_stats_var._gather_saveables_for_checkpoint()[ # pylint: disable=protected-access + "fertile_stats_variable"](name=resource_handle.name)) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) resources.register_resource(resource_handle, create_op, is_initialized_op) return resource_handle diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 1f5591fe2a6..26d54eb156c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -141,6 +141,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { std::vector input_edges; TF_RETURN_IF_ERROR(node->input_edges(&input_edges)); std::vector> input_node_and_ports; + input_node_and_ports.reserve(input_edges.size()); for (const Edge* input_edge : input_edges) { input_node_and_ports.emplace_back(&input_edge->src()->def(), input_edge->src_output()); @@ -923,7 +924,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { converted_segments.push_back(std::move(curr_segment)); if (VLOG_IS_ON(8)) { - string fname = curr_engine.engine_name; + string fname = engine_segments.back().engine_name; StrAppend(&fname, ".pb"); std::fstream f; f.open(fname.c_str(), std::fstream::out | std::fstream::binary); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index a6f954391d3..e2988f5f2a8 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -1552,6 +1552,7 @@ tensorflow::Status ConvertPlugin(OpConverterParams* params) { const auto& node_def = params->node_def; // prepare input std::vector all_inputs; + all_inputs.reserve(inputs.size()); for (auto input : inputs) { all_inputs.emplace_back(const_cast(input.tensor())); } diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 67327d32000..a0a9cb3f31a 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -246,6 +246,7 @@ py_library( "python/tpu/bfloat16.py", "python/tpu/device_assignment.py", "python/tpu/session_support.py", + "python/tpu/tensor_tracer.py", "python/tpu/topology.py", "python/tpu/tpu.py", "python/tpu/tpu_feed.py", diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 38d1c3049ef..541fbf33a30 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -94,13 +94,6 @@ tf_proto_library( visibility = ["//visibility:public"], ) -tf_proto_library( - name = "tf_op_stats_proto", - srcs = ["tf_op_stats.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - tf_proto_library( name = "tpu_profiler_analysis_proto", srcs = ["tpu_profiler_analysis.proto"], diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto deleted file mode 100644 index 1e66801efd4..00000000000 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ /dev/null @@ -1,261 +0,0 @@ -// This proto describes the format of tensorflow operation level stats for -// profiling (in tensorboard) purpose. - -syntax = "proto2"; - -package tensorflow.tpu; - -// Result proto for OpMetrics. -message OpMetricsResult { - // True if this OP is executed on the device; False if it is executed on the - // host. 
- optional bool on_device = 1; - reserved 2; // was uint32 id. - // Name of this OP. - optional string name = 3; - // Rank of this OP. - optional uint64 rank = 4; - // The starting time in cycles of the last instance of this OP executed. - optional double last_starttime_in_cycles = 5; - // The ending time in cycles of the last instance of this OP executed. - optional double last_endtime_in_cycles = 6; - // If this OP (say A), is an immediate child of another OP (say B), this field - // stores the sum of duration in microseconds of A inside B. If A appears more - // than once in B, the duration of all A's appearances will be added together. - // This sum will be reset after the self-time of B is calculated so that it - // can be reused for a new parent OP. - optional double sum_of_duration_in_us_as_children = 7; - // Number of instances that this OP occurred. - optional uint64 occurrences = 8; - // Total time in microseconds spent in this OP (accumulated - // over all of its occurrences). - optional double total_time_in_us = 9; - // Total self time in microseconds spent in this OP - // (accumulated over all of its occurrences). - optional double total_self_time_in_us = 10; - // The total self time as a fraction of sum of all OP's - // total self time on the host. - optional double host_total_self_time_as_fraction_of_all_op_time = 11; - // Cumulative total self time in fraction on the host. - optional double host_cumulative_total_self_time_as_fraction_of_all_op_time = - 12; - // The total self time as a fraction of sum of all OP's - // total self time on the device. - optional double device_total_self_time_as_fraction_of_all_op_time = 13; - // Cumulative total self time in fraction on the device. - optional double device_cumulative_total_self_time_as_fraction_of_all_op_time = - 14; - // Total number of FLOPs incurred by this OP. - optional double total_flops = 15; - // Total number of bytes accessed by this OP. - optional double total_bytes_accessed = 16; - // Total time in microseconds that special hw unit 1 is occupied by this OP. - optional double unit1_occupancy_in_us = 17; - // Total time in microseconds that special hw unit 2 is occupied by this OP. - optional double unit2_occupancy_in_us = 18; - // Total memory stall time in microseconds. - optional double total_memory_stall_in_us = 19; -} - -// Result proto for OpMetricsDb. -message OpMetricsDbResult { - // A bunch of OpMetricsResults. - repeated OpMetricsResult metrics_db = 1; - // The total host infeed-enqueue duration in picoseconds. - optional uint64 total_host_infeed_enq_duration_ps = 2; - // The total of the difference between the start times of two - // consecutive infeed-enqueues (per host) in picoseconds. - optional uint64 total_host_infeed_enq_start_timestamp_ps_diff = 3; - // The total device time in microseconds. - optional double total_device_time_in_us = 4; - // The total host time in microseconds. - optional double total_host_time_in_us = 5; -} - -// Result proto for StepInfo. -message StepInfoResult { - // The (micro) step number. - optional uint32 step_num = 1; - // The step duration in picoseconds. - optional uint64 duration_ps = 2; - // The infeed duration in picoseconds. - optional uint64 infeed_duration_ps = 3; - // The outfeed duration in picoseconds. - optional uint64 host_outfeed_ps = 8; - // The start time of this step in picoseconds. - optional uint64 begin_ps = 4; - // The waiting time within this step in picoseconds. - optional uint64 wait_duration_ps = 5; - // The unit b outfeed duration in picoseconds. 
- optional uint64 unit_b_outfeed_ps = 9; - // The time spent on cross-replica-sum in picoseconds. - optional uint64 crs_duration_ps = 6; - // Percentage of unit b time spent on infeed. - optional double unit_b_infeed_percent = 7; -} - -// Result proto for a sequence of steps. -message StepSequenceResult { - // A sequence of StepInfoResults. - repeated StepInfoResult step_sequence = 1; -} - -// Result proto for a StepDatabase. -message StepDatabaseResult { - // A map from core_id to StepSequenceResult. - map step_sequence_per_core = 1; -} - -// Result proto for looping-related metrics. -message LoopingResult { - // The total iteration time in nanoseconds. - optional double iteration_time_ns = 1; - // The total number of iterations. - optional int32 num_iterations = 2; - // The total computation time in nanoseconds. - optional double computation_time_ns = 3; - // The total number of computations. - optional int32 num_computations = 4; -} - -// Result proto for HloExtraInfo. -message HloExtraInfoResult { - // Category of the HLO op given by the compiler. - optional string category = 1; - // The long name of the HLO that includes the dimensions. - optional string long_name = 2; - // The per-TPU-core batch size inferred from this HLO. - optional int64 per_core_batch_size = 3; -} - -// Result proto for HloExtraInfoMap. -message HloExtraInfoMapResult { - // A map from HLO name to HloExtraInfo. - map hlo_extrainfo_map = 1; -} - -// Result proto for host-independent job information. -message HostIndependentJobInfoResult { - // The change-list number of this build. - optional int64 change_list = 1; - // The time of this build. - optional int64 build_time = 2; - // The target of this build. - optional string build_target = 3; -} - -// Result proto for host-dependent job information. -message HostDependentJobInfoResult { - // This ID of the host where the job was run on. - optional string host_id = 1; - // The command line used to run the job. - optional string command_line = 2; - // The start time of the job on this host. - optional int64 start_time = 3; -} - -// Result proto for RunEnvironment (the run environment of a profiling session). -message RunEnvironmentResult { - // Number of hosts used. - optional int32 host_count = 1; - // The type of TPU used. - optional string tpu_type = 2; - // The number of TPU cores used. - optional int32 tpu_core_count = 3; - // The per-TPU-core batch size. - optional int32 per_core_batch_size = 4; - // Host-independent job information. - optional HostIndependentJobInfoResult host_independent_job_info = 5; - // Host-dependent job information. - repeated HostDependentJobInfoResult host_dependent_job_info = 6; - // The number of replicas, corresponds to input parallelism. - // If there is no model parallelism, replica_count = tpu_core_count - optional int32 replica_count = 7; - // The number of cores used for a single replica, e.g. model parallelism. - // If there is no model parallelism, then num_cores_per_replica = 1 - optional int32 num_cores_per_replica = 8; -} - -// The types of host operations that are tracked. -enum HostOp { - // Invalid host op. - kINVALIDHostOp = 0; - // Each of host op type has two parts: - // (1) the stage where the op happens and (2) the op name. - // stage = Input Data Producer, op = Get Next Batch. - kInputDataProducerGetNextBatch = 1; - // stage = Input Data Producer, op = Session Run. - kInputDataProducerSessionRun = 2; - // stage = Input Data Producer, op = Forward Batch. 
- kInputDataProducerForwardBatch = 3; - // stage = Infeed Thread, op = Get Next Batch. - kInfeedThreadGetNextBatch = 4; - // stage = Infeed Thread, op = Session Run. - kInfeedThreadSessionRun = 5; - // stage = Infeed Thread, op = Forward Batch. - kInfeedThreadForwardBatch = 6; - // stage = Outfeed Thread, op = Get Next Batch. - kOutfeedThreadGetNextBatch = 7; - // stage = Outfeed Thread, op = Session Run. - kOutfeedThreadSessionRun = 8; - // stage = Outfeed Thread, op = Forward Batch. - kOutfeedThreadForwardBatch = 9; -} - -// Result proto for the host ops per TPU step. -message HostOpsPerTpuStep { - // Whether the data in this message is valid. - optional bool valid = 1 [default = false]; - // The current TPU step number. - optional uint32 tpu_step_num = 2; - // The beginning time of the current TPU step on the device in picoseconds. - optional uint64 tpu_step_begin_ps = 3; - // The ending time of the current TPU step on the device in picoseconds. - optional uint64 tpu_step_end_ps = 4; - // For each possible host operation, maps to the difference between the TPU - // step number that the host op targets and the current TPU step number. - // The key is HostOp, value is the step difference. - map step_diffs = 5; -} - -message HostOpsDetailsPerCore { - // Map from core id to HostOpsPerTpuStep. - map core_map = 1; -} - -message HostOpsDetailsPerHost { - // Map from hostname to a map from core id to HostOpsPerTpuStep. - map host_map = 1; -} - -// Result proto for the host ops for all TPU steps. -message HostOpsResult { - reserved 1; // (was repeated HostOpsPerTpuStep host_op_sequence) - // A sequence of records with one for each TPU step. Each record - // is a map from hostname to a map from core id to HostOpsPerTpuStep. - repeated HostOpsDetailsPerHost hostops_details = 2; -} - -// Result proto for TfStatsHelper. -message TfOpStats { - // The result for the TF-metric database. - optional OpMetricsDbResult tf_metrics_db = 1; - // The result for the HLO-metric database. - optional OpMetricsDbResult hlo_metrics_db = 2; - // The result for the step database. - optional StepDatabaseResult step_db = 3; - // The result for the looping-related metrics. - optional LoopingResult looping = 4; - // The result for the HloExtraInfoMap. - optional HloExtraInfoMapResult hlo_extrainfo_map = 5; - // Overall matrix unit utilization in percentage. - optional double matrix_unit_utilization_percent = 6; - // The run environment of this profiling session. - optional RunEnvironmentResult run_environment = 7; - // The result for the host operations. - optional HostOpsResult host_ops = 8; - // A map from core ID to name. - map core_id_to_name_map = 9; - // The result for hw unit b stats. - optional bytes unit_b_stats = 10; -} diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index c2e3be03db0..aae1ab1d37a 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -154,6 +154,14 @@ message OptimizationParameters { // updates; not present means no limits are applied. ClippingLimits gradient_clipping_limits = 7; + // Amount of weight decay to apply; see weight_decay_optimizers.py for + // details. Almost all optimizers are supported with this option (MDL Adagrad + // Light does not work, and SGD does not behave as expected if it is enabled). 
+ // Although there is no check, users who want weight decay will probably also + // want to enable gradient accumulation as well so that the decay will happen + // once per minibatch. + float weight_decay_factor = 16; + // Whether to use gradient accumulation (do two passes over the input // gradients: one to accumulate them into a temporary array and another to // apply them using the actual optimization algorithm). This feature is diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index c32bd5997c1..1cf7f9fcf67 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -164,14 +164,15 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step) + + for l in self._listeners: + l.after_save(session, step) + end_time = time.time() logging.info("Checkpoint actual writing time: (%.3f sec)", end_time - start_time) logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - for l in self._listeners: - l.before_save(session, step) - if not asynchronous: _save_fn() return diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index ce2c322ff49..08f58a5f5b8 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1184,8 +1184,7 @@ class TPUFunction(object): # pipelined loop. return None, None - if (self.model.uses_learning_phase and - not isinstance(K.learning_phase(), int)): + if not isinstance(K.learning_phase(), int): # Remove the learning_phase flag at the end. We currently hard code the # learning_phase in TPUFunction. assert isinstance(inputs[-1], int), ( @@ -1651,7 +1650,7 @@ class KerasTPUModel(models.Model): self._make_train_function() sample_weights = sample_weights or [] val_sample_weights = val_sample_weights or [] - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + if not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [1] else: ins = inputs + targets + sample_weights diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py new file mode 100644 index 00000000000..70baea203cc --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -0,0 +1,553 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ======================================================================== +"""A utility to trace tensor values on TPU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path +import re + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging + +_TRACER_LOG_PREFIX = ' [>>>TT>>>]' +_DEVICE_TYPE_TPU = 'tpu' +_DEVICE_TYPE_CPU = 'cpu' +_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' +_TRACE_MODE_NAN_INF = 'nan-inf' +_TRACE_MODE_PART_TENSOR = 'part-tensor' +_TRACE_MODE_PART_TENSOR_SIZE = 3 +_TRACE_MODE_FULL_TENSOR = 'full-tensor' +_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' +_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' +_RECORD_FILTERED_OUT = 'not-traced-filtered-out' +_RECORD_SCALAR = 'not-traced-scalar' +_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' +_RECORD_GET_TRACED = 'get-traced' +_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' +_MARKER_SECTION_END = '!!!!!!! section-end:' +_SECTION_NAME_CONFIG = 'configuration' +_SECTION_NAME_REASON = 'reason' +_SECTION_NAME_OP_LIST = 'op-list' +_SECTION_NAME_GRAPH = 'graph' +_FIELD_NAME_VERSION = 'version:' +_FIELD_NAME_DEVICE = 'device:' +_FIELD_NAME_TRACE_MODE = 'trace-mode:' +_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' +_FIELD_NAME_NUM_OPS = 'number-of-ops:' +_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' +_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' +_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") +_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') +_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') +_FLAG_NAME_ENABLE = 'enable' +_FLAG_NAME_TRACE_MODE = 'trace_mode' +_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' +_FLAG_NAME_TRACE_FILE = 'trace_file_path' +_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' +_FLAG_NAME_OP_RANGE = 'op_range' +_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') +_OUTPUT_STREAM_ESCAPE = 'file://' +_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' + + +class TensorTracer(object): + """A software construct for tracing tensor values in a TF graph on TPU. + + This utility is disabled by default. It can be enabled by setting + the TENSOR_TRACER_FLAGS env variable as: + export TENSOR_TRACER_FLAGS="--enable=1" + If it is enabled, it will trace the output tensor values of + selected Ops in the graph. It has two outputs: (1) the traces and (2) + a report. The traces are dumped to a specified local file on the TPU + host. The report is printed to the log.info of the TPU job. + By passing options via the env variable, users can change: + (1) the trace mode (e.g., detecting NaN/Inf, printing partial or + full tensor values) + (2) which Ops to be traced (via op.name or op.type) + (3) output trace file path. 
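  For example, a fuller configuration can be set from Python before the graph
  is constructed. The sketch below is illustrative only and is not part of this
  change: the flag names and trace modes come from the _FLAG_NAME_* and
  _TRACE_MODE_* constants defined in this module, while the op names and the
  trace file path are arbitrary placeholders.

      import os

      # Hypothetical configuration: trace partial values of two named ops among
      # the first 1000 ops, writing the trace to a local file on the TPU host.
      os.environ['TENSOR_TRACER_FLAGS'] = (
          '--enable=1 '
          '--trace_mode=part-tensor '
          '--interesting_ops="dense/MatMul softmax/Softmax" '
          '--op_range=0:1000 '
          '--trace_file_path=/tmp/tensor_tracer_trace.txt')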
+ """ + + @staticmethod + def _match_next_flag(flags, pos): + """Returns the match for the next TensorTracer flag.""" + + match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) + if match: + return match + match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) + if match: + return match + match = _FLAG_NO_QUOTE_PAT.match(flags, pos) + return match + + @staticmethod + def print_flag_values(): + """Prints all TensorTracer flags passed via environment variables.""" + + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR + result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR, + tensor_tracer_flags) + result += 'Individual flag value:\n' + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + flag_value = match.group(2) + result += ' %s: %s\n'%(flag_name, flag_value) + pos = match.end() + result += '\n' + return result + + @staticmethod + def get_flag_value(wanted_flag_name): + """Returns the value of a TensorTracer flags.""" + + tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return '' + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + return '' + flag_name = match.group(1) + flag_value = match.group(2) + if flag_name == wanted_flag_name: + return flag_value + pos = match.end() + return '' + + @staticmethod + def is_enabled(): + """Returns True if TensorTracer is enabled.""" + + flag_value = TensorTracer.get_flag_value(_FLAG_NAME_ENABLE) + flag_value = flag_value.lower() + enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] + return enabled + + @staticmethod + def use_test_undeclared_outputs_dir(): + """Decides the output directory of the trace file. + + Args: + None. + + Returns: + True if the output trace file should be written to the + test-undeclared-outputs-directory defined via an + env variable. + """ + + flag_value = TensorTracer.get_flag_value( + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) + flag_value = flag_value.lower() + enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] + return enabled + + @staticmethod + def check_device_type(device_type): + """Checks if the given device type is valid.""" + + if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]: + raise ValueError('Invalid device_type "%s"'%device_type) + + @staticmethod + def check_trace_mode(trace_mode): + """Checks if the given trace mode is valid.""" + + valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, + _TRACE_MODE_FULL_TENSOR] + if trace_mode not in valid_trace_modes: + raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' + 'Valid trace modes are: %s'%(trace_mode, + valid_trace_modes)) + + @staticmethod + def should_trace(device_type, op): + """Returns True if the given Op should be traced.""" + + if device_type != _DEVICE_TYPE_TPU: + raise ValueError('Non TPU device type is not supported') + if control_flow_util.IsInCond(op): + return False + if op.type in ['Reshape', 'ArgMin', 'ArgMax']: + return False + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR in op.node_def.attr + # pylint: enable=protected-access + + @staticmethod + def reason(op_idx, details): + """Returns why the Op at op_idx is traced or not.""" + return '%d %s'%(op_idx, details) + + @staticmethod + def topological_sort(g): + """Performs topological sort on the given graph. + + Args: + g: the graph. 
+ + Returns: + A pair where the first element indicates if the topological + sort succeeded (True if there is no cycle found; False if a + cycle is found) and the second element is either the sorted + list of nodes or the cycle of nodes found. + """ + + def visit(op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops): + """Recursively visits all Ops in a graph. + + Args: + op: the current Op being visited. + cycle: a cycle of Ops found. + permanently_marked_ops: the set of Ops that were already visited. + temporarily_marked_ops: the set of Ops that we have visited during + the current descent. + sorted_ops: the list of Ops sorted in topological order. + """ + + if cycle: + return + if op in permanently_marked_ops: + return + if op in temporarily_marked_ops: + cycle = temporarily_marked_ops + return + temporarily_marked_ops.add(op) + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + for consumer_op in out_tensor.consumers(): + visit(consumer_op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + # pylint: disable=protected-access + for ctrl_output_op in op._control_outputs: + # pylint: enable=protected-access + visit(ctrl_output_op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + temporarily_marked_ops.remove(op) + permanently_marked_ops.add(op) + sorted_ops.insert(0, op) + + graph_cycle = set([]) + sorted_ops = [] + permanently_marked_ops = set([]) + temporarily_marked_ops = set([]) + unsorted_ops = g.get_operations() + for op in unsorted_ops: + visit(op, graph_cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + if graph_cycle: + return (False, graph_cycle) + else: + assert len(unsorted_ops) == len(sorted_ops) + return (True, sorted_ops) + + def __init__(self): + """Initializes a TensorTracer. + + Sets the various member fields from the flags (if given) or the defaults. + """ + self._version = 'use-outside-compilation' + self._device_type = None + self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) + if not self._trace_mode: + self._trace_mode = _TRACE_MODE_NAN_INF + TensorTracer.check_trace_mode(self._trace_mode) + self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE + self._instrument_records = {} + interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) + self._selected_ops = interesting_ops.split() + self._set_trace_file_path() + self._set_op_range() + self._num_replicas = None + self._replica_id = None + + def _add_replica_id_to_graph(self, num_replicas, result_tensor): + """Adds nodes for computing the replica ID to the graph.""" + + if not num_replicas: + self._replica_id = 'unknown' + return result_tensor + + self._num_replicas = num_replicas + + with ops.control_dependencies(None): + # Uses None as dependency to run outside of TPU graph rewrites. + self._replica_id = tpu_ops.tpu_replicated_input( + list(range(self._num_replicas)), + name='tt_replica_id') + use_replica_id = array_ops.identity(self._replica_id).op + with ops.control_dependencies([use_replica_id]): + # Adds a control dependency from the result_tensor to + # the replica_id to ensure that replica_id will be added to the graph. 
+ return array_ops.identity(result_tensor) + + def _set_trace_file_path(self): + """Sets the path of the output trace file.""" + + self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) + if not self._trace_file_path: + raise ValueError('--%s is not set in the environment variable %s' + %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) + elif TensorTracer.use_test_undeclared_outputs_dir(): + if os.path.isabs(self._trace_file_path): + raise ValueError('If use_test_undeclared_outputs_dir is set,' + 'trace_file_path cannot be an absolute path (%s)' + %self._trace_file_path) + outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) + self._trace_file_path = os.path.join(outputs_dir, + self._trace_file_path) + + def _set_op_range(self): + """Sets the index range of the Ops that we will consider tracing.""" + + op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) + if not op_range: + self._op_range = (-1, -1) # this means including all ops. + return + match = _OP_RANGE_PAT.match(op_range) + if not match: + self._op_range = (-1, -1) # this means including all ops. + return + self._op_range = (int(match.group(1)), int(match.group(2))) + + def _inside_op_range(self, idx): + """Return True if the given index is inside the selected range.""" + + if idx < self._op_range[0]: + return False + return self._op_range[1] < 0 or idx <= self._op_range[1] + + def _write_report(self, content): + """Writes the given content to the report.""" + + logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) + + def _is_selected_op(self, op_name): + """Returns True if the Op with op_name is selected to be traced.""" + + if not self._selected_ops: + return True + if op_name in self._selected_ops: + return True + return False + + def _write_config_section(self): + """Writes the config section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) + self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) + self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) + self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) + self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) + + def _write_reason_section(self): + """Writes the reason section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) + for key in sorted(self._instrument_records): + self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) + + def _write_op_list_section(self, op_list): + """Writes the Op-list section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) + for i in range(0, len(op_list)): + self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) + + def _write_graph_section(self, succeed, sorted_or_cycle): + """Writes the graph section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) + self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, + succeed)) + l = list(sorted_or_cycle) + for i in range(0, len(l)): + self._write_report('%d "%s"\n'%(i, l[i].name)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) + + def 
_make_tensor_trace_fun(self, op_name, output_idx): + """Makes the tensor tracing function called by outside compilation. + + Args: + op_name: the name of the Op that outputs the tensor to be traced. + output_idx: which output of the Op it is (0 means the first output). + + Returns: + A function to be passed as the first argument to outside compilation. + + Raises: + RuntimeError: If the trace mode is invalid. + """ + + def _print_tensor(op_name, output_idx, num_elements, tensor, output_tensor): + """Prints a tensor value to a file. + + Args: + op_name: the name of the Op that outputs the tensor to be printed. + output_idx: which output of the Op it is (0 means the first output). + num_elements: number of elements to print. + tensor: the tensor needs to be returned. + output_tensor: the tensor needs to be printed. + + Returns: + The same tensor passed via the "tensor" argument. + """ + msg = '"%s:%d" '%(op_name, output_idx) + output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), + ' @', self._replica_id, + '\n', output_tensor, + summarize=num_elements, + output_stream=output_stream) + with ops.control_dependencies([print_op]): + return array_ops.identity(tensor).op + + def _detect_nan_inf(tensor): + """Trace function for detecting any NaN/Inf in the tensor.""" + + if tensor.dtype.is_floating: + # Since host can't handle bf16, always convert tensor to f32. + tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = math_ops.reduce_any( + gen_math_ops.logical_or(gen_math_ops.is_nan(tensor), + gen_math_ops.is_inf(tensor))) + else: + output_tensor = constant_op.constant(0) + return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) + + def _show_global_step(tensor): + """Trace function for printing the global step count.""" + + return _print_tensor(op_name, output_idx, 1, tensor, tensor) + + def _show_part_tensor(tensor): + """Trace function for printing part of the tensor.""" + + return _print_tensor(op_name, output_idx, self._part_tensor_size, + tensor, tensor) + + def _show_full_tensor(tensor): + """Trace function for printing the entire tensor.""" + + return _print_tensor(op_name, output_idx, -1, tensor, tensor) + + if op_name == _GLOBAL_STEP_OP_NAME: + return _show_global_step + if self._trace_mode == _TRACE_MODE_NAN_INF: + return _detect_nan_inf + if self._trace_mode == _TRACE_MODE_PART_TENSOR: + return _show_part_tensor + if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + return _show_full_tensor + + raise RuntimeError('Tensor trace fun for %s is not yet implemented' + %self._trace_mode) + + def trace_tpu(self, graph, result_tensor, num_replicas=None): + """Traces the tensors generated by TPU Ops in a TF graph. + + Args: + graph: the graph of Ops. + result_tensor: a result tensor of evaluating the graph. + num_replicas: number of replicas used on the TPU. + + Returns: + A tuple (result_tensor_copy, tracing_ops), where: + result_tensor_copy: an exact copy of result_tensor + tracing_ops: a list of tracing ops. If this list + is non empty, the caller of this function + should pose control dependencies upon these + Ops so that they will be executed when the + graph is evaluated. 
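      For illustration only (mirroring how the tpu_estimator.py change later in
      this patch consumes the return value), a caller might use this method
      roughly as in the sketch below; `loss`, `train_op` and `num_replicas` are
      placeholders for values owned by the caller's training graph:

        from tensorflow.contrib.tpu.python.tpu import tensor_tracer
        from tensorflow.python.framework import ops
        from tensorflow.python.ops import array_ops

        if tensor_tracer.TensorTracer.is_enabled():
          tt = tensor_tracer.TensorTracer()
          # trace_tpu returns a copy of the result tensor plus the tracing ops
          # that still have to be anchored into the graph.
          loss, tracing_ops = tt.trace_tpu(
              ops.get_default_graph(), loss, num_replicas)
        else:
          tracing_ops = []

        # Force the dangling tracing ops to run whenever the train step runs.
        with ops.control_dependencies([train_op] + tracing_ops):
          step_output = array_ops.identity(loss)

      The traces themselves are written to the file given by --trace_file_path
      (via a file:// output stream), while the report sections go to log.info.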
+ """ + + self._device_type = _DEVICE_TYPE_TPU + TensorTracer.check_device_type(self._device_type) + result_tensor_copy = self._add_replica_id_to_graph(num_replicas, + result_tensor) + self._write_config_section() + tracing_ops = [] + operations = graph.get_operations() + self._write_op_list_section(operations) + # Does the topological sort before adding any nodes to the graph. + (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + for op_id, op in enumerate(operations): + if not self._inside_op_range(op_id): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_OUTSIDE_OP_RANGE) + continue + if not TensorTracer.should_trace(self._device_type, op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_SHOULD_NOT_TRACE) + continue + if not self._is_selected_op(op.name): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_FILTERED_OUT) + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if not out_tensor.get_shape().is_fully_defined(): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_DYNAMIC_SHAPE) + continue # cannot trace tensors with dynamic shape. + rank = len(out_tensor.shape) + if rank < 1: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_SCALAR) + continue # cannot trace scalar. + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_GET_TRACED) + consumers = out_tensor.consumers() + trace_op = tpu.outside_compilation( + self._make_tensor_trace_fun(op.name, i), out_tensor) + if consumers: + for consumer_op in consumers: + # pylint: disable=protected-access + consumer_op._add_control_input(trace_op) + # pylint: enable=protected-access + else: + # if there is no consumer, we will add the control dependence later + # when we add the control dependency to the output operations. 
+ tracing_ops.append(trace_op) + + self._write_reason_section() + self._write_graph_section(succeed, sorted_or_cycle) + + return (result_tensor_copy, tracing_ops) diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index b6bb5c6e56c..6ae718cc2c9 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -189,12 +189,13 @@ class Topology(object): def cpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the CPU device attached to a logical core.""" return _tpu_host_device_name( - job, self._topology_tasks[device_coordinates]) + job, self._topology_tasks[tuple(device_coordinates)]) def tpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the name of the TPU device assigned to a logical core.""" - return _tpu_device_name(job, self._topology_tasks[device_coordinates], - self._topology_devices[device_coordinates]) + return _tpu_device_name(job, + self._topology_tasks[tuple(device_coordinates)], + self._topology_devices[tuple(device_coordinates)]) @property def num_tasks(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 555ad0f1fdb..7cb8c4aa7f1 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,6 +31,7 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import session_support @@ -108,6 +109,15 @@ ops.register_proto_function( from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access +def _is_iterable(obj): + """A Python 2 and 3 compatible util to check whether `obj` is iterable.""" + try: + iter(obj) + return True + except TypeError: + return False + + def _create_global_step(graph): graph = graph or ops.get_default_graph() if training.get_global_step(graph) is not None: @@ -1317,9 +1327,15 @@ class _ModelFnWrapper(object): captured_training_hooks.capture(estimator_spec.training_hooks) + tracing_ops = [] + if tensor_tracer.TensorTracer.is_enabled(): + tt = tensor_tracer.TensorTracer() + loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, + self._ctx.num_replicas) + # We must run train_op to update the variables prior to running the # outfeed. - with ops.control_dependencies([train_op]): + with ops.control_dependencies([train_op]+tracing_ops): host_call_outfeed_ops = [] if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access and estimator_spec.host_call is not None): @@ -2250,8 +2266,7 @@ class TPUEstimator(estimator_lib.Estimator): # Only fetching `tpu_tensors_on_cpu` does not trigger # TPU computation and blocks, so we add the control dependency here. 
control_inputs = ( - tpu_tensors_on_cpu if isinstance(tpu_tensors_on_cpu, - (list, tuple)) else + tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else (tpu_tensors_on_cpu,)) with ops.control_dependencies(control_inputs): new_tensors.append(array_ops.identity(t)) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index afe4c46c8ef..a701b38d4b3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -383,6 +383,7 @@ cc_library( ":lib_platform", ":platform_base", "//tensorflow/core/platform/default/build_config:port", + "@com_google_absl//absl/base", "@snappy", ], ) @@ -1057,6 +1058,7 @@ tf_gen_op_libs( "logging_ops", "manip_ops", "math_ops", + "mkl_nn_ops", "nccl_ops", "nn_ops", "no_op", @@ -1229,7 +1231,7 @@ cc_library( ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", - ] + tf_additional_cloud_op_deps(), + ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(), alwayslink = 1, ) @@ -1285,7 +1287,9 @@ cc_library( ":framework", ":lib", ":nn_ops_op_lib", - ], + ] + if_mkl([ + ":mkl_nn_ops_op_lib", + ]), alwayslink = 1, ) @@ -1668,6 +1672,7 @@ cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", ], ) @@ -2168,6 +2173,7 @@ cc_library( "lib/**/*.cc", "platform/*.cc", "platform/profile_utils/**/*.cc", + ] + [ "framework/resource_handle.cc", "util/env_var.cc", ], @@ -2811,7 +2817,6 @@ tf_cuda_library( ":functional_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:required", - ":core_cpu_impl", ]), alwayslink = 1, ) @@ -2997,6 +3002,16 @@ cc_library( deps = [":lib_internal"], ) +tf_cuda_library( + name = "metrics", + srcs = ["common_runtime/metrics.cc"], + hdrs = ["common_runtime/metrics.h"], + deps = [ + ":lib", + "@com_google_absl//absl/time", + ], +) + tf_cuda_library( name = "direct_session_internal", srcs = ["common_runtime/direct_session.cc"], @@ -3013,10 +3028,12 @@ tf_cuda_library( ":graph", ":lib", ":lib_internal", + ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", "//tensorflow/core/kernels:function_ops", + "@com_google_absl//absl/time", ], alwayslink = 1, ) @@ -3048,7 +3065,9 @@ tf_cuda_library( ], copts = tf_copts(), cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()), - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow:internal", + ], deps = [ ":core_cpu_internal", ":lib", diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt index cdaeb5091c7..bfaf3d2ea59 100644 --- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt @@ -4,7 +4,7 @@ op { in_arg { name: "float_values" description: <
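As a closing aside on the tpu_estimator.py hunk above: the new _is_iterable helper lets the control-dependency normalization accept any iterable of tensors rather than only a list or tuple, while a single tensor-like value is still wrapped in a one-element tuple. Below is a minimal self-contained sketch of that idiom, illustrative only; the constants are placeholders, _is_iterable is repeated just so the snippet runs on its own, and it assumes TF 1.x graph mode, where iterating a Tensor raises TypeError.

import tensorflow as tf


def _is_iterable(obj):
  # Python 2/3 compatible check for whether `obj` is iterable.
  try:
    iter(obj)
    return True
  except TypeError:
    return False


single = tf.constant(1.0)                    # not iterable in graph mode
many = [tf.constant(2.0), tf.constant(3.0)]  # already iterable

for value in (single, many):
  control_inputs = value if _is_iterable(value) else (value,)
  with tf.control_dependencies(control_inputs):
    out = tf.identity(tf.constant(0.0))  # runs only after control_inputs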