diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 760e3ebada6..64d2da499db 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -719,6 +719,7 @@ cc_library( ":compilation_cache", ":compiler", ":computation_layout", + ":computation_placer", ":device_memory_allocator", ":dump", ":dynamic_dimension_inference", @@ -933,6 +934,7 @@ cc_library( hdrs = ["compiler.h"], deps = [ ":buffer_value", + ":computation_placer", ":executable", ":hlo", ":hlo_module_config", diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 5209da93ee9..a4758c2b9db 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -75,6 +75,10 @@ CompileOnlyService::CompileAheadOfTime( *execution_options.mutable_debug_options() = debug_options; *execution_options.mutable_shape_with_output_layout() = instance.result_layout->ToProto(); + if (options.has_static_device_assignment()) { + TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( + execution_options.mutable_device_assignment())); + } TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, CreateModuleConfig( diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d4db95da8eb..9b483bd97e9 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -82,12 +83,24 @@ class AotCompilationOptions { const DebugOptions& debug_options() const { return debug_options_; } DebugOptions* mutable_debug_options() { return &debug_options_; } + bool has_static_device_assignment() const { + return static_device_assignment_.has_value(); + } + const DeviceAssignment& static_device_assignment() const { + CHECK(static_device_assignment_.has_value()); + return *static_device_assignment_; + } + void set_static_device_assignment(const DeviceAssignment& device_assignment) { + static_device_assignment_ = device_assignment; + } + protected: AotCompilationOptions(); private: DeviceMemoryAllocator* device_allocator_ = nullptr; DebugOptions debug_options_; + absl::optional<DeviceAssignment> static_device_assignment_; }; // Abstract superclass describing metadata produced during ahead-of-time diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 91efbdd2fea..49c346d87fc 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" @@ -319,6 +320,15 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( config->set_intra_op_parallelism_threads( execute_backend_->eigen_intra_op_thread_pool()->NumThreads()); } + + if (execution_options != nullptr && + execution_options->has_device_assignment()) { + TF_ASSIGN_OR_RETURN( + auto device_assignment, + DeviceAssignment::Deserialize(execution_options->device_assignment())); + config->set_static_device_assignment(*device_assignment); + } + return std::move(config); } diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 879929697a4..6d71ae866c3 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -288,6 +288,10 @@ message ExecutionOptions { // Number of replicas of the computation to run. If zero, uses the default // number of replicas for the XLA service. int32 num_replicas = 6; + + // This optional field specifies the device assignment if known at compile + // time. + DeviceAssignmentProto device_assignment = 7; } message GetDeviceHandlesRequest {