diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 659ae055cdf..03a9a3ad3a4 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable( build_options.set_device_allocator(options.device_allocator); build_options.set_alias_passthrough_params(options.alias_passthrough_params); - auto compile_result = - client_->Compile(*result.computation, argument_layouts, build_options); - if (!compile_result.ok()) { - return compile_result.status(); - } - *executable = std::move(compile_result.ValueOrDie()); + TF_ASSIGN_OR_RETURN( + auto executables, + client_->Compile(*result.computation, argument_layouts, build_options)); + TF_RET_CHECK(executables.size() == 1); + *executable = std::move(executables[0]); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index c66112cc5fa..0392cc7d345 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -117,8 +117,10 @@ XlaJitCompiledCpuFunction::Compile( // Compile the executable. The static_cast to the CpuExecutable subclass is // necessary since the raw function and buffer assignments are only available // there. 
- TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + TF_ASSIGN_OR_RETURN(auto executables, client->Compile(computation, arg_shapes, build_options)); + TF_RET_CHECK(executables.size() == 1); + std::unique_ptr executable = std::move(executables[0]); const xla::cpu::CpuExecutable* cpu_executable = static_cast(executable->executable()); XlaCompiledCpuFunction::RawFunction raw_function = diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index c93ad9f98ce..7b29e9c4e90 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -337,7 +337,7 @@ Backend* LocalClient::mutable_backend() { return local_service_->mutable_backend(); } -StatusOr> LocalClient::Compile( +StatusOr>> LocalClient::Compile( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& options) { @@ -347,12 +347,20 @@ StatusOr> LocalClient::Compile( VLOG(3) << "Set device ordinal to default value of: " << updated_options.device_ordinal(); } - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - local_service_->CompileExecutable( + TF_ASSIGN_OR_RETURN(std::vector> executables, + local_service_->CompileExecutables( computation, argument_layouts, updated_options)); - return absl::make_unique(std::move(executable), - local_service_->mutable_backend(), - updated_options); + + std::vector> local_executables; + local_executables.reserve(executables.size()); + + for (auto& executable : executables) { + local_executables.push_back(absl::make_unique( + std::move(executable), local_service_->mutable_backend(), + updated_options)); + } + + return std::move(local_executables); } StatusOr LocalClient::LiteralToShapedBuffer( diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 6cfa7cf6cd7..3f9ed37b05f 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h 
@@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_ #include +#include #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" @@ -110,12 +111,13 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // Build and return a LocalExecutable object. The executable is compiled using - // the given XlaComputation, argument layouts and options. + // Build and return LocalExecutable objects (one per partition, as specified + // by the build options). The executables are compiled using the given + // XlaComputation, argument layouts and options. // // The given ExecutableBuildOptions overrides any values from XLA_FLAGS // environment variable. - StatusOr> Compile( + StatusOr>> Compile( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& options); diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index 6e150a4c24e..c825bdce14d 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -676,16 +676,27 @@ static std::shared_ptr LookupDevice(const PyLocalClient& client, } PyLocalExecutable::PyLocalExecutable( - std::shared_ptr executable, + std::vector> executables, DeviceAssignment device_assignment, std::shared_ptr client) : client_(std::move(client)), - executable_(std::move(executable)), device_assignment_( std::make_shared(device_assignment)) { + executables_.reserve(executables.size()); + for (auto& executable : executables) { + executables_.emplace_back(std::move(executable)); + } + + // This must go after `executables_` is initialized. 
VLOG(1) << "PyLocalExecutable " << name() << " device_assignment:\n" << device_assignment_->ToString(); + const int num_replicas = device_assignment_->replica_count(); const int num_partitions = device_assignment_->computation_count(); + + CHECK_EQ(num_partitions, executables_.size()) + << "Number of executables " << executables_.size() + << " did not match number of partitions " << num_partitions; + for (int replica = 0; replica < num_replicas; ++replica) { for (int partition = 0; partition < num_partitions; ++partition) { int device_id = (*device_assignment_)(replica, partition); @@ -704,7 +715,7 @@ PyLocalExecutable::PyLocalExecutable( } const std::string& PyLocalExecutable::name() const { - Executable* executable = executable_->executable(); + Executable* executable = executables_[0]->executable(); if (executable->has_module()) { return executable->module().name(); } else { @@ -779,7 +790,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( device_state->compute_semaphore().ScopedAcquire(1)); StatusOr result_buffer_or_status = - executable_->RunAsync(argument_buffer_ptrs, options); + executables_[partition]->RunAsync(argument_buffer_ptrs, options); VLOG(1) << "Replica " << replica << " partition " << partition << " completed; ok=" << result_buffer_or_status.ok(); @@ -809,7 +820,8 @@ StatusOr> PyLocalExecutable::ExecuteHelper( device_state->ThenRelease( device_state->compute_stream(), - std::make_tuple(executable_, compute_reservation, device_assignment_)); + std::make_tuple(executables_[partition], compute_reservation, + device_assignment_)); return absl::make_unique( result_buffer.on_host_shape(), result_buffer.on_device_shape(), std::move(out_buffer), client_, device); @@ -1081,13 +1093,14 @@ PyLocalExecutable::Compile(const XlaComputation& computation, TF_RETURN_IF_ERROR(assign_layouts(&result_layout)); options.set_result_layout(result_layout); - TF_ASSIGN_OR_RETURN(std::unique_ptr local_executable, - client->client()->Compile( - computation, 
argument_layout_pointers, options)); + TF_ASSIGN_OR_RETURN( + std::vector> local_executables, + client->client()->Compile(computation, argument_layout_pointers, + options)); - return absl::make_unique( - std::shared_ptr(std::move(local_executable)), - std::move(*device_assignment), std::move(client)); + return absl::make_unique(std::move(local_executables), + std::move(*device_assignment), + std::move(client)); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index c3d7cd7b341..72afa3d0135 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -283,7 +283,8 @@ class PyLocalBuffer { }; // Represents a compiled computation that can be executed given handles to -// device-allocated literals. Wraps an XLA LocalExecutable. +// device-allocated literals. Wraps one or more XLA LocalExecutables (one per +// partition, as specified by the build options). class PyLocalExecutable { public: // Compiles a computation to an executable. 
@@ -304,20 +305,24 @@ class PyLocalExecutable { std::shared_ptr client, absl::optional device_assignment); - PyLocalExecutable(std::shared_ptr executable, + PyLocalExecutable(std::vector> executables, DeviceAssignment device_assignment, std::shared_ptr client); int num_replicas() const { - return executable_->build_options().num_replicas(); + return executables_[0]->build_options().num_replicas(); } int num_partitions() const { - return executable_->build_options().num_partitions(); + return executables_[0]->build_options().num_partitions(); } int64 SizeOfGeneratedCodeInBytes() const { - return executable_->executable()->SizeOfGeneratedCodeInBytes(); + int64 size = 0; + for (auto& executable : executables_) { + size += executable->executable()->SizeOfGeneratedCodeInBytes(); + } + return size; } const DeviceAssignment& device_assignment() const { @@ -346,7 +351,7 @@ class PyLocalExecutable { StatusOr>> ExecuteOnLocalDevices( absl::Span> argument_handles); - void Delete() { executable_ = nullptr; } + void Delete() { executables_.clear(); } const string& name() const; @@ -359,7 +364,8 @@ class PyLocalExecutable { // asynchronous execution, the process being executed can outlive the // executable itself. std::shared_ptr const client_; - std::shared_ptr executable_; + // One executable per partition. 
+ std::vector> executables_; std::shared_ptr device_assignment_; // The replica and partition indices of device_assignment_ to be run by this diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index a7872241e8f..91a00b5555a 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -119,7 +119,8 @@ ExecutionOptions CreateExecutionOptions( } // namespace -StatusOr> LocalService::CompileExecutable( +StatusOr>> +LocalService::CompileExecutables( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { @@ -178,9 +179,29 @@ StatusOr> LocalService::CompileExecutable( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); - return BuildExecutable(proto, std::move(module_config), - execute_backend_.get(), executor, - build_options.device_allocator()); + // TODO(cjfj): Investigate why there are a couple of test failures when the + // single partition computations are built using `BuildExecutables`, fix it, + // and remove this special case (provided the performance is similar). + if (build_options.num_partitions() == 1) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + BuildExecutable(proto, std::move(module_config), execute_backend_.get(), + executor, build_options.device_allocator())); + std::vector> executables; + executables.push_back(std::move(executable)); + return executables; + } else { + std::vector> module_configs; + module_configs.push_back(std::move(module_config)); + // BuildExecutables uses the executors length to determine the number of + // cores per module, but otherwise only uses the first executor. 
+ std::vector executors(build_options.num_partitions(), + executor); + + return BuildExecutables({&proto}, std::move(module_configs), + execute_backend_.get(), {executors}, + build_options.device_allocator()); + } } StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 170d226e336..3e684a32274 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_ #include +#include #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" @@ -41,12 +42,12 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given XlaComputation, argument layouts and + // Builds Executables with the given XlaComputation, argument layouts and // options. If result_layout is non-null, then the executable is compiled to // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. 
- StatusOr> CompileExecutable( + StatusOr>> CompileExecutables( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 77183779217..8f18e7fc241 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -452,7 +452,10 @@ xla_test( name = "while_test", srcs = ["while_test.cc"], deps = [ + ":client_library_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -464,9 +467,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -480,7 +480,9 @@ xla_test( "interpreter", ], deps = [ + ":client_library_test_base", ":test_macros_header", + ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -489,8 +491,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", @@ -950,7 +950,12 @@ xla_test( "optonly", ], deps = [ + ":client_library_test_base", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":test_utils", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", 
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", @@ -959,11 +964,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -984,7 +984,12 @@ xla_test( "optonly", ], deps = [ + ":client_library_test_base", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":test_utils", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", @@ -993,11 +998,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1048,7 +1048,12 @@ xla_test( "optonly", ], deps = [ + ":client_library_test_base", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":test_utils", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", @@ -1057,11 +1062,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/service:hlo_parser", - 
"//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1383,7 +1383,10 @@ xla_test( timeout = "moderate", srcs = ["dynamic_ops_test.cc"], deps = [ + ":client_library_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:test_helpers", @@ -1395,9 +1398,6 @@ xla_test( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -2205,7 +2205,11 @@ xla_test( name = "cpu_gpu_fusion_test", srcs = ["cpu_gpu_fusion_test.cc"], deps = [ + ":client_library_test_base", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2214,10 +2218,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", @@ 
-2295,7 +2295,11 @@ xla_test( shard_count = 30, tags = ["optonly"], deps = [ + ":literal_test_util", + ":local_client_test_base", ":test_macros_header", + ":test_utils", + ":xla_internal_test_main", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -2304,16 +2308,12 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:local_client_test_base", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -2526,13 +2526,13 @@ tf_cc_test( srcs = ["multiple_devices_on_host_test.cc"], args = ["--xla_force_host_platform_device_count=4"], deps = [ + ":xla_internal_test_main", # fixdeps: keep "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc index 83ed3c93df1..2a1eed7c7a7 100644 --- 
a/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc @@ -882,13 +882,14 @@ void BM_ParallelFusion(int num_iters) { .ConsumeValueOrDie(); // Build executable. - std::unique_ptr executable = + auto executables = client ->Compile(computation, {&buffer0.on_host_shape(), &buffer1.on_host_shape(), &buffer2.on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); + auto executable = std::move(executables[0]); se::Stream stream(executors[device_ordinal]); stream.Init(); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6742e863b9b..6d64cb0a510 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1672,11 +1672,10 @@ void DOT_ReorderContracting(int num_iters) { client->LiteralToShapedBuffer(input_literal, device_ordinal) .ConsumeValueOrDie(); - std::unique_ptr executable = - client - ->Compile(computation, {&buffer0.on_host_shape()}, - ExecutableBuildOptions()) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, client->Compile(computation, {&buffer0.on_host_shape()}, + ExecutableBuildOptions())); + auto executable = std::move(executables[0]); se::Stream stream(executors[device_ordinal]); stream.Init(); diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 9ea27585e61..555dfc48d9e 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -779,9 +779,10 @@ void BM_DynamicSlice(int num_iters) { DynamicSlice(input, start_indices, {1, 1, 1, 1}); auto computation = builder.Build().ConsumeValueOrDie(); - std::unique_ptr executable = - client->Compile(computation, host_shapes, ExecutableBuildOptions()) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, + client->Compile(computation, host_shapes, 
ExecutableBuildOptions())); + auto executable = std::move(executables[0]); // Run some warm-up executions. ExecutableRunOptions options; diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h index 1aa06a0aa63..67e6d6d630a 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h @@ -242,7 +242,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { [&](const Literal* input_literal) { return &input_literal->shape(); }); TF_ASSIGN_OR_RETURN( - auto executable, + auto executables, client_->Compile(computation, input_shapes, build_opts)); std::vector input_buffers; @@ -264,7 +264,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { run_opts.set_intra_op_thread_pool( client_->backend().eigen_intra_op_thread_pool_device()); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - executable->Run(input_buffer_pointers, run_opts)); + executables[0]->Run(input_buffer_pointers, run_opts)); TF_ASSIGN_OR_RETURN(Literal result_literal, client_->ShapedBufferToLiteral(result)); diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index 5511190caf9..1868159ef7b 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -46,12 +46,13 @@ TEST_F(HloMetadataTest, MetadataPropagation) { Shape argument_layout = ShapeUtil::MakeShape(F32, {}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, + auto executables, local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout, &argument_layout}, ExecutableBuildOptions())); - auto instruction = executable->executable() + auto instruction = executables[0] + ->executable() ->module() .entry_computation() ->root_instruction(); @@ -67,15 +68,14 @@ TEST_F(HloMetadataTest, MetadataClearing) { BuildAddComputation(&builder); Shape argument_layout = 
ShapeUtil::MakeShape(F32, {}); - auto executable_status = local_client_->Compile( - builder.Build().ValueOrDie(), {&argument_layout, &argument_layout}, - ExecutableBuildOptions()); - ASSERT_IS_OK(executable_status); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, + local_client_->Compile(builder.Build().ValueOrDie(), + {&argument_layout, &argument_layout}, + ExecutableBuildOptions())); - std::unique_ptr executable = - executable_status.ConsumeValueOrDie(); - - auto instruction = executable->executable() + auto instruction = executables[0] + ->executable() ->module() .entry_computation() ->root_instruction(); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 67a1abacd18..6d156f12b36 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -759,17 +760,17 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { Shape argument_layout = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}); - auto executable_status = + TF_ASSERT_OK_AND_ASSIGN( + auto executables, local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout}, - ExecutableBuildOptions()); - ASSERT_IS_OK(executable_status); - std::unique_ptr executable = - executable_status.ConsumeValueOrDie(); + ExecutableBuildOptions())); + EXPECT_EQ(1, executables.size()); auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = - executable->Run({&x_array}, DefaultExecutableRunOptions()) + executables[0] + 
->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); ASSERT_IS_OK(local_client_->mutable_backend() ->BorrowStream(0) @@ -780,6 +781,31 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } +XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) { + if (local_client_->device_count() < 2) { + GTEST_SKIP_("requires two devices"); + } + + XlaBuilder builder(TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + auto z = ConstantR1(&builder, {5.0f, 6.0f, 7.0f}); + auto r = Add(x, y); + builder.SetSharding(sharding_builder::AssignDevice(1)); + Add(r, z); + builder.ClearSharding(); + + Shape argument_layout = + ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}); + ExecutableBuildOptions build_options; + build_options.set_num_partitions(2); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, + local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout}, + build_options)); + EXPECT_EQ(2, executables.size()); +} + XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { // Test copying Literals to the device as ShapedBuffers, then copying them // back again to Literals. 
@@ -928,11 +954,10 @@ void BM_LocalClientOverhead(int num_iters) { const int kWarmups = 2; - auto executable_status = client->Compile( - computation, {&buffer.on_host_shape()}, ExecutableBuildOptions()); - ASSERT_IS_OK(executable_status); - std::unique_ptr executable = - executable_status.ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, client->Compile(computation, {&buffer.on_host_shape()}, + ExecutableBuildOptions())); + std::unique_ptr executable = std::move(executables[0]); ExecutableRunOptions run_options; run_options.set_allocator(&allocator).set_stream(stream.get()); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index fdb3489f450..4c5951476d8 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -194,9 +194,10 @@ StatusOr LocalClientTestBase::ExecuteLocally( argument_layouts[i] = &arguments[i]->on_host_shape(); } TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, + auto executables, local_client_->Compile(computation, argument_layouts, build_options)); - TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options)); + TF_RET_CHECK(executables.size() == 1); + TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options)); auto device_ordinal = build_options.device_ordinal() == -1 ? 
0 : build_options.device_ordinal(); diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc index c530591c6e5..2b19aaded9c 100644 --- a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc +++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc @@ -65,8 +65,9 @@ void TestWithDeviceCount(const int device_count) { TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, + auto executables, client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{})); + std::unique_ptr executable = std::move(executables[0]); std::vector threads; absl::Mutex results_mutex; std::vector>> results; diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 2a0d98ad1f1..8a99976e60c 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -47,12 +47,12 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); - auto executable_status = local_client_->Compile( - computation_status.ValueOrDie(), {&pair_float, &single_float}, - ExecutableBuildOptions()); - TF_ASSERT_OK(executable_status.status()); - HloModule& module = const_cast( - executable_status.ValueOrDie()->executable()->module()); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, local_client_->Compile(computation_status.ValueOrDie(), + {&pair_float, &single_float}, + ExecutableBuildOptions())); + HloModule& module = + const_cast(executables[0]->executable()->module()); TF_ASSERT_OK(MakeFakeArguments(&module).status()); } diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 4d80a57ad40..5a482305513 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ 
-1314,9 +1314,10 @@ void BM_WhileLoop(int num_iters) { While(condition, body, init); auto computation = builder.Build().ConsumeValueOrDie(); - std::unique_ptr executable = - client->Compile(computation, {}, ExecutableBuildOptions()) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, + client->Compile(computation, {}, ExecutableBuildOptions())); + auto executable = std::move(executables[0]); // Run some warm-up executions. ExecutableRunOptions options; diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 957e96d5a43..1b8203e02a9 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -158,11 +158,11 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, ExecutableBuildOptions build_options; build_options.mutable_debug_options()->set_xla_hlo_profile(true); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr local_executable, + auto local_executables, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, build_options)); - Executable* executable = local_executable->executable(); + Executable* executable = local_executables[0]->executable(); HloExecutionProfile hlo_execution_profile( &executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index df2d3d18b9f..90e2596dc10 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -85,10 +85,11 @@ void RealMain(absl::Span args) { ExecutableBuildOptions build_options; build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); - StatusOr> executable = - local_service->CompileExecutable(computation, layouts, build_options); - 
- const HloModule& module = executable.ValueOrDie()->module(); + auto executables = + local_service->CompileExecutables(computation, layouts, build_options) + .ConsumeValueOrDie(); + CHECK_EQ(executables.size(), 1); + const HloModule& module = executables[0]->module(); OperationDumper dumper(arg); for (auto* computation : module.computations()) { diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 35bb82ca22f..c4dc6d10670 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -62,10 +62,11 @@ void RealMain(absl::Span args, bool compile) { ExecutableBuildOptions build_options; build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); - StatusOr> executable = - local_service->CompileExecutable(computation, layouts, build_options); - - const HloModule& module = executable.ValueOrDie()->module(); + auto executables = + local_service->CompileExecutables(computation, layouts, build_options) + .ConsumeValueOrDie(); + CHECK_EQ(executables.size(), 1); + const HloModule& module = executables[0]->module(); fprintf(stdout, "HLO compiled for %s backend:\n%s\n", local_service->backend().platform()->Name().c_str(), diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 639f91b8b53..3b5023457b2 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -125,7 +125,11 @@ StatusOr> CompileExecutable( } ExecutableBuildOptions exec_build_options; *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags(); - return client->Compile(computation, argument_layout_ptrs, exec_build_options); + TF_ASSIGN_OR_RETURN( + auto executables, + client->Compile(computation, argument_layout_ptrs, exec_build_options)); + 
TF_RET_CHECK(executables.size() == 1); + return std::move(executables[0]); } absl::optional GetXfeedShape(bool is_infeed, diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 99fb092335e..7304008cef1 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -128,12 +128,11 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, } VLOG(1) << "Building executable"; - auto compile_result = - client->Compile(computation, argument_layout_ptrs, build_options); - if (!compile_result.ok()) { - return compile_result.status(); - } - *program = std::move(compile_result.ValueOrDie()); + TF_ASSIGN_OR_RETURN( + auto executables, + client->Compile(computation, argument_layout_ptrs, build_options)); + TF_RET_CHECK(executables.size() == 1); + *program = std::move(executables[0]); return Status::OK(); } diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 08a99756426..ec23f3d4a97 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -285,9 +285,12 @@ xla::ProgramShape XlaCompiledProgramShape( for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) { parameters_shapes.push_back(&input_program_shape.parameters(i)); } - auto local_executable = + std::vector> local_executables = client->Compile(computation, parameters_shapes, exec_options) - .ValueOrDie(); + .ConsumeValueOrDie(); + EXPECT_EQ(local_executables.size(), 1); + std::unique_ptr local_executable = + std::move(local_executables[0]); return local_executable->executable() ->module() .entry_computation()