[XLA] Wire through static device assignments from the Python client to the compiler

PiperOrigin-RevId: 296544701
Change-Id: I49ca867e15a9da92e2bdba6410da5ccb16a57440
This commit is contained in:
James Bradbury 2020-02-21 17:39:08 -08:00 committed by TensorFlower Gardener
parent 88e2109ac8
commit bdf0ea41b9
6 changed files with 27 additions and 1 deletions

View File

@ -97,6 +97,7 @@ cc_library(
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",

View File

@ -70,6 +70,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_num_partitions(
return *this; return *this;
} }
ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
const DeviceAssignment& device_assignment) {
device_assignment_ = device_assignment;
return *this;
}
string ExecutableBuildOptions::ToString() const { string ExecutableBuildOptions::ToString() const {
string result_layout = "nullopt"; string result_layout = "nullopt";
if (result_layout_set_) { if (result_layout_set_) {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla.pb.h"
@ -76,6 +77,18 @@ class ExecutableBuildOptions {
int num_partitions() const { return num_partitions_; } int num_partitions() const { return num_partitions_; }
ExecutableBuildOptions& set_num_partitions(int num_partitions); ExecutableBuildOptions& set_num_partitions(int num_partitions);
// If set, this specifies a static device assignment for the computation.
// Otherwise, the computation will be compiled generically and can be run with
// any device assignment compatible with the computation's replica and
// partition counts.
bool has_device_assignment() const { return device_assignment_.has_value(); }
ExecutableBuildOptions& set_device_assignment(
const DeviceAssignment& device_assignment);
const DeviceAssignment& device_assignment() const {
CHECK(device_assignment_.has_value());
return device_assignment_.value();
}
// Whether input and output buffers are aliased if the associated parameter is // Whether input and output buffers are aliased if the associated parameter is
// passed-through XLA modules without being changed. // passed-through XLA modules without being changed.
bool alias_passthrough_params() const { return alias_passthrough_params_; } bool alias_passthrough_params() const { return alias_passthrough_params_; }
@ -91,6 +104,7 @@ class ExecutableBuildOptions {
se::DeviceMemoryAllocator* device_allocator_ = nullptr; se::DeviceMemoryAllocator* device_allocator_ = nullptr;
int num_replicas_ = 1; int num_replicas_ = 1;
int num_partitions_ = 1; int num_partitions_ = 1;
absl::optional<DeviceAssignment> device_assignment_;
bool alias_passthrough_params_ = false; bool alias_passthrough_params_ = false;
}; };

View File

@ -929,6 +929,7 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
VLOG(2) << "PyLocalExecutable::Compile using default device_assignment:\n" VLOG(2) << "PyLocalExecutable::Compile using default device_assignment:\n"
<< device_assignment->ToString(); << device_assignment->ToString();
} }
options.set_device_assignment(device_assignment.value());
if (!argument_layouts) { if (!argument_layouts) {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, TF_ASSIGN_OR_RETURN(ProgramShape program_shape,

View File

@ -112,6 +112,10 @@ ExecutionOptions CreateExecutionOptions(
} }
execution_options.set_num_replicas(build_options.num_replicas()); execution_options.set_num_replicas(build_options.num_replicas());
execution_options.set_num_partitions(build_options.num_partitions()); execution_options.set_num_partitions(build_options.num_partitions());
if (build_options.has_device_assignment()) {
TF_CHECK_OK(build_options.device_assignment().Serialize(
execution_options.mutable_device_assignment()));
}
execution_options.set_alias_passthrough_params( execution_options.set_alias_passthrough_params(
build_options.alias_passthrough_params()); build_options.alias_passthrough_params());
return execution_options; return execution_options;

View File

@ -184,7 +184,7 @@ class Service : public ServiceInterface {
Backend* mutable_backend() { return execute_backend_.get(); } Backend* mutable_backend() { return execute_backend_.get(); }
// Create a Hlo module config for the given program shape and arguments. // Create a Hlo module config for the given program shape and arguments.
// execution_options is optional; if not given a default is used. // aot_options is optional; if not given a default is used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig( StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape, const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes, absl::Span<const Shape* const> argument_shapes,