From c709683965331d94ae196464f8d3f7bc633b1d7a Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 24 May 2019 02:31:20 -0700 Subject: [PATCH] [XLA] Add the notion of a RunId. A RunId uniquely identifies a particular execution of a model in XLA. It's used in a later patch to differentiate between multiple, concurrent runs of an XLA module. Pass the RunId to all XLA:GPU thunks. PiperOrigin-RevId: 249799806 --- tensorflow/compiler/xla/BUILD | 1 + .../compiler/xla/executable_run_options.cc | 20 +++++++++++ .../compiler/xla/executable_run_options.h | 33 +++++++++++++++++++ tensorflow/compiler/xla/service/gpu/BUILD | 2 ++ .../xla/service/gpu/cholesky_thunk.cc | 2 +- .../compiler/xla/service/gpu/cholesky_thunk.h | 2 +- .../xla/service/gpu/conditional_thunk.cc | 4 +-- .../xla/service/gpu/conditional_thunk.h | 2 +- .../xla/service/gpu/convolution_thunk.cc | 2 +- .../xla/service/gpu/convolution_thunk.h | 2 +- .../compiler/xla/service/gpu/copy_thunk.cc | 4 +-- .../compiler/xla/service/gpu/copy_thunk.h | 4 +-- .../xla/service/gpu/cudnn_batchnorm_thunk.cc | 6 ++-- .../xla/service/gpu/cudnn_batchnorm_thunk.h | 6 ++-- .../xla/service/gpu/custom_call_thunk.cc | 2 +- .../xla/service/gpu/custom_call_thunk.h | 2 +- .../compiler/xla/service/gpu/fft_thunk.cc | 2 +- .../compiler/xla/service/gpu/fft_thunk.h | 2 +- .../compiler/xla/service/gpu/for_thunk.cc | 6 ++-- .../compiler/xla/service/gpu/for_thunk.h | 2 +- .../compiler/xla/service/gpu/gemm_thunk.cc | 2 +- .../compiler/xla/service/gpu/gemm_thunk.h | 2 +- .../xla/service/gpu/gpu_executable.cc | 3 +- .../compiler/xla/service/gpu/infeed_thunk.cc | 2 +- .../compiler/xla/service/gpu/infeed_thunk.h | 2 +- .../compiler/xla/service/gpu/kernel_thunk.cc | 2 +- .../compiler/xla/service/gpu/kernel_thunk.h | 2 +- .../compiler/xla/service/gpu/memset_thunk.cc | 4 +-- .../compiler/xla/service/gpu/memset_thunk.h | 4 +-- .../xla/service/gpu/nccl_all_reduce_thunk.cc | 4 +-- .../xla/service/gpu/nccl_all_reduce_thunk.h | 2 +- .../compiler/xla/service/gpu/outfeed_thunk.cc | 2 +- .../compiler/xla/service/gpu/outfeed_thunk.h | 2 +- .../xla/service/gpu/sequential_thunk.cc | 4 +-- .../xla/service/gpu/sequential_thunk.h | 2 +- tensorflow/compiler/xla/service/gpu/thunk.h | 3 +- .../xla/service/gpu/triangular_solve_thunk.cc | 2 +- .../xla/service/gpu/triangular_solve_thunk.h | 2 +- .../compiler/xla/service/gpu/tuple_thunk.cc | 2 +- .../compiler/xla/service/gpu/tuple_thunk.h | 2 +- .../compiler/xla/service/gpu/while_thunk.cc | 8 ++--- .../compiler/xla/service/gpu/while_thunk.h | 2 +- tensorflow/compiler/xla/service/hlo_runner.cc | 11 ++++--- tensorflow/compiler/xla/service/hlo_runner.h | 3 +- .../service/service_executable_run_options.h | 2 +- 45 files changed, 122 insertions(+), 60 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 679471affee..4bd10195577 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -575,6 +575,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 39c90b60a09..1cfb449ebd0 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -15,8 +15,21 @@ limitations under the License. 
#include "tensorflow/compiler/xla/executable_run_options.h" +#include + +#include "absl/strings/str_cat.h" + namespace xla { +RunId::RunId() { + static std::atomic counter{0}; + data_ = counter.fetch_add(1); +} + +bool operator==(const RunId& a, const RunId& b) { return a.data_ == b.data_; } + +std::string RunId::ToString() const { return absl::StrCat("RunId: ", data_); } + ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal( int device_ordinal) { device_ordinal_ = device_ordinal; @@ -94,4 +107,11 @@ ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) { int ExecutableRunOptions::rng_seed() const { return rng_seed_; } +ExecutableRunOptions& ExecutableRunOptions::set_run_id(RunId id) { + run_id_ = id; + return *this; +} + +RunId ExecutableRunOptions::run_id() const { return run_id_; } + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 84629593953..4de8148451b 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ +#include + +#include "tensorflow/compiler/xla/types.h" + // These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't @@ -35,6 +39,31 @@ namespace xla { class DeviceAssignment; class ExecutionProfile; +// A unique identifier for a particular "logical execution" of an XLA model. +// +// A logical execution might encompass multiple executions of one or more +// HloModules. Runs that are part of the same logical execution can +// communicate via collective ops (e.g. kAllToAll), whereas runs that are part +// of different logical executions are isolated. +class RunId { + public: + // Creates a new, unique RunId. + RunId(); + + RunId(const RunId&) = default; + RunId& operator=(const RunId&) = default; + friend bool operator==(const RunId& a, const RunId& b); + std::string ToString() const; + + template + friend H AbslHashValue(H h, const RunId& id) { + return H::combine(std::move(h), id.data_); + } + + private: + int64 data_; +}; + // Class containing options for running a LocalExecutable. 
class ExecutableRunOptions { public: @@ -87,6 +116,9 @@ class ExecutableRunOptions { ExecutableRunOptions& set_rng_seed(int rng_seed); int rng_seed() const; + ExecutableRunOptions& set_run_id(RunId id); + RunId run_id() const; + private: stream_executor::DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; @@ -96,6 +128,7 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; stream_executor::Stream* host_to_device_stream_ = nullptr; + RunId run_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index ec75f28a1d7..e75708c5c96 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -333,6 +333,7 @@ cc_library( deps = [ ":buffer_allocations", ":hlo_execution_profiler", + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -347,6 +348,7 @@ tf_cuda_library( ":buffer_allocations", ":hlo_execution_profiler", ":thunk", + "//tensorflow/compiler/xla:refcounting_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/synchronization", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc index 7daef16cb62..84970a71ac3 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc @@ -52,7 +52,7 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options, Status CholeskyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(3) << "type=" << PrimitiveType_Name(type_) << " uplo=" << se::blas::UpperLowerString(uplo_) << " batch_size=" << batch_size_ << " n=" << n_ diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h index cde245a7e8b..eb6f02baa8c 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h @@ -52,7 +52,7 @@ class CholeskyThunk : public Thunk { CholeskyThunk& operator=(const CholeskyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index ea639249826..90f797e7e15 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable, Status ConditionalThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& run_id, HloExecutionProfiler* profiler) { auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); // Copy the predicate value from device. int32 branch_index = -1; @@ -89,7 +89,7 @@ Status ConditionalThunk::ExecuteOnStream( // Execute the branch computation corresponding to the value of branch_index. 
profiler->StartHloComputation(); TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream( - buffer_allocations, stream, profiler)); + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation( hlo_instruction()->branch_computation(branch_index)); diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index c0093ca6397..ca625f4a97b 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -54,7 +54,7 @@ class ConditionalThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index e1dffad3045..265a3f67020 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -43,7 +43,7 @@ ConvolutionThunk::ConvolutionThunk( Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { std::vector<se::DeviceMemoryBase> operand_se_buffers; for (const auto& buffer : operand_buffers_) { operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index c71515490c9..4a29164cbe6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -54,7 +54,7 @@ class ConvolutionThunk : public Thunk { // Does the convolution for the thunk on "stream".
Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index 92e03f94c11..62878cf864d 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -32,7 +32,7 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( Status HostToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); @@ -51,7 +51,7 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( Status DeviceToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 91564b520ac..30fb71f4c4e 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -41,7 +41,7 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -65,7 +65,7 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 5f6175abb87..3147bc66e3f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -100,7 +100,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; std::tie(operand_desc, scale_offset_desc) = @@ -164,7 +164,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; // The BatchNormTraining HLO outputs a tuple of three elements: output data, @@ -250,7 +250,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( Status 
CudnnBatchNormBackwardThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index d2143b39529..e0e6e86818f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -61,7 +61,7 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { const CudnnBatchNormForwardInferenceThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -92,7 +92,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { const CudnnBatchNormForwardTrainingThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -126,7 +126,7 @@ class CudnnBatchNormBackwardThunk : public Thunk { delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index f0f3152ac98..b521e36108b 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -50,7 +50,7 @@ CustomCallThunk::CustomCallThunk( Status CustomCallThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { // gpu_stream is CUstream or e.g. the equivalent type in ROCm. 
auto gpu_stream = se::gpu::AsGpuStreamValue(stream); auto typed_call_target = diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h index 9011fa26ffa..6db7950e8e0 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h @@ -45,7 +45,7 @@ class CustomCallThunk : public Thunk { const HloInstruction* instr); Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 1609f0d60c4..55300a8d33a 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -107,7 +107,7 @@ FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length, output_shape_(output_shape) {} Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index f653e4f12fe..12718db873b 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -73,7 +73,7 @@ class FftThunk : public Thunk { // Does the FFT for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 88f0b4d71c9..ee47fea38c3 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -41,7 +41,7 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) { VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " << (hlo_instruction() ? hlo_instruction()->ToString() : ""); @@ -49,8 +49,8 @@ Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, for (int64 i = 0; i < loop_limit_; ++i) { profiler->StartHloComputation(); // Invoke loop body thunk sequence.
- TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, - stream, profiler)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream( + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index c2d39071b29..e3bef820e57 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -40,7 +40,7 @@ class ForThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 9bbe1ab5a38..237c065cd73 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -446,7 +446,7 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, implements_whole_instruction_(implements_whole_instruction) {} Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { auto fn = [&]() { switch (output_shape_.element_type()) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index e4f07d04820..3cba1d5e169 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -50,7 +50,7 @@ class GemmThunk : public Thunk { // Does the gemm operation for the thunk on "stream", which must be non-null. 
Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index b5dbe33b6e9..8be1655367f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -137,7 +137,8 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR( - thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); + thunk->ExecuteOnStream(buffer_allocations, stream, + run_options->run_options().run_id(), &profiler)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = absl::make_unique<se::Event>(main_stream->parent()); finish_event->Init(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 676380c3b10..dbf590591c3 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -28,7 +28,7 @@ InfeedThunk::InfeedThunk( : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 59487e245b7..50d9c53d957 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -41,7 +41,7 @@ class InfeedThunk : public Thunk { InfeedThunk& operator=(const InfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index fbe22e3a18e..c85b35ed386 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -70,7 +70,7 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { } Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 2cea89e4e2a..e867904bcf2 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -63,7 +63,7 @@ class KernelThunk : public Thunk { // Executes the kernel for the thunk on "stream", which must be non-null.
Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc index 9fd6cf7157e..7a5b14be7b0 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -23,7 +23,7 @@ namespace gpu { Status MemzeroThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemZero(&dest_data, dest_data.size()); @@ -32,7 +32,7 @@ Status MemzeroThunk::ExecuteOnStream( Status Memset32BitValueThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemset32(&dest_data, value_, dest_data.size()); diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index d1fec0bd76b..727f2441f39 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -37,7 +37,7 @@ class MemzeroThunk : public Thunk { : Thunk(Kind::kMemzero, hlo), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: @@ -54,7 +54,7 @@ class Memset32BitValueThunk : public Thunk { : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId&, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index d9accdd44ab..a98d532bd44 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -509,7 +509,7 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { Status NcclAllReduceThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& run_id, HloExecutionProfiler* profiler) { auto* global_rendezvous = GetGlobalRendezvous(); ParticipantData participant; @@ -553,7 +553,7 @@ NcclAllReduceThunk::~NcclAllReduceThunk() { Status NcclAllReduceThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& run_id, HloExecutionProfiler* profiler) { return Unimplemented( "NCCL support is not available: this binary was not built with a CUDA " "compiler, which is necessary to build the NCCL source library."); diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index 9ff4fb187af..351a954492b 100644 --- 
a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -57,7 +57,7 @@ class NcclAllReduceThunk : public Thunk { ~NcclAllReduceThunk() override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index e0f3e84a4cb..527305070b7 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -30,7 +30,7 @@ OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices, Status OutfeedThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h index 8ed89f05f0c..5e7bc7cea1a 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -39,7 +39,7 @@ class OutfeedThunk : public Thunk { OutfeedThunk& operator=(const OutfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index 84285be70a4..2f456938d92 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -35,11 +35,11 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable, Status SequentialThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& run_id, HloExecutionProfiler* profiler) { auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR( - thunk->ExecuteOnStream(buffer_allocations, stream, profiler)); + thunk->ExecuteOnStream(buffer_allocations, stream, run_id, profiler)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 3c4de1d1a6c..e617c99c2c9 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -42,7 +42,7 @@ class SequentialThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index bdd06718717..9670a3ece08 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -19,6 +19,7 @@ limitations under the License.
#include <memory> #include <vector> +#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -92,7 +93,7 @@ class Thunk { // // Precondition: Initialize(stream->parent()) has been called. virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) = 0; protected: diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc index 5200a2af412..2635a7b3c45 100644 --- a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc @@ -70,7 +70,7 @@ TriangularSolveThunk::TriangularSolveThunk( Status TriangularSolveThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, - HloExecutionProfiler* profiler) { + const RunId& /*run_id*/, HloExecutionProfiler* profiler) { VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_) << " side=" << se::blas::SideString(side_) << " diagonal=" << se::blas::DiagonalString(unit_diagonal_) diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h index c947162ea32..94bf6bf6442 100644 --- a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h @@ -49,7 +49,7 @@ class TriangularSolveThunk : public Thunk { TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 989b542ff45..f7dda240367 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -23,7 +23,7 @@ namespace xla { namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& /*run_id*/, HloExecutionProfiler* profiler) { auto size = tuple_element_buffers_.size(); auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size); diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index dcdbf2cf3c2..47784c5c373 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -46,7 +46,7 @@ class TupleThunk : public Thunk { TupleThunk& operator=(const TupleThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index c4754fe3789..0223582f2a9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -48,7 +48,7 @@ Status WhileThunk::Initialize(const GpuExecutable& executable, } Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, +
se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) { se::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); @@ -59,7 +59,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, profiler->StartHloComputation(); VLOG(3) << "Executing condition computation"; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream( - buffer_allocations, stream, profiler)); + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_condition()); // Copy the result of condition computation and break the loop if 'false'. @@ -83,8 +83,8 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << "Executing body computation"; // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. - TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, - stream, profiler)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream( + buffer_allocations, stream, run_id, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 9270f95ee67..97ac24f61cc 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -49,7 +49,7 @@ class WhileThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream, + se::Stream* stream, const RunId& run_id, HloExecutionProfiler* profiler) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 5ba390acfd4..9fb0cd7e077 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -207,7 +207,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( stream.Init(); ServiceExecutableRunOptions service_run_options = GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, - nullptr); + nullptr, RunId()); TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, CreateExecutable(std::move(module), run_hlo_passes)); @@ -243,7 +243,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( stream.Init(); ServiceExecutableRunOptions service_run_options = GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, - nullptr); + nullptr, RunId()); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer retval, @@ -294,6 +294,7 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated( options.num_replicas * options.arguments.size() + 1); std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices; int64 index = 0; + RunId run_id; for (int64 i = 0; i < options.num_replicas; ++i) { int64 device = (*device_assignment)(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, streams.push_back(absl::make_unique<se::Stream>(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( - device, streams.back().get(), device_assignment)); + device, streams.back().get(), device_assignment, run_id)); // Copy arguments to device.
for (const Literal* argument : options.arguments) { @@ -443,7 +444,8 @@ StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable( } ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( - int64 device, se::Stream* stream, DeviceAssignment* device_assignment) { + int64 device, se::Stream* stream, DeviceAssignment* device_assignment, + RunId run_id) { ExecutableRunOptions run_options; run_options.set_device_ordinal(device); run_options.set_stream(stream); @@ -453,6 +455,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } + run_options.set_run_id(run_id); return ServiceExecutableRunOptions(run_options, backend().StreamBorrower()); } diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 7e666a8186e..c077ccd95fe 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -206,7 +206,8 @@ class HloRunner { // will be used to configure the replication parameters. Replicated executions // should pass the device_assignment parameter. ServiceExecutableRunOptions GetServiceRunOptionsForDevice( - int64 device, se::Stream* stream, DeviceAssignment* device_assignment); + int64 device, se::Stream* stream, DeviceAssignment* device_assignment, + RunId run_id); std::unique_ptr<Backend> backend_; }; diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 7fc66310ee7..58028aebe1f 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // Class containing options for running a LocalExecutable and other auxiliary -// data, now only a stream cache for GPU backend. +// data. class ServiceExecutableRunOptions { public: using StreamBorrower = std::function<StatusOr<StreamPool::Ptr>(int)>;
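
Usage note (reviewer sketch, not part of the patch): the snippet below illustrates how the pieces introduced here fit together, assuming the patched headers above. MyThunk and MakeRunOptions are hypothetical names used only for illustration, and the thunk kind and logging are placeholders.

// Sketch only: with this patch, every Thunk subclass must accept the RunId
// in ExecuteOnStream, even if it ignores it.
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"

namespace xla {
namespace gpu {

class MyThunk : public Thunk {  // Hypothetical example thunk.
 public:
  explicit MyThunk(const HloInstruction* hlo) : Thunk(Kind::kKernel, hlo) {}

  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream, const RunId& run_id,
                         HloExecutionProfiler* profiler) override {
    // RunIds are process-unique, copyable, equality-comparable, and hashable
    // (via AbslHashValue), so they can key a hash map that groups concurrent
    // executions belonging to the same logical run.
    VLOG(2) << "Executing under " << run_id.ToString();
    return Status::OK();
  }
};

}  // namespace gpu

// Callers mint one RunId per logical execution; every ExecutableRunOptions
// participating in that execution must carry the same id so collective ops
// can rendezvous.
inline ExecutableRunOptions MakeRunOptions(stream_executor::Stream* stream) {
  ExecutableRunOptions options;
  options.set_stream(stream);
  options.set_run_id(RunId());  // Default construction yields a fresh id.
  return options;
}

}  // namespace xla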