Add device assignment as an optional field to HloConfig to enable assignment-specific optimizations.
PiperOrigin-RevId: 238349255
This commit is contained in:
parent
7b50beb0b5
commit
383eb9afbe
@ -719,6 +719,7 @@ cc_library(
|
||||
":compilation_cache",
|
||||
":compiler",
|
||||
":computation_layout",
|
||||
":computation_placer",
|
||||
":device_memory_allocator",
|
||||
":dump",
|
||||
":dynamic_dimension_inference",
|
||||
@ -933,6 +934,7 @@ cc_library(
|
||||
hdrs = ["compiler.h"],
|
||||
deps = [
|
||||
":buffer_value",
|
||||
":computation_placer",
|
||||
":executable",
|
||||
":hlo",
|
||||
":hlo_module_config",
|
||||
|
@ -75,6 +75,10 @@ CompileOnlyService::CompileAheadOfTime(
|
||||
*execution_options.mutable_debug_options() = debug_options;
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
instance.result_layout->ToProto();
|
||||
if (options.has_static_device_assignment()) {
|
||||
TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
|
||||
execution_options.mutable_device_assignment()));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModuleConfig> module_config,
|
||||
CreateModuleConfig(
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
@ -82,12 +83,24 @@ class AotCompilationOptions {
|
||||
const DebugOptions& debug_options() const { return debug_options_; }
|
||||
DebugOptions* mutable_debug_options() { return &debug_options_; }
|
||||
|
||||
bool has_static_device_assignment() const {
|
||||
return static_device_assignment_.has_value();
|
||||
}
|
||||
// Returns the stored static device assignment.
//
// CHECK-fails if no assignment has been set; use
// has_static_device_assignment() to test for presence first.
const DeviceAssignment& static_device_assignment() const {
  CHECK(static_device_assignment_.has_value());
  const DeviceAssignment& assignment = *static_device_assignment_;
  return assignment;
}
|
||||
void set_static_device_assignment(const DeviceAssignment& device_assignment) {
|
||||
static_device_assignment_ = device_assignment;
|
||||
}
|
||||
|
||||
protected:
|
||||
AotCompilationOptions();
|
||||
|
||||
private:
|
||||
DeviceMemoryAllocator* device_allocator_ = nullptr;
|
||||
DebugOptions debug_options_;
|
||||
absl::optional<DeviceAssignment> static_device_assignment_;
|
||||
};
|
||||
|
||||
// Abstract superclass describing metadata produced during ahead-of-time
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/dump.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
|
||||
@ -319,6 +320,15 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
config->set_intra_op_parallelism_threads(
|
||||
execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
|
||||
}
|
||||
|
||||
if (execution_options != nullptr &&
|
||||
execution_options->has_device_assignment()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto device_assignment,
|
||||
DeviceAssignment::Deserialize(execution_options->device_assignment()));
|
||||
config->set_static_device_assignment(*device_assignment);
|
||||
}
|
||||
|
||||
return std::move(config);
|
||||
}
|
||||
|
||||
|
@ -288,6 +288,10 @@ message ExecutionOptions {
|
||||
// Number of replicas of the computation to run. If zero, uses the default
|
||||
// number of replicas for the XLA service.
|
||||
int32 num_replicas = 6;
|
||||
|
||||
// This optional field specifies the device assignment if known at compile
|
||||
// time.
|
||||
DeviceAssignmentProto device_assignment = 7;
|
||||
}
|
||||
|
||||
message GetDeviceHandlesRequest {
|
||||
|
Loading…
Reference in New Issue
Block a user