[XLA] Add the notion of a RunId.
A RunId uniquely identifies a particular execution of a model in XLA. It's used in a later patch to differentiate between multiple, concurrent runs of an XLA module.

Pass the RunId to all XLA:GPU thunks.

PiperOrigin-RevId: 249799806
parent 5d51c3a37c
commit c709683965
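As orientation for the diff below, here is a minimal caller-side sketch (not part of this change) of the new API. The main() harness and the printed value are illustrative assumptions; RunId, ExecutableRunOptions::set_run_id() and run_id() are the pieces this patch adds.

    #include <iostream>

    #include "tensorflow/compiler/xla/executable_run_options.h"

    int main() {
      // Each default-constructed RunId draws a fresh value from a process-wide
      // atomic counter, so concurrent executions get distinct ids.
      xla::RunId run_id;

      // Attach the id to the run options; GpuExecutable::ExecuteThunks then
      // forwards it to every XLA:GPU thunk (see the ExecuteThunks hunk below).
      xla::ExecutableRunOptions run_options;
      run_options.set_run_id(run_id);

      std::cout << run_options.run_id().ToString() << std::endl;  // e.g. "RunId: 0"
      return 0;
    }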
@@ -575,6 +575,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":types",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -15,8 +15,21 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/executable_run_options.h"
 
+#include <atomic>
+
+#include "absl/strings/str_cat.h"
+
 namespace xla {
 
+RunId::RunId() {
+  static std::atomic<int64> counter{0};
+  data_ = counter.fetch_add(1);
+}
+
+bool operator==(const RunId& a, const RunId& b) { return a.data_ == b.data_; }
+
+std::string RunId::ToString() const { return absl::StrCat("RunId: ", data_); }
+
 ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
     int device_ordinal) {
   device_ordinal_ = device_ordinal;
@@ -94,4 +107,11 @@ ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
 
 int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
 
+ExecutableRunOptions& ExecutableRunOptions::set_run_id(RunId id) {
+  run_id_ = id;
+  return *this;
+}
+
+RunId ExecutableRunOptions::run_id() const { return run_id_; }
+
 }  // namespace xla
@@ -16,6 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
 #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
 
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+
 // These classes are forward declared so that ExecutableRunOptions can be linked
 // into an XLA-compiled binary without having to link all of the pointed-to
 // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't
@@ -35,6 +39,31 @@ namespace xla {
 class DeviceAssignment;
 class ExecutionProfile;
 
+// A unique identifier for a particular "logical execution" of an XLA model.
+//
+// A logical execution might encompass multiple executions of one or more
+// HloModules. Runs that are part of the same logical execution can
+// communicate via collective ops (e.g. kAllToAll), whereas runs that are part
+// of different logical executions are isolated.
+class RunId {
+ public:
+  // Creates a new, unique RunId.
+  RunId();
+
+  RunId(const RunId&) = default;
+  RunId& operator=(const RunId&) = default;
+  friend bool operator==(const RunId& a, const RunId& b);
+  std::string ToString() const;
+
+  template <typename H>
+  friend H AbslHashValue(H h, const RunId& id) {
+    return H::combine(std::move(h), id.data_);
+  }
+
+ private:
+  int64 data_;
+};
+
 // Class containing options for running a LocalExecutable.
 class ExecutableRunOptions {
  public:
@@ -87,6 +116,9 @@ class ExecutableRunOptions {
   ExecutableRunOptions& set_rng_seed(int rng_seed);
   int rng_seed() const;
 
+  ExecutableRunOptions& set_run_id(RunId id);
+  RunId run_id() const;
+
  private:
   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -96,6 +128,7 @@ class ExecutableRunOptions {
   ExecutionProfile* execution_profile_ = nullptr;
   int rng_seed_ = 0;
   stream_executor::Stream* host_to_device_stream_ = nullptr;
+  RunId run_id_;
 };
 
 }  // namespace xla
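Because RunId defines operator== and a friend AbslHashValue(), it can key Abseil hash containers; per the commit message, a later patch uses this to keep per-run state for concurrent executions (the refcounting_hash_map dependency added below points the same way). A small, hypothetical sketch, not taken from that later patch:

    #include <string>

    #include "absl/container/flat_hash_map.h"
    #include "tensorflow/compiler/xla/executable_run_options.h"

    int main() {
      // AbslHashValue() and operator== make RunId usable as a hash-map key.
      absl::flat_hash_map<xla::RunId, std::string> per_run_state;

      xla::RunId first;   // unique, e.g. "RunId: 0"
      xla::RunId second;  // unique, e.g. "RunId: 1"
      per_run_state[first] = "state for the first logical execution";
      per_run_state[second] = "state for the second logical execution";

      return per_run_state.size() == 2 ? 0 : 1;
    }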
@@ -333,6 +333,7 @@ cc_library(
     deps = [
         ":buffer_allocations",
         ":hlo_execution_profiler",
+        "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
@@ -347,6 +348,7 @@ tf_cuda_library(
         ":buffer_allocations",
         ":hlo_execution_profiler",
         ":thunk",
+        "//tensorflow/compiler/xla:refcounting_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/synchronization",
         "//tensorflow/compiler/xla:util",
@@ -52,7 +52,7 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
 
 Status CholeskyThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   VLOG(3) << "type=" << PrimitiveType_Name(type_)
           << " uplo=" << se::blas::UpperLowerString(uplo_)
           << " batch_size=" << batch_size_ << " n=" << n_
@@ -52,7 +52,7 @@ class CholeskyThunk : public Thunk {
   CholeskyThunk& operator=(const CholeskyThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
 
 Status ConditionalThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& run_id, HloExecutionProfiler* profiler) {
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
   // Copy the predicate value from device.
   int32 branch_index = -1;
@@ -89,7 +89,7 @@ Status ConditionalThunk::ExecuteOnStream(
   // Execute the branch computation corresponding to the value of branch_index.
   profiler->StartHloComputation();
   TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(
-      buffer_allocations, stream, profiler));
+      buffer_allocations, stream, run_id, profiler));
   profiler->FinishHloComputation(
       hlo_instruction()->branch_computation(branch_index));
 
@@ -54,7 +54,7 @@ class ConditionalThunk : public Thunk {
   Status Initialize(const GpuExecutable& executable,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -43,7 +43,7 @@ ConvolutionThunk::ConvolutionThunk(
 
 Status ConvolutionThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   std::vector<se::DeviceMemoryBase> operand_se_buffers;
   for (const auto& buffer : operand_buffers_) {
     operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
@@ -54,7 +54,7 @@ class ConvolutionThunk : public Thunk {
 
   // Does the convolution for the thunk on "stream".
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -32,7 +32,7 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk(
 
 Status HostToDeviceCopyThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   se::DeviceMemoryBase destination_data =
       buffer_allocations.GetDeviceAddress(destination_buffer_);
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
@@ -51,7 +51,7 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
 
 Status DeviceToDeviceCopyThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   se::DeviceMemoryBase destination_data =
       buffer_allocations.GetDeviceAddress(destination_buffer_);
   se::DeviceMemoryBase source_data =
@@ -41,7 +41,7 @@ class HostToDeviceCopyThunk : public Thunk {
   HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -65,7 +65,7 @@ class DeviceToDeviceCopyThunk : public Thunk {
   DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -100,7 +100,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
 
 Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   dnn::BatchDescriptor operand_desc;
   dnn::BatchDescriptor scale_offset_desc;
   std::tie(operand_desc, scale_offset_desc) =
@@ -164,7 +164,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
 
 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   dnn::BatchDescriptor operand_desc;
   dnn::BatchDescriptor scale_offset_desc;
   // The BatchNormTraining HLO outputs a tuple of three elements: output data,
@@ -250,7 +250,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
 
 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   dnn::BatchDescriptor operand_desc;
   dnn::BatchDescriptor scale_offset_desc;
 
@@ -61,7 +61,7 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
       const CudnnBatchNormForwardInferenceThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -92,7 +92,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
      const CudnnBatchNormForwardTrainingThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -126,7 +126,7 @@ class CudnnBatchNormBackwardThunk : public Thunk {
       delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -50,7 +50,7 @@ CustomCallThunk::CustomCallThunk(
 
 Status CustomCallThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   // gpu_stream is CUstream or e.g. the equivalent type in ROCm.
   auto gpu_stream = se::gpu::AsGpuStreamValue(stream);
   auto typed_call_target =
@@ -45,7 +45,7 @@ class CustomCallThunk : public Thunk {
                   const HloInstruction* instr);
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -107,7 +107,7 @@ FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
       output_shape_(output_shape) {}
 
 Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
                                  se::Stream* stream,
-                                 se::Stream* stream, const RunId& /*run_id*/,
                                  HloExecutionProfiler* profiler) {
   VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
   VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
@@ -73,7 +73,7 @@ class FftThunk : public Thunk {
 
   // Does the FFT for the thunk on "stream".
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -41,7 +41,7 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
 }
 
 Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                 se::Stream* stream,
+                                 se::Stream* stream, const RunId& run_id,
                                  HloExecutionProfiler* profiler) {
   VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
           << (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
@@ -49,8 +49,8 @@ Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
   for (int64 i = 0; i < loop_limit_; ++i) {
     profiler->StartHloComputation();
     // Invoke loop body thunk sequence.
-    TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
-                                                             stream, profiler));
+    TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(
+        buffer_allocations, stream, run_id, profiler));
     profiler->FinishHloComputation(hlo_instruction()->while_body());
   }
   return Status::OK();
@@ -40,7 +40,7 @@ class ForThunk : public Thunk {
   Status Initialize(const GpuExecutable& executable,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -446,7 +446,7 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
       implements_whole_instruction_(implements_whole_instruction) {}
 
 Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                  se::Stream* stream,
+                                  se::Stream* stream, const RunId& /*run_id*/,
                                   HloExecutionProfiler* profiler) {
   auto fn = [&]() {
     switch (output_shape_.element_type()) {
@@ -50,7 +50,7 @@ class GemmThunk : public Thunk {
 
   // Does the gemm operation for the thunk on "stream", which must be non-null.
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -137,7 +137,8 @@ Status GpuExecutable::ExecuteThunks(
             << thunk->hlo_instruction()->ToString() << " on stream "
             << stream_no;
     TF_RETURN_IF_ERROR(
-        thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
+        thunk->ExecuteOnStream(buffer_allocations, stream,
+                               run_options->run_options().run_id(), &profiler));
     if (thunk_schedule_->Depended(thunk)) {
       auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
       finish_event->Init();
@@ -28,7 +28,7 @@ InfeedThunk::InfeedThunk(
     : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
 
 Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                    se::Stream* stream,
+                                    se::Stream* stream, const RunId& /*run_id*/,
                                     HloExecutionProfiler* profiler) {
   VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();
 
@@ -41,7 +41,7 @@ class InfeedThunk : public Thunk {
   InfeedThunk& operator=(const InfeedThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -70,7 +70,7 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
 }
 
 Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                    se::Stream* stream,
+                                    se::Stream* stream, const RunId& /*run_id*/,
                                     HloExecutionProfiler* profiler) {
   // Load the kernel.
   se::StreamExecutor* executor = stream->parent();
@@ -63,7 +63,7 @@ class KernelThunk : public Thunk {
 
   // Executes the kernel for the thunk on "stream", which must be non-null.
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -23,7 +23,7 @@ namespace gpu {
 
 Status MemzeroThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
   stream->ThenMemZero(&dest_data, dest_data.size());
@@ -32,7 +32,7 @@ Status MemzeroThunk::ExecuteOnStream(
 
 Status Memset32BitValueThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
   stream->ThenMemset32(&dest_data, value_, dest_data.size());
@@ -37,7 +37,7 @@ class MemzeroThunk : public Thunk {
       : Thunk(Kind::kMemzero, hlo), dest_(dest) {}
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -54,7 +54,7 @@ class Memset32BitValueThunk : public Thunk {
       : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId&,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -509,7 +509,7 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
 
 Status NcclAllReduceThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& run_id, HloExecutionProfiler* profiler) {
   auto* global_rendezvous = GetGlobalRendezvous();
 
   ParticipantData participant;
@@ -553,7 +553,7 @@ NcclAllReduceThunk::~NcclAllReduceThunk() {
 
 Status NcclAllReduceThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& run_id, HloExecutionProfiler* profiler) {
   return Unimplemented(
       "NCCL support is not available: this binary was not built with a CUDA "
       "compiler, which is necessary to build the NCCL source library.");
@@ -57,7 +57,7 @@ class NcclAllReduceThunk : public Thunk {
   ~NcclAllReduceThunk() override;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -30,7 +30,7 @@ OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
 
 Status OutfeedThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
 
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
@@ -39,7 +39,7 @@ class OutfeedThunk : public Thunk {
   OutfeedThunk& operator=(const OutfeedThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -35,11 +35,11 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
 
 Status SequentialThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& run_id, HloExecutionProfiler* profiler) {
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
   for (const auto& thunk : thunks_) {
     TF_RETURN_IF_ERROR(
-        thunk->ExecuteOnStream(buffer_allocations, stream, profiler));
+        thunk->ExecuteOnStream(buffer_allocations, stream, run_id, profiler));
   }
   return Status::OK();
 }
@@ -42,7 +42,7 @@ class SequentialThunk : public Thunk {
   Status Initialize(const GpuExecutable& executable,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -92,7 +93,7 @@ class Thunk {
   //
   // Precondition: Initialize(stream->parent()) has been called.
   virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                 se::Stream* stream,
+                                 se::Stream* stream, const RunId& run_id,
                                  HloExecutionProfiler* profiler) = 0;
 
  protected:
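Every Thunk subclass now implements the widened virtual signature above. A hypothetical skeleton of what an override looks like: MyNoOpThunk is not in the patch, the header path is assumed, and Kind::kMemzero is reused purely for illustration. Thunks that do not need the id leave the parameter unnamed or spell it /*run_id*/, as the thunk implementations in this change do.

    #include "tensorflow/compiler/xla/service/gpu/thunk.h"  // assumed path

    namespace xla {
    namespace gpu {

    class MyNoOpThunk : public Thunk {
     public:
      explicit MyNoOpThunk(const HloInstruction* hlo)
          : Thunk(Kind::kMemzero, hlo) {}  // Kind reused only for this sketch.

      Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                             se::Stream* stream, const RunId& run_id,
                             HloExecutionProfiler* profiler) override {
        // This sketch ignores all arguments; a collective thunk such as
        // NcclAllReduceThunk would consume run_id to rendezvous with peers.
        return Status::OK();
      }
    };

    }  // namespace gpu
    }  // namespace xla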
@@ -70,7 +70,7 @@ TriangularSolveThunk::TriangularSolveThunk(
 
 Status TriangularSolveThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
-    HloExecutionProfiler* profiler) {
+    const RunId& /*run_id*/, HloExecutionProfiler* profiler) {
   VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
           << " side=" << se::blas::SideString(side_)
           << " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
@@ -49,7 +49,7 @@ class TriangularSolveThunk : public Thunk {
   TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -23,7 +23,7 @@ namespace xla {
 namespace gpu {
 
 Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                   se::Stream* stream,
+                                   se::Stream* stream, const RunId& /*run_id*/,
                                    HloExecutionProfiler* profiler) {
   auto size = tuple_element_buffers_.size();
   auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size);
@@ -46,7 +46,7 @@ class TupleThunk : public Thunk {
   TupleThunk& operator=(const TupleThunk&) = delete;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -48,7 +48,7 @@ Status WhileThunk::Initialize(const GpuExecutable& executable,
 }
 
 Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                   se::Stream* stream,
+                                   se::Stream* stream, const RunId& run_id,
                                    HloExecutionProfiler* profiler) {
   se::DeviceMemoryBase condition_result_data =
       buffer_allocations.GetDeviceAddress(condition_result_buffer_index_);
@@ -59,7 +59,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
     profiler->StartHloComputation();
     VLOG(3) << "Executing condition computation";
     TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
-        buffer_allocations, stream, profiler));
+        buffer_allocations, stream, run_id, profiler));
     profiler->FinishHloComputation(hlo_instruction()->while_condition());
 
     // Copy the result of condition computation and break the loop if 'false'.
@@ -83,8 +83,8 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
     VLOG(3) << "Executing body computation";
     // Invoke thunk sequence for while 'body' computation, and pass on
     // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
-    TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
-                                                             stream, profiler));
+    TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(
+        buffer_allocations, stream, run_id, profiler));
     profiler->FinishHloComputation(hlo_instruction()->while_body());
   }
   return Status::OK();
@@ -49,7 +49,7 @@ class WhileThunk : public Thunk {
   Status Initialize(const GpuExecutable& executable,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                         se::Stream* stream,
+                         se::Stream* stream, const RunId& run_id,
                          HloExecutionProfiler* profiler) override;
 
  private:
@@ -207,7 +207,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
   stream.Init();
   ServiceExecutableRunOptions service_run_options =
       GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
-                                    nullptr);
+                                    nullptr, RunId());
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       CreateExecutable(std::move(module), run_hlo_passes));
@@ -243,7 +243,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
   stream.Init();
   ServiceExecutableRunOptions service_run_options =
       GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
-                                    nullptr);
+                                    nullptr, RunId());
 
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer retval,
@@ -294,6 +294,7 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
       options.num_replicas * options.arguments.size() + 1);
   std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
   int64 index = 0;
+  RunId run_id;
   for (int64 i = 0; i < options.num_replicas; ++i) {
     int64 device = (*device_assignment)(i, 0);
     TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
@@ -301,7 +302,7 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
     streams.push_back(absl::make_unique<se::Stream>(executor));
     streams.back()->Init();
     service_run_options.emplace_back(GetServiceRunOptionsForDevice(
-        device, streams.back().get(), device_assignment));
+        device, streams.back().get(), device_assignment, run_id));
 
     // Copy arguments to device.
     for (const Literal* argument : options.arguments) {
@@ -443,7 +444,8 @@ StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
 }
 
 ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
-    int64 device, se::Stream* stream, DeviceAssignment* device_assignment) {
+    int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
+    RunId run_id) {
   ExecutableRunOptions run_options;
   run_options.set_device_ordinal(device);
   run_options.set_stream(stream);
@@ -453,6 +455,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
   if (device_assignment != nullptr) {
     run_options.set_device_assignment(device_assignment);
   }
+  run_options.set_run_id(run_id);
   return ServiceExecutableRunOptions(run_options, backend().StreamBorrower());
 }
 
@@ -206,7 +206,8 @@ class HloRunner {
   // will be used to configure the replication parameters. Replicated executions
   // should pass the device_assignment parameter.
   ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
-      int64 device, se::Stream* stream, DeviceAssignment* device_assignment);
+      int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
+      RunId run_id);
 
   std::unique_ptr<Backend> backend_;
 };
@@ -24,7 +24,7 @@ limitations under the License.
 namespace xla {
 
 // Class containing options for running a LocalExecutable and other auxiliary
-// data, now only a stream cache for GPU backend.
+// data.
 class ServiceExecutableRunOptions {
  public:
   using StreamBorrower = std::function<StatusOr<StreamPool::Ptr>(int)>;