Add a tool that lowers HLO to LLVM IR via XLA GPU and use that to write FileCheck tests for scatter

This will be useful to make sure we keep generating the same LLVM IR as we add an LHLO backend.

I also discovered a source of non-determinism via the scatter.hlo test, where in
IrEmitterUnnested::BuildKernelThunk we were iterating over a std::map that had
pointer keys.  Fix that as well.
PiperOrigin-RevId: 318581722
Change-Id: I0b5f0cdd760880378bb046f66468af2f0c4a6a15
This commit is contained in:
Sanjoy Das 2020-06-26 18:51:42 -07:00 committed by TensorFlower Gardener
parent 07dd408634
commit 15f7e2fca2
9 changed files with 602 additions and 97 deletions

View File

@ -73,7 +73,7 @@ tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt' 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir'
] ]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs) llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -48,6 +48,7 @@ mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir/xla', 'tensorflow/compiler/mlir/xla',
'tensorflow/compiler/aot', 'tensorflow/compiler/aot',
'tensorflow/compiler/xla/service/mlir_gpu', 'tensorflow/compiler/xla/service/mlir_gpu',
'tensorflow/compiler/xla/service/gpu/tests',
] ]
config.mlir_tf_tools_dirs = [ config.mlir_tf_tools_dirs = [
os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s) os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s)

View File

@ -1133,14 +1133,12 @@ cc_library(
deps = [ deps = [
":alias_passthrough_params", ":alias_passthrough_params",
":cudnn_batchnorm_rewriter", ":cudnn_batchnorm_rewriter",
":cudnn_pad_for_convolutions",
":fusion_merger", ":fusion_merger",
":gemm_rewriter", ":gemm_rewriter",
":gpu_constants", ":gpu_constants",
":gpu_conv_algorithm_picker", ":gpu_conv_algorithm_picker",
":gpu_conv_padding_legalization",
":gpu_conv_rewriter",
":gpu_copy_insertion", ":gpu_copy_insertion",
":gpu_device_info",
":gpu_executable", ":gpu_executable",
":gpu_hlo_schedule", ":gpu_hlo_schedule",
":gpu_layout_assignment", ":gpu_layout_assignment",
@ -1188,7 +1186,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
"//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/service:hlo_proto_util", "//tensorflow/compiler/xla/service:hlo_proto_util",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:hlo_verifier",
@ -1214,11 +1211,8 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor:stream_executor_headers",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Core", "@llvm-project//llvm:Core",
], ],
) )

View File

