Add device assignment as an optional field to HloModuleConfig to enable assignment-specific optimizations.
PiperOrigin-RevId: 238349255
commit 383eb9afbe
parent 7b50beb0b5
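
The patch threads a compile-time DeviceAssignment through two entry points: ExecutionOptions for the regular service path and AotCompilationOptions for ahead-of-time compilation. Below is a minimal client-side sketch of the ExecutionOptions path; the include paths and the Serialize/TF_CHECK_OK calls are written from the XLA tree around this revision and are illustrative, not part of the change.

#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/status.h"

// Build a replica->device mapping that is already known before compilation
// and forward it to the service inside ExecutionOptions.
xla::ExecutionOptions MakeOptionsWithStaticAssignment() {
  // Two replicas of one computation, pinned to devices 0 and 1.
  xla::DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
  assignment(0, 0) = 0;
  assignment(1, 0) = 1;

  xla::ExecutionOptions execution_options;
  execution_options.set_num_replicas(2);
  TF_CHECK_OK(
      assignment.Serialize(execution_options.mutable_device_assignment()));
  return execution_options;
}
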
tensorflow/compiler/xla/service/BUILD
@@ -719,6 +719,7 @@ cc_library(
         ":compilation_cache",
         ":compiler",
         ":computation_layout",
+        ":computation_placer",
         ":device_memory_allocator",
         ":dump",
         ":dynamic_dimension_inference",
@@ -933,6 +934,7 @@ cc_library(
     hdrs = ["compiler.h"],
     deps = [
         ":buffer_value",
+        ":computation_placer",
         ":executable",
         ":hlo",
         ":hlo_module_config",
tensorflow/compiler/xla/service/compile_only_service.cc
@@ -75,6 +75,10 @@ CompileOnlyService::CompileAheadOfTime(
     *execution_options.mutable_debug_options() = debug_options;
     *execution_options.mutable_shape_with_output_layout() =
         instance.result_layout->ToProto();
+    if (options.has_static_device_assignment()) {
+      TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
+          execution_options.mutable_device_assignment()));
+    }
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModuleConfig> module_config,
         CreateModuleConfig(
tensorflow/compiler/xla/service/compiler.h
@@ -28,6 +28,7 @@ limitations under the License.
 
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -82,12 +83,24 @@ class AotCompilationOptions {
   const DebugOptions& debug_options() const { return debug_options_; }
   DebugOptions* mutable_debug_options() { return &debug_options_; }
 
+  bool has_static_device_assignment() const {
+    return static_device_assignment_.has_value();
+  }
+  const DeviceAssignment& static_device_assignment() const {
+    CHECK(static_device_assignment_.has_value());
+    return *static_device_assignment_;
+  }
+  void set_static_device_assignment(const DeviceAssignment& device_assignment) {
+    static_device_assignment_ = device_assignment;
+  }
+
  protected:
   AotCompilationOptions();
 
  private:
   DeviceMemoryAllocator* device_allocator_ = nullptr;
   DebugOptions debug_options_;
+  absl::optional<DeviceAssignment> static_device_assignment_;
 };
 
 // Abstract superclass describing metadata produced during ahead-of-time
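
For the AOT path, a client sets the assignment on its backend-specific AotCompilationOptions subclass before calling CompileAheadOfTime. A hedged sketch using only the accessors added above; the helper function name and the use of a generic pointer are illustrative, since AotCompilationOptions itself is abstract.

#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/platform/logging.h"

// Sketch: attach a known replica->device mapping to AOT compilation options.
// `aot_options` points at a concrete, backend-specific subclass.
void AttachStaticAssignment(xla::AotCompilationOptions* aot_options) {
  xla::DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
  assignment(0, 0) = 0;  // Replica 0 runs on device 0.
  assignment(1, 0) = 1;  // Replica 1 runs on device 1.
  aot_options->set_static_device_assignment(assignment);
  CHECK(aot_options->has_static_device_assignment());
}
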
tensorflow/compiler/xla/service/service.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
 #include "tensorflow/compiler/xla/service/dump.h"
 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
@@ -319,6 +320,15 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     config->set_intra_op_parallelism_threads(
         execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
   }
 
+  if (execution_options != nullptr &&
+      execution_options->has_device_assignment()) {
+    TF_ASSIGN_OR_RETURN(
+        auto device_assignment,
+        DeviceAssignment::Deserialize(execution_options->device_assignment()));
+    config->set_static_device_assignment(*device_assignment);
+  }
+
   return std::move(config);
 }
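
With the assignment stored on the HloModuleConfig, assignment-specific optimizations can branch on it. The sketch below assumes HloModuleConfig exposes has_static_device_assignment()/static_device_assignment() getters mirroring the setter used above; those getters are not shown in this excerpt, so treat the names as assumptions.

#include "tensorflow/compiler/xla/service/hlo_module.h"

// Hypothetical helper for a pass: true if the module is statically assigned
// to exactly one device (one replica of one computation). The config getters
// are assumed, not shown in this diff.
bool HasSingleDeviceAssignment(const xla::HloModule& module) {
  const auto& config = module.config();
  if (!config.has_static_device_assignment()) return false;
  const xla::DeviceAssignment& assignment = config.static_device_assignment();
  return assignment.replica_count() == 1 && assignment.computation_count() == 1;
}
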
tensorflow/compiler/xla/xla.proto
@@ -288,6 +288,10 @@ message ExecutionOptions {
   // Number of replicas of the computation to run. If zero, uses the default
   // number of replicas for the XLA service.
   int32 num_replicas = 6;
+
+  // This optional field specifies the device assignment if known at compile
+  // time.
+  DeviceAssignmentProto device_assignment = 7;
 }
 
 message GetDeviceHandlesRequest {
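
The new proto field carries the same replica-by-computation grid that the C++ DeviceAssignment class holds, so a serialize/deserialize round trip through ExecutionOptions should be lossless. A small self-check sketch, assuming the Serialize/Deserialize signatures from computation_placer.h; it is not part of the patch.

#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

// Round-trip a DeviceAssignment through the new ExecutionOptions field and
// verify that the shape of the mapping survives.
void RoundTripAssignment() {
  xla::DeviceAssignment original(/*replica_count=*/4, /*computation_count=*/2);
  for (int replica = 0; replica < 4; ++replica) {
    for (int computation = 0; computation < 2; ++computation) {
      original(replica, computation) = replica * 2 + computation;
    }
  }

  xla::ExecutionOptions options;
  TF_CHECK_OK(original.Serialize(options.mutable_device_assignment()));

  auto restored =
      xla::DeviceAssignment::Deserialize(options.device_assignment())
          .ValueOrDie();
  CHECK_EQ(restored->replica_count(), original.replica_count());
  CHECK_EQ(restored->computation_count(), original.computation_count());
}
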