[XLA] Add the notion of a RunId.
A RunId uniquely identifies a particular execution of a model in XLA. It's used in a later patch to differentiate between multiple, concurrent runs of an XLA module. Pass the RunId to all XLA:GPU thunks. PiperOrigin-RevId: 249799806
This commit is contained in:
parent
5d51c3a37c
commit
c709683965
@ -575,6 +575,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":types",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,8 +15,21 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
RunId::RunId() {
|
||||
static std::atomic<int64> counter{0};
|
||||
data_ = counter.fetch_add(1);
|
||||
}
|
||||
|
||||
bool operator==(const RunId& a, const RunId& b) { return a.data_ == b.data_; }
|
||||
|
||||
std::string RunId::ToString() const { return absl::StrCat("RunId: ", data_); }
|
||||
|
||||
ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
|
||||
int device_ordinal) {
|
||||
device_ordinal_ = device_ordinal;
|
||||
@ -94,4 +107,11 @@ ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
|
||||
|
||||
int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
|
||||
|
||||
ExecutableRunOptions& ExecutableRunOptions::set_run_id(RunId id) {
|
||||
run_id_ = id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
RunId ExecutableRunOptions::run_id() const { return run_id_; }
|
||||
|
||||
} // namespace xla
|
||||
|
@ -16,6 +16,10 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
// These classes are forward declared so that ExecutableRunOptions can be linked
|
||||
// into an XLA-compiled binary without having to link all of the pointed-to
|
||||
// objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't
|
||||
@ -35,6 +39,31 @@ namespace xla {
|
||||
class DeviceAssignment;
|
||||
class ExecutionProfile;
|
||||
|
||||
// A unique identifier for a particular "logical execution" of an XLA model.
|
||||
//
|
||||
// A logical execution might encompass multiple executions of one or more
|
||||
// HloModules. Runs that are part of the same logical execution can
|
||||
// communicate via collective ops (e.g. kAllToAll), whereas runs that are part
|
||||
// of different logical executions are isolated.
|
||||
class RunId {
|
||||
public:
|
||||
// Creates a new, unique RunId.
|
||||
RunId();
|
||||
|
||||
RunId(const RunId&) = default;
|
||||
RunId& operator=(const RunId&) = default;
|
||||
friend bool operator==(const RunId& a, const RunId& b);
|
||||
std::string ToString() const;
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const RunId& id) {
|
||||
return H::combine(std::move(h), id.data_);
|
||||
}
|
||||
|
||||
private:
|
||||
int64 data_;
|
||||
};
|
||||
|
||||
// Class containing options for running a LocalExecutable.
|
||||
class ExecutableRunOptions {
|
||||
public:
|
||||
@ -87,6 +116,9 @@ class ExecutableRunOptions {
|
||||
ExecutableRunOptions& set_rng_seed(int rng_seed);
|
||||
int rng_seed() const;
|
||||
|
||||
ExecutableRunOptions& set_run_id(RunId id);
|
||||
RunId run_id() const;
|
||||
|
||||
private:
|
||||
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
|
||||
int device_ordinal_ = -1;
|
||||
@ -96,6 +128,7 @@ class ExecutableRunOptions {
|
||||
ExecutionProfile* execution_profile_ = nullptr;
|
||||
int rng_seed_ = 0;
|
||||
stream_executor::Stream* host_to_device_stream_ = nullptr;
|
||||
RunId run_id_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -333,6 +333,7 @@ cc_library(
|
||||
deps = [
|
||||
":buffer_allocations",
|
||||
":hlo_execution_profiler",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
@ -347,6 +348,7 @@ tf_cuda_library(
|
||||
":buffer_allocations",
|
||||
":hlo_execution_profiler",
|
||||
":thunk",
|
||||
"//tensorflow/compiler/xla:refcounting_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
|
@ -52,7 +52,7 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
|
||||
|
||||
Status CholeskyThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
VLOG(3) << "type=" << PrimitiveType_Name(type_)
|
||||
<< " uplo=" << se::blas::UpperLowerString(uplo_)
|
||||
<< " batch_size=" << batch_size_ << " n=" << n_
|
||||
|
@ -52,7 +52,7 @@ class CholeskyThunk : public Thunk {
|
||||
CholeskyThunk& operator=(const CholeskyThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
|
||||
|
||||
Status ConditionalThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& run_id, HloExecutionProfiler* profiler) {
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
// Copy the predicate value from device.
|
||||
int32 branch_index = -1;
|
||||
@ -89,7 +89,7 @@ Status ConditionalThunk::ExecuteOnStream(
|
||||
// Execute the branch computation corresponding to the value of branch_index.
|
||||
profiler->StartHloComputation();
|
||||
TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(
|
||||
buffer_allocations, stream, profiler));
|
||||
buffer_allocations, stream, run_id, profiler));
|
||||
profiler->FinishHloComputation(
|
||||
hlo_instruction()->branch_computation(branch_index));
|
||||
|
||||
|
@ -54,7 +54,7 @@ class ConditionalThunk : public Thunk {
|
||||
Status Initialize(const GpuExecutable& executable,
|
||||
se::StreamExecutor* executor) override;
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -43,7 +43,7 @@ ConvolutionThunk::ConvolutionThunk(
|
||||
|
||||
Status ConvolutionThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
std::vector<se::DeviceMemoryBase> operand_se_buffers;
|
||||
for (const auto& buffer : operand_buffers_) {
|
||||
operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
|
||||
|
@ -54,7 +54,7 @@ class ConvolutionThunk : public Thunk {
|
||||
|
||||
// Does the convolution for the thunk on "stream".
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -32,7 +32,7 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk(
|
||||
|
||||
Status HostToDeviceCopyThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
se::DeviceMemoryBase destination_data =
|
||||
buffer_allocations.GetDeviceAddress(destination_buffer_);
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
@ -51,7 +51,7 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
|
||||
|
||||
Status DeviceToDeviceCopyThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
se::DeviceMemoryBase destination_data =
|
||||
buffer_allocations.GetDeviceAddress(destination_buffer_);
|
||||
se::DeviceMemoryBase source_data =
|
||||
|
@ -41,7 +41,7 @@ class HostToDeviceCopyThunk : public Thunk {
|
||||
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
@ -65,7 +65,7 @@ class DeviceToDeviceCopyThunk : public Thunk {
|
||||
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -100,7 +100,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
|
||||
|
||||
Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
dnn::BatchDescriptor operand_desc;
|
||||
dnn::BatchDescriptor scale_offset_desc;
|
||||
std::tie(operand_desc, scale_offset_desc) =
|
||||
@ -164,7 +164,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
|
||||
|
||||
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
dnn::BatchDescriptor operand_desc;
|
||||
dnn::BatchDescriptor scale_offset_desc;
|
||||
// The BatchNormTraining HLO outputs a tuple of three elements: output data,
|
||||
@ -250,7 +250,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
|
||||
|
||||
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
dnn::BatchDescriptor operand_desc;
|
||||
dnn::BatchDescriptor scale_offset_desc;
|
||||
|
||||
|
@ -61,7 +61,7 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
|
||||
const CudnnBatchNormForwardInferenceThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
@ -92,7 +92,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
|
||||
const CudnnBatchNormForwardTrainingThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
@ -126,7 +126,7 @@ class CudnnBatchNormBackwardThunk : public Thunk {
|
||||
delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -50,7 +50,7 @@ CustomCallThunk::CustomCallThunk(
|
||||
|
||||
Status CustomCallThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
// gpu_stream is CUstream or e.g. the equivalent type in ROCm.
|
||||
auto gpu_stream = se::gpu::AsGpuStreamValue(stream);
|
||||
auto typed_call_target =
|
||||
|
@ -45,7 +45,7 @@ class CustomCallThunk : public Thunk {
|
||||
const HloInstruction* instr);
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -107,7 +107,7 @@ FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
|
||||
output_shape_(output_shape) {}
|
||||
|
||||
Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& /*run_id*/,
|
||||
HloExecutionProfiler* profiler) {
|
||||
VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
|
||||
VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
|
||||
|
@ -73,7 +73,7 @@ class FftThunk : public Thunk {
|
||||
|
||||
// Does the FFT for the thunk on "stream".
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -41,7 +41,7 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
|
||||
}
|
||||
|
||||
Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) {
|
||||
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
|
||||
<< (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
|
||||
@ -49,8 +49,8 @@ Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
for (int64 i = 0; i < loop_limit_; ++i) {
|
||||
profiler->StartHloComputation();
|
||||
// Invoke loop body thunk sequence.
|
||||
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
|
||||
stream, profiler));
|
||||
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(
|
||||
buffer_allocations, stream, run_id, profiler));
|
||||
profiler->FinishHloComputation(hlo_instruction()->while_body());
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -40,7 +40,7 @@ class ForThunk : public Thunk {
|
||||
Status Initialize(const GpuExecutable& executable,
|
||||
se::StreamExecutor* executor) override;
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -446,7 +446,7 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
|
||||
implements_whole_instruction_(implements_whole_instruction) {}
|
||||
|
||||
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& /*run_id*/,
|
||||
HloExecutionProfiler* profiler) {
|
||||
auto fn = [&]() {
|
||||
switch (output_shape_.element_type()) {
|
||||
|
@ -50,7 +50,7 @@ class GemmThunk : public Thunk {
|
||||
|
||||
// Does the gemm operation for the thunk on "stream", which must be non-null.
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -137,7 +137,8 @@ Status GpuExecutable::ExecuteThunks(
|
||||
<< thunk->hlo_instruction()->ToString() << " on stream "
|
||||
<< stream_no;
|
||||
TF_RETURN_IF_ERROR(
|
||||
thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
|
||||
thunk->ExecuteOnStream(buffer_allocations, stream,
|
||||
run_options->run_options().run_id(), &profiler));
|
||||
if (thunk_schedule_->Depended(thunk)) {
|
||||
auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
|
||||
finish_event->Init();
|
||||
|
@ -28,7 +28,7 @@ InfeedThunk::InfeedThunk(
|
||||
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
|
||||
|
||||
Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& /*run_id*/,
|
||||
HloExecutionProfiler* profiler) {
|
||||
VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();
|
||||
|
||||
|
@ -41,7 +41,7 @@ class InfeedThunk : public Thunk {
|
||||
InfeedThunk& operator=(const InfeedThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -70,7 +70,7 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
|
||||
}
|
||||
|
||||
Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& /*run_id*/,
|
||||
HloExecutionProfiler* profiler) {
|
||||
// Load the kernel.
|
||||
se::StreamExecutor* executor = stream->parent();
|
||||
|
@ -63,7 +63,7 @@ class KernelThunk : public Thunk {
|
||||
|
||||
// Executes the kernel for the thunk on "stream", which must be non-null.
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -23,7 +23,7 @@ namespace gpu {
|
||||
|
||||
Status MemzeroThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
stream->ThenMemZero(&dest_data, dest_data.size());
|
||||
@ -32,7 +32,7 @@ Status MemzeroThunk::ExecuteOnStream(
|
||||
|
||||
Status Memset32BitValueThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
stream->ThenMemset32(&dest_data, value_, dest_data.size());
|
||||
|
@ -37,7 +37,7 @@ class MemzeroThunk : public Thunk {
|
||||
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
@ -54,7 +54,7 @@ class Memset32BitValueThunk : public Thunk {
|
||||
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId&,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -509,7 +509,7 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||
|
||||
Status NcclAllReduceThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& run_id, HloExecutionProfiler* profiler) {
|
||||
auto* global_rendezvous = GetGlobalRendezvous();
|
||||
|
||||
ParticipantData participant;
|
||||
@ -553,7 +553,7 @@ NcclAllReduceThunk::~NcclAllReduceThunk() {
|
||||
|
||||
Status NcclAllReduceThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& run_id, HloExecutionProfiler* profiler) {
|
||||
return Unimplemented(
|
||||
"NCCL support is not available: this binary was not built with a CUDA "
|
||||
"compiler, which is necessary to build the NCCL source library.");
|
||||
|
@ -57,7 +57,7 @@ class NcclAllReduceThunk : public Thunk {
|
||||
~NcclAllReduceThunk() override;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -30,7 +30,7 @@ OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
|
||||
|
||||
Status OutfeedThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
|
||||
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
|
@ -39,7 +39,7 @@ class OutfeedThunk : public Thunk {
|
||||
OutfeedThunk& operator=(const OutfeedThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -35,11 +35,11 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
|
||||
|
||||
Status SequentialThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& run_id, HloExecutionProfiler* profiler) {
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
for (const auto& thunk : thunks_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
thunk->ExecuteOnStream(buffer_allocations, stream, profiler));
|
||||
thunk->ExecuteOnStream(buffer_allocations, stream, run_id, profiler));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ class SequentialThunk : public Thunk {
|
||||
Status Initialize(const GpuExecutable& executable,
|
||||
se::StreamExecutor* executor) override;
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -92,7 +93,7 @@ class Thunk {
|
||||
//
|
||||
// Precondition: Initialize(stream->parent()) has been called.
|
||||
virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) = 0;
|
||||
|
||||
protected:
|
||||
|
@ -70,7 +70,7 @@ TriangularSolveThunk::TriangularSolveThunk(
|
||||
|
||||
Status TriangularSolveThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
|
||||
VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
|
||||
<< " side=" << se::blas::SideString(side_)
|
||||
<< " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
|
||||
|
@ -49,7 +49,7 @@ class TriangularSolveThunk : public Thunk {
|
||||
TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -23,7 +23,7 @@ namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& /*run_id*/,
|
||||
HloExecutionProfiler* profiler) {
|
||||
auto size = tuple_element_buffers_.size();
|
||||
auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size);
|
||||
|
@ -46,7 +46,7 @@ class TupleThunk : public Thunk {
|
||||
TupleThunk& operator=(const TupleThunk&) = delete;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -48,7 +48,7 @@ Status WhileThunk::Initialize(const GpuExecutable& executable,
|
||||
}
|
||||
|
||||
Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) {
|
||||
se::DeviceMemoryBase condition_result_data =
|
||||
buffer_allocations.GetDeviceAddress(condition_result_buffer_index_);
|
||||
@ -59,7 +59,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
profiler->StartHloComputation();
|
||||
VLOG(3) << "Executing condition computation";
|
||||
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
|
||||
buffer_allocations, stream, profiler));
|
||||
buffer_allocations, stream, run_id, profiler));
|
||||
profiler->FinishHloComputation(hlo_instruction()->while_condition());
|
||||
|
||||
// Copy the result of condition computation and break the loop if 'false'.
|
||||
@ -83,8 +83,8 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
VLOG(3) << "Executing body computation";
|
||||
// Invoke thunk sequence for while 'body' computation, and pass on
|
||||
// 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
|
||||
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
|
||||
stream, profiler));
|
||||
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(
|
||||
buffer_allocations, stream, run_id, profiler));
|
||||
profiler->FinishHloComputation(hlo_instruction()->while_body());
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -49,7 +49,7 @@ class WhileThunk : public Thunk {
|
||||
Status Initialize(const GpuExecutable& executable,
|
||||
se::StreamExecutor* executor) override;
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
se::Stream* stream, const RunId& run_id,
|
||||
HloExecutionProfiler* profiler) override;
|
||||
|
||||
private:
|
||||
|
@ -207,7 +207,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
|
||||
stream.Init();
|
||||
ServiceExecutableRunOptions service_run_options =
|
||||
GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
|
||||
nullptr);
|
||||
nullptr, RunId());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||
CreateExecutable(std::move(module), run_hlo_passes));
|
||||
@ -243,7 +243,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
|
||||
stream.Init();
|
||||
ServiceExecutableRunOptions service_run_options =
|
||||
GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
|
||||
nullptr);
|
||||
nullptr, RunId());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ScopedShapedBuffer retval,
|
||||
@ -294,6 +294,7 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
|
||||
options.num_replicas * options.arguments.size() + 1);
|
||||
std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
|
||||
int64 index = 0;
|
||||
RunId run_id;
|
||||
for (int64 i = 0; i < options.num_replicas; ++i) {
|
||||
int64 device = (*device_assignment)(i, 0);
|
||||
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
|
||||
@ -301,7 +302,7 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
|
||||
streams.push_back(absl::make_unique<se::Stream>(executor));
|
||||
streams.back()->Init();
|
||||
service_run_options.emplace_back(GetServiceRunOptionsForDevice(
|
||||
device, streams.back().get(), device_assignment));
|
||||
device, streams.back().get(), device_assignment, run_id));
|
||||
|
||||
// Copy arguments to device.
|
||||
for (const Literal* argument : options.arguments) {
|
||||
@ -443,7 +444,8 @@ StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
|
||||
}
|
||||
|
||||
ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
|
||||
int64 device, se::Stream* stream, DeviceAssignment* device_assignment) {
|
||||
int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
|
||||
RunId run_id) {
|
||||
ExecutableRunOptions run_options;
|
||||
run_options.set_device_ordinal(device);
|
||||
run_options.set_stream(stream);
|
||||
@ -453,6 +455,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
|
||||
if (device_assignment != nullptr) {
|
||||
run_options.set_device_assignment(device_assignment);
|
||||
}
|
||||
run_options.set_run_id(run_id);
|
||||
return ServiceExecutableRunOptions(run_options, backend().StreamBorrower());
|
||||
}
|
||||
|
||||
|
@ -206,7 +206,8 @@ class HloRunner {
|
||||
// will be used to configure the replication parameters. Replicated executions
|
||||
// should pass the device_assignment parameter.
|
||||
ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
|
||||
int64 device, se::Stream* stream, DeviceAssignment* device_assignment);
|
||||
int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
|
||||
RunId run_id);
|
||||
|
||||
std::unique_ptr<Backend> backend_;
|
||||
};
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
// Class containing options for running a LocalExecutable and other auxiliary
|
||||
// data, now only a stream cache for GPU backend.
|
||||
// data.
|
||||
class ServiceExecutableRunOptions {
|
||||
public:
|
||||
using StreamBorrower = std::function<StatusOr<StreamPool::Ptr>(int)>;
|
||||
|
Loading…
Reference in New Issue
Block a user