Add a launch id field in run options and hlo module config.
PiperOrigin-RevId: 307922589 Change-Id: Ie1ea0b389e5228f827d570086799227983035f81
This commit is contained in:
parent
9e6096a772
commit
49e59a8cad
@ -127,6 +127,13 @@ class ExecutableRunOptions {
|
||||
ExecutableRunOptions& set_rng_seed(int rng_seed);
|
||||
int rng_seed() const;
|
||||
|
||||
ExecutableRunOptions& set_launch_id(int32 launch_id) {
|
||||
launch_id_ = launch_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32 launch_id() const { return launch_id_; }
|
||||
|
||||
ExecutableRunOptions& set_run_id(RunId id);
|
||||
RunId run_id() const;
|
||||
|
||||
@ -153,6 +160,7 @@ class ExecutableRunOptions {
|
||||
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
|
||||
ExecutionProfile* execution_profile_ = nullptr;
|
||||
int rng_seed_ = 0;
|
||||
int32 launch_id_ = 0;
|
||||
stream_executor::Stream* host_to_device_stream_ = nullptr;
|
||||
ThenExecuteFunction* then_execute_function_ = nullptr;
|
||||
RunId run_id_;
|
||||
|
@ -108,6 +108,12 @@ class HloModuleConfig {
|
||||
void set_seed(uint64 seed) { seed_ = seed; }
|
||||
uint64 seed() const { return seed_; }
|
||||
|
||||
// Set the launch id of the program. Launch id identifies a set of programs
|
||||
// that should be launched together.
|
||||
void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; }
|
||||
|
||||
int32 launch_id() const { return launch_id_; }
|
||||
|
||||
void set_replica_count(int64 replica_count) {
|
||||
replica_count_ = replica_count;
|
||||
}
|
||||
@ -197,6 +203,9 @@ class HloModuleConfig {
|
||||
// Module/graph-level seed handle.
|
||||
uint64 seed_ = 0;
|
||||
|
||||
// Program id that identifies a set of program to be launched together.
|
||||
int32 launch_id_ = 0;
|
||||
|
||||
// The number of replicas (data parallelism) to compile this binary for.
|
||||
int64 replica_count_ = 1;
|
||||
|
||||
|
@ -314,6 +314,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
config->set_num_partitions(execution_options->num_partitions());
|
||||
}
|
||||
config->set_seed(execution_options->seed());
|
||||
config->set_launch_id(execution_options->launch_id());
|
||||
config->set_debug_options(execution_options->debug_options());
|
||||
} else {
|
||||
config->set_replica_count(options_.number_of_replicas());
|
||||
|
@ -319,6 +319,9 @@ message ExecutionOptions {
|
||||
// Number of partitions of the computation to run (model parallelism).
|
||||
// If zero, uses the default number of partitions for the XLA service.
|
||||
int32 num_partitions = 9;
|
||||
|
||||
// Used to identify a set of programs that should be launch together.
|
||||
int32 launch_id = 10;
|
||||
}
|
||||
|
||||
message GetDeviceHandlesRequest {
|
||||
|
Loading…
Reference in New Issue
Block a user