diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py
index 2d225342b56..00d909b523e 100644
--- a/tensorflow/compiler/mlir/runlit.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.cfg.py
@@ -73,7 +73,7 @@ tool_names = [
     'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
     'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
     'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
-    'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
+    'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py
index 3e7596c75d7..c5cd2b17920 100644
--- a/tensorflow/compiler/mlir/runlit.site.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.site.cfg.py
@@ -48,6 +48,7 @@ mlir_tf_tools_dirs = [
     'tensorflow/compiler/mlir/xla',
     'tensorflow/compiler/aot',
     'tensorflow/compiler/xla/service/mlir_gpu',
+    'tensorflow/compiler/xla/service/gpu/tests',
 ]
 config.mlir_tf_tools_dirs = [
     os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s)
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 5fb405ced53..785122e23b4 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1133,14 +1133,12 @@ cc_library(
     deps = [
         ":alias_passthrough_params",
         ":cudnn_batchnorm_rewriter",
-        ":cudnn_pad_for_convolutions",
         ":fusion_merger",
         ":gemm_rewriter",
         ":gpu_constants",
         ":gpu_conv_algorithm_picker",
-        ":gpu_conv_padding_legalization",
-        ":gpu_conv_rewriter",
         ":gpu_copy_insertion",
+        ":gpu_device_info",
         ":gpu_executable",
         ":gpu_hlo_schedule",
         ":gpu_layout_assignment",
@@ -1188,7 +1186,6 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
         "//tensorflow/compiler/xla/service:hlo_pass",
         "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
-        "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xla/service:hlo_proto_util",
         "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
         "//tensorflow/compiler/xla/service:hlo_verifier",
@@ -1214,11 +1211,8 @@ cc_library(
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/stream_executor:stream_executor_headers",
-        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Core",
     ],
 )
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 3dd722c885d..2b31099d26f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -464,6 +464,66 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
   return std::move(module);
 }
 
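+// Compiles `hlo_module` to LLVM IR and, as a side effect, hands back the
+// stream assignment, GPU HLO schedule, buffer assignment, and thunk sequence
+// produced along the way, so that RunBackend() can assemble a GpuExecutable
+// from them.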
+static Status CompileModuleToLlvmIrImpl(
+    HloModule* hlo_module, llvm::LLVMContext* llvm_context,
+    const std::string& target_triple, const std::string& data_layout,
+    const std::string& platform_name, GpuDeviceInfo gpu_device_info,
+    absl::optional<CudaComputeCapability> cuda_compute_capability,
+    const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
+    int pointer_size, std::unique_ptr<llvm::Module>* llvm_module,
+    std::unique_ptr<StreamAssignment>* stream_assignment,
+    std::unique_ptr<GpuHloSchedule>* hlo_schedule,
+    std::unique_ptr<BufferAssignment>* buffer_assignment,
+    std::unique_ptr<ThunkSequence>* thunk_sequence) {
+  *llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
+
+  (*llvm_module)->setTargetTriple(target_triple);
+  (*llvm_module)->setDataLayout(data_layout);
+
+  *stream_assignment = AssignStreams(*hlo_module);
+  TF_ASSIGN_OR_RETURN(
+      *hlo_schedule,
+      GpuHloSchedule::Build(*hlo_module, **stream_assignment, pointer_size));
+
+  auto buffer_size_bytes_function =
+      [pointer_size](const BufferValue& buffer_value) -> int64 {
+    return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size);
+  };
+
+  TF_ASSIGN_OR_RETURN(
+      *buffer_assignment,
+      BufferAssigner::Run(
+          hlo_module, (*hlo_schedule)->ConsumeHloOrdering(),
+          buffer_size_bytes_function,
+          /*color_alignment=*/
+          [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
+          /*allocate_buffers_for_constants=*/true,
+          /*colorer=*/BufferAssigner::DefaultColorer(),
+          /*must_not_live_out=*/{}, can_share_buffer_function));
+
+  VLOG(1) << "Buffer Assignment Stats "
+          << (*buffer_assignment)->GetStats().ToString();
+  DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment,
+                         "after_optimizations");
+
+  IrEmitterContext ir_emitter_context(
+      hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
+      cuda_compute_capability, llvm_module->get());
+
+  HloComputation* entry_computation = hlo_module->entry_computation();
+  IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
+                               &ir_emitter_context);
+
+  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
+
+  {
+    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
+    TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
+  }
+  *thunk_sequence = ir_emitter.ConsumeThunkSequence();
+  return Status::OK();
+}
+
 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
     se::DeviceMemoryAllocator* device_allocator) {
@@ -483,35 +543,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   };
   llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
 
-  llvm::Module llvm_module(module->name().c_str(), llvm_context);
-  // Set the target triple and the data layout.
-  llvm_module.setTargetTriple(target_triple_);
-  llvm_module.setDataLayout(data_layout_);
-
-  // Determine the HLO schedule, which is an ordering of HLO instructions. This
-  // is used by buffer assignment to enable buffer reuse, and the same ordering
-  // must also be used to determine the thunk launch schedule.
-  std::unique_ptr<StreamAssignment> stream_assignment = AssignStreams(*module);
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<GpuHloSchedule> hlo_schedule,
-      GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
-
-  // Run buffer analysis on the HLO graph. This analysis figures out which
-  // temporary buffers are required to run the computation.
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<BufferAssignment> buffer_assignment,
-      BufferAssigner::Run(
-          module.get(), hlo_schedule->ConsumeHloOrdering(),
-          BufferSizeBytesFunction(),
-          /*color_alignment=*/
-          [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
-          /*allocate_buffers_for_constants=*/true,
-          /*colorer=*/BufferAssigner::DefaultColorer(),
-          /*must_not_live_out=*/{}, GetCanShareBuffer()));
-  VLOG(1) << "Buffer Assignment Stats "
-          << buffer_assignment->GetStats().ToString();
-  DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
-
   GpuDeviceInfo gpu_device_info;
   gpu_device_info.threads_per_block_limit =
       stream_exec->GetDeviceDescription().threads_per_block_limit();
@@ -531,32 +562,29 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
     return cuda_compute_capability;
   }();
 
-  IrEmitterContext ir_emitter_context(
-      module.get(), buffer_assignment.get(), stream_exec->platform()->Name(),
-      gpu_device_info, cuda_compute_capability, &llvm_module);
+  std::unique_ptr<llvm::Module> llvm_module;
+  std::unique_ptr<StreamAssignment> stream_assignment;
+  std::unique_ptr<GpuHloSchedule> hlo_schedule;
+  std::unique_ptr<BufferAssignment> buffer_assignment;
+  std::unique_ptr<ThunkSequence> thunk_sequence;
 
-  HloComputation* entry_computation = module->entry_computation();
-  IrEmitterUnnested ir_emitter(module->config(), entry_computation,
-                               &ir_emitter_context);
-
-  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
-
-  {
-    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
-    TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
-  }
+  TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
+      module.get(), &llvm_context, target_triple_, data_layout_,
+      stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
+      GetCanShareBuffer(), pointer_size_, &llvm_module, &stream_assignment,
+      &hlo_schedule, &buffer_assignment, &thunk_sequence));
 
   if (user_pre_optimization_hook_) {
-    user_pre_optimization_hook_(llvm_module);
+    user_pre_optimization_hook_(*llvm_module);
   }
   string ir_module_string_before_opt;
   const bool embed_ir_in_executable =
       module->config().debug_options().xla_embed_ir_in_executable();
   if (embed_ir_in_executable) {
-    ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module);
+    ir_module_string_before_opt = llvm_ir::DumpModuleToString(*llvm_module);
   }
 
-  llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false);
+  llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
 
   {
     XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
@@ -565,7 +593,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
     std::string err;
     llvm::raw_string_ostream err_stream(err);
 
     // verifyModule() returns true if the module is broken.
-    TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
+    TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
         << "Invalid LLVM IR before optimizations:\n"
         << err_stream.str()
         << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
@@ -578,11 +606,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
   TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
-                      CompileTargetBinary(module.get(), &llvm_module,
+                      CompileTargetBinary(module.get(), llvm_module.get(),
                                           gpu_version, stream_exec));
 
   auto thunk_schedule = absl::make_unique<ThunkSchedule>(
-      ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
+      std::move(thunk_sequence), std::move(stream_assignment),
       hlo_schedule->ThunkLaunchOrder());
   if (DumpingEnabledForHloModule(*module)) {
     DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
@@ -602,8 +630,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
                             cost_analysis.bytes_accessed());
     if (module->config().hlo_profiling_enabled()) {
       profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
-      profile_printer = CreateHloProfilePrinterData(
-          *profile_index_map, cost_analysis, entry_computation->name());
+      profile_printer =
+          CreateHloProfilePrinterData(*profile_index_map, cost_analysis,
+                                      module->entry_computation()->name());
     }
   }
 
@@ -625,5 +654,30 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
   return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
 }
 
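+// CanShareBuffer hook used by CompileModuleToLlvmIr below. Returning
+// absl::nullopt expresses "no opinion", so buffer-sharing decisions fall back
+// to HloDataflowAnalysis's default rules.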
" @@ -578,11 +606,11 @@ StatusOr> GpuCompiler::RunBackend( using BackendCompileResult = std::pair>; TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result, - CompileTargetBinary(module.get(), &llvm_module, + CompileTargetBinary(module.get(), llvm_module.get(), gpu_version, stream_exec)); auto thunk_schedule = absl::make_unique( - ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), + std::move(thunk_sequence), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_schedule", @@ -602,8 +630,9 @@ StatusOr> GpuCompiler::RunBackend( cost_analysis.bytes_accessed()); if (module->config().hlo_profiling_enabled()) { profile_index_map = absl::make_unique(*module); - profile_printer = CreateHloProfilePrinterData( - *profile_index_map, cost_analysis, entry_computation->name()); + profile_printer = + CreateHloProfilePrinterData(*profile_index_map, cost_analysis, + module->entry_computation()->name()); } } @@ -625,5 +654,30 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); } +static absl::optional DummyCanShareBufferFunction(const HloInstruction*, + const HloInstruction*, + const ShapeIndex&) { + return absl::nullopt; +} + +StatusOr> CompileModuleToLlvmIr( + HloModule* hlo_module, llvm::LLVMContext* llvm_context, + const std::string& target_triple, const std::string& data_layout, + const std::string& platform_name, GpuDeviceInfo gpu_device_info, + absl::optional cuda_compute_capability, + int pointer_size) { + std::unique_ptr llvm_module; + std::unique_ptr stream_assignment; + std::unique_ptr hlo_schedule; + std::unique_ptr buffer_assignment; + std::unique_ptr thunk_sequence; + + TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( + hlo_module, llvm_context, target_triple, data_layout, platform_name, + gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction, + pointer_size, &llvm_module, &stream_assignment, &hlo_schedule, + &buffer_assignment, &thunk_sequence)); + return llvm_module; +} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index a7706005ba2..7b6e4c78832 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -94,15 +95,19 @@ class GpuCompiler : public LLVMCompiler { HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { // Capture just the pointer size, not the entire GpuCompiler object. return [pointer_size = pointer_size_](const Shape& shape) { - if (shape.is_static() || shape.IsTuple()) { - return ShapeUtil::ByteSizeOf(shape, pointer_size); - } - // Each dynamic dimension size is represented as a S32. 
-      int64 metadata_size = sizeof(int32) * shape.dimensions_size();
-      return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
+      return GetSizeOfShape(shape, pointer_size);
     };
   }
 
+  static int64 GetSizeOfShape(const Shape& shape, int pointer_size) {
+    if (shape.is_static() || shape.IsTuple()) {
+      return ShapeUtil::ByteSizeOf(shape, pointer_size);
+    }
+    // Each dynamic dimension size is represented as an S32.
+    int64 metadata_size = sizeof(int32) * shape.dimensions_size();
+    return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
+  }
+
  private:
   se::Platform::Id platform_id_;
@@ -117,6 +122,16 @@ class GpuCompiler : public LLVMCompiler {
   TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler);
 };
 
+// Compile `hlo_module` using XLA GPU and return the LLVM module thus generated.
+// The GpuExecutable (and the Thunks that are part of it) are not returned.
+StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
+    HloModule* hlo_module, llvm::LLVMContext* llvm_context,
+    const std::string& target_triple, const std::string& data_layout,
+    const std::string& platform_name, GpuDeviceInfo gpu_device_info,
+    absl::optional<CudaComputeCapability> cuda_compute_capability,
+    int pointer_size);
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 93dc8de0e0c..1065928687f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1444,8 +1444,6 @@ Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
   return Status::OK();
 }
 
-namespace {}  // namespace
-
 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
   VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
           << "; operand count: " << crs->operand_count()
@@ -1557,29 +1555,37 @@ Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
   return Status::OK();
 }
 
+// Describes how to access a particular subshape for an HLO. For instance, if
+// `.hlo_index` is {1} and `.gte_index` is {3, 4}, then the buffer for `.instr`
+// at ShapeIndex {1} (i.e. the buffer for the second tuple element of the HLO)
+// is found at `.buffer_slice`[3][4]. That is, `.buffer_slice` is a void***,
+// which we dereference twice -- first at index 3, and then at index 4 -- to
+// get the address of our buffer.
+struct HloBufferSlice {
+  const HloInstruction* instr;
+  ShapeIndex hlo_index;
+
+  // The root buffer to look at.
+  BufferAllocation::Slice buffer_slice;
+
+  // Describes how to dereference starting at that buffer to get to the buffer
+  // in question.
+  ShapeIndex gte_index;
+};
+
 // Figures out how to access the buffers for all subshapes of hlo's operands and
 // for hlo itself (i.e. all the buffers produced by HLO).
 //
-// Returns a map keyed on the pair {HloInstruction, ShapeIndex}. The value for
-// this key is a pair {Slice, ShapeIndex}, where the slice tells you the root
-// buffer to look in, and the ShapeIndex describes how to dereference starting
-// at that buffer to get to the buffer in question.
-//
-// For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
-// hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
-// is found at slice[3][4]. That is, slice is a void***, which we dereference
-// twice -- first at index 3, and then at index 4 -- to get the address of our
-// buffer.
+// Returns a vector of `HloBufferSlice`s, one for each HLO subshape `hlo` needs
+// to access (including one or more for itself).
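+// The result is consumed by BuildKernelThunk below to decide which buffer
+// allocations the generated kernel must receive as arguments.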
 //
 // This function conservatively assumes that we'll touch all sub-buffers of
 // every operand and of the output.
-static std::map<std::pair<const HloInstruction*, ShapeIndex>,
-                std::pair<BufferAllocation::Slice, ShapeIndex>>
-GetHloBufferSlices(const HloInstruction* hlo,
-                   const BufferAssignment& buffer_assn) {
-  std::map<std::pair<const HloInstruction*, ShapeIndex>,
-           std::pair<BufferAllocation::Slice, ShapeIndex>>
-      slices;
+static std::vector<HloBufferSlice> GetHloBufferSlices(
+    const HloInstruction* hlo, const BufferAssignment& buffer_assn) {
+  std::vector<HloBufferSlice> result;
+  absl::flat_hash_set<std::pair<const HloInstruction*, ShapeIndex>>
+      inserted_buffer_slices;
 
   // Tries to find a slice plus an array of indices i1, ..., iN such that the
   // sub-buffer for instr at index can be found at slice[i1]...[iN].
@@ -1646,13 +1652,18 @@ GetHloBufferSlices(const HloInstruction* hlo,
   auto add_slices_for = [&](const HloInstruction* instr) {
     ShapeUtil::ForEachSubshape(
         instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
-          if (slices.count({instr, index})) {
+          if (!inserted_buffer_slices.insert({instr, index}).second) {
            // HLOs can have duplicate operands; don't bother redoing work.
             return;
           }
           auto maybe_slice = find_slice_for(instr, index);
           if (maybe_slice.has_value()) {
-            slices[{instr, index}] = *maybe_slice;
+            HloBufferSlice hlo_buffer_slice;
+            hlo_buffer_slice.instr = instr;
+            hlo_buffer_slice.hlo_index = index;
+            hlo_buffer_slice.buffer_slice = maybe_slice->first;
+            hlo_buffer_slice.gte_index = maybe_slice->second;
+            result.push_back(hlo_buffer_slice);
           } else {
             VLOG(1) << "Couldn't find buffer for " << instr->ToString()
                     << " at index " << index.ToString();
@@ -1667,7 +1678,7 @@ GetHloBufferSlices(const HloInstruction* hlo,
     add_slices_for(operand);
   }
 
-  return slices;
+  return result;
 }
 
 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
@@ -1675,9 +1686,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
   const BufferAssignment& buffer_assn =
       ir_emitter_context_->buffer_assignment();
 
-  std::map<std::pair<const HloInstruction*, ShapeIndex>,
-           std::pair<BufferAllocation::Slice, ShapeIndex>>
-      hlo_slices = GetHloBufferSlices(inst, buffer_assn);
+  std::vector<HloBufferSlice> hlo_slices =
+      GetHloBufferSlices(inst, buffer_assn);
 
   // Figure out which buffer allocations need to be passed as arguments to our
   // kernel. This is simply all of the allocations referenced in hlo_slices,
   // plus the temp buffer (if we have it). We always include the temp buffer
   // because even if the kernel itself doesn't use it, a nested
   // subcomputation within the kernel (e.g. a kMap's computation) might.
   std::unordered_set<const BufferAllocation*> buffers_needed;
-  for (const auto& kv : hlo_slices) {
-    buffers_needed.insert(kv.second.first.allocation());
+  for (const auto& hlo_buffer_slice : hlo_slices) {
+    buffers_needed.insert(hlo_buffer_slice.buffer_slice.allocation());
   }
   absl::optional<const BufferAllocation*> temp_buffer;
   for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
@@ -1730,11 +1740,11 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
 
   // For each buffer our kernel might want to touch, bind it to a value derived
   // from our kernel args.
- for (const auto& kv : hlo_slices) { - const HloInstruction* instr = kv.first.first; - const ShapeIndex& index = kv.first.second; - const BufferAllocation::Slice& slice = kv.second.first; - const ShapeIndex& gte_index = kv.second.second; + for (const auto& hlo_buffer_slice : hlo_slices) { + const HloInstruction* instr = hlo_buffer_slice.instr; + const ShapeIndex& index = hlo_buffer_slice.hlo_index; + const BufferAllocation::Slice& slice = hlo_buffer_slice.buffer_slice; + const ShapeIndex& gte_index = hlo_buffer_slice.gte_index; VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() << " is found in slice " << slice.ToString() << " at GTE index " diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index e2765e429a0..a23c14017a4 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -5,11 +5,12 @@ # need to run on machines with GPUs present. load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( default_visibility = [":friends"], @@ -456,3 +457,37 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) + +tf_cc_binary( + name = "hlo_to_llvm_ir", + srcs = ["hlo_to_llvm_ir.cc"], + deps = [ + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:gpu_compiler", + "//tensorflow/compiler/xla/service/gpu:gpu_device_info", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/compiler/xla/tools:hlo_module_loader", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +glob_lit_tests( + data = [":test_utilities"], + default_tags = tf_cuda_tests_tags() + [ + "no_pip", + ], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = ["hlo"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/xla/service/gpu/tests:hlo_to_llvm_ir", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/hlo_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/tests/hlo_to_llvm_ir.cc new file mode 100644 index 00000000000..588f96bdf8d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/hlo_to_llvm_ir.cc @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/tools/hlo_module_loader.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +const char* const kUsage = R"( +This tool reads in an HloMoudle from a file, compiles it using the NVPTX +compiler and prints out the LLVM IR generated by the IR emitter. The LLVM IR is +not optimized by the LLVM pass pipeline, so this tool can be used to unit test +the XLA GPU IR emitters. + +Note that the LLVM IR does not contain the *full* module, but only parts that +will be code generated into PTX. The NVPTX compiler also generates a +GpuExecutable on the size that is not printed.)"; + +namespace { +xla::Status CompileAndPrintLlvmIr(const std::string& hlo_text) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + xla::LoadModuleFromData(/*data=*/hlo_text, /*format=*/"hlo")); + llvm::LLVMContext llvm_context; + + // For now we pretend we're compiling for V100. This can be generalized + // later. + + xla::gpu::GpuDeviceInfo gpu_device_info; + gpu_device_info.threads_per_block_limit = 1024; + gpu_device_info.threads_per_warp = 32; + gpu_device_info.shared_memory_per_block = 1024 * 96; + + xla::gpu::CudaComputeCapability cuda_compute_capability; + cuda_compute_capability.cc_major = 7; + cuda_compute_capability.cc_minor = 0; + std::string target_triple = "nvptx64-nvidia-cuda"; + std::string datalayout = "nvptx64-nvidia-cuda"; + TF_ASSIGN_OR_RETURN(std::unique_ptr llvm_module, + xla::gpu::CompileModuleToLlvmIr( + hlo_module.get(), &llvm_context, + /*target_triple=*/xla::gpu::nvptx::kTargetTriple, + /*data_layout=*/xla::gpu::nvptx::kDataLayout, + /*platform_name=*/"CUDA", gpu_device_info, + cuda_compute_capability, /*pointer_size=*/8)); + + llvm_module->print(llvm::outs(), nullptr); + return xla::Status::OK(); +} + +xla::Status CompileAndPrintLlvmIrFromFile(const std::string& file_name) { + std::string full_text; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + file_name, &full_text)); + + std::vector hlo_module_texts = + absl::StrSplit(full_text, "// -----"); + for (const std::string& hlo_module_text : hlo_module_texts) { + TF_RETURN_IF_ERROR(CompileAndPrintLlvmIr(hlo_module_text)); + } + + return xla::Status::OK(); +} +} // namespace + +int main(int argc, char** argv) { + std::vector flag_list; + xla::AppendDebugOptionsFlags(&flag_list); + // The usage string includes the message at the top of the file, the + // DebugOptions flags and the flags defined above. 
+ const std::string kUsageString = absl::StrCat( + kUsage, "\n\n", tensorflow::Flags::Usage(argv[0], flag_list)); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(kUsageString.c_str(), &argc, &argv); + if (!parse_ok) { + LOG(QFATAL) << kUsageString; + } + + QCHECK(argc == 2) << "Must specify a single input file"; + TF_CHECK_OK(CompileAndPrintLlvmIrFromFile(argv[1])); + + return 0; +} diff --git a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo new file mode 100644 index 00000000000..b1cfb826e5f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo @@ -0,0 +1,296 @@ +// RUN: hlo_to_llvm_ir %s | FileCheck %s + +// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) { +// CHECK: entry: +// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 +// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]* +// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 +// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [3 x [3 x i32]]* +// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0 +// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [2 x i32]* +// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0 +// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [2 x [3 x i32]]* +// CHECK: %[[VAL_12:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK: %[[VAL_13:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK: %[[VAL_14:.*]] = mul nuw nsw i32 %[[VAL_12]], 6 +// CHECK: %[[VAL_15:.*]] = add nuw nsw i32 %[[VAL_14]], %[[VAL_13]] +// CHECK: %[[VAL_16:.*]] = icmp ult i32 %[[VAL_15]], 6 +// CHECK: call void @llvm.assume(i1 %[[VAL_16]]) +// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_15]], 1 +// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 3 +// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_15]], 3 +// CHECK: %[[VAL_20:.*]] = icmp ult i32 %[[VAL_15]], 6 +// CHECK: br i1 %[[VAL_20]], label %[[VAL_21:.*]], label %[[VAL_22:.*]] +// CHECK: scatter_TensorFlowScatterV1.in_bounds-after: ; preds = %[[VAL_23:.*]], %[[VAL_24:.*]] +// CHECK: ret void +// CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_24]] +// CHECK: %[[VAL_25:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_8]], i32 0, i32 %[[VAL_19]] +// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4, !noalias !5 +// CHECK: %[[VAL_27:.*]] = add i32 0, %[[VAL_26]] +// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_26]], 3 +// CHECK: %[[VAL_29:.*]] = and i1 true, %[[VAL_28]] +// CHECK: br i1 %[[VAL_29]], label %[[VAL_30:.*]], label %[[VAL_23]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_30]], %[[VAL_21]] +// CHECK: br label %[[VAL_22]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_31:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_18]] +// CHECK: %[[VAL_32:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_33:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_11]] to i32* +// CHECK: %[[VAL_34:.*]] = getelementptr inbounds i32, i32* %[[VAL_33]], i32 %[[VAL_15]] +// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4, !noalias !5 +// CHECK: store i32 %[[VAL_35]], i32* %[[VAL_32]], 
align 4 +// CHECK: %[[VAL_36:.*]] = load i32, i32* %[[VAL_32]], align 4 +// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4 +// CHECK: br label %[[VAL_23]] +// CHECK: !nvvm.annotations = !{!0, !1} +// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6} +// CHECK: !2 = !{i32 0, i32 1} +// CHECK: !3 = !{i32 0, i32 6} +// CHECK: !4 = !{} +// CHECK: !5 = !{!6} +// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7} +// CHECK: !7 = !{!"XLA global AA domain"} + + +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter_TensorFlowScatterV1 = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} + + +// ----- + +// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* align 64 dereferenceable(4) %alloc0, i8* align 16 dereferenceable(4) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 %alloc3) { +// CHECK: entry: +// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 +// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32* +// CHECK: %[[VAL_40:.*]] = getelementptr inbounds i8, i8* %[[VAL_41:.*]], i64 0 +// CHECK: %[[VAL_42:.*]] = bitcast i8* %[[VAL_40]] to i32* +// CHECK: %[[VAL_43:.*]] = getelementptr inbounds i8, i8* %[[VAL_44:.*]], i64 0 +// CHECK: %[[VAL_45:.*]] = bitcast i8* %[[VAL_43]] to [0 x i32]* +// CHECK: %[[VAL_46:.*]] = getelementptr inbounds i8, i8* %[[VAL_47:.*]], i64 0 +// CHECK: %[[VAL_48:.*]] = bitcast i8* %[[VAL_46]] to i32* +// CHECK: %[[VAL_49:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK: %[[VAL_50:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK: %[[VAL_51:.*]] = mul nuw nsw i32 %[[VAL_49]], 1 +// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_51]], %[[VAL_50]] +// CHECK: %[[VAL_53:.*]] = icmp ult i32 %[[VAL_52]], 1 +// CHECK: call void @llvm.assume(i1 %[[VAL_53]]) +// CHECK: %[[VAL_54:.*]] = icmp ult i32 %[[VAL_52]], 1 +// CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]] +// CHECK: scatter_ScatterIntoScalar.in_bounds-after: ; preds = %[[VAL_57:.*]], %[[VAL_58:.*]] +// CHECK: ret void +// CHECK: scatter_ScatterIntoScalar.in_bounds-true: ; preds = %[[VAL_58]] +// CHECK: br i1 true, label %[[VAL_59:.*]], label %[[VAL_57]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_59]], %[[VAL_55]] +// CHECK: br label %[[VAL_56]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_55]] +// CHECK: %[[VAL_60:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3, !noalias !4 +// CHECK: store i32 %[[VAL_61]], i32* %[[VAL_60]], align 4 +// CHECK: %[[VAL_62:.*]] = load i32, i32* %[[VAL_60]], align 4 +// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4 +// CHECK: br label %[[VAL_57]] +// CHECK: !nvvm.annotations = !{!0, !1} +// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} +// CHECK: !2 = !{i32 0, i32 1} +// CHECK: !3 = !{} +// CHECK: !4 = !{!5} +// CHECK: !5 = 
!{!"buffer: {index:0, offset:0, size:4}", !6} +// CHECK: !6 = !{!"XLA global AA domain"} + +HloModule ScatterIntoScalar + +update_s32 { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + parameter.1 = s32[] parameter(0) + parameter.2 = s32[0]{0} parameter(1) + parameter.3 = s32[] parameter(2) + ROOT scatter_ScatterIntoScalar = s32[] scatter(parameter.1, parameter.2, parameter.3), + update_window_dims={}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={}, + index_vector_dim=0, + to_apply=update_s32 +} + + +// ----- + +// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) { +// CHECK: %[[VAL_63:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_64:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0 +// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]* +// CHECK: %[[VAL_68:.*]] = getelementptr inbounds i8, i8* %[[VAL_69:.*]], i64 0 +// CHECK: %[[VAL_70:.*]] = bitcast i8* %[[VAL_68]] to [3 x [3 x i32]]* +// CHECK: %[[VAL_71:.*]] = getelementptr inbounds i8, i8* %[[VAL_72:.*]], i64 0 +// CHECK: %[[VAL_73:.*]] = bitcast i8* %[[VAL_71]] to [2 x i32]* +// CHECK: %[[VAL_74:.*]] = getelementptr inbounds i8, i8* %[[VAL_75:.*]], i64 0 +// CHECK: %[[VAL_76:.*]] = bitcast i8* %[[VAL_74]] to [2 x [3 x i32]]* +// CHECK: %[[VAL_77:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK: %[[VAL_78:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK: %[[VAL_79:.*]] = mul nuw nsw i32 %[[VAL_77]], 6 +// CHECK: %[[VAL_80:.*]] = add nuw nsw i32 %[[VAL_79]], %[[VAL_78]] +// CHECK: %[[VAL_81:.*]] = icmp ult i32 %[[VAL_80]], 6 +// CHECK: call void @llvm.assume(i1 %[[VAL_81]]) +// CHECK: %[[VAL_82:.*]] = udiv i32 %[[VAL_80]], 1 +// CHECK: %[[VAL_83:.*]] = urem i32 %[[VAL_82]], 3 +// CHECK: %[[VAL_84:.*]] = udiv i32 %[[VAL_80]], 3 +// CHECK: %[[VAL_85:.*]] = icmp ult i32 %[[VAL_80]], 6 +// CHECK: br i1 %[[VAL_85]], label %[[VAL_86:.*]], label %[[VAL_87:.*]] +// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-after: ; preds = %[[VAL_88:.*]], %[[VAL_89:.*]] +// CHECK: ret void +// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_89]] +// CHECK: %[[VAL_90:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_73]], i32 0, i32 %[[VAL_84]] +// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4, !noalias !5 +// CHECK: %[[VAL_92:.*]] = add i32 0, %[[VAL_91]] +// CHECK: %[[VAL_93:.*]] = icmp ult i32 %[[VAL_91]], 3 +// CHECK: %[[VAL_94:.*]] = and i1 true, %[[VAL_93]] +// CHECK: br i1 %[[VAL_94]], label %[[VAL_95:.*]], label %[[VAL_88]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_96:.*]], %[[VAL_86]] +// CHECK: br label %[[VAL_87]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_86]] +// CHECK: %[[VAL_97:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_67]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_83]] +// CHECK: %[[VAL_98:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_99:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_76]] to i32* +// CHECK: %[[VAL_100:.*]] = getelementptr inbounds i32, i32* %[[VAL_99]], i32 %[[VAL_80]] +// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4, !noalias !5 +// CHECK: store i32 %[[VAL_101]], i32* %[[VAL_98]], align 4 +// CHECK: %[[VAL_102:.*]] = load i32, i32* %[[VAL_98]], align 4 +// 
CHECK: %[[VAL_103:.*]] = load i32, i32* %[[VAL_97]], align 4 +// CHECK: store i32 %[[VAL_103]], i32* %[[VAL_64]], align 4 +// CHECK: br label %[[VAL_104:.*]] +// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_104]] +// CHECK: br label %[[VAL_88]] +// CHECK: atomic_op_loop_body: ; preds = %[[VAL_104]], %[[VAL_95]] +// CHECK: %[[VAL_105:.*]] = load i32, i32* %[[VAL_64]], align 4 +// CHECK: store i32 %[[VAL_105]], i32* %[[VAL_63]], align 4 +// CHECK: call void @mul_s32(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]], i8* null) +// CHECK: %[[VAL_106:.*]] = load i32, i32* %[[VAL_63]], align 4 +// CHECK: %[[VAL_107:.*]] = cmpxchg i32* %[[VAL_97]], i32 %[[VAL_105]], i32 %[[VAL_106]] seq_cst seq_cst +// CHECK: %[[VAL_108:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 0 +// CHECK: store i32 %[[VAL_108]], i32* %[[VAL_64]], align 4 +// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1 +// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]] +// CHECK: !nvvm.annotations = !{!0, !1} +// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6} +// CHECK: !2 = !{i32 0, i32 1} +// CHECK: !3 = !{i32 0, i32 6} +// CHECK: !4 = !{} +// CHECK: !5 = !{!6} +// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7} +// CHECK: !7 = !{!"XLA global AA domain"} +// CHECK: !8 = !{!9} +// CHECK: !9 = !{!"buffer: {index:4, offset:0, size:4}", !7} +// CHECK: !10 = !{!11} +// CHECK: !11 = !{!"buffer: {index:6, offset:0, size:4}", !7} +// CHECK: !12 = !{!13} +// CHECK: !13 = !{!"buffer: {index:5, offset:0, size:4}", !7} + +HloModule TensorFlowScatter_Mul + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter_TensorFlowScatter_Mul = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} + +// ----- + +// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* align 64 dereferenceable(16) %alloc0, i8* align 16 dereferenceable(16) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 dereferenceable(4) %alloc3) { +// CHECK: entry: +// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 +// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]* +// CHECK: %[[VAL_121:.*]] = getelementptr inbounds i8, i8* %[[VAL_122:.*]], i64 0 +// CHECK: %[[VAL_123:.*]] = bitcast i8* %[[VAL_121]] to [4 x i32]* +// CHECK: %[[VAL_124:.*]] = getelementptr inbounds i8, i8* %[[VAL_125:.*]], i64 0 +// CHECK: %[[VAL_126:.*]] = bitcast i8* %[[VAL_124]] to i32* +// CHECK: %[[VAL_127:.*]] = getelementptr inbounds i8, i8* %[[VAL_128:.*]], i64 0 +// CHECK: %[[VAL_129:.*]] = bitcast i8* %[[VAL_127]] to i32* +// CHECK: %[[VAL_130:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK: %[[VAL_131:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_130]], 1 +// CHECK: %[[VAL_133:.*]] = add nuw nsw i32 %[[VAL_132]], %[[VAL_131]] +// CHECK: %[[VAL_134:.*]] = icmp ult i32 %[[VAL_133]], 1 +// CHECK: call void @llvm.assume(i1 %[[VAL_134]]) +// CHECK: %[[VAL_135:.*]] = icmp ult i32 %[[VAL_133]], 1 +// CHECK: br i1 %[[VAL_135]], label 
%[[VAL_136:.*]], label %[[VAL_137:.*]] +// CHECK: scatter_ScalarUpdate.in_bounds-after: ; preds = %[[VAL_138:.*]], %[[VAL_139:.*]] +// CHECK: ret void +// CHECK: scatter_ScalarUpdate.in_bounds-true: ; preds = %[[VAL_139]] +// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3, !noalias !4 +// CHECK: %[[VAL_141:.*]] = add i32 0, %[[VAL_140]] +// CHECK: %[[VAL_142:.*]] = icmp ult i32 %[[VAL_140]], 4 +// CHECK: %[[VAL_143:.*]] = and i1 true, %[[VAL_142]] +// CHECK: br i1 %[[VAL_143]], label %[[VAL_144:.*]], label %[[VAL_138]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_144]], %[[VAL_136]] +// CHECK: br label %[[VAL_137]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_136]] +// CHECK: %[[VAL_145:.*]] = getelementptr inbounds [4 x i32], [4 x i32]* %[[VAL_120]], i32 0, i32 %[[VAL_141]] +// CHECK: %[[VAL_146:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3, !noalias !4 +// CHECK: store i32 %[[VAL_147]], i32* %[[VAL_146]], align 4 +// CHECK: %[[VAL_148:.*]] = load i32, i32* %[[VAL_146]], align 4 +// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4 +// CHECK: br label %[[VAL_138]] +// CHECK: !nvvm.annotations = !{!0, !1} +// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} +// CHECK: !2 = !{i32 0, i32 1} +// CHECK: !3 = !{} +// CHECK: !4 = !{!5} +// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:16}", !6} +// CHECK: !6 = !{!"XLA global AA domain"} + +HloModule ScalarUpdate + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + updates = s32[] parameter(2) + ROOT scatter_ScalarUpdate = s32[4]{0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +}
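+
+// To add another test case to this file, append a new HloModule after a
+// dashed "-----" comment separator (as between the cases above) and write
+// CHECK lines for the LLVM IR that hlo_to_llvm_ir prints for it. The RUN line
+// at the top of the file pipes the tool's output for every module in this
+// file through FileCheck.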