@ -464,6 +464,66 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
return std::move(module); return std::move(module);
} }
static Status CompileModuleToLlvmIrImpl(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
const std::string& platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
int pointer_size, std::unique_ptr<llvm::Module>* llvm_module,
std::unique_ptr<StreamAssignment>* stream_assignment,
std::unique_ptr<GpuHloSchedule>* hlo_schedule,
std::unique_ptr<BufferAssignment>* buffer_assignment,
std::unique_ptr<ThunkSequence>* thunk_sequence) {
*llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
(*llvm_module)->setTargetTriple(target_triple);
(*llvm_module)->setDataLayout(data_layout);
*stream_assignment = AssignStreams(*hlo_module);
TF_ASSIGN_OR_RETURN(
*hlo_schedule,
GpuHloSchedule::Build(*hlo_module, **stream_assignment, pointer_size));
auto buffer_size_bytes_function =
[pointer_size](const BufferValue& buffer_value) -> int64 {
return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size);
};
TF_ASSIGN_OR_RETURN(
*buffer_assignment,
BufferAssigner::Run(
hlo_module, (*hlo_schedule)->ConsumeHloOrdering(),
buffer_size_bytes_function,
/*color_alignment=*/
[](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
/*allocate_buffers_for_constants=*/true,
/*colorer=*/BufferAssigner::DefaultColorer(),
/*must_not_live_out=*/{}, can_share_buffer_function));
VLOG(1) << "Buffer Assignment Stats "
<< (*buffer_assignment)->GetStats().ToString();
DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment,
"after_optimizations");
IrEmitterContext ir_emitter_context(
hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
cuda_compute_capability, llvm_module->get());
HloComputation* entry_computation = hlo_module->entry_computation();
IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
&ir_emitter_context);
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
}
*thunk_sequence = ir_emitter.ConsumeThunkSequence();
return Status::OK();
}
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend( StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) { se::DeviceMemoryAllocator* device_allocator) {
@ -483,35 +543,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
}; };
llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer); llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
llvm::Module llvm_module(module->name().c_str(), llvm_context);
// Set the target triple and the data layout.
llvm_module.setTargetTriple(target_triple_);
llvm_module.setDataLayout(data_layout_);
// Determine the HLO schedule, which is an ordering of HLO instructions. This
// is used by buffer assignment to enable buffer reuse, and the same ordering
// must also be used to determine the thunk launch schedule.
std::unique_ptr<StreamAssignment> stream_assignment = AssignStreams(*module);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<GpuHloSchedule> hlo_schedule,
GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> buffer_assignment,
BufferAssigner::Run(
module.get(), hlo_schedule->ConsumeHloOrdering(),
BufferSizeBytesFunction(),
/*color_alignment=*/
[](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
/*allocate_buffers_for_constants=*/true,
/*colorer=*/BufferAssigner::DefaultColorer(),
/*must_not_live_out=*/{}, GetCanShareBuffer()));
VLOG(1) << "Buffer Assignment Stats "
<< buffer_assignment->GetStats().ToString();
DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
GpuDeviceInfo gpu_device_info; GpuDeviceInfo gpu_device_info;
gpu_device_info.threads_per_block_limit = gpu_device_info.threads_per_block_limit =
stream_exec->GetDeviceDescription().threads_per_block_limit(); stream_exec->GetDeviceDescription().threads_per_block_limit();
@ -531,32 +562,29 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
return cuda_compute_capability; return cuda_compute_capability;
}(); }();
IrEmitterContext ir_emitter_context( std::unique_ptr<llvm::Module> llvm_module;
module.get(), buffer_assignment.get(), stream_exec->platform()->Name(), std::unique_ptr<StreamAssignment> stream_assignment;
gpu_device_info, cuda_compute_capability, &llvm_module); std::unique_ptr<GpuHloSchedule> hlo_schedule;
std::unique_ptr<BufferAssignment> buffer_assignment;
std::unique_ptr<ThunkSequence> thunk_sequence;
HloComputation* entry_computation = module->entry_computation(); TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
IrEmitterUnnested ir_emitter(module->config(), entry_computation, module.get(), &llvm_context, target_triple_, data_layout_,
&ir_emitter_context); stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
GetCanShareBuffer(), pointer_size_, &llvm_module, &stream_assignment,
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); &hlo_schedule, &buffer_assignment, &thunk_sequence));
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
}
if (user_pre_optimization_hook_) { if (user_pre_optimization_hook_) {
user_pre_optimization_hook_(llvm_module); user_pre_optimization_hook_(*llvm_module);
} }
string ir_module_string_before_opt; string ir_module_string_before_opt;
const bool embed_ir_in_executable = const bool embed_ir_in_executable =
module->config().debug_options().xla_embed_ir_in_executable(); module->config().debug_options().xla_embed_ir_in_executable();
if (embed_ir_in_executable) { if (embed_ir_in_executable) {
ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); ir_module_string_before_opt = llvm_ir::DumpModuleToString(*llvm_module);
} }
llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false); llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
{ {
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
@ -565,7 +593,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
llvm::raw_string_ostream err_stream(err); llvm::raw_string_ostream err_stream(err);
// verifyModule() returns true if the module is broken. // verifyModule() returns true if the module is broken.
TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
<< "Invalid LLVM IR before optimizations:\n" << "Invalid LLVM IR before optimizations:\n"
<< err_stream.str() << err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
@ -578,11 +606,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>; using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result, TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
CompileTargetBinary(module.get(), &llvm_module, CompileTargetBinary(module.get(), llvm_module.get(),
gpu_version, stream_exec)); gpu_version, stream_exec));
auto thunk_schedule = absl::make_unique<ThunkSchedule>( auto thunk_schedule = absl::make_unique<ThunkSchedule>(
ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), std::move(thunk_sequence), std::move(stream_assignment),
hlo_schedule->ThunkLaunchOrder()); hlo_schedule->ThunkLaunchOrder());
if (DumpingEnabledForHloModule(*module)) { if (DumpingEnabledForHloModule(*module)) {
DumpToFileInDirOrStdout(*module, "", "thunk_schedule", DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
@ -602,8 +630,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
cost_analysis.bytes_accessed()); cost_analysis.bytes_accessed());
if (module->config().hlo_profiling_enabled()) { if (module->config().hlo_profiling_enabled()) {
profile_index_map = absl::make_unique<HloProfileIndexMap>(*module); profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer = CreateHloProfilePrinterData( profile_printer =
*profile_index_map, cost_analysis, entry_computation->name()); CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
module->entry_computation()->name());
} }
} }
@ -625,5 +654,30 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
} }
static absl::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
const HloInstruction*,
const ShapeIndex&) {
return absl::nullopt;
}
StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
const std::string& platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
int pointer_size) {
std::unique_ptr<llvm::Module> llvm_module;
std::unique_ptr<StreamAssignment> stream_assignment;
std::unique_ptr<GpuHloSchedule> hlo_schedule;
std::unique_ptr<BufferAssignment> buffer_assignment;
std::unique_ptr<ThunkSequence> thunk_sequence;
TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
hlo_module, llvm_context, target_triple, data_layout, platform_name,
gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
pointer_size, &llvm_module, &stream_assignment, &hlo_schedule,
&buffer_assignment, &thunk_sequence));
return llvm_module;
}
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
@ -94,13 +95,17 @@ class GpuCompiler : public LLVMCompiler {
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
// Capture just the pointer size, not the entire GpuCompiler object. // Capture just the pointer size, not the entire GpuCompiler object.
return [pointer_size = pointer_size_](const Shape& shape) { return [pointer_size = pointer_size_](const Shape& shape) {
return GetSizeOfShape(shape, pointer_size);
};
}
static int64 GetSizeOfShape(const Shape& shape, int pointer_size) {
if (shape.is_static() || shape.IsTuple()) { if (shape.is_static() || shape.IsTuple()) {
return ShapeUtil::ByteSizeOf(shape, pointer_size); return ShapeUtil::ByteSizeOf(shape, pointer_size);
} }
// Each dynamic dimension size is represented as a S32. // Each dynamic dimension size is represented as a S32.
int64 metadata_size = sizeof(int32) * shape.dimensions_size(); int64 metadata_size = sizeof(int32) * shape.dimensions_size();
return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size; return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
};
} }
private: private:
@ -117,6 +122,16 @@ class GpuCompiler : public LLVMCompiler {
TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler);
}; };
// Compile `hlo_module` using XLA GPU and return the LLVM module thus generated.
// The GpuExecutable (and the Thunks that are part of it) are not returned.
StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
const std::string& platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
int pointer_size);
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla

