[XLA/GPU] Remove uses of Thunk::hlo_instruction() for profiling.

This CL consists of two steps:
* First, refactor all Thunks to take a ThunkInfo instead of a const HloInstruction*. This makes future extensions to ThunkInfo easier as we move away from HloInstruction*.
* Second, change the data pipeline from:
    Emitter -> Thunk* -> hlo_instruction() -> profiler(HloInstruction*)
  to:
    Emitter -> Thunk with profile index -> profiler(profile_index)

The profiler doesn't actually depend on the HloInstruction* itself, only on its
pointer identity, and an integer profile index captures that just as well.
Removing the dependency on HloInstruction helps with the MLIR migration.
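
As a standalone illustration of the new shape of the API (stand-in types only:
ThunkInfo, Profiler, CopyThunk, and the method names below are simplified
placeholders, not the real XLA classes), a thunk now receives an optional
profile index at construction time and hands that index, rather than an
HloInstruction*, to the profiler:

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <unordered_map>

    struct ThunkInfo {
      std::optional<int64_t> profile_index;  // looked up from an HloProfileIndexMap-like table
    };

    class Profiler {
     public:
      // No-op when the index is absent (e.g. for sub-thunks that are logically
      // "part of" their parent and should not be profiled separately).
      void RecordCycles(std::optional<int64_t> index, int64_t cycles) {
        if (index.has_value()) cycles_taken_[*index] += cycles;
      }
      void Dump() const {
        for (const auto& entry : cycles_taken_)
          std::cout << "profile index " << entry.first << ": " << entry.second
                    << " cycles\n";
      }

     private:
      std::unordered_map<int64_t, int64_t> cycles_taken_;
    };

    class Thunk {
     public:
      explicit Thunk(ThunkInfo info) : profile_index_(info.profile_index) {}
      virtual ~Thunk() = default;
      std::optional<int64_t> profile_index() const { return profile_index_; }
      virtual void ExecuteOnStream(Profiler* profiler) = 0;

     private:
      std::optional<int64_t> profile_index_;
    };

    class CopyThunk : public Thunk {
     public:
      using Thunk::Thunk;
      void ExecuteOnStream(Profiler* profiler) override {
        // Previously keyed by hlo_instruction(); now only the integer index
        // carried in ThunkInfo is needed.
        profiler->RecordCycles(profile_index(), /*cycles=*/42);
      }
    };

    int main() {
      Profiler profiler;
      CopyThunk profiled(ThunkInfo{/*profile_index=*/7});
      CopyThunk unprofiled(ThunkInfo{});  // no index, so profiling is skipped
      profiled.ExecuteOnStream(&profiler);
      unprofiled.ExecuteOnStream(&profiler);
      profiler.Dump();  // prints "profile index 7: 42 cycles"
    }

In the actual CL, the emitter fills ThunkInfo::profile_index from the
HloProfileIndexMap (see IrEmitterUnnested::GetThunkInfo below), and sub-thunks
that should not be profiled separately are constructed with a default ThunkInfo().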

PiperOrigin-RevId: 320687291
Change-Id: I7027d4c032f73ed615e5b520e01f3740781735be
Tim Shen 2020-07-10 15:25:28 -07:00 committed by TensorFlower Gardener
parent aa47bcc6f1
commit 5bbf4a1d11
56 changed files with 366 additions and 295 deletions


@ -242,7 +242,6 @@ cc_library(
deps = [
":backend_configs_cc",
":buffer_allocations",
":cudnn_batchnorm_runner",
":elemental_ir_emitter",
":gpu_constants",
":gpu_conv_runner",
@ -267,6 +266,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service:while_loop_analysis",
@ -282,7 +282,6 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:sort_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",


@ -31,13 +31,13 @@ limitations under the License.
namespace xla {
namespace gpu {
CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type, int64 batch_size, int64 n,
const HloInstruction* hlo)
: Thunk(Kind::kCholesky, hlo),
PrimitiveType type, int64 batch_size, int64 n)
: Thunk(Kind::kCholesky, thunk_info),
uplo_(options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper),
a_buffer_(a_buffer),
@ -45,9 +45,10 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
info_buffer_(info_buffer),
type_(type),
batch_size_(batch_size),
a_batch_stride_(n * n *
ShapeUtil::ByteSizeOfPrimitiveType(
hlo->operand(0)->shape().element_type())),
a_batch_stride_(
n * n *
ShapeUtil::ByteSizeOfPrimitiveType(
thunk_info.hlo_instruction->operand(0)->shape().element_type())),
n_(n) {}
Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {


@ -41,12 +41,11 @@ namespace gpu {
class CholeskyThunk : public Thunk {
public:
static StatusOr<int64> ScratchBufferSize(int64 n);
CholeskyThunk(const CholeskyOptions& options,
CholeskyThunk(ThunkInfo thunk_info, const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type,
int64 batch_size, int64 n, const HloInstruction* hlo);
BufferAllocation::Slice info_buffer, PrimitiveType type,
int64 batch_size, int64 n);
CholeskyThunk(const CholeskyThunk&) = delete;
CholeskyThunk& operator=(const CholeskyThunk&) = delete;


@ -218,14 +218,14 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
} // anonymous namespace
CollectivePermuteThunk::CollectivePermuteThunk(
const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest,
const HloInstruction* instr)
: Thunk(kCollectivePermute, instr), src_(src), dest_(dest) {}
ThunkInfo thunk_info, const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest)
: Thunk(kCollectivePermute, thunk_info), src_(src), dest_(dest) {}
Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
auto* instr = Cast<HloCollectivePermuteInstruction>(hlo_instruction());
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
// Rendezvous with the threads for all other devices that are participating in
// this CollectivePermute.


@ -26,9 +26,9 @@ namespace gpu {
// Thunk that implements the collective-permute HLO.
class CollectivePermuteThunk : public Thunk {
public:
CollectivePermuteThunk(const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest,
const HloInstruction* instr);
CollectivePermuteThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest);
Status ExecuteOnStream(const ExecuteParams& params) override;


@ -24,12 +24,14 @@ namespace xla {
namespace gpu {
ConditionalThunk::ConditionalThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo)
: Thunk(Kind::kConditional, hlo),
branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED),
std::vector<ThunkSequence> branch_thunk_sequences)
: Thunk(Kind::kConditional, thunk_info),
branch_index_is_bool_(
thunk_info.hlo_instruction->operand(0)->shape().element_type() ==
PRED),
branch_index_buffer_index_(branch_index_buffer_index),
branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(),
branch_operand_buffer_indexes.end()) {
@ -39,7 +41,7 @@ ConditionalThunk::ConditionalThunk(
branch_thunks_.reserve(branch_thunk_sequences.size());
for (auto& branch_thunk_sequence : branch_thunk_sequences) {
branch_thunks_.emplace_back(
new SequentialThunk(std::move(branch_thunk_sequence), nullptr));
new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence)));
}
}
@ -67,7 +69,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& profiler = *params.profiler;
auto& stream = *params.stream;
auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction());
auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index());
// Copy the predicate value from device.
int32 branch_index = -1;
bool pred = false;


@ -43,10 +43,10 @@ namespace gpu {
class ConditionalThunk : public Thunk {
public:
ConditionalThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo);
std::vector<ThunkSequence> branch_thunk_sequences);
ConditionalThunk(const ConditionalThunk&) = delete;
ConditionalThunk& operator=(const ConditionalThunk&) = delete;


@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
@ -30,16 +31,16 @@ namespace xla {
namespace gpu {
ConvolutionThunk::ConvolutionThunk(
const HloCustomCallInstruction* cudnn_call,
std::vector<BufferAllocation::Slice> operand_slices,
ThunkInfo thunk_info, std::vector<BufferAllocation::Slice> operand_slices,
BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
BufferAllocation::Slice tuple_result_slice)
: Thunk(Kind::kConvolution, cudnn_call),
cudnn_call_(cudnn_call),
: Thunk(Kind::kConvolution, thunk_info),
operand_buffers_(std::move(operand_slices)),
result_buffer_(result_slice),
scratch_buffer_(scratch_slice),
tuple_result_buffer_(tuple_result_slice) {}
tuple_result_buffer_(tuple_result_slice) {
cudnn_call_ = Cast<HloCustomCallInstruction>(hlo_instruction());
}
Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
const auto& buffer_allocations = *params.buffer_allocations;
@ -56,7 +57,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
buffer_allocations.GetDeviceAddress(scratch_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers),
result_buffer, scratch, params.stream));


@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk {
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
// operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
ConvolutionThunk(ThunkInfo thunk_info,
std::vector<BufferAllocation::Slice> operand_slices,
BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,


@ -22,10 +22,9 @@ namespace xla {
namespace gpu {
HostToDeviceCopyThunk::HostToDeviceCopyThunk(
const void* source_address,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kCopy, hlo_instruction),
ThunkInfo thunk_info, const void* source_address,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
: Thunk(Kind::kCopy, thunk_info),
source_address_(source_address),
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
@ -34,16 +33,15 @@ Status HostToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase destination_data =
params.buffer_allocations->GetDeviceAddress(destination_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
return Status::OK();
}
DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kCopy, hlo_instruction),
ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
: Thunk(Kind::kCopy, thunk_info),
source_buffer_(source_buffer),
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
@ -54,7 +52,7 @@ Status DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase source_data =
params.buffer_allocations->GetDeviceAddress(source_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemcpy(&destination_data, source_data, mem_size_);
return Status::OK();
}


@ -33,9 +33,9 @@ class HostToDeviceCopyThunk : public Thunk {
// Constructs a CopyThunk that copies host data from `source_address` to the
// device buffer `destination_buffer`. `mem_size` is the size of the data in
// bytes.
HostToDeviceCopyThunk(const void* source_address,
HostToDeviceCopyThunk(ThunkInfo thunk_info, const void* source_address,
const BufferAllocation::Slice& destination_buffer,
uint64 mem_size, const HloInstruction* hlo_instruction);
uint64 mem_size);
HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete;
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
@ -54,10 +54,10 @@ class DeviceToDeviceCopyThunk : public Thunk {
// Constructs a CopyThunk that copies host data from `source_buffer` to the
// device buffer `destination_buffer`. `mem_size` is the size of the data in
// bytes.
DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer,
DeviceToDeviceCopyThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
uint64 mem_size,
const HloInstruction* hlo_instruction);
uint64 mem_size);
DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete;
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;


@ -92,12 +92,12 @@ void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) {
} // namespace
CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
const BufferAllocation::Slice& output, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo),
const BufferAllocation::Slice& output)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
operand_(operand),
scale_(scale),
offset_(offset),
@ -106,6 +106,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
epsilon_(epsilon),
feature_index_(feature_index),
output_(output) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(),
kCudnnBatchNormForwardInferenceCallTarget);
@ -118,7 +119,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
const ExecuteParams& params) {
auto& buffer_allocations = *params.buffer_allocations;
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::DeviceMemoryBase output_base =
buffer_allocations.GetDeviceAddress(output_);
se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
@ -139,14 +140,14 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
}
CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo),
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
operand_(operand),
scale_(scale),
offset_(offset),
@ -156,6 +157,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
output_mean_(output_mean),
output_inv_stddev_(output_inv_stddev),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
@ -178,7 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
se::DeviceMemory<float> null_device_ptr(nullptr);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto& stream = *params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
hlo_instruction(), operand, output_data, output_mean, output_inv_stddev,
@ -203,15 +205,15 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
}
CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
const BufferAllocation::Slice& grad_output, float epsilon,
int64 feature_index, const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo),
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
operand_(operand),
scale_(scale),
mean_(mean),
@ -223,6 +225,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
output_grad_scale_(output_grad_scale),
output_grad_offset_(output_grad_offset),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
@ -247,7 +250,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(output_grad_offset_));
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::Stream* stream = params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
hlo_instruction(), operand, output_grad_data, grad_output,


@ -46,14 +46,14 @@ namespace gpu {
class CudnnBatchNormForwardInferenceThunk : public Thunk {
public:
CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output,
const HloInstruction* hlo);
const BufferAllocation::Slice& output);
CudnnBatchNormForwardInferenceThunk(
const CudnnBatchNormForwardInferenceThunk&) = delete;
@ -76,13 +76,13 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
class CudnnBatchNormForwardTrainingThunk : public Thunk {
public:
CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset, float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);
CudnnBatchNormForwardTrainingThunk(
const CudnnBatchNormForwardTrainingThunk&) = delete;
@ -105,7 +105,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
class CudnnBatchNormBackwardThunk : public Thunk {
public:
CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
@ -114,8 +115,7 @@ class CudnnBatchNormBackwardThunk : public Thunk {
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple,
const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);
CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =


@ -22,15 +22,15 @@ namespace xla {
namespace gpu {
CustomCallThunk::CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr)
: Thunk(Thunk::kCustomCall, instr),
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque)
: Thunk(Thunk::kCustomCall, thunk_info),
call_target_(call_target),
operand_slices_(std::move(operand_slices)),
result_slices_(std::move(result_slices)),
opaque_(std::move(opaque)) {
const HloInstruction* instr = hlo_instruction();
CHECK_EQ(instr->operand_count(), operand_slices_.size());
for (int64 i = 0; i < instr->operand_count(); ++i) {
const auto& s1 = operand_slices_[i].shape();


@ -39,10 +39,9 @@ namespace gpu {
class CustomCallThunk : public Thunk {
public:
CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr);
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque);
Status ExecuteOnStream(const ExecuteParams& params) override;


@ -42,9 +42,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
struct NcclAllReduceThunk::AuxData {};
NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
ThunkInfo thunk_info, int64 replica_count,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
buffers_(std::move(buffers)) {}


@ -98,12 +98,12 @@ string FftTypeToString(se::fft::Type type) {
} // namespace
FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo)
: Thunk(Kind::kFft, hlo),
const Shape& input_shape, const Shape& output_shape)
: Thunk(Kind::kFft, thunk_info),
fft_type_(
FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
input_shape.element_type() == C128)),
@ -127,7 +127,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
buffer_allocations.memory_allocator());
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
if (fft_plan_ == nullptr) {
const int64 fft_rank = fft_length_.size();
CHECK_LE(fft_rank, 3);


@ -62,11 +62,11 @@ class FftThunk : public Thunk {
public:
// Constructs a thunk for launching an FFT on a stream.
// Semantics of null hlo_instruction argument are as in Thunk.
FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo);
const Shape& input_shape, const Shape& output_shape);
FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_
FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_


@ -23,16 +23,15 @@ limitations under the License.
namespace xla {
namespace gpu {
ForThunk::ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence)
: Thunk(Kind::kWhile, thunk_info),
loop_limit_(loop_limit),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
std::move(*body_thunk_sequence), nullptr)) {}
ThunkInfo(), std::move(*body_thunk_sequence))) {}
void ForThunk::ComputeAnnotations() {
Thunk::ComputeAnnotations();
@ -49,7 +48,7 @@ Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
<< (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (int64 i = 0; i < loop_limit_; ++i) {
params.profiler->StartHloComputation();
// Invoke loop body thunk sequence.


@ -31,9 +31,8 @@ namespace gpu {
// ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'.
class ForThunk : public Thunk {
public:
ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo);
ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence);
ForThunk(const ForThunk&) = delete;
ForThunk& operator=(const ForThunk&) = delete;


@ -132,6 +132,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer,
stream,
/*implements_whole_instruction=*/true,
/*profile_index=*/-1,
/*profiler=*/nullptr,
/*profile_result=*/&profile_result, algorithm)
.ok());


@ -33,13 +33,13 @@ limitations under the License.
namespace xla {
namespace gpu {
GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer,
GemmThunk::GemmThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice &lhs_buffer,
const BufferAllocation::Slice &rhs_buffer,
const BufferAllocation::Slice &output_buffer,
bool implements_whole_instruction,
const HloInstruction *hlo_instruction,
const GemmBackendConfig &backend_config)
: Thunk(Kind::kGemm, hlo_instruction),
: Thunk(Kind::kGemm, thunk_info),
lhs_buffer_(lhs_buffer),
rhs_buffer_(rhs_buffer),
output_buffer_(output_buffer),
@ -57,7 +57,7 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data,
output_data, params.stream, implements_whole_instruction_,
params.profiler);
profile_index(), params.profiler);
}
// This struct contains the metadata of a matrix, e.g., its base address and
@ -160,6 +160,7 @@ Status RunGemm(const HloInstruction *gemm,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream *stream,
bool implements_whole_instruction,
absl::optional<int64> profile_index,
HloExecutionProfiler *profiler,
se::blas::ProfileResult *profile_result,
absl::optional<se::blas::AlgorithmType> algorithm) {
@ -240,7 +241,7 @@ Status RunGemm(const HloInstruction *gemm,
rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
std::unique_ptr<ScopedInstructionProfiler> op_profiler =
profiler ? profiler->MakeScopedInstructionProfiler(
implements_whole_instruction ? gemm : nullptr)
implements_whole_instruction ? profile_index : -1)
: nullptr;
if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {


@ -39,11 +39,10 @@ class GemmThunk : public Thunk {
public:
// Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
// BLAS gemm (alpha is stored in the instruction GemmBackendConfig).
GemmThunk(const BufferAllocation::Slice& lhs_buffer,
GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
bool implements_whole_instruction,
const HloInstruction* hlo_instruction,
const GemmBackendConfig& backend_config);
GemmThunk(const GemmThunk&) = delete;
@ -72,7 +71,8 @@ Status RunGemm(
const HloInstruction* gemm, const GemmBackendConfig& backend_config,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream* stream,
bool implements_whole_instruction, HloExecutionProfiler* profiler = nullptr,
bool implements_whole_instruction, absl::optional<int64> profile_index,
HloExecutionProfiler* profiler = nullptr,
se::blas::ProfileResult* profile_result = nullptr,
absl::optional<se::blas::AlgorithmType> algorithm = absl::nullopt);


@ -472,7 +472,8 @@ static Status CompileModuleToLlvmIrImpl(
const std::string& platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
int pointer_size, std::unique_ptr<llvm::Module>* llvm_module,
int pointer_size, const HloProfileIndexMap* profile_index_map,
std::unique_ptr<llvm::Module>* llvm_module,
std::unique_ptr<BufferAssignment>* buffer_assignment,
std::unique_ptr<ThunkSchedule>* thunk_schedule) {
*llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
@ -509,7 +510,7 @@ static Status CompileModuleToLlvmIrImpl(
IrEmitterContext ir_emitter_context(
hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
cuda_compute_capability, llvm_module->get());
cuda_compute_capability, profile_index_map, llvm_module->get());
HloComputation* entry_computation = hlo_module->entry_computation();
IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
@ -532,10 +533,14 @@ static Status CompileModuleToLlvmIrImpl(
// not all explicitly checked, but at least we can document them here:
// * The entry HloComputation shall not have dead code (all reachable from
// ROOT).
// * For each visit of HloInstruction, either none or one Thunk will be
// returned.
// * The visited instructions are all instructions in the entry
// computation.
// * For each visit of these HloInstructions, either none or one Thunk
// will be returned.
// * If there is a thunk returned, thunk->hlo_instruction() equals the
// input HloInstruction*.
// * A returned thunk may contain other sub-thunks. A sub-thunk may or may
// not have an associated hlo_instruction().
TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString();
if (!thunks->empty()) {
auto thunk = std::move(thunks->front());
@ -603,6 +608,25 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
return cuda_compute_capability;
}();
std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinterData> profile_printer;
if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
VLOG(1) << "HLO memory read+written: "
<< tensorflow::strings::HumanReadableNumBytes(
cost_analysis.bytes_accessed());
if (module->config().hlo_profiling_enabled()) {
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
module->entry_computation()->name());
}
}
std::unique_ptr<llvm::Module> llvm_module;
std::unique_ptr<BufferAssignment> buffer_assignment;
std::unique_ptr<ThunkSchedule> thunk_schedule;
@ -610,8 +634,8 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
module.get(), &llvm_context, target_triple_, data_layout_,
stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
GetCanShareBuffer(), pointer_size_, &llvm_module, &buffer_assignment,
&thunk_schedule));
GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module,
&buffer_assignment, &thunk_schedule));
if (user_pre_optimization_hook_) {
user_pre_optimization_hook_(*llvm_module);
@ -653,25 +677,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
thunk_schedule->ToString());
}
std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinterData> profile_printer;
if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
VLOG(1) << "HLO memory read+written: "
<< tensorflow::strings::HumanReadableNumBytes(
cost_analysis.bytes_accessed());
if (module->config().hlo_profiling_enabled()) {
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
module->entry_computation()->name());
}
}
auto* gpu_executable = new GpuExecutable(
backend_result.first, backend_result.second, gpu_version,
std::move(thunk_schedule), std::move(module),
@ -709,7 +714,8 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
hlo_module, llvm_context, target_triple, data_layout, platform_name,
gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
pointer_size, &llvm_module, &buffer_assignment, &thunk_schedule));
pointer_size, /*profile_index_map=*/nullptr, &llvm_module,
&buffer_assignment, &thunk_schedule));
return llvm_module;
}
} // namespace gpu


@ -23,7 +23,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -97,26 +96,24 @@ void HloExecutionProfiler::StartHloInstruction() {
}
}
void HloExecutionProfiler::FinishHloInstruction(
const HloInstruction* hlo_instruction) {
void HloExecutionProfiler::FinishHloInstruction(size_t index) {
if (do_profile_) {
hlo_instructions_.erase(hlo_instruction);
profile_->SetCyclesTakenBy(
hlo_instruction,
GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
indices_.erase(index);
profile_->SetCyclesTakenBy(index, GetCyclesTaken(&timers_, sub_streams_,
stream_, clock_rate_ghz_));
}
}
std::unique_ptr<ScopedInstructionProfiler>
HloExecutionProfiler::MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction) {
if (do_profile_ && hlo_instruction != nullptr) {
absl::optional<int64> index) {
if (do_profile_ && index.has_value()) {
// Make sure that we are not already measuring the time for the same
// 'hlo_instruction'.
CHECK(hlo_instructions_.insert(hlo_instruction).second)
<< hlo_instruction->name();
// instruction.
// TODO(timshen): provide more useful printout.
CHECK(indices_.insert(*index).second) << *index;
}
return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction);
return absl::make_unique<ScopedInstructionProfiler>(this, index);
}
} // namespace gpu


@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -58,14 +57,17 @@ class HloExecutionProfiler {
void StartHloInstruction();
// If profiling is enabled, stops the per-operation timer and records the time
// that the hlo_instruction took to execute in the profile.
void FinishHloInstruction(const HloInstruction* hlo_instruction);
// that at `profile_index`. Profile indices can be looked up from
// HloProfileIndexMap.
void FinishHloInstruction(size_t profile_index);
// Returns a ScopedInstructionProfiler and triggers a call to
// StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
// out of scope, it triggers a call to FinishHloInstruction().
//
// If profile_index < 0, it results in a no-op.
std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction);
absl::optional<int64> profile_index);
private:
const bool do_profile_;
@ -77,7 +79,7 @@ class HloExecutionProfiler {
std::stack<std::unique_ptr<se::Timer>> timers_;
// Contains the HLO instructions for which we are currently measuring the
// time.
std::unordered_set<const HloInstruction*> hlo_instructions_;
std::unordered_set<size_t> indices_;
bool finished_execution_ = false;
};
@ -87,21 +89,21 @@ class HloExecutionProfiler {
class ScopedInstructionProfiler {
public:
ScopedInstructionProfiler(HloExecutionProfiler* profiler,
const HloInstruction* hlo_instruction)
: profiler_(profiler), hlo_instruction_(hlo_instruction) {
if (hlo_instruction != nullptr) {
absl::optional<int64> index)
: profiler_(profiler), index_(index) {
if (index_.has_value()) {
profiler->StartHloInstruction();
}
}
~ScopedInstructionProfiler() {
if (hlo_instruction_ != nullptr) {
profiler_->FinishHloInstruction(hlo_instruction_);
if (index_.has_value()) {
profiler_->FinishHloInstruction(*index_);
}
}
private:
HloExecutionProfiler* profiler_;
const HloInstruction* hlo_instruction_;
absl::optional<int64> index_;
};
} // namespace gpu


@ -23,9 +23,9 @@ namespace xla {
namespace gpu {
InfeedThunk::InfeedThunk(
const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
ThunkInfo thunk_info,
const ShapeTree<BufferAllocation::Slice>& infeed_slices)
: Thunk(Kind::kInfeed, thunk_info), infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& stream = *params.stream;
@ -34,7 +34,7 @@ Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
ShapeTree<InfeedBuffer> infeed_buffers =
GetOrCreateInfeedManager()->BlockingGetNextDestination();


@ -34,8 +34,8 @@ class InfeedThunk : public Thunk {
public:
// Constructs a InfeedThunk that copies data from the on-device
// infeed queue into the buffers in the given shape tree.
InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction);
InfeedThunk(ThunkInfo thunk_info,
const ShapeTree<BufferAllocation::Slice>& infeed_slices);
InfeedThunk(const InfeedThunk&) = delete;
InfeedThunk& operator=(const InfeedThunk&) = delete;


@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
namespace xla {
@ -33,12 +34,13 @@ class IrEmitterContext {
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
std::string platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
llvm::Module* llvm_module)
const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module)
: hlo_module_(hlo_module),
buffer_assignment_(buffer_assignment),
platform_name_(std::move(platform_name)),
gpu_device_info_(gpu_device_info),
cuda_compute_capability_(cuda_compute_capability),
profile_index_map_(profile_index_map),
llvm_module_(llvm_module) {}
// Disallow copy and assign.
IrEmitterContext(const IrEmitterContext&) = delete;
@ -54,6 +56,7 @@ class IrEmitterContext {
absl::optional<CudaComputeCapability> cuda_compute_capability() const {
return cuda_compute_capability_;
}
const HloProfileIndexMap* profile_index_map() { return profile_index_map_; }
llvm::Module* llvm_module() { return llvm_module_; }
NameUniquer* name_uniquer() { return &name_uniquer_; }
@ -63,6 +66,7 @@ class IrEmitterContext {
std::string platform_name_;
GpuDeviceInfo gpu_device_info_;
absl::optional<CudaComputeCapability> cuda_compute_capability_;
const HloProfileIndexMap* profile_index_map_;
llvm::Module* llvm_module_;
NameUniquer name_uniquer_;
};


@ -652,8 +652,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
/*updates_gen=*/
scatter_fused_emitter.GetGenerator(root->operand(2))));
}
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(fusion), std::move(thunks)));
return Status::OK();
}
// In the case of root tuple, it can be either reduce or slice input
@ -739,10 +739,11 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
auto destination_buffer = GetAllocationSlice(*copy);
if (operand_buffer != destination_buffer) {
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
GetThunkInfo(copy),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/
ByteSizeOf(copy->operand(0)->shape()), copy));
ByteSizeOf(copy->operand(0)->shape())));
}
return Status::OK();
}
@ -816,7 +817,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
}
AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
GetThunkInfo(tuple), tuple_element_buffers,
GetAllocationSlice(*tuple)));
return Status::OK();
}
AddThunkToThunkSequence(
@ -848,7 +850,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
thunks.push_back(BuildKernelThunk(select_and_scatter,
/*implements_whole_instruction=*/false));
std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
absl::make_unique<SequentialThunk>(GetThunkInfo(select_and_scatter),
std::move(thunks));
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
if (window_util::HasDilation(window)) {
@ -1082,10 +1085,10 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
auto destination_buffer = GetAllocationSlice(*scatter);
if (operand_buffer != destination_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()),
/*hlo_instruction=*/nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape())));
}
thunks.push_back(
@ -1109,8 +1112,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(scatter), std::move(thunks)));
}
return Status::OK();
@ -1282,10 +1285,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
// TODO(b/26783907): Figure out why we never seem to share buffers for
// key/value sort.
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/source_address,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape())));
}
}
@ -1419,8 +1422,8 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
}
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(sort), std::move(thunks)));
if (sort->operand_count() > 1) {
// Emit the tuple as part of the last stage of sorting.
// We are currently in the block sorted.in_bounds.after.
@ -1438,14 +1441,15 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
}
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
AddThunkToThunkSequence(
absl::make_unique<ReplicaIdThunk>(GetAllocationSlice(*hlo), hlo));
AddThunkToThunkSequence(absl::make_unique<ReplicaIdThunk>(
GetThunkInfo(hlo), GetAllocationSlice(*hlo)));
return Status::OK();
}
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo));
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)),
GetAllocationSlice(*hlo)));
return Status::OK();
}
@ -1478,15 +1482,16 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
tuple_element_buffers.push_back(buffers[i].destination_buffer);
}
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
GetThunkInfo(crs),
/*replica_count=*/hlo_module_config_.replica_count(),
/*buffers=*/std::move(buffers), crs);
/*buffers=*/std::move(buffers));
if (crs->shape().IsTuple()) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(all_reduce_thunk));
thunks.push_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(crs), std::move(thunks)));
} else {
AddThunkToThunkSequence(std::move(all_reduce_thunk));
}
@ -1520,9 +1525,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
CHECK(crs->operand(0)->shape().IsArray())
<< "Operands to all-reduce must be arrays: " << crs->ToString();
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
GetThunkInfo(crs),
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*crs),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape())));
return Status::OK();
}
@ -1535,16 +1541,17 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
.GetUniqueSlice(crs, {i})
.ValueOrDie());
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape())));
}
// Output a tuple of the buffers above.
thunks.push_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
absl::make_unique<SequentialThunk>(GetThunkInfo(crs), std::move(thunks)));
return Status::OK();
}
@ -1787,8 +1794,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
}
return absl::make_unique<KernelThunk>(
non_constant_buffers, std::string(kernel->getName()),
implements_whole_instruction ? inst : nullptr);
implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
non_constant_buffers, std::string(kernel->getName()));
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
@ -1838,8 +1845,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
absl::Span<const uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
nullptr)};
return {absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(),
GetAllocationSlice(*hlo, index))};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@ -1857,7 +1864,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
return {absl::make_unique<Memset32BitValueThunk>(
pattern32, GetAllocationSlice(*hlo, index), nullptr)};
Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))};
}
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@ -1868,7 +1875,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
return {absl::make_unique<Memset32BitValueThunk>(
word, GetAllocationSlice(*hlo, index), nullptr)};
Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))};
}
}
@ -2014,9 +2021,10 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
TF_CHECK_OK(body->Accept(&ir_emitter_body));
return absl::make_unique<WhileThunk>(
GetThunkInfo(hlo),
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition.ConsumeThunkSequence(),
ir_emitter_body.ConsumeThunkSequence(), hlo);
ir_emitter_body.ConsumeThunkSequence());
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
@ -2031,8 +2039,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
return absl::make_unique<ForThunk>(
loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
return absl::make_unique<ForThunk>(GetThunkInfo(hlo), loop_limit,
ir_emitter_body.ConsumeThunkSequence());
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
@ -2054,8 +2062,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
}
return absl::make_unique<ConditionalThunk>(
GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks), hlo);
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks));
}
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
@ -3589,8 +3597,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
ir_emitter_context_->llvm_module());
thunks.push_back(std::move(kernel_thunk));
auto sequential_thunk =
absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
auto sequential_thunk = absl::make_unique<SequentialThunk>(
GetThunkInfo(unnested_hlo), std::move(thunks));
AddThunkToThunkSequence(std::move(sequential_thunk));
return Status::OK();
@ -3757,5 +3765,15 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
return emit_status;
}
Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
const HloInstruction* hlo) const {
auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
if (const auto* index_map = ir_emitter_context_->profile_index_map()) {
info.profile_index.emplace(
static_cast<int64>(index_map->GetProfileIndexFor(*hlo)));
}
return info;
}
} // namespace gpu
} // namespace xla


@ -548,6 +548,8 @@ class IrEmitterUnnested : public IrEmitter,
// Returns the last generated thunk.
Thunk* LastThunk() const { return thunk_sequence_.back().get(); }
Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;
// The thunk sequence this IrEmitter generates for the input computation.
ThunkSequence thunk_sequence_;


@ -33,10 +33,10 @@ limitations under the License.
namespace xla {
namespace gpu {
KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kKernel, hlo_instruction),
KernelThunk::KernelThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation* const> args,
const string& kernel_name)
: Thunk(Kind::kKernel, thunk_info),
args_(args.begin(), args.end()),
kernel_name_(kernel_name) {}
@ -114,7 +114,7 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {
}
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
return ExecuteKernelOnStream(*kernel, buffer_args,
launch_dimensions.threads_per_block(),
launch_dimensions.block_count(), params.stream);


@ -47,8 +47,9 @@ class KernelThunk : public Thunk {
// Constructs a thunk for the given kernel.
//
// `hlo_instruction` is as in Thunk. Other arguments are as the class members.
KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name, const HloInstruction* hlo_instruction);
KernelThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation* const> args,
const string& kernel_name);
KernelThunk(const KernelThunk&) = delete;
KernelThunk& operator=(const KernelThunk&) = delete;
~KernelThunk() override = default;


@ -25,7 +25,7 @@ Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase dest_data =
params.buffer_allocations->GetDeviceAddress(dest_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemZero(&dest_data, dest_data.size());
return Status::OK();
}
@ -34,7 +34,7 @@ Status Memset32BitValueThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase dest_data =
params.buffer_allocations->GetDeviceAddress(dest_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemset32(&dest_data, value_, dest_data.size());
return Status::OK();
}


@ -32,9 +32,9 @@ namespace gpu {
// Thunk that zeroes out a given chunk of memory.
class MemzeroThunk : public Thunk {
public:
explicit MemzeroThunk(const BufferAllocation::Slice& dest,
const HloInstruction* hlo)
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
explicit MemzeroThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kMemzero, thunk_info), dest_(dest) {}
Status ExecuteOnStream(const ExecuteParams& params) override;
@ -46,10 +46,11 @@ class MemzeroThunk : public Thunk {
// destination chunk must have size divisible by 32 bits.
class Memset32BitValueThunk : public Thunk {
public:
explicit Memset32BitValueThunk(uint32 value,
const BufferAllocation::Slice& dest,
const HloInstruction* hlo)
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32 value,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kMemset32BitValue, thunk_info),
value_(value),
dest_(dest) {}
Status ExecuteOnStream(const ExecuteParams& params) override;


@ -541,9 +541,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
}
NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
ThunkInfo thunk_info, int64 replica_count,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
buffers_(std::move(buffers)),
aux_data_(absl::make_unique<AuxData>()) {
@ -555,7 +555,7 @@ NcclAllReduceThunk::NcclAllReduceThunk(
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(1) << "Starting NcclAllReduceThunk.";
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
int64 local_device_ordinal = params.stream->parent()->device_ordinal();


@ -56,8 +56,8 @@ class NcclAllReduceThunk : public Thunk {
BufferAllocation::Slice source_buffer;
BufferAllocation::Slice destination_buffer;
};
NcclAllReduceThunk(int64 replica_count, std::vector<Buffer> buffers,
const HloInstruction* all_reduce);
NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count,
std::vector<Buffer> buffers);
~NcclAllReduceThunk() override;
Status ExecuteOnStream(const ExecuteParams& params) override;


@ -23,9 +23,9 @@ limitations under the License.
namespace xla {
namespace gpu {
OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kOutfeed, hlo_instruction),
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info,
ShapeTree<BufferAllocation::Slice> outfeed_slices)
: Thunk(Kind::kOutfeed, thunk_info),
outfeed_slices_(std::move(outfeed_slices)) {}
Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
@ -35,7 +35,7 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
outfeed_manager->BlockingGetNextDestination();


@ -32,8 +32,8 @@ class OutfeedThunk : public Thunk {
public:
// Constructs a OutfeedThunk that copies data to the host-side
// outfeed queue from the buffers in the given shape tree.
OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
const HloInstruction* hlo_instruction);
OutfeedThunk(ThunkInfo thunk_info,
ShapeTree<BufferAllocation::Slice> outfeed_slices);
OutfeedThunk(const OutfeedThunk&) = delete;
OutfeedThunk& operator=(const OutfeedThunk&) = delete;

View File

@ -18,13 +18,13 @@ limitations under the License.
namespace xla {
namespace gpu {
ReplicaIdThunk::ReplicaIdThunk(const BufferAllocation::Slice& dest,
const HloInstruction* instr)
: Thunk(Kind::kReplicaId, instr), dest_(dest) {}
ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_);
TF_ASSIGN_OR_RETURN(int replica_id,

View File

@ -26,8 +26,7 @@ namespace gpu {
// Thunk that implements the ReplicaId HLO.
class ReplicaIdThunk : public Thunk {
public:
ReplicaIdThunk(const BufferAllocation::Slice& dest,
const HloInstruction* instr);
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
Status ExecuteOnStream(const ExecuteParams& params) override;

View File

@ -24,9 +24,9 @@ namespace gpu {
using ::tensorflow::profiler::ScopedAnnotation;
SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo)
: Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {}
SequentialThunk::SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks)
: Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {}
void SequentialThunk::ComputeAnnotations() {
for (const auto& thunk : thunks_) {
@ -44,7 +44,7 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (const auto& thunk : thunks_) {
ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params));

View File

@ -32,8 +32,8 @@ namespace gpu {
// require multiple kernel launches or library calls.
class SequentialThunk : public Thunk {
public:
SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo);
SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks);
SequentialThunk(const SequentialThunk&) = delete;
SequentialThunk& operator=(const SequentialThunk&) = delete;

View File

@ -68,13 +68,21 @@ class Thunk {
kWhile,
};
struct ThunkInfo {
const HloInstruction* hlo_instruction = nullptr;
absl::optional<int64> profile_index;
// TODO(timshen): Remove hlo_instruction and add name(),
// profile_annotation() here.
};
// The hlo_instruction in thunk_info is meant to be the instruction this
// thunk was generated from, but Thunk never uses it other than to save it
// to hlo_instruction_, so it can be null.
explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
explicit Thunk(Kind kind, ThunkInfo thunk_info)
: kind_(kind),
hlo_instruction_(hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : "") {}
hlo_instruction_(thunk_info.hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : ""),
profile_index_(thunk_info.profile_index) {}
virtual ~Thunk() {}
Thunk(const Thunk&) = delete;
Thunk& operator=(const Thunk&) = delete;
@ -128,6 +136,8 @@ class Thunk {
protected:
const HloInstruction* hlo_instruction() const { return hlo_instruction_; }
absl::optional<int64> profile_index() const { return profile_index_; }
const HloModuleConfig& GetModuleConfig() const {
return hlo_instruction()->GetModule()->config();
}
@ -146,8 +156,12 @@ class Thunk {
private:
Kind kind_;
// Will be removed in the future, as Thunk is migrating away from the
// monolithic HloInstruction.
const HloInstruction* hlo_instruction_;
std::string name_;
absl::optional<int64> profile_index_;
string profile_annotation_;
};
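
For orientation, a minimal sketch of the pattern this refactor establishes: a derived thunk forwards the emitter-provided ThunkInfo to the base class and later profiles itself by index rather than by HloInstruction pointer. This is illustrative only and not part of the CL; ExampleThunk and its destination slice are hypothetical, and Kind::kMemzero is reused purely for illustration.

// Illustrative sketch only, not part of this CL. It assumes the Thunk and
// ThunkInfo declarations above.
class ExampleThunk : public Thunk {
 public:
  ExampleThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
      : Thunk(Kind::kMemzero, thunk_info), dest_(dest) {}

  Status ExecuteOnStream(const ExecuteParams& params) override {
    // Profiling is keyed on the emitter-assigned index, not on an
    // HloInstruction pointer.
    auto op_profiler =
        params.profiler->MakeScopedInstructionProfiler(profile_index());
    // ... enqueue the actual device work on params.stream here ...
    return Status::OK();
  }

 private:
  const BufferAllocation::Slice dest_;
};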

View File

@ -40,11 +40,11 @@ namespace gpu {
std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
return absl::make_unique<FftThunk>(
inst->fft_type(), inst->fft_length(),
context_->GetThunkInfo(inst), inst->fft_type(), inst->fft_length(),
/*input_buffer=*/GetAllocationSlice(*operand),
/*output_buffer=*/GetAllocationSlice(*inst),
/*input_shape=*/operand->shape(),
/*output_shape=*/inst->shape(), inst);
/*output_shape=*/inst->shape());
}
std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
@ -63,11 +63,11 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
: n * n * elem_size;
int64 b_batch_stride = m * n * elem_size;
return absl::make_unique<TriangularSolveThunk>(
inst->triangular_solve_options(),
context_->GetThunkInfo(inst), inst->triangular_solve_options(),
/*a_input_buffer=*/GetAllocationSlice(*a),
/*b_input_buffer=*/GetAllocationSlice(*inst),
inst->shape().element_type(), batch_size, m, n, a_batch_stride,
b_batch_stride, inst);
b_batch_stride);
}
std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
@ -86,24 +86,27 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_buffer=*/GetAllocationSlice(*bias),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape())));
thunks.push_back(absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/false, inst,
std::move(gemm_config)));
return absl::make_unique<SequentialThunk>(std::move(thunks), inst);
/*implements_whole_instruction=*/false, std::move(gemm_config)));
return absl::make_unique<SequentialThunk>(context_->GetThunkInfo(inst),
std::move(thunks));
}
}
return absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/true, inst, std::move(gemm_config));
/*implements_whole_instruction=*/true, std::move(gemm_config));
}
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
@ -115,7 +118,7 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
[&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
*slice = GetAllocationSlice(*inst, index);
});
return absl::make_unique<InfeedThunk>(slices, inst);
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst), slices);
}
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
@ -130,7 +133,8 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
*slice = status_or_slice.ValueOrDie();
}
});
return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
std::move(slices));
}
Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
@ -152,6 +156,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -159,8 +164,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*variance=*/GetAllocationSlice(*custom_call->operand(4)),
/*epsilon=*/epsilon_value,
/*feature_index=*/feature_index_value,
/*output=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}
@ -181,6 +185,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_inv_stddev = GetAllocationSlice(*custom_call, {2});
AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -189,8 +194,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*output_data=*/output_data,
/*output_mean=*/output_mean,
/*output_inv_stddev=*/output_inv_stddev,
/*output_tuple=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output_tuple=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}
@ -209,6 +213,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_grad_scale = GetAllocationSlice(*custom_call, {1});
auto output_grad_offset = GetAllocationSlice(*custom_call, {2});
AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*mean=*/GetAllocationSlice(*custom_call->operand(2)),
@ -219,8 +224,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*output_grad_data=*/output_grad_data,
/*output_grad_scale=*/output_grad_scale,
/*output_grad_offset=*/output_grad_offset,
/*output_tuple=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output_tuple=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}
@ -235,7 +239,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto scratch_slice = GetAllocationSlice(*custom_call, {1});
AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
context_->GetThunkInfo(custom_call), std::move(operand_slices),
conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
@ -269,22 +273,23 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
if (operand_buffer != a_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(custom_call),
/*source_address=*/operand_buffer,
/*destination_buffer=*/a_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
/*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
}
thunks.push_back(absl::make_unique<CholeskyThunk>(
options, a_buffer, workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n,
custom_call));
context_->GetThunkInfo(custom_call), options, a_buffer,
workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(custom_call), std::move(thunks)));
}
return Status::OK();
@ -311,8 +316,9 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
ShapeTree<BufferAllocation::Slice> result_slices =
get_slices_for_instr(custom_call);
AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
call_target, std::move(operand_slices), std::move(result_slices),
Cast<HloCustomCallInstruction>(custom_call)->opaque(), custom_call));
context_->GetThunkInfo(custom_call), call_target,
std::move(operand_slices), std::move(result_slices),
Cast<HloCustomCallInstruction>(custom_call)->opaque()));
return Status::OK();
}
#endif
@ -347,9 +353,10 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
auto destination_buffer = GetAllocationSlice(*hlo);
if (operand_buffer != destination_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(hlo),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo));
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape())));
}
thunks.push_back(BuildTriangularSolveThunk(hlo));
@ -358,8 +365,8 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), hlo));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(hlo), std::move(thunks)));
}
return Status::OK();
}
@ -374,5 +381,12 @@ Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
return Status::OK();
}
Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
const HloInstruction* hlo) const {
CHECK(hlo);
Thunk::ThunkInfo info;
info.hlo_instruction = hlo;
return info;
}
} // namespace gpu
} // namespace xla
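
The base EmissionContext::GetThunkInfo only records the HloInstruction pointer. To complete the "Emitter -> Thunk with profile indices" pipeline, an emitter that owns an HLO profile index map would also fill in profile_index. A sketch of such an override, assuming a hypothetical ExampleEmissionContext with a profile_index_map_ member (neither appears in this diff):

// Hypothetical override, for illustration only.
Thunk::ThunkInfo ExampleEmissionContext::GetThunkInfo(
    const HloInstruction* hlo) const {
  Thunk::ThunkInfo info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
  if (profile_index_map_ != nullptr) {
    // Attach the compile-time profile index so the thunk no longer needs the
    // HloInstruction pointer at run time.
    info.profile_index = profile_index_map_->GetProfileIndexFor(*hlo);
  }
  return info;
}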

View File

@ -36,6 +36,7 @@ class ThunkEmitter {
const HloInstruction& hlo, const ShapeIndex& index) const = 0;
virtual int64 ByteSizeOf(const Shape& shape) const = 0;
virtual absl::string_view platform_name() const = 0;
virtual Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const;
virtual ~EmissionContext() = default;
};

View File

@ -32,12 +32,12 @@ namespace xla {
namespace gpu {
TriangularSolveThunk::TriangularSolveThunk(
const TriangularSolveOptions& options,
ThunkInfo thunk_info, const TriangularSolveOptions& options,
const BufferAllocation::Slice& a_buffer,
const BufferAllocation::Slice& b_buffer, PrimitiveType type,
int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
int64 b_batch_stride, const HloInstruction* hlo)
: Thunk(Kind::kTriangularSolve, hlo),
int64 b_batch_stride)
: Thunk(Kind::kTriangularSolve, thunk_info),
uplo_(options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper),
side_(options.left_side() ? se::blas::Side::kLeft

View File

@ -38,12 +38,12 @@ namespace gpu {
// Thread-compatible.
class TriangularSolveThunk : public Thunk {
public:
TriangularSolveThunk(const TriangularSolveOptions& options,
TriangularSolveThunk(ThunkInfo thunk_info,
const TriangularSolveOptions& options,
const BufferAllocation::Slice& a_buffer,
const BufferAllocation::Slice& b_buffer,
PrimitiveType type, int64 batch_size, int64 m, int64 n,
int64 a_batch_stride, int64 b_batch_stride,
const HloInstruction* hlo);
int64 a_batch_stride, int64 b_batch_stride);
TriangularSolveThunk(const TriangularSolveThunk&) = delete;
TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;

View File

@ -34,7 +34,7 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) {
}
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
SafeH2DMemcpy(se::DeviceMemory<void*>(
buffer_allocations.GetDeviceAddress(dest_buffer_)),
std::move(tuple_data), n, &stream,

View File

@ -34,10 +34,10 @@ namespace gpu {
// issue (b/31336476).
class TupleThunk : public Thunk {
public:
TupleThunk(absl::Span<const BufferAllocation::Slice> tuple_element_buffers,
const BufferAllocation::Slice& dest_buffer,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kTuple, hlo_instruction),
TupleThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation::Slice> tuple_element_buffers,
const BufferAllocation::Slice& dest_buffer)
: Thunk(Kind::kTuple, thunk_info),
tuple_element_buffers_(tuple_element_buffers.begin(),
tuple_element_buffers.end()),
dest_buffer_(dest_buffer) {}

View File

@ -24,20 +24,20 @@ namespace xla {
namespace gpu {
WhileThunk::WhileThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& condition_result_buffer_index,
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
std::unique_ptr<ThunkSequence> body_thunk_sequence)
: Thunk(Kind::kWhile, thunk_info),
condition_result_buffer_index_(condition_result_buffer_index),
// Pass an empty ThunkInfo to the condition_thunk_sequence_ and
// body_thunk_sequence_ constructors because these SequentialThunks are
// logically "part of" this WhileThunk, and shouldn't be profiled
// separately from it.
condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*condition_thunk_sequence), nullptr)),
ThunkInfo(), std::move(*condition_thunk_sequence))),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*body_thunk_sequence), nullptr)) {}
ThunkInfo(), std::move(*body_thunk_sequence))) {}
void WhileThunk::ComputeAnnotations() {
Thunk::ComputeAnnotations();
@ -61,7 +61,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
params.buffer_allocations->GetDeviceAddress(
condition_result_buffer_index_);
auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction());
auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index());
while (true) {
// Invoke thunk sequence for while 'condition' computation.
profiler.StartHloComputation();

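The WhileThunk constructor above hands an empty ThunkInfo to its nested SequentialThunks, so they carry no profile index of their own. A sketch, illustrative rather than the actual HloExecutionProfiler code, of how an index-based scoped timer can degrade to a no-op when the index is absent (ReadCycleCounterNow is a hypothetical timestamp source):

// Illustrative sketch only; the real profiler plumbing is not shown here.
uint64 ReadCycleCounterNow();  // hypothetical timestamp source

class ScopedIndexTimer {
 public:
  ScopedIndexTimer(HloExecutionProfile* profile, absl::optional<int64> index)
      : profile_(profile), index_(index), start_(ReadCycleCounterNow()) {}

  ~ScopedIndexTimer() {
    // Thunks built from an empty ThunkInfo have no index, so nothing is
    // recorded for them.
    if (profile_ != nullptr && index_.has_value()) {
      profile_->SetCyclesTakenBy(static_cast<size_t>(*index_),
                                 ReadCycleCounterNow() - start_);
    }
  }

 private:
  HloExecutionProfile* profile_;
  absl::optional<int64> index_;
  uint64 start_;
};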
View File

@ -39,10 +39,10 @@ namespace gpu {
class WhileThunk : public Thunk {
public:
// Constructs a WhileThunk to compute the 'while' instruction described by
// thunk_info.
WhileThunk(const BufferAllocation::Slice& condition_result_buffer_index,
WhileThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& condition_result_buffer_index,
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo);
std::unique_ptr<ThunkSequence> body_thunk_sequence);
WhileThunk(const WhileThunk&) = delete;
WhileThunk& operator=(const WhileThunk&) = delete;

View File

@ -133,8 +133,12 @@ HloExecutionProfile::HloExecutionProfile(
void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo,
uint64 cycles_taken) {
profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] =
cycles_taken;
SetCyclesTakenBy(hlo_profile_index_map_.GetProfileIndexFor(*hlo),
cycles_taken);
}
void HloExecutionProfile::SetCyclesTakenBy(size_t index, uint64 cycles_taken) {
profile_counters_[index] = cycles_taken;
}
uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {

View File

@ -114,6 +114,9 @@ class HloExecutionProfile {
// Record how many cycles this HLO took to execute.
void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken);
// Record how many cycles the HLO at the given profile index took to execute.
void SetCyclesTakenBy(size_t index, uint64 cycles_taken);
// Returns how many cycles this HLO took to execute. Profiling information
// may not be available for some instructions in which case zero is returned.
uint64 GetCyclesTakenBy(const HloInstruction& hlo) const;
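
Taken together with the GPU changes above, the index overload lets the profile index be resolved once at thunk-emission time and reported with only that integer at run time. A short illustrative sketch (the helper names are hypothetical; GetProfileIndexFor is the existing HloProfileIndexMap accessor used in the .cc change above):

// Illustrative only, not from this CL.
size_t ResolveProfileIndexAtEmission(const HloProfileIndexMap& index_map,
                                     const HloInstruction& hlo) {
  return index_map.GetProfileIndexFor(hlo);
}

void ReportThunkCycles(HloExecutionProfile* profile, size_t profile_index,
                       uint64 cycles_taken) {
  // No HloInstruction* is needed at run time anymore.
  profile->SetCyclesTakenBy(profile_index, cycles_taken);
}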

View File

@ -88,6 +88,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
const HloInstruction& hlo, const ShapeIndex& index) const override;
int64 ByteSizeOf(const Shape& shape) const override;
absl::string_view platform_name() const override;
mlir::Location getLocation(const HloInstruction* instr) const;
xla::mlir_gpu::EmissionContext* emission_context_;

View File

@ -436,8 +436,10 @@ StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
kernel, operand_to_value_map, ordered_operands, assignment, buffers));
// Finally, create the thunk and set the launch dimensions.
auto thunk = absl::make_unique<gpu::KernelThunk>(
buffers, kernel.getName().str(), instr);
gpu::Thunk::ThunkInfo info;
info.hlo_instruction = instr;
auto thunk = absl::make_unique<gpu::KernelThunk>(info, buffers,
kernel.getName().str());
// Set launch bounds.
mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();