From bdf0ea41b9bb5498c3ba144e7dd03c17c2fb59af Mon Sep 17 00:00:00 2001 From: James Bradbury Date: Fri, 21 Feb 2020 17:39:08 -0800 Subject: [PATCH] [XLA] Wire through static device assignments from the Python client to the compiler PiperOrigin-RevId: 296544701 Change-Id: I49ca867e15a9da92e2bdba6410da5ccb16a57440 --- tensorflow/compiler/xla/client/BUILD | 1 + .../xla/client/executable_build_options.cc | 6 ++++++ .../compiler/xla/client/executable_build_options.h | 14 ++++++++++++++ tensorflow/compiler/xla/python/local_client.cc | 1 + tensorflow/compiler/xla/service/local_service.cc | 4 ++++ tensorflow/compiler/xla/service/service.h | 2 +- 6 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 7b53f8504ea..0a47920bd9a 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -97,6 +97,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index bb3d3317ec5..cd52e2f5e45 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -70,6 +70,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_num_partitions( return *this; } +ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment( + const DeviceAssignment& device_assignment) { + device_assignment_ = device_assignment; + return *this; +} + string ExecutableBuildOptions::ToString() const { string result_layout = "nullopt"; if (result_layout_set_) { diff --git 
a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 461fd834115..360ad0260df 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -76,6 +77,18 @@ class ExecutableBuildOptions { int num_partitions() const { return num_partitions_; } ExecutableBuildOptions& set_num_partitions(int num_partitions); + // If set, this specifies a static device assignment for the computation. + // Otherwise, the computation will be compiled generically and can be run with + // any device assignment compatible with the computation's replica and + // partition counts. + bool has_device_assignment() const { return device_assignment_.has_value(); } + ExecutableBuildOptions& set_device_assignment( + const DeviceAssignment& device_assignment); + const DeviceAssignment& device_assignment() const { + CHECK(device_assignment_.has_value()); + return device_assignment_.value(); + } + // Whether input and output buffers are aliased if the associated parameter is // passed-through XLA modules without being changed. 
bool alias_passthrough_params() const { return alias_passthrough_params_; } @@ -91,6 +104,7 @@ class ExecutableBuildOptions { se::DeviceMemoryAllocator* device_allocator_ = nullptr; int num_replicas_ = 1; int num_partitions_ = 1; + absl::optional<DeviceAssignment> device_assignment_; bool alias_passthrough_params_ = false; }; diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index 39da7f086b5..a35b20f6aa1 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -929,6 +929,7 @@ PyLocalExecutable::Compile(const XlaComputation& computation, VLOG(2) << "PyLocalExecutable::Compile using default device_assignment:\n" << device_assignment->ToString(); } + options.set_device_assignment(device_assignment.value()); if (!argument_layouts) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 91a00b5555a..ef8ddfc1a76 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -112,6 +112,10 @@ ExecutionOptions CreateExecutionOptions( } execution_options.set_num_replicas(build_options.num_replicas()); execution_options.set_num_partitions(build_options.num_partitions()); + if (build_options.has_device_assignment()) { + TF_CHECK_OK(build_options.device_assignment().Serialize( + execution_options.mutable_device_assignment())); + } execution_options.set_alias_passthrough_params( build_options.alias_passthrough_params()); return execution_options; diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 3a4e17d7f44..d58020655de 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -184,7 +184,7 @@ class Service : public ServiceInterface { Backend* mutable_backend() { return execute_backend_.get(); } // Create a 
Hlo module config for the given program shape and arguments. - // execution_options is optional; if not given a default is used. + // aot_options is optional; if not given a default is used. StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig( const ProgramShape& program_shape, absl::Span<const Shape* const> argument_shapes,