Add device assignment as an optional field to HloModuleConfig to enable assignment-specific optimizations.

PiperOrigin-RevId: 238349255
Author: Yuanzhong Xu
Date: 2019-03-13 17:54:34 -07:00
Committed by: TensorFlower Gardener
Commit: 383eb9afbe
Parent: 7b50beb0b5

5 changed files with 33 additions and 0 deletions
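In short, the change threads a statically known DeviceAssignment from the AOT client through ExecutionOptions into HloModuleConfig, so backends can specialize for a fixed placement. A minimal client-side sketch, not part of the diff; MyAotCompilationOptions is a hypothetical stand-in for a backend-specific subclass, since the AotCompilationOptions base class below has a protected constructor:

// Illustration only: MyAotCompilationOptions is hypothetical.
void ConfigureAotOptions(MyAotCompilationOptions* options) {
  // Fix a 2-replica, 1-computation placement on devices 0 and 1.
  // DeviceAssignment maps (replica, computation) -> device ordinal.
  xla::DeviceAssignment assignment(/*replica_count=*/2,
                                   /*computation_count=*/1);
  assignment(0, 0) = 0;
  assignment(1, 0) = 1;
  // New in this change: expose the assignment to the AOT compiler.
  options->set_static_device_assignment(assignment);
}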


@@ -719,6 +719,7 @@ cc_library(
":compilation_cache",
":compiler",
":computation_layout",
":computation_placer",
":device_memory_allocator",
":dump",
":dynamic_dimension_inference",
@@ -933,6 +934,7 @@ cc_library(
hdrs = ["compiler.h"],
deps = [
":buffer_value",
":computation_placer",
":executable",
":hlo",
":hlo_module_config",


@@ -75,6 +75,10 @@ CompileOnlyService::CompileAheadOfTime(
*execution_options.mutable_debug_options() = debug_options;
*execution_options.mutable_shape_with_output_layout() =
instance.result_layout->ToProto();
if (options.has_static_device_assignment()) {
TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
execution_options.mutable_device_assignment()));
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(
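The hunk above serializes the assignment into the compile request; Service::CreateModuleConfig (further below) deserializes it again. A minimal sketch of that round trip, assuming only the Serialize/Deserialize signatures visible in this diff:

xla::StatusOr<std::unique_ptr<xla::DeviceAssignment>> RoundTrip(
    const xla::DeviceAssignment& assignment) {
  xla::DeviceAssignmentProto proto;
  // Serialize writes the (replica, computation) -> device table into a proto
  // suitable for embedding in ExecutionOptions.
  TF_RETURN_IF_ERROR(assignment.Serialize(&proto));
  // Deserialize rebuilds an equivalent DeviceAssignment on the service side.
  return xla::DeviceAssignment::Deserialize(proto);
}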


@@ -28,6 +28,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -82,12 +83,24 @@ class AotCompilationOptions {
const DebugOptions& debug_options() const { return debug_options_; }
DebugOptions* mutable_debug_options() { return &debug_options_; }
bool has_static_device_assignment() const {
return static_device_assignment_.has_value();
}
const DeviceAssignment& static_device_assignment() const {
CHECK(static_device_assignment_.has_value());
return *static_device_assignment_;
}
void set_static_device_assignment(const DeviceAssignment& device_assignment) {
static_device_assignment_ = device_assignment;
}
protected:
AotCompilationOptions();
private:
DeviceMemoryAllocator* device_allocator_ = nullptr;
DebugOptions debug_options_;
absl::optional<DeviceAssignment> static_device_assignment_;
};
// Abstract superclass describing metadata produced during ahead-of-time
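A backend compiler can then branch on the new accessors during ahead-of-time compilation. A hedged sketch, not part of this change; the helper function and the replica_count/computation_count calls on DeviceAssignment are illustrative:

// Illustrative helper, not XLA code.
void LogStaticAssignment(const xla::AotCompilationOptions& options) {
  if (options.has_static_device_assignment()) {
    const xla::DeviceAssignment& assignment =
        options.static_device_assignment();
    VLOG(1) << "AOT compiling with a fixed assignment: "
            << assignment.replica_count() << " replica(s), "
            << assignment.computation_count() << " computation(s).";
    // ... e.g. specialize cross-replica collectives for the known placement.
  }
}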


@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
@@ -319,6 +320,15 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
config->set_intra_op_parallelism_threads(
execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
}
if (execution_options != nullptr &&
execution_options->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(
auto device_assignment,
DeviceAssignment::Deserialize(execution_options->device_assignment()));
config->set_static_device_assignment(*device_assignment);
}
return std::move(config);
}
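Downstream passes could read the assignment back off the module's config. Only the setter is visible in this diff; the has_/getter pair on HloModuleConfig used below is an assumption for illustration:

// Assumed accessors: counterparts of set_static_device_assignment() above.
void MaybeLogStaticAssignment(const xla::HloModule& module) {
  const xla::HloModuleConfig& config = module.config();
  if (config.has_static_device_assignment()) {
    VLOG(1) << "Module " << module.name()
            << " has a static device assignment:\n"
            << config.static_device_assignment().ToString();
  }
}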


@@ -288,6 +288,10 @@ message ExecutionOptions {
// Number of replicas of the computation to run. If zero, uses the default
// number of replicas for the XLA service.
int32 num_replicas = 6;
// This optional field specifies the device assignment if known at compile
// time.
DeviceAssignmentProto device_assignment = 7;
}
message GetDeviceHandlesRequest {
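A client would fill the new field the same way CompileOnlyService does above. A small sketch of building an ExecutionOptions with the assignment attached; MakeExecutionOptions is illustrative:

xla::ExecutionOptions MakeExecutionOptions(
    const xla::DeviceAssignment& assignment) {
  xla::ExecutionOptions execution_options;
  execution_options.set_num_replicas(assignment.replica_count());
  // Populate the new device_assignment field (field number 7 above).
  TF_CHECK_OK(
      assignment.Serialize(execution_options.mutable_device_assignment()));
  return execution_options;
}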