[XLA] Add the notion of a RunId.

A RunId uniquely identifies a particular execution of a model in XLA. It's
used in a later patch to differentiate between multiple concurrent runs of an
XLA module.

Pass the RunId to all XLA:GPU thunks.

PiperOrigin-RevId: 249799806
This commit is contained in:
Justin Lebar 2019-05-24 02:31:20 -07:00 committed by TensorFlower Gardener
parent 5d51c3a37c
commit c709683965
45 changed files with 122 additions and 60 deletions

View File

@ -575,6 +575,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":types",
"@com_google_absl//absl/strings",
],
)

View File

@ -15,8 +15,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/executable_run_options.h"
#include <atomic>
#include "absl/strings/str_cat.h"
namespace xla {
RunId::RunId() {
  // Draw the next value from a process-wide, monotonically increasing atomic
  // counter, so every default-constructed RunId is unique within the process.
  static std::atomic<int64> next_id{0};
  data_ = next_id++;
}
// Two RunIds are equal iff they hold the same underlying counter value, i.e.
// they originate (possibly via copies) from the same RunId() construction.
bool operator==(const RunId& a, const RunId& b) { return a.data_ == b.data_; }
// Renders the id as "RunId: <n>" for logging/debugging.  std::to_string
// produces the same decimal digits absl::StrCat would for an integral value.
std::string RunId::ToString() const {
  return "RunId: " + std::to_string(data_);
}
ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
int device_ordinal) {
device_ordinal_ = device_ordinal;
@ -94,4 +107,11 @@ ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
// Returns the RNG seed previously set via set_rng_seed().
int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
// Stamps these run options with the given logical-run identifier.  Returns
// *this so setters can be chained.
ExecutableRunOptions& ExecutableRunOptions::set_run_id(RunId id) {
run_id_ = id;
return *this;
}
// Returns the RunId for this execution (the value set via set_run_id(), or
// the RunId generated when this object's run_id_ member was constructed).
RunId ExecutableRunOptions::run_id() const { return run_id_; }
} // namespace xla

View File

