[XLA] Wire through static device assignments from the Python client to the compiler
PiperOrigin-RevId: 296544701
Change-Id: I49ca867e15a9da92e2bdba6410da5ccb16a57440
parent 88e2109ac8
commit bdf0ea41b9
@@ -97,6 +97,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla:xla_proto_cc",
+        "//tensorflow/compiler/xla/service:computation_placer",
         "//tensorflow/stream_executor:device_memory_allocator",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",

@@ -70,6 +70,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_num_partitions(
   return *this;
 }

+ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
+    const DeviceAssignment& device_assignment) {
+  device_assignment_ = device_assignment;
+  return *this;
+}
+
 string ExecutableBuildOptions::ToString() const {
   string result_layout = "nullopt";
   if (result_layout_set_) {

@@ -18,6 +18,7 @@ limitations under the License.

 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla.pb.h"

@@ -76,6 +77,18 @@ class ExecutableBuildOptions {
   int num_partitions() const { return num_partitions_; }
   ExecutableBuildOptions& set_num_partitions(int num_partitions);

+  // If set, this specifies a static device assignment for the computation.
+  // Otherwise, the computation will be compiled generically and can be run with
+  // any device assignment compatible with the computation's replica and
+  // partition counts.
+  bool has_device_assignment() const { return device_assignment_.has_value(); }
+  ExecutableBuildOptions& set_device_assignment(
+      const DeviceAssignment& device_assignment);
+  const DeviceAssignment& device_assignment() const {
+    CHECK(device_assignment_.has_value());
+    return device_assignment_.value();
+  }
+
   // Whether input and output buffers are aliased if the associated parameter is
   // passed-through XLA modules without being changed.
   bool alias_passthrough_params() const { return alias_passthrough_params_; }

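The accessors above are the public surface of this change. A minimal usage sketch follows, assuming the header paths shown in the hunks above; the helper name and device ids are illustrative only and not part of this commit. DeviceAssignment is the 2D (replica x partition) device mapping defined in computation_placer.h.

// Sketch only: pin a 2-replica x 1-partition computation to devices 0 and 1.
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"

xla::ExecutableBuildOptions MakeStaticallyAssignedOptions() {
  xla::DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
  assignment(0, 0) = 0;  // replica 0, partition 0 -> device 0 (illustrative)
  assignment(1, 0) = 1;  // replica 1, partition 0 -> device 1 (illustrative)

  xla::ExecutableBuildOptions options;
  options.set_num_replicas(2);
  options.set_num_partitions(1);
  options.set_device_assignment(assignment);
  return options;
}
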
@@ -91,6 +104,7 @@ class ExecutableBuildOptions {
   se::DeviceMemoryAllocator* device_allocator_ = nullptr;
   int num_replicas_ = 1;
   int num_partitions_ = 1;
+  absl::optional<DeviceAssignment> device_assignment_;
   bool alias_passthrough_params_ = false;
 };

@@ -929,6 +929,7 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
     VLOG(2) << "PyLocalExecutable::Compile using default device_assignment:\n"
             << device_assignment->ToString();
   }
+  options.set_device_assignment(device_assignment.value());

   if (!argument_layouts) {
     TF_ASSIGN_OR_RETURN(ProgramShape program_shape,

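When the caller does not pass an assignment, the Python client falls back to a default one (logged by the VLOG above) and now forwards it to the build options as well. Below is a hedged sketch of how such a default can be produced with ComputationPlacer; this is an assumption about the mechanism, not this commit's code, and the helper name is made up.

// Sketch only: a default replica/partition -> device mapping via ComputationPlacer.
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<xla::DeviceAssignment> DefaultAssignment(int num_replicas,
                                                       int num_partitions) {
  xla::ComputationPlacer placer;
  // Lays replicas and partitions out over logical device ids 0..N-1.
  return placer.AssignDevices(num_replicas, num_partitions);
}
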
@@ -112,6 +112,10 @@ ExecutionOptions CreateExecutionOptions(
   }
   execution_options.set_num_replicas(build_options.num_replicas());
   execution_options.set_num_partitions(build_options.num_partitions());
+  if (build_options.has_device_assignment()) {
+    TF_CHECK_OK(build_options.device_assignment().Serialize(
+        execution_options.mutable_device_assignment()));
+  }
   execution_options.set_alias_passthrough_params(
       build_options.alias_passthrough_params());
   return execution_options;

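This hunk is where the static assignment crosses from the build options into the ExecutionOptions proto that reaches the compiler. The reverse direction uses DeviceAssignment::Deserialize; the sketch below illustrates the round trip and is not code from this commit (the helper name is made up).

// Sketch only: recover the static assignment from an ExecutionOptions proto.
#include <memory>

#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla.pb.h"

xla::StatusOr<std::unique_ptr<xla::DeviceAssignment>> RecoverAssignment(
    const xla::ExecutionOptions& execution_options) {
  return xla::DeviceAssignment::Deserialize(
      execution_options.device_assignment());
}
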
@@ -184,7 +184,7 @@ class Service : public ServiceInterface {
   Backend* mutable_backend() { return execute_backend_.get(); }

   // Create a Hlo module config for the given program shape and arguments.
-  // execution_options is optional; if not given a default is used.
+  // aot_options is optional; if not given a default is used.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
       absl::Span<const Shape* const> argument_shapes,