Add a launch id field in run options and hlo module config.

PiperOrigin-RevId: 307922589
Change-Id: Ie1ea0b389e5228f827d570086799227983035f81
This commit is contained in:
Yunxing Dai 2020-04-22 16:06:22 -07:00 committed by TensorFlower Gardener
parent 9e6096a772
commit 49e59a8cad
4 changed files with 21 additions and 0 deletions

View File

@ -127,6 +127,13 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_rng_seed(int rng_seed);
int rng_seed() const;
ExecutableRunOptions& set_launch_id(int32 launch_id) {
launch_id_ = launch_id;
return *this;
}
int32 launch_id() const { return launch_id_; }
ExecutableRunOptions& set_run_id(RunId id);
RunId run_id() const;
@ -153,6 +160,7 @@ class ExecutableRunOptions {
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
int32 launch_id_ = 0;
stream_executor::Stream* host_to_device_stream_ = nullptr;
ThenExecuteFunction* then_execute_function_ = nullptr;
RunId run_id_;

View File

@ -108,6 +108,12 @@ class HloModuleConfig {
void set_seed(uint64 seed) { seed_ = seed; }
uint64 seed() const { return seed_; }
// Set the launch id of the program. Launch id identifies a set of programs
// that should be launched together.
void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; }
int32 launch_id() const { return launch_id_; }
void set_replica_count(int64 replica_count) {
replica_count_ = replica_count;
}
@ -197,6 +203,9 @@ class HloModuleConfig {
// Module/graph-level seed handle.
uint64 seed_ = 0;
// Program id that identifies a set of program to be launched together.
int32 launch_id_ = 0;
// The number of replicas (data parallelism) to compile this binary for.
int64 replica_count_ = 1;

View File

@ -314,6 +314,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
config->set_num_partitions(execution_options->num_partitions());
}
config->set_seed(execution_options->seed());
config->set_launch_id(execution_options->launch_id());
config->set_debug_options(execution_options->debug_options());
} else {
config->set_replica_count(options_.number_of_replicas());

View File

@ -319,6 +319,9 @@ message ExecutionOptions {
// Number of partitions of the computation to run (model parallelism).
// If zero, uses the default number of partitions for the XLA service.
int32 num_partitions = 9;
// Used to identify a set of programs that should be launch together.
int32 launch_id = 10;
}
message GetDeviceHandlesRequest {