View File

@ -1444,8 +1444,6 @@ Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
return Status::OK(); return Status::OK();
} }
namespace {} // namespace
Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count() VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
<< "; operand count: " << crs->operand_count() << "; operand count: " << crs->operand_count()
@ -1557,29 +1555,37 @@ Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
return Status::OK(); return Status::OK();
} }
// Describes how to access a particular subshape for an HLO. For instance if
// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at
// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is found
// at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we
// dereference twice -- first at index 3, and then at index 4 -- to get the
// address of our buffer.
struct HloBufferSlice {
const HloInstruction* instr;
ShapeIndex hlo_index;
// The root buffer to look at.
BufferAllocation::Slice buffer_slice;
// Describes how to dereference starting at that buffer to get to the buffer
// in question.
ShapeIndex gte_index;
};
// Figures out how to access the buffers for all subshapes of hlo's operands and // Figures out how to access the buffers for all subshapes of hlo's operands and
// for hlo itself (i.e. all the buffers produced by HLO). // for hlo itself (i.e. all the buffers produced by HLO).
// //
// Returns a map keyed on the pair {HloInstruction, ShapeIndex}. The value for // Returns a vector of `HloBufferSlice`s, one for each HLO subshape `hlo` needs
// this key is a pair {Slice, ShapeIndex}, where the slice tells you the root // to access (including one or more for itself).
// buffer to look in, and the ShapeIndex describes how to dereference starting
// at that buffer to get to the buffer in question.
//
// For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
// hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
// is found at slice[3][4]. That is, slice is a void***, which we dereference
// twice -- first at index 3, and then at index 4 -- to get the address of our
// buffer.
// //
// This function conservatively assumes that we'll touch all sub-buffers of // This function conservatively assumes that we'll touch all sub-buffers of
// every operand and of the output. // every operand and of the output.
static std::map<std::pair<const HloInstruction*, ShapeIndex>, static std::vector<HloBufferSlice> GetHloBufferSlices(
std::pair<BufferAllocation::Slice, ShapeIndex>> const HloInstruction* hlo, const BufferAssignment& buffer_assn) {
GetHloBufferSlices(const HloInstruction* hlo, std::vector<HloBufferSlice> result;
const BufferAssignment& buffer_assn) { absl::flat_hash_set<std::pair<const HloInstruction*, ShapeIndex>>
std::map<std::pair<const HloInstruction*, ShapeIndex>, inserted_buffer_slices;
std::pair<BufferAllocation::Slice, ShapeIndex>>
slices;
// Tries to find a slice plus an array of indices i1, ..., iN such that the // Tries to find a slice plus an array of indices i1, ..., iN such that the
// sub-buffer for instr at index can be found at slice[i1]...[iN]. // sub-buffer for instr at index can be found at slice[i1]...[iN].
@ -1646,13 +1652,18 @@ GetHloBufferSlices(const HloInstruction* hlo,
auto add_slices_for = [&](const HloInstruction* instr) { auto add_slices_for = [&](const HloInstruction* instr) {
ShapeUtil::ForEachSubshape( ShapeUtil::ForEachSubshape(
instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
if (slices.count({instr, index})) { if (!inserted_buffer_slices.insert({instr, index}).second) {
// HLOs can have duplicate operands; don't bother redoing work. // HLOs can have duplicate operands; don't bother redoing work.
return; return;
} }
auto maybe_slice = find_slice_for(instr, index); auto maybe_slice = find_slice_for(instr, index);
if (maybe_slice.has_value()) { if (maybe_slice.has_value()) {
slices[{instr, index}] = *maybe_slice; HloBufferSlice hlo_buffer_slice;
hlo_buffer_slice.instr = instr;
hlo_buffer_slice.hlo_index = index;
hlo_buffer_slice.buffer_slice = maybe_slice->first;
hlo_buffer_slice.gte_index = maybe_slice->second;
result.push_back(hlo_buffer_slice);
} else { } else {
VLOG(1) << "Couldn't find buffer for " << instr->ToString() VLOG(1) << "Couldn't find buffer for " << instr->ToString()
<< " at index " << index.ToString(); << " at index " << index.ToString();
@ -1667,7 +1678,7 @@ GetHloBufferSlices(const HloInstruction* hlo,
add_slices_for(operand); add_slices_for(operand);
} }
return slices; return result;
} }
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
@ -1675,9 +1686,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
const BufferAssignment& buffer_assn = const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment(); ir_emitter_context_->buffer_assignment();
std::map<std::pair<const HloInstruction*, ShapeIndex>, std::vector<HloBufferSlice> hlo_slices =
std::pair<BufferAllocation::Slice, ShapeIndex>> GetHloBufferSlices(inst, buffer_assn);
hlo_slices = GetHloBufferSlices(inst, buffer_assn);
// Figure out which buffer allocations need to be passed as arguments to our // Figure out which buffer allocations need to be passed as arguments to our
// kernel. This is simply all of the allocations referenced in hlo_slices, // kernel. This is simply all of the allocations referenced in hlo_slices,
@ -1685,8 +1695,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
// buffer because even if the kernel itself doesn't use it, a nested // buffer because even if the kernel itself doesn't use it, a nested
// subcomputation within the kernel (e.g. a kMap's computation) might. // subcomputation within the kernel (e.g. a kMap's computation) might.
std::unordered_set<const BufferAllocation*> buffers_needed; std::unordered_set<const BufferAllocation*> buffers_needed;
for (const auto& kv : hlo_slices) { for (const auto& hlo_buffer_slice : hlo_slices) {
buffers_needed.insert(kv.second.first.allocation()); buffers_needed.insert(hlo_buffer_slice.buffer_slice.allocation());
} }
absl::optional<const BufferAllocation*> temp_buffer; absl::optional<const BufferAllocation*> temp_buffer;
for (const BufferAllocation& alloc : buffer_assn.Allocations()) { for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
@ -1730,11 +1740,11 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
// For each buffer our kernel might want to touch, bind it to a value derived // For each buffer our kernel might want to touch, bind it to a value derived
// from our kernel args. // from our kernel args.
for (const auto& kv : hlo_slices) { for (const auto& hlo_buffer_slice : hlo_slices) {
const HloInstruction* instr = kv.first.first; const HloInstruction* instr = hlo_buffer_slice.instr;
const ShapeIndex& index = kv.first.second; const ShapeIndex& index = hlo_buffer_slice.hlo_index;
const BufferAllocation::Slice& slice = kv.second.first; const BufferAllocation::Slice& slice = hlo_buffer_slice.buffer_slice;
const ShapeIndex& gte_index = kv.second.second; const ShapeIndex& gte_index = hlo_buffer_slice.gte_index;
VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
<< " is found in slice " << slice.ToString() << " at GTE index " << " is found in slice " << slice.ToString() << " at GTE index "

View File

@ -5,11 +5,12 @@
# need to run on machines with GPUs present. # need to run on machines with GPUs present.
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load( load(
"//tensorflow/core/platform:build_config_root.bzl", "//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags", "tf_cuda_tests_tags",
) )
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package( package(
default_visibility = [":friends"], default_visibility = [":friends"],
@ -456,3 +457,37 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tests:xla_internal_test_main",
], ],
) )
tf_cc_binary(
name = "hlo_to_llvm_ir",
srcs = ["hlo_to_llvm_ir.cc"],
deps = [
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/gpu:gpu_compiler",
"//tensorflow/compiler/xla/service/gpu:gpu_device_info",
"//tensorflow/compiler/xla/service/gpu:target_constants",
"//tensorflow/compiler/xla/tools:hlo_module_loader",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
glob_lit_tests(
data = [":test_utilities"],
default_tags = tf_cuda_tests_tags() + [
"no_pip",
],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["hlo"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/xla/service/gpu/tests:hlo_to_llvm_ir",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
const char* const kUsage = R"(
This tool reads in an HloMoudle from a file, compiles it using the NVPTX
compiler and prints out the LLVM IR generated by the IR emitter. The LLVM IR is
not optimized by the LLVM pass pipeline, so this tool can be used to unit test
the XLA GPU IR emitters.
Note that the LLVM IR does not contain the *full* module, but only parts that
will be code generated into PTX. The NVPTX compiler also generates a
GpuExecutable on the size that is not printed.)";
namespace {
xla::Status CompileAndPrintLlvmIr(const std::string& hlo_text) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::HloModule> hlo_module,
xla::LoadModuleFromData(/*data=*/hlo_text, /*format=*/"hlo"));
llvm::LLVMContext llvm_context;
// For now we pretend we're compiling for V100. This can be generalized
// later.
xla::gpu::GpuDeviceInfo gpu_device_info;
gpu_device_info.threads_per_block_limit = 1024;
gpu_device_info.threads_per_warp = 32;
gpu_device_info.shared_memory_per_block = 1024 * 96;
xla::gpu::CudaComputeCapability cuda_compute_capability;
cuda_compute_capability.cc_major = 7;
cuda_compute_capability.cc_minor = 0;
std::string target_triple = "nvptx64-nvidia-cuda";
std::string datalayout = "nvptx64-nvidia-cuda";
TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::Module> llvm_module,
xla::gpu::CompileModuleToLlvmIr(
hlo_module.get(), &llvm_context,
/*target_triple=*/xla::gpu::nvptx::kTargetTriple,
/*data_layout=*/xla::gpu::nvptx::kDataLayout,
/*platform_name=*/"CUDA", gpu_device_info,
cuda_compute_capability, /*pointer_size=*/8));
llvm_module->print(llvm::outs(), nullptr);
return xla::Status::OK();
}
xla::Status CompileAndPrintLlvmIrFromFile(const std::string& file_name) {
std::string full_text;
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
file_name, &full_text));
std::vector<std::string> hlo_module_texts =
absl::StrSplit(full_text, "// -----");
for (const std::string& hlo_module_text : hlo_module_texts) {
TF_RETURN_IF_ERROR(CompileAndPrintLlvmIr(hlo_module_text));
}
return xla::Status::OK();
}
} // namespace
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::AppendDebugOptionsFlags(&flag_list);
// The usage string includes the message at the top of the file, the
// DebugOptions flags and the flags defined above.
const std::string kUsageString = absl::StrCat(
kUsage, "\n\n", tensorflow::Flags::Usage(argv[0], flag_list));
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(kUsageString.c_str(), &argc, &argv);
if (!parse_ok) {
LOG(QFATAL) << kUsageString;
}
QCHECK(argc == 2) << "Must specify a single input file";
TF_CHECK_OK(CompileAndPrintLlvmIrFromFile(argv[1]));
return 0;
}

View File

@ -0,0 +1,296 @@
// RUN: hlo_to_llvm_ir %s | FileCheck %s
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) {
// CHECK: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [2 x i32]*
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [2 x [3 x i32]]*
// CHECK: %[[VAL_12:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_13:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_14:.*]] = mul nuw nsw i32 %[[VAL_12]], 6
// CHECK: %[[VAL_15:.*]] = add nuw nsw i32 %[[VAL_14]], %[[VAL_13]]
// CHECK: %[[VAL_16:.*]] = icmp ult i32 %[[VAL_15]], 6
// CHECK: call void @llvm.assume(i1 %[[VAL_16]])
// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_15]], 1
// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 3
// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_15]], 3
// CHECK: %[[VAL_20:.*]] = icmp ult i32 %[[VAL_15]], 6
// CHECK: br i1 %[[VAL_20]], label %[[VAL_21:.*]], label %[[VAL_22:.*]]
// CHECK: scatter_TensorFlowScatterV1.in_bounds-after: ; preds = %[[VAL_23:.*]], %[[VAL_24:.*]]
// CHECK: ret void
// CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_24]]
// CHECK: %[[VAL_25:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_8]], i32 0, i32 %[[VAL_19]]
// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4, !noalias !5
// CHECK: %[[VAL_27:.*]] = add i32 0, %[[VAL_26]]
// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_26]], 3
// CHECK: %[[VAL_29:.*]] = and i1 true, %[[VAL_28]]
// CHECK: br i1 %[[VAL_29]], label %[[VAL_30:.*]], label %[[VAL_23]]
// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_30]], %[[VAL_21]]
// CHECK: br label %[[VAL_22]]
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_21]]
// CHECK: %[[VAL_31:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_18]]
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
// CHECK: %[[VAL_33:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_11]] to i32*
// CHECK: %[[VAL_34:.*]] = getelementptr inbounds i32, i32* %[[VAL_33]], i32 %[[VAL_15]]
// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4, !noalias !5
// CHECK: store i32 %[[VAL_35]], i32* %[[VAL_32]], align 4
// CHECK: %[[VAL_36:.*]] = load i32, i32* %[[VAL_32]], align 4
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
// CHECK: br label %[[VAL_23]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{i32 0, i32 6}
// CHECK: !4 = !{}
// CHECK: !5 = !{!6}
// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7}
// CHECK: !7 = !{!"XLA global AA domain"}
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter_TensorFlowScatterV1 = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
// -----
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* align 64 dereferenceable(4) %alloc0, i8* align 16 dereferenceable(4) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 %alloc3) {
// CHECK: entry:
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32*
// CHECK: %[[VAL_40:.*]] = getelementptr inbounds i8, i8* %[[VAL_41:.*]], i64 0
// CHECK: %[[VAL_42:.*]] = bitcast i8* %[[VAL_40]] to i32*
// CHECK: %[[VAL_43:.*]] = getelementptr inbounds i8, i8* %[[VAL_44:.*]], i64 0
// CHECK: %[[VAL_45:.*]] = bitcast i8* %[[VAL_43]] to [0 x i32]*
// CHECK: %[[VAL_46:.*]] = getelementptr inbounds i8, i8* %[[VAL_47:.*]], i64 0
// CHECK: %[[VAL_48:.*]] = bitcast i8* %[[VAL_46]] to i32*
// CHECK: %[[VAL_49:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_50:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
// CHECK: %[[VAL_51:.*]] = mul nuw nsw i32 %[[VAL_49]], 1
// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_51]], %[[VAL_50]]
// CHECK: %[[VAL_53:.*]] = icmp ult i32 %[[VAL_52]], 1
// CHECK: call void @llvm.assume(i1 %[[VAL_53]])
// CHECK: %[[VAL_54:.*]] = icmp ult i32 %[[VAL_52]], 1
// CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]]
// CHECK: scatter_ScatterIntoScalar.in_bounds-after: ; preds = %[[VAL_57:.*]], %[[VAL_58:.*]]
// CHECK: ret void
// CHECK: scatter_ScatterIntoScalar.in_bounds-true: ; preds = %[[VAL_58]]
// CHECK: br i1 true, label %[[VAL_59:.*]], label %[[VAL_57]]
// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_59]], %[[VAL_55]]
// CHECK: br label %[[VAL_56]]
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_55]]
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3, !noalias !4
// CHECK: store i32 %[[VAL_61]], i32* %[[VAL_60]], align 4
// CHECK: %[[VAL_62:.*]] = load i32, i32* %[[VAL_60]], align 4
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
// CHECK: br label %[[VAL_57]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{}
// CHECK: !4 = !{!5}
// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:4}", !6}
// CHECK: !6 = !{!"XLA global AA domain"}
HloModule ScatterIntoScalar
update_s32 {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
parameter.1 = s32[] parameter(0)
parameter.2 = s32[0]{0} parameter(1)
parameter.3 = s32[] parameter(2)
ROOT scatter_ScatterIntoScalar = s32[] scatter(parameter.1, parameter.2, parameter.3),
update_window_dims={},
inserted_window_dims={},
scatter_dims_to_operand_dims={},
index_vector_dim=0,
to_apply=update_s32
}
// -----
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) {
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0
// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_68:.*]] = getelementptr inbounds i8, i8* %[[VAL_69:.*]], i64 0
// CHECK: %[[VAL_70:.*]] = bitcast i8* %[[VAL_68]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_71:.*]] = getelementptr inbounds i8, i8* %[[VAL_72:.*]], i64 0
// CHECK: %[[VAL_73:.*]] = bitcast i8* %[[VAL_71]] to [2 x i32]*
// CHECK: %[[VAL_74:.*]] = getelementptr inbounds i8, i8* %[[VAL_75:.*]], i64 0
// CHECK: %[[VAL_76:.*]] = bitcast i8* %[[VAL_74]] to [2 x [3 x i32]]*
// CHECK: %[[VAL_77:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_78:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_79:.*]] = mul nuw nsw i32 %[[VAL_77]], 6
// CHECK: %[[VAL_80:.*]] = add nuw nsw i32 %[[VAL_79]], %[[VAL_78]]
// CHECK: %[[VAL_81:.*]] = icmp ult i32 %[[VAL_80]], 6
// CHECK: call void @llvm.assume(i1 %[[VAL_81]])
// CHECK: %[[VAL_82:.*]] = udiv i32 %[[VAL_80]], 1
// CHECK: %[[VAL_83:.*]] = urem i32 %[[VAL_82]], 3
// CHECK: %[[VAL_84:.*]] = udiv i32 %[[VAL_80]], 3
// CHECK: %[[VAL_85:.*]] = icmp ult i32 %[[VAL_80]], 6
// CHECK: br i1 %[[VAL_85]], label %[[VAL_86:.*]], label %[[VAL_87:.*]]
// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-after: ; preds = %[[VAL_88:.*]], %[[VAL_89:.*]]
// CHECK: ret void
// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_89]]
// CHECK: %[[VAL_90:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_73]], i32 0, i32 %[[VAL_84]]
// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4, !noalias !5
// CHECK: %[[VAL_92:.*]] = add i32 0, %[[VAL_91]]
// CHECK: %[[VAL_93:.*]] = icmp ult i32 %[[VAL_91]], 3
// CHECK: %[[VAL_94:.*]] = and i1 true, %[[VAL_93]]
// CHECK: br i1 %[[VAL_94]], label %[[VAL_95:.*]], label %[[VAL_88]]
// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_96:.*]], %[[VAL_86]]
// CHECK: br label %[[VAL_87]]
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_86]]
// CHECK: %[[VAL_97:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_67]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_83]]
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
// CHECK: %[[VAL_99:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_76]] to i32*
// CHECK: %[[VAL_100:.*]] = getelementptr inbounds i32, i32* %[[VAL_99]], i32 %[[VAL_80]]
// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4, !noalias !5
// CHECK: store i32 %[[VAL_101]], i32* %[[VAL_98]], align 4
// CHECK: %[[VAL_102:.*]] = load i32, i32* %[[VAL_98]], align 4
// CHECK: %[[VAL_103:.*]] = load i32, i32* %[[VAL_97]], align 4
// CHECK: store i32 %[[VAL_103]], i32* %[[VAL_64]], align 4
// CHECK: br label %[[VAL_104:.*]]
// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_104]]
// CHECK: br label %[[VAL_88]]
// CHECK: atomic_op_loop_body: ; preds = %[[VAL_104]], %[[VAL_95]]
// CHECK: %[[VAL_105:.*]] = load i32, i32* %[[VAL_64]], align 4
// CHECK: store i32 %[[VAL_105]], i32* %[[VAL_63]], align 4
// CHECK: call void @mul_s32(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]], i8* null)
// CHECK: %[[VAL_106:.*]] = load i32, i32* %[[VAL_63]], align 4
// CHECK: %[[VAL_107:.*]] = cmpxchg i32* %[[VAL_97]], i32 %[[VAL_105]], i32 %[[VAL_106]] seq_cst seq_cst
// CHECK: %[[VAL_108:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 0
// CHECK: store i32 %[[VAL_108]], i32* %[[VAL_64]], align 4
// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1
// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{i32 0, i32 6}
// CHECK: !4 = !{}
// CHECK: !5 = !{!6}
// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7}
// CHECK: !7 = !{!"XLA global AA domain"}
// CHECK: !8 = !{!9}
// CHECK: !9 = !{!"buffer: {index:4, offset:0, size:4}", !7}
// CHECK: !10 = !{!11}
// CHECK: !11 = !{!"buffer: {index:6, offset:0, size:4}", !7}
// CHECK: !12 = !{!13}
// CHECK: !13 = !{!"buffer: {index:5, offset:0, size:4}", !7}
HloModule TensorFlowScatter_Mul
mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter_TensorFlowScatter_Mul = s32[3,3] scatter(operand, indices, updates),
to_apply=mul_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
// -----
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* align 64 dereferenceable(16) %alloc0, i8* align 16 dereferenceable(16) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 dereferenceable(4) %alloc3) {
// CHECK: entry:
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]*
// CHECK: %[[VAL_121:.*]] = getelementptr inbounds i8, i8* %[[VAL_122:.*]], i64 0
// CHECK: %[[VAL_123:.*]] = bitcast i8* %[[VAL_121]] to [4 x i32]*
// CHECK: %[[VAL_124:.*]] = getelementptr inbounds i8, i8* %[[VAL_125:.*]], i64 0
// CHECK: %[[VAL_126:.*]] = bitcast i8* %[[VAL_124]] to i32*
// CHECK: %[[VAL_127:.*]] = getelementptr inbounds i8, i8* %[[VAL_128:.*]], i64 0
// CHECK: %[[VAL_129:.*]] = bitcast i8* %[[VAL_127]] to i32*
// CHECK: %[[VAL_130:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_131:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
// CHECK: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_130]], 1
// CHECK: %[[VAL_133:.*]] = add nuw nsw i32 %[[VAL_132]], %[[VAL_131]]
// CHECK: %[[VAL_134:.*]] = icmp ult i32 %[[VAL_133]], 1
// CHECK: call void @llvm.assume(i1 %[[VAL_134]])
// CHECK: %[[VAL_135:.*]] = icmp ult i32 %[[VAL_133]], 1
// CHECK: br i1 %[[VAL_135]], label %[[VAL_136:.*]], label %[[VAL_137:.*]]
// CHECK: scatter_ScalarUpdate.in_bounds-after: ; preds = %[[VAL_138:.*]], %[[VAL_139:.*]]
// CHECK: ret void
// CHECK: scatter_ScalarUpdate.in_bounds-true: ; preds = %[[VAL_139]]
// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3, !noalias !4
// CHECK: %[[VAL_141:.*]] = add i32 0, %[[VAL_140]]
// CHECK: %[[VAL_142:.*]] = icmp ult i32 %[[VAL_140]], 4
// CHECK: %[[VAL_143:.*]] = and i1 true, %[[VAL_142]]
// CHECK: br i1 %[[VAL_143]], label %[[VAL_144:.*]], label %[[VAL_138]]
// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_144]], %[[VAL_136]]
// CHECK: br label %[[VAL_137]]
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_136]]
// CHECK: %[[VAL_145:.*]] = getelementptr inbounds [4 x i32], [4 x i32]* %[[VAL_120]], i32 0, i32 %[[VAL_141]]
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3, !noalias !4
// CHECK: store i32 %[[VAL_147]], i32* %[[VAL_146]], align 4
// CHECK: %[[VAL_148:.*]] = load i32, i32* %[[VAL_146]], align 4
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
// CHECK: br label %[[VAL_138]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{}
// CHECK: !4 = !{!5}
// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:16}", !6}
// CHECK: !6 = !{!"XLA global AA domain"}
HloModule ScalarUpdate
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
updates = s32[] parameter(2)
ROOT scatter_ScalarUpdate = s32[4]{0} scatter(operand, index, updates),
to_apply=update_s32,
update_window_dims={},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
}