[XLA/GPU] Remove uses of Thunk::hlo_instruction() for profiling.
This CL consists of two steps:

* First, refactor all Thunks to take a ThunkInfo instead of a const HloInstruction*. This will benefit future extensions to ThunkInfo as we move away from HloInstruction*.
* Second, change the data pipeline from:
    Emitter -> Thunk* -> hlo_instruction() -> profiler(HloInstruction*)
  to:
    Emitter -> Thunk with profile indices

The profile doesn't really depend on the HloInstruction itself, only on its pointer identity. Removing the dependency on HloInstruction helps with the MLIR migration.

PiperOrigin-RevId: 320687291
Change-Id: I7027d4c032f73ed615e5b520e01f3740781735be
parent aa47bcc6f1
commit 5bbf4a1d11
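The refactoring in the diff below threads a small ThunkInfo struct through every Thunk constructor in place of a raw const HloInstruction*. What follows is a minimal sketch of that struct, inferred from the hunks on this page; the real definition lives in thunk.h, which this page does not show, and may carry additional fields. std::optional stands in for the absl::optional<int64> used in XLA.

#include <cstdint>
#include <optional>  // stand-in for absl::optional<int64> used in XLA

class HloInstruction;  // opaque here; defined in hlo_instruction.h

// Hypothetical approximation of Thunk::ThunkInfo, for illustration only.
struct ThunkInfo {
  // Still populated by the emitter for thunks that inspect shapes at
  // construction time (see the CholeskyThunk hunk below), but execution no
  // longer depends on it.
  const HloInstruction* hlo_instruction = nullptr;
  // Profile slot resolved at emission time via HloProfileIndexMap;
  // std::nullopt (formerly a null HloInstruction*) means "do not profile".
  std::optional<int64_t> profile_index;
};

At execution time a thunk then asks for profile_index() instead of hlo_instruction(), which is what lets the profiler drop its dependency on instruction pointer identity.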
@@ -242,7 +242,6 @@ cc_library(
deps = [
":backend_configs_cc",
":buffer_allocations",
":cudnn_batchnorm_runner",
":elemental_ir_emitter",
":gpu_constants",
":gpu_conv_runner",
@@ -267,6 +266,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service:while_loop_analysis",
@@ -282,7 +282,6 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:sort_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",

@@ -31,13 +31,13 @@ limitations under the License.
namespace xla {
namespace gpu {

CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type, int64 batch_size, int64 n,
const HloInstruction* hlo)
: Thunk(Kind::kCholesky, hlo),
PrimitiveType type, int64 batch_size, int64 n)
: Thunk(Kind::kCholesky, thunk_info),
uplo_(options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper),
a_buffer_(a_buffer),
@@ -45,9 +45,10 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
info_buffer_(info_buffer),
type_(type),
batch_size_(batch_size),
a_batch_stride_(n * n *
ShapeUtil::ByteSizeOfPrimitiveType(
hlo->operand(0)->shape().element_type())),
a_batch_stride_(
n * n *
ShapeUtil::ByteSizeOfPrimitiveType(
thunk_info.hlo_instruction->operand(0)->shape().element_type())),
n_(n) {}

Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {

@@ -41,12 +41,11 @@ namespace gpu {
class CholeskyThunk : public Thunk {
public:
static StatusOr<int64> ScratchBufferSize(int64 n);
CholeskyThunk(const CholeskyOptions& options,
CholeskyThunk(ThunkInfo thunk_info, const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type,
int64 batch_size, int64 n, const HloInstruction* hlo);
BufferAllocation::Slice info_buffer, PrimitiveType type,
int64 batch_size, int64 n);

CholeskyThunk(const CholeskyThunk&) = delete;
CholeskyThunk& operator=(const CholeskyThunk&) = delete;

@@ -218,14 +218,14 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
} // anonymous namespace

CollectivePermuteThunk::CollectivePermuteThunk(
const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest,
const HloInstruction* instr)
: Thunk(kCollectivePermute, instr), src_(src), dest_(dest) {}
ThunkInfo thunk_info, const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest)
: Thunk(kCollectivePermute, thunk_info), src_(src), dest_(dest) {}

Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
auto* instr = Cast<HloCollectivePermuteInstruction>(hlo_instruction());
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());

// Rendezvous with the threads for all other devices that are participating in
// this CollectivePermute.

@@ -26,9 +26,9 @@ namespace gpu {
// Thunk that implements the collective-permute HLO.
class CollectivePermuteThunk : public Thunk {
public:
CollectivePermuteThunk(const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest,
const HloInstruction* instr);
CollectivePermuteThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& src,
const BufferAllocation::Slice& dest);

Status ExecuteOnStream(const ExecuteParams& params) override;


@@ -24,12 +24,14 @@ namespace xla {
namespace gpu {

ConditionalThunk::ConditionalThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo)
: Thunk(Kind::kConditional, hlo),
branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED),
std::vector<ThunkSequence> branch_thunk_sequences)
: Thunk(Kind::kConditional, thunk_info),
branch_index_is_bool_(
thunk_info.hlo_instruction->operand(0)->shape().element_type() ==
PRED),
branch_index_buffer_index_(branch_index_buffer_index),
branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(),
branch_operand_buffer_indexes.end()) {
@@ -39,7 +41,7 @@ ConditionalThunk::ConditionalThunk(
branch_thunks_.reserve(branch_thunk_sequences.size());
for (auto& branch_thunk_sequence : branch_thunk_sequences) {
branch_thunks_.emplace_back(
new SequentialThunk(std::move(branch_thunk_sequence), nullptr));
new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence)));
}
}

@@ -67,7 +69,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& profiler = *params.profiler;
auto& stream = *params.stream;

auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction());
auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index());
// Copy the predicate value from device.
int32 branch_index = -1;
bool pred = false;

@@ -43,10 +43,10 @@ namespace gpu {
class ConditionalThunk : public Thunk {
public:
ConditionalThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo);
std::vector<ThunkSequence> branch_thunk_sequences);

ConditionalThunk(const ConditionalThunk&) = delete;
ConditionalThunk& operator=(const ConditionalThunk&) = delete;

@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,16 +31,16 @@ namespace xla {
namespace gpu {

ConvolutionThunk::ConvolutionThunk(
const HloCustomCallInstruction* cudnn_call,
std::vector<BufferAllocation::Slice> operand_slices,
ThunkInfo thunk_info, std::vector<BufferAllocation::Slice> operand_slices,
BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
BufferAllocation::Slice tuple_result_slice)
: Thunk(Kind::kConvolution, cudnn_call),
cudnn_call_(cudnn_call),
: Thunk(Kind::kConvolution, thunk_info),
operand_buffers_(std::move(operand_slices)),
result_buffer_(result_slice),
scratch_buffer_(scratch_slice),
tuple_result_buffer_(tuple_result_slice) {}
tuple_result_buffer_(tuple_result_slice) {
cudnn_call_ = Cast<HloCustomCallInstruction>(hlo_instruction());
}

Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
const auto& buffer_allocations = *params.buffer_allocations;
@@ -56,7 +57,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
buffer_allocations.GetDeviceAddress(scratch_buffer_);

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers),
result_buffer, scratch, params.stream));


@@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk {
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
// operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
ConvolutionThunk(ThunkInfo thunk_info,
std::vector<BufferAllocation::Slice> operand_slices,
BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,

@@ -22,10 +22,9 @@ namespace xla {
namespace gpu {

HostToDeviceCopyThunk::HostToDeviceCopyThunk(
const void* source_address,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kCopy, hlo_instruction),
ThunkInfo thunk_info, const void* source_address,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
: Thunk(Kind::kCopy, thunk_info),
source_address_(source_address),
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
@@ -34,16 +33,15 @@ Status HostToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase destination_data =
params.buffer_allocations->GetDeviceAddress(destination_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
return Status::OK();
}

DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kCopy, hlo_instruction),
ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer, uint64 mem_size)
: Thunk(Kind::kCopy, thunk_info),
source_buffer_(source_buffer),
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
@@ -54,7 +52,7 @@ Status DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase source_data =
params.buffer_allocations->GetDeviceAddress(source_buffer_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemcpy(&destination_data, source_data, mem_size_);
return Status::OK();
}

@@ -33,9 +33,9 @@ class HostToDeviceCopyThunk : public Thunk {
// Constructs a CopyThunk that copies host data from `source_address` to the
// device buffer `destination_buffer`. `mem_size` is the size of the data in
// bytes.
HostToDeviceCopyThunk(const void* source_address,
HostToDeviceCopyThunk(ThunkInfo thunk_info, const void* source_address,
const BufferAllocation::Slice& destination_buffer,
uint64 mem_size, const HloInstruction* hlo_instruction);
uint64 mem_size);

HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete;
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
@@ -54,10 +54,10 @@ class DeviceToDeviceCopyThunk : public Thunk {
// Constructs a CopyThunk that copies host data from `source_buffer` to the
// device buffer `destination_buffer`. `mem_size` is the size of the data in
// bytes.
DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer,
DeviceToDeviceCopyThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
uint64 mem_size,
const HloInstruction* hlo_instruction);
uint64 mem_size);

DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete;
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;

@@ -92,12 +92,12 @@ void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) {
} // namespace

CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance, float epsilon, int64 feature_index,
const BufferAllocation::Slice& output, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo),
const BufferAllocation::Slice& output)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
operand_(operand),
scale_(scale),
offset_(offset),
@@ -106,6 +106,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
epsilon_(epsilon),
feature_index_(feature_index),
output_(output) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(),
kCudnnBatchNormForwardInferenceCallTarget);
@@ -118,7 +119,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
const ExecuteParams& params) {
auto& buffer_allocations = *params.buffer_allocations;
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::DeviceMemoryBase output_base =
buffer_allocations.GetDeviceAddress(output_);
se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
@@ -139,14 +140,14 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
}

CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo),
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
operand_(operand),
scale_(scale),
offset_(offset),
@@ -156,6 +157,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
output_mean_(output_mean),
output_inv_stddev_(output_inv_stddev),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
@@ -178,7 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(

se::DeviceMemory<float> null_device_ptr(nullptr);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto& stream = *params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
hlo_instruction(), operand, output_data, output_mean, output_inv_stddev,
@@ -203,15 +205,15 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
}

CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
const BufferAllocation::Slice& grad_output, float epsilon,
int64 feature_index, const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo),
const BufferAllocation::Slice& output_tuple)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
operand_(operand),
scale_(scale),
mean_(mean),
@@ -223,6 +225,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
output_grad_scale_(output_grad_scale),
output_grad_offset_(output_grad_offset),
output_tuple_(output_tuple) {
const auto* hlo = hlo_instruction();
CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall);
CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget);
CHECK_EQ(hlo->shape().tuple_shapes_size(), 3);
@@ -247,7 +250,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(output_grad_offset_));

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
se::Stream* stream = params.stream;
TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
hlo_instruction(), operand, output_grad_data, grad_output,

@@ -46,14 +46,14 @@ namespace gpu {

class CudnnBatchNormForwardInferenceThunk : public Thunk {
public:
CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& variance,
float epsilon, int64 feature_index,
const BufferAllocation::Slice& output,
const HloInstruction* hlo);
const BufferAllocation::Slice& output);

CudnnBatchNormForwardInferenceThunk(
const CudnnBatchNormForwardInferenceThunk&) = delete;
@@ -76,13 +76,13 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
class CudnnBatchNormForwardTrainingThunk : public Thunk {
public:
CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& operand,
ThunkInfo thunk_info, const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& offset, float epsilon, int64 feature_index,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);

CudnnBatchNormForwardTrainingThunk(
const CudnnBatchNormForwardTrainingThunk&) = delete;
@@ -105,7 +105,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {

class CudnnBatchNormBackwardThunk : public Thunk {
public:
CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand,
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
@@ -114,8 +115,7 @@ class CudnnBatchNormBackwardThunk : public Thunk {
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple,
const HloInstruction* hlo);
const BufferAllocation::Slice& output_tuple);

CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =

@@ -22,15 +22,15 @@ namespace xla {
namespace gpu {

CustomCallThunk::CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr)
: Thunk(Thunk::kCustomCall, instr),
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque)
: Thunk(Thunk::kCustomCall, thunk_info),
call_target_(call_target),
operand_slices_(std::move(operand_slices)),
result_slices_(std::move(result_slices)),
opaque_(std::move(opaque)) {
const HloInstruction* instr = hlo_instruction();
CHECK_EQ(instr->operand_count(), operand_slices_.size());
for (int64 i = 0; i < instr->operand_count(); ++i) {
const auto& s1 = operand_slices_[i].shape();

@@ -39,10 +39,9 @@ namespace gpu {
class CustomCallThunk : public Thunk {
public:
CustomCallThunk(
void* call_target,
ThunkInfo thunk_info, void* call_target,
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque,
const HloInstruction* instr);
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque);

Status ExecuteOnStream(const ExecuteParams& params) override;


@@ -42,9 +42,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
struct NcclAllReduceThunk::AuxData {};

NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
ThunkInfo thunk_info, int64 replica_count,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
buffers_(std::move(buffers)) {}


@@ -98,12 +98,12 @@ string FftTypeToString(se::fft::Type type) {

} // namespace

FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo)
: Thunk(Kind::kFft, hlo),
const Shape& input_shape, const Shape& output_shape)
: Thunk(Kind::kFft, thunk_info),
fft_type_(
FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
input_shape.element_type() == C128)),
@@ -127,7 +127,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
buffer_allocations.memory_allocator());

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
if (fft_plan_ == nullptr) {
const int64 fft_rank = fft_length_.size();
CHECK_LE(fft_rank, 3);

@@ -62,11 +62,11 @@ class FftThunk : public Thunk {
public:
// Constructs a thunk for launching an FFT on a stream.
// Semantics of null hlo_instruction argument are as in Thunk.
FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
FftThunk(ThunkInfo thunk_info, FftType fft_type,
absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo);
const Shape& input_shape, const Shape& output_shape);

FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_
FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_

@@ -23,16 +23,15 @@ limitations under the License.
namespace xla {
namespace gpu {

ForThunk::ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence)
: Thunk(Kind::kWhile, thunk_info),
loop_limit_(loop_limit),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
std::move(*body_thunk_sequence), nullptr)) {}
ThunkInfo(), std::move(*body_thunk_sequence))) {}

void ForThunk::ComputeAnnotations() {
Thunk::ComputeAnnotations();
@@ -49,7 +48,7 @@ Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
<< (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (int64 i = 0; i < loop_limit_; ++i) {
params.profiler->StartHloComputation();
// Invoke loop body thunk sequence.

@@ -31,9 +31,8 @@ namespace gpu {
// ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'.
class ForThunk : public Thunk {
public:
ForThunk(const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo);
ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence);
ForThunk(const ForThunk&) = delete;
ForThunk& operator=(const ForThunk&) = delete;


@@ -132,6 +132,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer,
stream,
/*implements_whole_instruction=*/true,
/*profile_index=*/-1,
/*profiler=*/nullptr,
/*profile_result=*/&profile_result, algorithm)
.ok());

@@ -33,13 +33,13 @@ limitations under the License.
namespace xla {
namespace gpu {

GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer,
GemmThunk::GemmThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice &lhs_buffer,
const BufferAllocation::Slice &rhs_buffer,
const BufferAllocation::Slice &output_buffer,
bool implements_whole_instruction,
const HloInstruction *hlo_instruction,
const GemmBackendConfig &backend_config)
: Thunk(Kind::kGemm, hlo_instruction),
: Thunk(Kind::kGemm, thunk_info),
lhs_buffer_(lhs_buffer),
rhs_buffer_(rhs_buffer),
output_buffer_(output_buffer),
@@ -57,7 +57,7 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data,
output_data, params.stream, implements_whole_instruction_,
params.profiler);
profile_index(), params.profiler);
}

// This struct contains the metadata of a matrix, e.g., its base address and
@@ -160,6 +160,7 @@ Status RunGemm(const HloInstruction *gemm,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream *stream,
bool implements_whole_instruction,
absl::optional<int64> profile_index,
HloExecutionProfiler *profiler,
se::blas::ProfileResult *profile_result,
absl::optional<se::blas::AlgorithmType> algorithm) {
@@ -240,7 +241,7 @@ Status RunGemm(const HloInstruction *gemm,
rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
std::unique_ptr<ScopedInstructionProfiler> op_profiler =
profiler ? profiler->MakeScopedInstructionProfiler(
implements_whole_instruction ? gemm : nullptr)
implements_whole_instruction ? profile_index : -1)
: nullptr;

if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {

@@ -39,11 +39,10 @@ class GemmThunk : public Thunk {
public:
// Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
// BLAS gemm (alpha is stored in the instruction GemmBackendConfig).
GemmThunk(const BufferAllocation::Slice& lhs_buffer,
GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
bool implements_whole_instruction,
const HloInstruction* hlo_instruction,
const GemmBackendConfig& backend_config);

GemmThunk(const GemmThunk&) = delete;
@@ -72,7 +71,8 @@ Status RunGemm(
const HloInstruction* gemm, const GemmBackendConfig& backend_config,
se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
se::DeviceMemoryBase output_buffer, se::Stream* stream,
bool implements_whole_instruction, HloExecutionProfiler* profiler = nullptr,
bool implements_whole_instruction, absl::optional<int64> profile_index,
HloExecutionProfiler* profiler = nullptr,
se::blas::ProfileResult* profile_result = nullptr,
absl::optional<se::blas::AlgorithmType> algorithm = absl::nullopt);


@@ -472,7 +472,8 @@ static Status CompileModuleToLlvmIrImpl(
const std::string& platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
int pointer_size, std::unique_ptr<llvm::Module>* llvm_module,
int pointer_size, const HloProfileIndexMap* profile_index_map,
std::unique_ptr<llvm::Module>* llvm_module,
std::unique_ptr<BufferAssignment>* buffer_assignment,
std::unique_ptr<ThunkSchedule>* thunk_schedule) {
*llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
@@ -509,7 +510,7 @@ static Status CompileModuleToLlvmIrImpl(

IrEmitterContext ir_emitter_context(
hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
cuda_compute_capability, llvm_module->get());
cuda_compute_capability, profile_index_map, llvm_module->get());

HloComputation* entry_computation = hlo_module->entry_computation();
IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
@@ -532,10 +533,14 @@ static Status CompileModuleToLlvmIrImpl(
// not all explicitly checked, but at least we can document them here:
// * The entry HloComputation shall not have dead code (all reachable from
// ROOT).
// * For each visit of HloInstruction, either none or one Thunk will be
// returned.
// * The visited instructions are all instructions in the entry
// computation.
// * For each visit of these HloInstructions, either none or one Thunk
// will be returned.
// * If there is a thunk returned, thunk->hlo_instruction() equals the
// input HloInstruction*.
// * A returned thunk may contain other sub-thunks. A sub-thunk may or may
// not have an associated hlo_instruction().
TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString();
if (!thunks->empty()) {
auto thunk = std::move(thunks->front());
@@ -603,6 +608,25 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
return cuda_compute_capability;
}();

std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinterData> profile_printer;

if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
VLOG(1) << "HLO memory read+written: "
<< tensorflow::strings::HumanReadableNumBytes(
cost_analysis.bytes_accessed());
if (module->config().hlo_profiling_enabled()) {
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
module->entry_computation()->name());
}
}

std::unique_ptr<llvm::Module> llvm_module;
std::unique_ptr<BufferAssignment> buffer_assignment;
std::unique_ptr<ThunkSchedule> thunk_schedule;
@@ -610,8 +634,8 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
module.get(), &llvm_context, target_triple_, data_layout_,
stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
GetCanShareBuffer(), pointer_size_, &llvm_module, &buffer_assignment,
&thunk_schedule));
GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module,
&buffer_assignment, &thunk_schedule));

if (user_pre_optimization_hook_) {
user_pre_optimization_hook_(*llvm_module);
@@ -653,25 +677,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
thunk_schedule->ToString());
}

std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinterData> profile_printer;

if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
VLOG(1) << "HLO memory read+written: "
<< tensorflow::strings::HumanReadableNumBytes(
cost_analysis.bytes_accessed());
if (module->config().hlo_profiling_enabled()) {
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
module->entry_computation()->name());
}
}

auto* gpu_executable = new GpuExecutable(
backend_result.first, backend_result.second, gpu_version,
std::move(thunk_schedule), std::move(module),
@@ -709,7 +714,8 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
hlo_module, llvm_context, target_triple, data_layout, platform_name,
gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
pointer_size, &llvm_module, &buffer_assignment, &thunk_schedule));
pointer_size, /*profile_index_map=*/nullptr, &llvm_module,
&buffer_assignment, &thunk_schedule));
return llvm_module;
}
} // namespace gpu
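The two RunBackend hunks above hoist the HloProfileIndexMap construction so that it exists before IR emission, letting the emitter bake a profile index into each ThunkInfo. As a rough, self-contained illustration of what an index map buys — the runtime only ever needs a dense slot number, never the instruction pointer — here is a toy version. It keys slots by instruction name purely for demonstration; the real HloProfileIndexMap is built from the HloModule and keyed on HloInstruction*.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for HloProfileIndexMap (hypothetical simplification): assign
// each instruction a dense slot in the profile-counter array.
class ToyProfileIndexMap {
 public:
  explicit ToyProfileIndexMap(const std::vector<std::string>& names) {
    for (const std::string& name : names) {
      index_.emplace(name, static_cast<int64_t>(index_.size()));
    }
  }
  int64_t GetProfileIndexFor(const std::string& name) const {
    return index_.at(name);
  }
  size_t size() const { return index_.size(); }

 private:
  std::unordered_map<std::string, int64_t> index_;
};

int main() {
  ToyProfileIndexMap map({"dot.1", "add.2", "fusion.3"});
  std::vector<uint64_t> cycles(map.size(), 0);     // dense profile buffer
  cycles[map.GetProfileIndexFor("add.2")] += 128;  // SetCyclesTakenBy(index, n)
  std::cout << "add.2 -> slot " << map.GetProfileIndexFor("add.2") << "\n";
  return 0;
}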
@@ -23,7 +23,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -97,26 +96,24 @@ void HloExecutionProfiler::StartHloInstruction() {
}
}

void HloExecutionProfiler::FinishHloInstruction(
const HloInstruction* hlo_instruction) {
void HloExecutionProfiler::FinishHloInstruction(size_t index) {
if (do_profile_) {
hlo_instructions_.erase(hlo_instruction);
profile_->SetCyclesTakenBy(
hlo_instruction,
GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
indices_.erase(index);
profile_->SetCyclesTakenBy(index, GetCyclesTaken(&timers_, sub_streams_,
stream_, clock_rate_ghz_));
}
}

std::unique_ptr<ScopedInstructionProfiler>
HloExecutionProfiler::MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction) {
if (do_profile_ && hlo_instruction != nullptr) {
absl::optional<int64> index) {
if (do_profile_ && index.has_value()) {
// Make sure that we are not already measuring the time for the same
// 'hlo_instruction'.
CHECK(hlo_instructions_.insert(hlo_instruction).second)
<< hlo_instruction->name();
// instruction.
// TODO(timshen): provide more useful printout.
CHECK(indices_.insert(*index).second) << *index;
}
return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction);
return absl::make_unique<ScopedInstructionProfiler>(this, index);
}

} // namespace gpu

@@ -23,7 +23,6 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

@@ -58,14 +57,17 @@ class HloExecutionProfiler {
void StartHloInstruction();

// If profiling is enabled, stops the per-operation timer and records the time
// that the hlo_instruction took to execute in the profile.
void FinishHloInstruction(const HloInstruction* hlo_instruction);
// that at `profile_index`. Profile indices can be looked up from
// HloProfileIndexMap.
void FinishHloInstruction(size_t profile_index);

// Returns a ScopedInstructionProfiler and triggers a call to
// StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
// out of scope, it triggers a call to FinishHloInstruction().
//
// If profile_index < 0, it results in a no-op.
std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction);
absl::optional<int64> profile_index);

private:
const bool do_profile_;
@@ -77,7 +79,7 @@ class HloExecutionProfiler {
std::stack<std::unique_ptr<se::Timer>> timers_;
// Contains the HLO instructions for which we are currently measuring the
// time.
std::unordered_set<const HloInstruction*> hlo_instructions_;
std::unordered_set<size_t> indices_;
bool finished_execution_ = false;
};

@@ -87,21 +89,21 @@ class HloExecutionProfiler {
class ScopedInstructionProfiler {
public:
ScopedInstructionProfiler(HloExecutionProfiler* profiler,
const HloInstruction* hlo_instruction)
: profiler_(profiler), hlo_instruction_(hlo_instruction) {
if (hlo_instruction != nullptr) {
absl::optional<int64> index)
: profiler_(profiler), index_(index) {
if (index_.has_value()) {
profiler->StartHloInstruction();
}
}
~ScopedInstructionProfiler() {
if (hlo_instruction_ != nullptr) {
profiler_->FinishHloInstruction(hlo_instruction_);
if (index_.has_value()) {
profiler_->FinishHloInstruction(*index_);
}
}

private:
HloExecutionProfiler* profiler_;
const HloInstruction* hlo_instruction_;
absl::optional<int64> index_;
};

} // namespace gpu
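To make the nullopt semantics of the profiler hunks above concrete, here is a small, dependency-free model of the RAII pattern. The Toy prefix flags that this is an illustration, not the real class: the real ScopedInstructionProfiler also starts and stops se::Timer objects on sub-streams, while this sketch keeps only the bookkeeping.

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <optional>
#include <unordered_set>

struct ToyProfiler {
  std::unordered_set<int64_t> in_flight;  // mirrors indices_ in the real class
  void SetCyclesTakenBy(int64_t index, uint64_t cycles) {
    std::cout << "slot " << index << " took " << cycles << " cycles\n";
  }
};

class ToyScopedInstructionProfiler {
 public:
  ToyScopedInstructionProfiler(ToyProfiler* profiler,
                               std::optional<int64_t> index)
      : profiler_(profiler), index_(index) {
    // nullopt (formerly a null HloInstruction*) turns the scope into a no-op.
    if (index_.has_value()) {
      // Mirrors the CHECK in the real code: no nested measurement of the
      // same profile slot.
      bool inserted = profiler_->in_flight.insert(*index_).second;
      if (!inserted) std::abort();
    }
  }
  ~ToyScopedInstructionProfiler() {
    if (index_.has_value()) {
      profiler_->in_flight.erase(*index_);
      profiler_->SetCyclesTakenBy(*index_, /*cycles=*/42);  // fake measurement
    }
  }

 private:
  ToyProfiler* profiler_;
  std::optional<int64_t> index_;
};

int main() {
  ToyProfiler profiler;
  { ToyScopedInstructionProfiler scoped(&profiler, 3); }             // records slot 3
  { ToyScopedInstructionProfiler scoped(&profiler, std::nullopt); }  // no-op
  return 0;
}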
@ -23,9 +23,9 @@ namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
InfeedThunk::InfeedThunk(
|
||||
const ShapeTree<BufferAllocation::Slice>& infeed_slices,
|
||||
const HloInstruction* hlo_instruction)
|
||||
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
|
||||
ThunkInfo thunk_info,
|
||||
const ShapeTree<BufferAllocation::Slice>& infeed_slices)
|
||||
: Thunk(Kind::kInfeed, thunk_info), infeed_slices_(infeed_slices) {}
|
||||
|
||||
Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto& stream = *params.stream;
|
||||
@ -34,7 +34,7 @@ Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
ShapeTree<InfeedBuffer> infeed_buffers =
|
||||
GetOrCreateInfeedManager()->BlockingGetNextDestination();
|
||||
|
||||
|
@ -34,8 +34,8 @@ class InfeedThunk : public Thunk {
|
||||
public:
|
||||
// Constructs a InfeedThunk that copies data from the on-device
|
||||
// infeed queue into the buffers in the given shape tree.
|
||||
InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices,
|
||||
const HloInstruction* hlo_instruction);
|
||||
InfeedThunk(ThunkInfo thunk_info,
|
||||
const ShapeTree<BufferAllocation::Slice>& infeed_slices);
|
||||
|
||||
InfeedThunk(const InfeedThunk&) = delete;
|
||||
InfeedThunk& operator=(const InfeedThunk&) = delete;
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
|
||||
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||
|
||||
namespace xla {
|
||||
@ -33,12 +34,13 @@ class IrEmitterContext {
|
||||
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
|
||||
std::string platform_name, GpuDeviceInfo gpu_device_info,
|
||||
absl::optional<CudaComputeCapability> cuda_compute_capability,
|
||||
llvm::Module* llvm_module)
|
||||
const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module)
|
||||
: hlo_module_(hlo_module),
|
||||
buffer_assignment_(buffer_assignment),
|
||||
platform_name_(std::move(platform_name)),
|
||||
gpu_device_info_(gpu_device_info),
|
||||
cuda_compute_capability_(cuda_compute_capability),
|
||||
profile_index_map_(profile_index_map),
|
||||
llvm_module_(llvm_module) {}
|
||||
// Disallow copy and assign.
|
||||
IrEmitterContext(const IrEmitterContext&) = delete;
|
||||
@ -54,6 +56,7 @@ class IrEmitterContext {
|
||||
absl::optional<CudaComputeCapability> cuda_compute_capability() const {
|
||||
return cuda_compute_capability_;
|
||||
}
|
||||
const HloProfileIndexMap* profile_index_map() { return profile_index_map_; }
|
||||
llvm::Module* llvm_module() { return llvm_module_; }
|
||||
NameUniquer* name_uniquer() { return &name_uniquer_; }
|
||||
|
||||
@ -63,6 +66,7 @@ class IrEmitterContext {
|
||||
std::string platform_name_;
|
||||
GpuDeviceInfo gpu_device_info_;
|
||||
absl::optional<CudaComputeCapability> cuda_compute_capability_;
|
||||
const HloProfileIndexMap* profile_index_map_;
|
||||
llvm::Module* llvm_module_;
|
||||
NameUniquer name_uniquer_;
|
||||
};
|
||||
|
@ -652,8 +652,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
/*updates_gen=*/
|
||||
scatter_fused_emitter.GetGenerator(root->operand(2))));
|
||||
}
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(fusion), std::move(thunks)));
|
||||
return Status::OK();
|
||||
}
|
||||
// In the case of root tuple, it can be either reduce or slice input
|
||||
@ -739,10 +739,11 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
|
||||
auto destination_buffer = GetAllocationSlice(*copy);
|
||||
if (operand_buffer != destination_buffer) {
|
||||
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
GetThunkInfo(copy),
|
||||
/*source_address=*/operand_buffer,
|
||||
/*destination_buffer=*/destination_buffer,
|
||||
/*mem_size=*/
|
||||
ByteSizeOf(copy->operand(0)->shape()), copy));
|
||||
ByteSizeOf(copy->operand(0)->shape())));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -816,7 +817,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
|
||||
tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
|
||||
}
|
||||
AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
|
||||
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
|
||||
GetThunkInfo(tuple), tuple_element_buffers,
|
||||
GetAllocationSlice(*tuple)));
|
||||
return Status::OK();
|
||||
}
|
||||
AddThunkToThunkSequence(
|
||||
@ -848,7 +850,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
||||
thunks.push_back(BuildKernelThunk(select_and_scatter,
|
||||
/*implements_whole_instruction=*/false));
|
||||
std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
|
||||
absl::make_unique<SequentialThunk>(GetThunkInfo(select_and_scatter),
|
||||
std::move(thunks));
|
||||
|
||||
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
|
||||
if (window_util::HasDilation(window)) {
|
||||
@ -1082,10 +1085,10 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
|
||||
auto destination_buffer = GetAllocationSlice(*scatter);
|
||||
if (operand_buffer != destination_buffer) {
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_address=*/operand_buffer,
|
||||
/*destination_buffer=*/destination_buffer,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()),
|
||||
/*hlo_instruction=*/nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape())));
|
||||
}
|
||||
|
||||
thunks.push_back(
|
||||
@ -1109,8 +1112,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
|
||||
if (thunks.size() == 1) {
|
||||
AddThunkToThunkSequence(std::move(thunks[0]));
|
||||
} else {
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(scatter), std::move(thunks)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -1282,10 +1285,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
// TODO(b/26783907): Figure out why we never seem to share buffers for
|
||||
// key/value sort.
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_address=*/source_address,
|
||||
/*destination_buffer=*/destination_buffer,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
|
||||
nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape())));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1419,8 +1422,8 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
|
||||
}
|
||||
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(sort), std::move(thunks)));
|
||||
if (sort->operand_count() > 1) {
|
||||
// Emit the tuple as part of the last stage of sorting.
|
||||
// We are currently in the block sorted.in_bounds.after.
|
||||
@ -1438,14 +1441,15 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<ReplicaIdThunk>(GetAllocationSlice(*hlo), hlo));
|
||||
AddThunkToThunkSequence(absl::make_unique<ReplicaIdThunk>(
|
||||
GetThunkInfo(hlo), GetAllocationSlice(*hlo)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
|
||||
AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
|
||||
GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo));
|
||||
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)),
|
||||
GetAllocationSlice(*hlo)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1478,15 +1482,16 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
||||
tuple_element_buffers.push_back(buffers[i].destination_buffer);
|
||||
}
|
||||
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
|
||||
GetThunkInfo(crs),
|
||||
/*replica_count=*/hlo_module_config_.replica_count(),
|
||||
/*buffers=*/std::move(buffers), crs);
|
||||
/*buffers=*/std::move(buffers));
|
||||
if (crs->shape().IsTuple()) {
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
thunks.push_back(std::move(all_reduce_thunk));
|
||||
thunks.push_back(absl::make_unique<TupleThunk>(
|
||||
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
|
||||
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(crs), std::move(thunks)));
|
||||
} else {
|
||||
AddThunkToThunkSequence(std::move(all_reduce_thunk));
|
||||
}
|
||||
@ -1520,9 +1525,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
||||
CHECK(crs->operand(0)->shape().IsArray())
|
||||
<< "Operands to all-reduce must be arrays: " << crs->ToString();
|
||||
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
GetThunkInfo(crs),
|
||||
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
|
||||
/*destination_buffer=*/GetAllocationSlice(*crs),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape())));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1535,16 +1541,17 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
||||
.GetUniqueSlice(crs, {i})
|
||||
.ValueOrDie());
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
|
||||
/*destination_buffer=*/tuple_element_buffers.back(),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape())));
|
||||
}
|
||||
|
||||
// Output a tuple of the buffers above.
|
||||
thunks.push_back(absl::make_unique<TupleThunk>(
|
||||
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
|
||||
Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs)));
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
|
||||
absl::make_unique<SequentialThunk>(GetThunkInfo(crs), std::move(thunks)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1787,8 +1794,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
|
||||
}
|
||||
|
||||
return absl::make_unique<KernelThunk>(
|
||||
non_constant_buffers, std::string(kernel->getName()),
|
||||
implements_whole_instruction ? inst : nullptr);
|
||||
implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
|
||||
non_constant_buffers, std::string(kernel->getName()));
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
@ -1838,8 +1845,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
absl::Span<const uint8> literal_bytes(
|
||||
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
|
||||
if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
|
||||
return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
|
||||
nullptr)};
|
||||
return {absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(),
|
||||
GetAllocationSlice(*hlo, index))};
|
||||
}
|
||||
|
||||
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
|
||||
@ -1857,7 +1864,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
}
|
||||
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
|
||||
return {absl::make_unique<Memset32BitValueThunk>(
|
||||
pattern32, GetAllocationSlice(*hlo, index), nullptr)};
|
||||
Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))};
|
||||
}
|
||||
|
||||
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
|
||||
@ -1868,7 +1875,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
uint32 word;
|
||||
memcpy(&word, literal_bytes.data(), sizeof(word));
|
||||
return {absl::make_unique<Memset32BitValueThunk>(
|
||||
word, GetAllocationSlice(*hlo, index), nullptr)};
|
||||
Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))};
|
||||
}
|
||||
}
|
||||
|
||||
@ -2014,9 +2021,10 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
TF_CHECK_OK(body->Accept(&ir_emitter_body));

return absl::make_unique<WhileThunk>(
GetThunkInfo(hlo),
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition.ConsumeThunkSequence(),
ir_emitter_body.ConsumeThunkSequence(), hlo);
ir_emitter_body.ConsumeThunkSequence());
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
@ -2031,8 +2039,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));

return absl::make_unique<ForThunk>(
loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
return absl::make_unique<ForThunk>(GetThunkInfo(hlo), loop_limit,
ir_emitter_body.ConsumeThunkSequence());
}

std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
@ -2054,8 +2062,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
}

return absl::make_unique<ConditionalThunk>(
GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks), hlo);
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks));
}

Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
@ -3589,8 +3597,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
ir_emitter_context_->llvm_module());

thunks.push_back(std::move(kernel_thunk));
auto sequential_thunk =
absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
auto sequential_thunk = absl::make_unique<SequentialThunk>(
GetThunkInfo(unnested_hlo), std::move(thunks));
AddThunkToThunkSequence(std::move(sequential_thunk));

return Status::OK();
@ -3757,5 +3765,15 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
return emit_status;
}

Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
const HloInstruction* hlo) const {
auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
if (const auto* index_map = ir_emitter_context_->profile_index_map()) {
info.profile_index.emplace(
static_cast<int64>(index_map->GetProfileIndexFor(*hlo)));
}
return info;
}

} // namespace gpu
} // namespace xla
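The new GetThunkInfo override above is the single point where profiling data enters the pipeline: the base EmissionContext fills in the instruction, and the emitter adds a profile index when a profile index map is present. A simplified standalone sketch of that two-step population (the map type is a stand-in for XLA's HloProfileIndexMap; it assumes, as XLA does, that every profiled instruction has a slot):

#include <cstdint>
#include <optional>
#include <unordered_map>

struct Instruction {};  // stand-in for HloInstruction

struct ThunkInfo {
  const Instruction* hlo_instruction = nullptr;
  std::optional<int64_t> profile_index;
};

// Stand-in for HloProfileIndexMap: instruction -> profile slot.
using ProfileIndexMap = std::unordered_map<const Instruction*, int64_t>;

// Base behavior (EmissionContext::GetThunkInfo): record the instruction only.
ThunkInfo BaseThunkInfo(const Instruction* hlo) {
  ThunkInfo info;
  info.hlo_instruction = hlo;
  return info;
}

// Emitter override: add the profile index when profiling is enabled.
ThunkInfo EmitterThunkInfo(const Instruction* hlo, const ProfileIndexMap* map) {
  ThunkInfo info = BaseThunkInfo(hlo);
  if (map != nullptr) {
    info.profile_index.emplace(map->at(hlo));
  }
  return info;
}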
@ -548,6 +548,8 @@ class IrEmitterUnnested : public IrEmitter,
// Returns the last generated thunk.
Thunk* LastThunk() const { return thunk_sequence_.back().get(); }

Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;

// The thunk sequence this IrEmitter generates for the input computation.
ThunkSequence thunk_sequence_;

@ -33,10 +33,10 @@ limitations under the License.
namespace xla {
namespace gpu {

KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kKernel, hlo_instruction),
KernelThunk::KernelThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation* const> args,
const string& kernel_name)
: Thunk(Kind::kKernel, thunk_info),
args_(args.begin(), args.end()),
kernel_name_(kernel_name) {}

@ -114,7 +114,7 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {
}

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
return ExecuteKernelOnStream(*kernel, buffer_args,
launch_dimensions.threads_per_block(),
launch_dimensions.block_count(), params.stream);
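Every ExecuteOnStream in this commit swaps MakeScopedInstructionProfiler(hlo_instruction()) for MakeScopedInstructionProfiler(profile_index()): the profiler only ever needed a counter slot, not the instruction pointer. A minimal sketch of an index-keyed scoped profiler under that assumption (simplified; not XLA's HloExecutionProfiler API, and nanoseconds stand in for the cycle counts XLA records):

#include <chrono>
#include <cstdint>
#include <optional>
#include <vector>

// Simplified profile storage: one counter per profile slot.
class Profile {
 public:
  explicit Profile(size_t num_slots) : counters_(num_slots, 0) {}
  void Add(size_t index, uint64_t amount) { counters_[index] += amount; }
  uint64_t Get(size_t index) const { return counters_[index]; }

 private:
  std::vector<uint64_t> counters_;
};

// RAII profiler keyed by an optional slot index. When a thunk has no index
// (it is part of a larger instruction), the destructor records nothing.
class ScopedSlotProfiler {
 public:
  ScopedSlotProfiler(Profile* profile, std::optional<int64_t> index)
      : profile_(profile),
        index_(index),
        start_(std::chrono::steady_clock::now()) {}

  ~ScopedSlotProfiler() {
    if (profile_ != nullptr && index_.has_value()) {
      auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                    std::chrono::steady_clock::now() - start_)
                    .count();
      profile_->Add(static_cast<size_t>(*index_), static_cast<uint64_t>(ns));
    }
  }

 private:
  Profile* profile_;
  std::optional<int64_t> index_;
  std::chrono::steady_clock::time_point start_;
};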
@ -47,8 +47,9 @@ class KernelThunk : public Thunk {
// Constructs a thunk for the given kernel.
//
// `thunk_info` is as in Thunk. Other arguments are as the class members.
KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name, const HloInstruction* hlo_instruction);
KernelThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation* const> args,
const string& kernel_name);
KernelThunk(const KernelThunk&) = delete;
KernelThunk& operator=(const KernelThunk&) = delete;
~KernelThunk() override = default;
@ -25,7 +25,7 @@ Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase dest_data =
params.buffer_allocations->GetDeviceAddress(dest_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemZero(&dest_data, dest_data.size());
return Status::OK();
}
@ -34,7 +34,7 @@ Status Memset32BitValueThunk::ExecuteOnStream(const ExecuteParams& params) {
se::DeviceMemoryBase dest_data =
params.buffer_allocations->GetDeviceAddress(dest_);
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
params.stream->ThenMemset32(&dest_data, value_, dest_data.size());
return Status::OK();
}

@ -32,9 +32,9 @@ namespace gpu {
// Thunk that zeroes out a given chunk of memory.
class MemzeroThunk : public Thunk {
public:
explicit MemzeroThunk(const BufferAllocation::Slice& dest,
const HloInstruction* hlo)
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
explicit MemzeroThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kMemzero, thunk_info), dest_(dest) {}

Status ExecuteOnStream(const ExecuteParams& params) override;

@ -46,10 +46,11 @@ class MemzeroThunk : public Thunk {
// destination chunk must have size divisible by 32 bits.
class Memset32BitValueThunk : public Thunk {
public:
explicit Memset32BitValueThunk(uint32 value,
const BufferAllocation::Slice& dest,
const HloInstruction* hlo)
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32 value,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kMemset32BitValue, thunk_info),
value_(value),
dest_(dest) {}

Status ExecuteOnStream(const ExecuteParams& params) override;

@ -541,9 +541,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
}

NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
ThunkInfo thunk_info, int64 replica_count,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
buffers_(std::move(buffers)),
aux_data_(absl::make_unique<AuxData>()) {
@ -555,7 +555,7 @@ NcclAllReduceThunk::NcclAllReduceThunk(
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(1) << "Starting NcclAllReduceThunk.";
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());

auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
int64 local_device_ordinal = params.stream->parent()->device_ordinal();

@ -56,8 +56,8 @@ class NcclAllReduceThunk : public Thunk {
BufferAllocation::Slice source_buffer;
BufferAllocation::Slice destination_buffer;
};
NcclAllReduceThunk(int64 replica_count, std::vector<Buffer> buffers,
const HloInstruction* all_reduce);
NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count,
std::vector<Buffer> buffers);
~NcclAllReduceThunk() override;

Status ExecuteOnStream(const ExecuteParams& params) override;
@ -23,9 +23,9 @@ limitations under the License.
namespace xla {
namespace gpu {

OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kOutfeed, hlo_instruction),
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info,
ShapeTree<BufferAllocation::Slice> outfeed_slices)
: Thunk(Kind::kOutfeed, thunk_info),
outfeed_slices_(std::move(outfeed_slices)) {}

Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
@ -35,7 +35,7 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
outfeed_manager->BlockingGetNextDestination();

@ -32,8 +32,8 @@ class OutfeedThunk : public Thunk {
public:
// Constructs an OutfeedThunk that copies data to the host-side
// outfeed queue from the buffers in the given shape tree.
OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
const HloInstruction* hlo_instruction);
OutfeedThunk(ThunkInfo thunk_info,
ShapeTree<BufferAllocation::Slice> outfeed_slices);

OutfeedThunk(const OutfeedThunk&) = delete;
OutfeedThunk& operator=(const OutfeedThunk&) = delete;
@ -18,13 +18,13 @@ limitations under the License.
namespace xla {
namespace gpu {

ReplicaIdThunk::ReplicaIdThunk(const BufferAllocation::Slice& dest,
const HloInstruction* instr)
: Thunk(Kind::kReplicaId, instr), dest_(dest) {}
ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}

Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());

auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_);
TF_ASSIGN_OR_RETURN(int replica_id,

@ -26,8 +26,7 @@ namespace gpu {
// Thunk that implements the ReplicaId HLO.
class ReplicaIdThunk : public Thunk {
public:
ReplicaIdThunk(const BufferAllocation::Slice& dest,
const HloInstruction* instr);
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);

Status ExecuteOnStream(const ExecuteParams& params) override;

@ -24,9 +24,9 @@ namespace gpu {

using ::tensorflow::profiler::ScopedAnnotation;

SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo)
: Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {}
SequentialThunk::SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks)
: Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {}

void SequentialThunk::ComputeAnnotations() {
for (const auto& thunk : thunks_) {
@ -44,7 +44,7 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,

Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (const auto& thunk : thunks_) {
ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params));

@ -32,8 +32,8 @@ namespace gpu {
// require multiple kernel launches or library calls.
class SequentialThunk : public Thunk {
public:
SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks,
const HloInstruction* hlo);
SequentialThunk(ThunkInfo thunk_info,
std::vector<std::unique_ptr<Thunk>> thunks);
SequentialThunk(const SequentialThunk&) = delete;
SequentialThunk& operator=(const SequentialThunk&) = delete;

@ -68,13 +68,21 @@ class Thunk {
kWhile,
};

struct ThunkInfo {
const HloInstruction* hlo_instruction = nullptr;
absl::optional<int64> profile_index;
// TODO(timshen): Remove hlo_instruction and add name(),
// profile_annotation() here.
};

// The hlo_instruction in thunk_info is meant to be the instruction this
// thunk was generated from, but Thunk only saves it and derives its name
// from it, so it can be null.
explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
explicit Thunk(Kind kind, ThunkInfo thunk_info)
: kind_(kind),
hlo_instruction_(hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : "") {}
hlo_instruction_(thunk_info.hlo_instruction),
name_(hlo_instruction_ ? hlo_instruction_->name() : ""),
profile_index_(thunk_info.profile_index) {}
virtual ~Thunk() {}
Thunk(const Thunk&) = delete;
Thunk& operator=(const Thunk&) = delete;
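Put together, the new constructor above copies both fields out of ThunkInfo and no longer requires a non-null HloInstruction*. A compilable miniature of the resulting base-class shape (one Kind value, a stand-in instruction type, and std::optional instead of absl::optional):

#include <cstdint>
#include <optional>
#include <string>

struct Instruction {  // stand-in for HloInstruction
  std::string name;
};

class Thunk {
 public:
  enum class Kind { kKernel };

  struct ThunkInfo {
    const Instruction* hlo_instruction = nullptr;
    std::optional<int64_t> profile_index;
  };

  explicit Thunk(Kind kind, ThunkInfo thunk_info)
      : kind_(kind),
        hlo_instruction_(thunk_info.hlo_instruction),
        name_(hlo_instruction_ ? hlo_instruction_->name : ""),
        profile_index_(thunk_info.profile_index) {}
  virtual ~Thunk() = default;

  Kind kind() const { return kind_; }

 protected:
  // Still available during the migration; new code should prefer the index.
  const Instruction* hlo_instruction() const { return hlo_instruction_; }
  std::optional<int64_t> profile_index() const { return profile_index_; }

 private:
  Kind kind_;
  const Instruction* hlo_instruction_;  // to be removed as migration proceeds
  std::string name_;
  std::optional<int64_t> profile_index_;
};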
@ -128,6 +136,8 @@ class Thunk {
protected:
const HloInstruction* hlo_instruction() const { return hlo_instruction_; }

absl::optional<int64> profile_index() const { return profile_index_; }

const HloModuleConfig& GetModuleConfig() const {
return hlo_instruction()->GetModule()->config();
}
@ -146,8 +156,12 @@ class Thunk {

private:
Kind kind_;

// Will be removed in the future, as Thunk is migrating away from the
// monolithic HloInstruction.
const HloInstruction* hlo_instruction_;
std::string name_;
absl::optional<int64> profile_index_;
string profile_annotation_;
};

@ -40,11 +40,11 @@ namespace gpu {
std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
return absl::make_unique<FftThunk>(
inst->fft_type(), inst->fft_length(),
context_->GetThunkInfo(inst), inst->fft_type(), inst->fft_length(),
/*input_buffer=*/GetAllocationSlice(*operand),
/*output_buffer=*/GetAllocationSlice(*inst),
/*input_shape=*/operand->shape(),
/*output_shape=*/inst->shape(), inst);
/*output_shape=*/inst->shape());
}

std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
@ -63,11 +63,11 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
: n * n * elem_size;
int64 b_batch_stride = m * n * elem_size;
return absl::make_unique<TriangularSolveThunk>(
inst->triangular_solve_options(),
context_->GetThunkInfo(inst), inst->triangular_solve_options(),
/*a_input_buffer=*/GetAllocationSlice(*a),
/*b_input_buffer=*/GetAllocationSlice(*inst),
inst->shape().element_type(), batch_size, m, n, a_batch_stride,
b_batch_stride, inst);
b_batch_stride);
}

std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
@ -86,24 +86,27 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_buffer=*/GetAllocationSlice(*bias),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr));
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape())));
thunks.push_back(absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/false, inst,
std::move(gemm_config)));
return absl::make_unique<SequentialThunk>(std::move(thunks), inst);
/*implements_whole_instruction=*/false, std::move(gemm_config)));
return absl::make_unique<SequentialThunk>(context_->GetThunkInfo(inst),
std::move(thunks));
}
}

return absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/true, inst, std::move(gemm_config));
/*implements_whole_instruction=*/true, std::move(gemm_config));
}

std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
@ -115,7 +118,7 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
[&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
*slice = GetAllocationSlice(*inst, index);
});
return absl::make_unique<InfeedThunk>(slices, inst);
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst), slices);
}

std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
@ -130,7 +133,8 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
*slice = status_or_slice.ValueOrDie();
}
});
return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
std::move(slices));
}

Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
@ -152,6 +156,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {

AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -159,8 +164,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*variance=*/GetAllocationSlice(*custom_call->operand(4)),
/*epsilon=*/epsilon_value,
/*feature_index=*/feature_index_value,
/*output=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}

@ -181,6 +185,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_inv_stddev = GetAllocationSlice(*custom_call, {2});
AddThunkToThunkSequence(
absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@ -189,8 +194,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*output_data=*/output_data,
/*output_mean=*/output_mean,
/*output_inv_stddev=*/output_inv_stddev,
/*output_tuple=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output_tuple=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}

@ -209,6 +213,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto output_grad_scale = GetAllocationSlice(*custom_call, {1});
auto output_grad_offset = GetAllocationSlice(*custom_call, {2});
AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
context_->GetThunkInfo(custom_call),
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*mean=*/GetAllocationSlice(*custom_call->operand(2)),
@ -219,8 +224,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*output_grad_data=*/output_grad_data,
/*output_grad_scale=*/output_grad_scale,
/*output_grad_offset=*/output_grad_offset,
/*output_tuple=*/GetAllocationSlice(*custom_call),
/*hlo=*/custom_call));
/*output_tuple=*/GetAllocationSlice(*custom_call)));
return Status::OK();
}

@ -235,7 +239,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
auto scratch_slice = GetAllocationSlice(*custom_call, {1});

AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
context_->GetThunkInfo(custom_call), std::move(operand_slices),
conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
@ -269,22 +273,23 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {

if (operand_buffer != a_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(custom_call),
/*source_address=*/operand_buffer,
/*destination_buffer=*/a_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
/*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
}

thunks.push_back(absl::make_unique<CholeskyThunk>(
options, a_buffer, workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n,
custom_call));
context_->GetThunkInfo(custom_call), options, a_buffer,
workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n));

// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(custom_call), std::move(thunks)));
}

return Status::OK();
@ -311,8 +316,9 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
ShapeTree<BufferAllocation::Slice> result_slices =
get_slices_for_instr(custom_call);
AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
call_target, std::move(operand_slices), std::move(result_slices),
Cast<HloCustomCallInstruction>(custom_call)->opaque(), custom_call));
context_->GetThunkInfo(custom_call), call_target,
std::move(operand_slices), std::move(result_slices),
Cast<HloCustomCallInstruction>(custom_call)->opaque()));
return Status::OK();
}
#endif
@ -347,9 +353,10 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
auto destination_buffer = GetAllocationSlice(*hlo);
if (operand_buffer != destination_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(hlo),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo));
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape())));
}

thunks.push_back(BuildTriangularSolveThunk(hlo));
@ -358,8 +365,8 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), hlo));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(hlo), std::move(thunks)));
}
return Status::OK();
}
@ -374,5 +381,12 @@ Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
return Status::OK();
}

Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
const HloInstruction* hlo) const {
CHECK(hlo);
Thunk::ThunkInfo info;
info.hlo_instruction = hlo;
return info;
}
} // namespace gpu
} // namespace xla

@ -36,6 +36,7 @@ class ThunkEmitter {
const HloInstruction& hlo, const ShapeIndex& index) const = 0;
virtual int64 ByteSizeOf(const Shape& shape) const = 0;
virtual absl::string_view platform_name() const = 0;
virtual Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const;

virtual ~EmissionContext() = default;
};
@ -32,12 +32,12 @@ namespace xla {
namespace gpu {

TriangularSolveThunk::TriangularSolveThunk(
const TriangularSolveOptions& options,
ThunkInfo thunk_info, const TriangularSolveOptions& options,
const BufferAllocation::Slice& a_buffer,
const BufferAllocation::Slice& b_buffer, PrimitiveType type,
int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
int64 b_batch_stride, const HloInstruction* hlo)
: Thunk(Kind::kTriangularSolve, hlo),
int64 b_batch_stride)
: Thunk(Kind::kTriangularSolve, thunk_info),
uplo_(options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper),
side_(options.left_side() ? se::blas::Side::kLeft

@ -38,12 +38,12 @@ namespace gpu {
// Thread-compatible.
class TriangularSolveThunk : public Thunk {
public:
TriangularSolveThunk(const TriangularSolveOptions& options,
TriangularSolveThunk(ThunkInfo thunk_info,
const TriangularSolveOptions& options,
const BufferAllocation::Slice& a_buffer,
const BufferAllocation::Slice& b_buffer,
PrimitiveType type, int64 batch_size, int64 m, int64 n,
int64 a_batch_stride, int64 b_batch_stride,
const HloInstruction* hlo);
int64 a_batch_stride, int64 b_batch_stride);

TriangularSolveThunk(const TriangularSolveThunk&) = delete;
TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;
@ -34,7 +34,7 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) {
}

auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
params.profiler->MakeScopedInstructionProfiler(profile_index());
SafeH2DMemcpy(se::DeviceMemory<void*>(
buffer_allocations.GetDeviceAddress(dest_buffer_)),
std::move(tuple_data), n, &stream,

@ -34,10 +34,10 @@ namespace gpu {
// issue (b/31336476).
class TupleThunk : public Thunk {
public:
TupleThunk(absl::Span<const BufferAllocation::Slice> tuple_element_buffers,
const BufferAllocation::Slice& dest_buffer,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kTuple, hlo_instruction),
TupleThunk(ThunkInfo thunk_info,
absl::Span<const BufferAllocation::Slice> tuple_element_buffers,
const BufferAllocation::Slice& dest_buffer)
: Thunk(Kind::kTuple, thunk_info),
tuple_element_buffers_(tuple_element_buffers.begin(),
tuple_element_buffers.end()),
dest_buffer_(dest_buffer) {}
@ -24,20 +24,20 @@ namespace xla {
namespace gpu {

WhileThunk::WhileThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& condition_result_buffer_index,
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
std::unique_ptr<ThunkSequence> body_thunk_sequence)
: Thunk(Kind::kWhile, thunk_info),
condition_result_buffer_index_(condition_result_buffer_index),
// Pass a default ThunkInfo to the condition_thunk_sequence_ and
// body_thunk_sequence_ constructors because these SequentialThunks
// are logically "part of" this WhileThunk, and shouldn't be profiled
// separately from it.
condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*condition_thunk_sequence), nullptr)),
ThunkInfo(), std::move(*condition_thunk_sequence))),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*body_thunk_sequence), nullptr)) {}
ThunkInfo(), std::move(*body_thunk_sequence))) {}

void WhileThunk::ComputeAnnotations() {
Thunk::ComputeAnnotations();
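The constructor above wraps the condition and body sequences in SequentialThunks built with a default ThunkInfo, so the sub-sequences are never profiled apart from the enclosing while thunk. A standalone sketch of that ownership pattern, reusing the simplified Thunk/ThunkInfo shape from the earlier sketches:

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

struct ThunkInfo {
  std::optional<int64_t> profile_index;
};

class Thunk {
 public:
  explicit Thunk(ThunkInfo info) : profile_index_(info.profile_index) {}
  virtual ~Thunk() = default;

  std::optional<int64_t> profile_index() const { return profile_index_; }

 private:
  std::optional<int64_t> profile_index_;
};

using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;

class SequentialThunk : public Thunk {
 public:
  SequentialThunk(ThunkInfo info, ThunkSequence thunks)
      : Thunk(info), thunks_(std::move(thunks)) {}

 private:
  ThunkSequence thunks_;
};

class WhileThunk : public Thunk {
 public:
  WhileThunk(ThunkInfo info, ThunkSequence condition, ThunkSequence body)
      : Thunk(info),
        // Default ThunkInfo: the sub-sequences are logically part of this
        // while thunk, so they carry no profile index of their own.
        condition_(std::make_unique<SequentialThunk>(ThunkInfo(),
                                                     std::move(condition))),
        body_(std::make_unique<SequentialThunk>(ThunkInfo(),
                                                std::move(body))) {}

 private:
  std::unique_ptr<SequentialThunk> condition_;
  std::unique_ptr<SequentialThunk> body_;
};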
@ -61,7 +61,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
params.buffer_allocations->GetDeviceAddress(
condition_result_buffer_index_);

auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction());
auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index());
while (true) {
// Invoke thunk sequence for while 'condition' computation.
profiler.StartHloComputation();

@ -39,10 +39,10 @@ namespace gpu {
class WhileThunk : public Thunk {
public:
// Constructs a WhileThunk to compute the given while instruction.
WhileThunk(const BufferAllocation::Slice& condition_result_buffer_index,
WhileThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& condition_result_buffer_index,
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
std::unique_ptr<ThunkSequence> body_thunk_sequence,
const HloInstruction* hlo);
std::unique_ptr<ThunkSequence> body_thunk_sequence);
WhileThunk(const WhileThunk&) = delete;
WhileThunk& operator=(const WhileThunk&) = delete;

@ -133,8 +133,12 @@ HloExecutionProfile::HloExecutionProfile(

void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo,
uint64 cycles_taken) {
profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] =
cycles_taken;
SetCyclesTakenBy(hlo_profile_index_map_.GetProfileIndexFor(*hlo),
cycles_taken);
}

void HloExecutionProfile::SetCyclesTakenBy(size_t index, uint64 cycles_taken) {
profile_counters_[index] = cycles_taken;
}

uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {

@ -114,6 +114,9 @@ class HloExecutionProfile {
// Record how many cycles this HLO took to execute.
void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken);

// Record how many cycles this HLO took to execute.
void SetCyclesTakenBy(size_t index, uint64 cycles_taken);

// Returns how many cycles this HLO took to execute. Profiling information
// may not be available for some instructions in which case zero is returned.
uint64 GetCyclesTakenBy(const HloInstruction& hlo) const;
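The instruction-keyed SetCyclesTakenBy above is now a thin wrapper over the new index-keyed overload, which is what the index-carrying thunks call. A simplified standalone model of the two overloads (IndexMap stands in for HloProfileIndexMap):

#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

struct Instruction {};  // stand-in for HloInstruction

// Stand-in for HloProfileIndexMap: instruction -> counter slot.
class IndexMap {
 public:
  explicit IndexMap(std::unordered_map<const Instruction*, size_t> m)
      : map_(std::move(m)) {}
  size_t GetProfileIndexFor(const Instruction& hlo) const {
    return map_.at(&hlo);
  }

 private:
  std::unordered_map<const Instruction*, size_t> map_;
};

class ExecutionProfile {
 public:
  ExecutionProfile(const IndexMap* index_map, size_t num_counters)
      : index_map_(index_map), profile_counters_(num_counters, 0) {}

  // Instruction-keyed entry point now delegates to the index-keyed one,
  // mirroring the refactoring in this commit.
  void SetCyclesTakenBy(const Instruction* hlo, uint64_t cycles_taken) {
    SetCyclesTakenBy(index_map_->GetProfileIndexFor(*hlo), cycles_taken);
  }
  void SetCyclesTakenBy(size_t index, uint64_t cycles_taken) {
    profile_counters_[index] = cycles_taken;
  }

 private:
  const IndexMap* index_map_;
  std::vector<uint64_t> profile_counters_;
};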
@ -88,6 +88,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
const HloInstruction& hlo, const ShapeIndex& index) const override;
int64 ByteSizeOf(const Shape& shape) const override;
absl::string_view platform_name() const override;

mlir::Location getLocation(const HloInstruction* instr) const;

xla::mlir_gpu::EmissionContext* emission_context_;

@ -436,8 +436,10 @@ StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
kernel, operand_to_value_map, ordered_operands, assignment, buffers));

// Finally, create the thunk and set the launch dimensions.
auto thunk = absl::make_unique<gpu::KernelThunk>(
buffers, kernel.getName().str(), instr);
gpu::Thunk::ThunkInfo info;
info.hlo_instruction = instr;
auto thunk = absl::make_unique<gpu::KernelThunk>(info, buffers,
kernel.getName().str());

// Set launch bounds.
mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();
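On the MLIR path above there is no EmissionContext helper yet, so the ThunkInfo is filled in by hand, with profile_index left unset. A small sketch of that call shape (the KernelThunk-like class and all names here are illustrative only, not the real gpu::KernelThunk):

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>

struct Instruction {};  // stand-in for HloInstruction

struct ThunkInfo {
  const Instruction* hlo_instruction = nullptr;
  std::optional<int64_t> profile_index;
};

// Illustrative stand-in for gpu::KernelThunk after this commit:
// ThunkInfo first, then the thunk's own arguments.
class KernelThunkLike {
 public:
  KernelThunkLike(ThunkInfo info, std::string kernel_name)
      : info_(info), kernel_name_(std::move(kernel_name)) {}

 private:
  ThunkInfo info_;
  std::string kernel_name_;
};

std::unique_ptr<KernelThunkLike> MakeThunk(const Instruction* instr,
                                           std::string kernel_name) {
  ThunkInfo info;
  info.hlo_instruction = instr;  // profile_index stays unset on this path
  return std::make_unique<KernelThunkLike>(info, std::move(kernel_name));
}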