@ -16,6 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
#define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
#include <string>
#include "tensorflow/compiler/xla/types.h"
// These classes are forward declared so that ExecutableRunOptions can be linked
// into an XLA-compiled binary without having to link all of the pointed-to
// objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't
@ -35,6 +39,31 @@ namespace xla {
class DeviceAssignment;
class ExecutionProfile;
// A unique identifier for a particular "logical execution" of an XLA model.
//
// A logical execution might encompass multiple executions of one or more
// HloModules. Runs that are part of the same logical execution can
// communicate via collective ops (e.g. kAllToAll), whereas runs that are part
// of different logical executions are isolated.
class RunId {
public:
// Creates a new, unique RunId.  IDs are drawn from a process-wide atomic
// counter (see executable_run_options.cc), so uniqueness is per-process.
RunId();
// RunIds are cheap to copy; a copy compares equal to the original.
RunId(const RunId&) = default;
RunId& operator=(const RunId&) = default;
// Equal iff both sides hold the same underlying counter value.
friend bool operator==(const RunId& a, const RunId& b);
// Human-readable form (e.g. "RunId: 42"), intended for logging/debugging.
std::string ToString() const;
// Hash support so RunId can be used as a key in absl hash containers.
template <typename H>
friend H AbslHashValue(H h, const RunId& id) {
return H::combine(std::move(h), id.data_);
}
private:
int64 data_;
};
// Class containing options for running a LocalExecutable.
class ExecutableRunOptions {
public:
@ -87,6 +116,9 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_rng_seed(int rng_seed);
int rng_seed() const;
ExecutableRunOptions& set_run_id(RunId id);
RunId run_id() const;
private:
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
@ -96,6 +128,7 @@ class ExecutableRunOptions {
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
stream_executor::Stream* host_to_device_stream_ = nullptr;
RunId run_id_;
};
} // namespace xla

View File

@ -333,6 +333,7 @@ cc_library(
deps = [
":buffer_allocations",
":hlo_execution_profiler",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@ -347,6 +348,7 @@ tf_cuda_library(
":buffer_allocations",
":hlo_execution_profiler",
":thunk",
"//tensorflow/compiler/xla:refcounting_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/synchronization",
"//tensorflow/compiler/xla:util",

View File

@ -52,7 +52,7 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
Status CholeskyThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
VLOG(3) << "type=" << PrimitiveType_Name(type_)
<< " uplo=" << se::blas::UpperLowerString(uplo_)
<< " batch_size=" << batch_size_ << " n=" << n_

View File

@ -52,7 +52,7 @@ class CholeskyThunk : public Thunk {
CholeskyThunk& operator=(const CholeskyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
Status ConditionalThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& run_id, HloExecutionProfiler* profiler) {
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
// Copy the predicate value from device.
int32 branch_index = -1;
@ -89,7 +89,7 @@ Status ConditionalThunk::ExecuteOnStream(
// Execute the branch computation corresponding to the value of branch_index.
profiler->StartHloComputation();
TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(
buffer_allocations, stream, profiler));
buffer_allocations, stream, run_id, profiler));
profiler->FinishHloComputation(
hlo_instruction()->branch_computation(branch_index));

View File

@ -54,7 +54,7 @@ class ConditionalThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -43,7 +43,7 @@ ConvolutionThunk::ConvolutionThunk(
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
std::vector<se::DeviceMemoryBase> operand_se_buffers;
for (const auto& buffer : operand_buffers_) {
operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));

View File

@ -54,7 +54,7 @@ class ConvolutionThunk : public Thunk {
// Does the convolution for the thunk on "stream".
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -32,7 +32,7 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk(
Status HostToDeviceCopyThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
@ -51,7 +51,7 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
Status DeviceToDeviceCopyThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
se::DeviceMemoryBase source_data =

View File

@ -41,7 +41,7 @@ class HostToDeviceCopyThunk : public Thunk {
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:
@ -65,7 +65,7 @@ class DeviceToDeviceCopyThunk : public Thunk {
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -100,7 +100,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
std::tie(operand_desc, scale_offset_desc) =
@ -164,7 +164,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
// The BatchNormTraining HLO outputs a tuple of three elements: output data,
@ -250,7 +250,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;

View File

@ -61,7 +61,7 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
const CudnnBatchNormForwardInferenceThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:
@ -92,7 +92,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
const CudnnBatchNormForwardTrainingThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:
@ -126,7 +126,7 @@ class CudnnBatchNormBackwardThunk : public Thunk {
delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -50,7 +50,7 @@ CustomCallThunk::CustomCallThunk(
Status CustomCallThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
// gpu_stream is CUstream or e.g. the equivalent type in ROCm.
auto gpu_stream = se::gpu::AsGpuStreamValue(stream);
auto typed_call_target =

View File

@ -45,7 +45,7 @@ class CustomCallThunk : public Thunk {
const HloInstruction* instr);
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -107,7 +107,7 @@ FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
output_shape_(output_shape) {}
Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& /*run_id*/,
HloExecutionProfiler* profiler) {
VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);

View File

@ -73,7 +73,7 @@ class FftThunk : public Thunk {
// Does the FFT for the thunk on "stream".
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -41,7 +41,7 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
}
Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
<< (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
@ -49,8 +49,8 @@ Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
for (int64 i = 0; i < loop_limit_; ++i) {
profiler->StartHloComputation();
// Invoke loop body thunk sequence.
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
stream, profiler));
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(
buffer_allocations, stream, run_id, profiler));
profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();

View File

@ -40,7 +40,7 @@ class ForThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -446,7 +446,7 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
implements_whole_instruction_(implements_whole_instruction) {}
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& /*run_id*/,
HloExecutionProfiler* profiler) {
auto fn = [&]() {
switch (output_shape_.element_type()) {

View File

@ -50,7 +50,7 @@ class GemmThunk : public Thunk {
// Does the gemm operation for the thunk on "stream", which must be non-null.
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -137,7 +137,8 @@ Status GpuExecutable::ExecuteThunks(
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
TF_RETURN_IF_ERROR(
thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
thunk->ExecuteOnStream(buffer_allocations, stream,
run_options->run_options().run_id(), &profiler));
if (thunk_schedule_->Depended(thunk)) {
auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
finish_event->Init();

View File

@ -28,7 +28,7 @@ InfeedThunk::InfeedThunk(
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& /*run_id*/,
HloExecutionProfiler* profiler) {
VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();

View File

@ -41,7 +41,7 @@ class InfeedThunk : public Thunk {
InfeedThunk& operator=(const InfeedThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -70,7 +70,7 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
}
Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& /*run_id*/,
HloExecutionProfiler* profiler) {
// Load the kernel.
se::StreamExecutor* executor = stream->parent();

View File

@ -63,7 +63,7 @@ class KernelThunk : public Thunk {
// Executes the kernel for the thunk on "stream", which must be non-null.
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -23,7 +23,7 @@ namespace gpu {
Status MemzeroThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemZero(&dest_data, dest_data.size());
@ -32,7 +32,7 @@ Status MemzeroThunk::ExecuteOnStream(
Status Memset32BitValueThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemset32(&dest_data, value_, dest_data.size());

View File

@ -37,7 +37,7 @@ class MemzeroThunk : public Thunk {
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:
@ -54,7 +54,7 @@ class Memset32BitValueThunk : public Thunk {
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId&,
HloExecutionProfiler* profiler) override;
private:

View File

@ -509,7 +509,7 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
Status NcclAllReduceThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& run_id, HloExecutionProfiler* profiler) {
auto* global_rendezvous = GetGlobalRendezvous();
ParticipantData participant;
@ -553,7 +553,7 @@ NcclAllReduceThunk::~NcclAllReduceThunk() {
Status NcclAllReduceThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& run_id, HloExecutionProfiler* profiler) {
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");

View File

@ -57,7 +57,7 @@ class NcclAllReduceThunk : public Thunk {
~NcclAllReduceThunk() override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -30,7 +30,7 @@ OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
Status OutfeedThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());

View File

@ -39,7 +39,7 @@ class OutfeedThunk : public Thunk {
OutfeedThunk& operator=(const OutfeedThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -35,11 +35,11 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
Status SequentialThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& run_id, HloExecutionProfiler* profiler) {
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (const auto& thunk : thunks_) {
TF_RETURN_IF_ERROR(
thunk->ExecuteOnStream(buffer_allocations, stream, profiler));
thunk->ExecuteOnStream(buffer_allocations, stream, run_id, profiler));
}
return Status::OK();
}

View File

@ -42,7 +42,7 @@ class SequentialThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -92,7 +93,7 @@ class Thunk {
//
// Precondition: Initialize(stream->parent()) has been called.
virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) = 0;
protected:

View File

@ -70,7 +70,7 @@ TriangularSolveThunk::TriangularSolveThunk(
Status TriangularSolveThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
<< " side=" << se::blas::SideString(side_)
<< " diagonal=" << se::blas::DiagonalString(unit_diagonal_)

View File

@ -49,7 +49,7 @@ class TriangularSolveThunk : public Thunk {
TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -23,7 +23,7 @@ namespace xla {
namespace gpu {
Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& /*run_id*/,
HloExecutionProfiler* profiler) {
auto size = tuple_element_buffers_.size();
auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size);

View File

@ -46,7 +46,7 @@ class TupleThunk : public Thunk {
TupleThunk& operator=(const TupleThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -48,7 +48,7 @@ Status WhileThunk::Initialize(const GpuExecutable& executable,
}
Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) {
se::DeviceMemoryBase condition_result_data =
buffer_allocations.GetDeviceAddress(condition_result_buffer_index_);
@ -59,7 +59,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
profiler->StartHloComputation();
VLOG(3) << "Executing condition computation";
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
buffer_allocations, stream, profiler));
buffer_allocations, stream, run_id, profiler));
profiler->FinishHloComputation(hlo_instruction()->while_condition());
// Copy the result of condition computation and break the loop if 'false'.
@ -83,8 +83,8 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
VLOG(3) << "Executing body computation";
// Invoke thunk sequence for while 'body' computation, and pass on
// 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
stream, profiler));
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(
buffer_allocations, stream, run_id, profiler));
profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();

View File

@ -49,7 +49,7 @@ class WhileThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
se::Stream* stream, const RunId& run_id,
HloExecutionProfiler* profiler) override;
private:

View File

@ -207,7 +207,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
stream.Init();
ServiceExecutableRunOptions service_run_options =
GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
nullptr);
nullptr, RunId());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
CreateExecutable(std::move(module), run_hlo_passes));
@ -243,7 +243,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
stream.Init();
ServiceExecutableRunOptions service_run_options =
GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
nullptr);
nullptr, RunId());
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer retval,
@ -294,6 +294,7 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
options.num_replicas * options.arguments.size() + 1);
std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
int64 index = 0;
RunId run_id;
for (int64 i = 0; i < options.num_replicas; ++i) {
int64 device = (*device_assignment)(i, 0);
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
@ -301,7 +302,7 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
streams.push_back(absl::make_unique<se::Stream>(executor));
streams.back()->Init();
service_run_options.emplace_back(GetServiceRunOptionsForDevice(
device, streams.back().get(), device_assignment));
device, streams.back().get(), device_assignment, run_id));
// Copy arguments to device.
for (const Literal* argument : options.arguments) {
@ -443,7 +444,8 @@ StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
}
ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
int64 device, se::Stream* stream, DeviceAssignment* device_assignment) {
int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
RunId run_id) {
ExecutableRunOptions run_options;
run_options.set_device_ordinal(device);
run_options.set_stream(stream);
@ -453,6 +455,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
if (device_assignment != nullptr) {
run_options.set_device_assignment(device_assignment);
}
run_options.set_run_id(run_id);
return ServiceExecutableRunOptions(run_options, backend().StreamBorrower());
}

View File

@ -206,7 +206,8 @@ class HloRunner {
// will be used to configure the replication parameters. Replicated executions
// should pass the device_assignment parameter.
ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
int64 device, se::Stream* stream, DeviceAssignment* device_assignment);
int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
RunId run_id);
std::unique_ptr<Backend> backend_;
};

View File

@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
// Class containing options for running a LocalExecutable and other auxiliary
// data, now only a stream cache for GPU backend.
// data.
class ServiceExecutableRunOptions {
public:
using StreamBorrower = std::function<StatusOr<StreamPool::Ptr>(int)>;