diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 6981b35975f..43ee0fdd820 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -127,6 +127,13 @@ class ExecutableRunOptions { ExecutableRunOptions& set_rng_seed(int rng_seed); int rng_seed() const; + ExecutableRunOptions& set_launch_id(int32 launch_id) { + launch_id_ = launch_id; + return *this; + } + + int32 launch_id() const { return launch_id_; } + ExecutableRunOptions& set_run_id(RunId id); RunId run_id() const; @@ -153,6 +160,7 @@ class ExecutableRunOptions { const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; + int32 launch_id_ = 0; stream_executor::Stream* host_to_device_stream_ = nullptr; ThenExecuteFunction* then_execute_function_ = nullptr; RunId run_id_; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index d90a1485441..5b80a6adca2 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -108,6 +108,12 @@ class HloModuleConfig { void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } + // Set the launch id of the program. Launch id identifies a set of programs + // that should be launched together. + void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; } + + int32 launch_id() const { return launch_id_; } + void set_replica_count(int64 replica_count) { replica_count_ = replica_count; } @@ -197,6 +203,9 @@ class HloModuleConfig { // Module/graph-level seed handle. uint64 seed_ = 0; + // Program id that identifies a set of program to be launched together. + int32 launch_id_ = 0; + // The number of replicas (data parallelism) to compile this binary for. int64 replica_count_ = 1; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index e12e1577211..ab71c30dcae 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -314,6 +314,7 @@ StatusOr> Service::CreateModuleConfig( config->set_num_partitions(execution_options->num_partitions()); } config->set_seed(execution_options->seed()); + config->set_launch_id(execution_options->launch_id()); config->set_debug_options(execution_options->debug_options()); } else { config->set_replica_count(options_.number_of_replicas()); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f8bd7a0750e..c8ba08fc351 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -319,6 +319,9 @@ message ExecutionOptions { // Number of partitions of the computation to run (model parallelism). // If zero, uses the default number of partitions for the XLA service. int32 num_partitions = 9; + + // Used to identify a set of programs that should be launch together. + int32 launch_id = 10; } message GetDeviceHandlesRequest {