From 46c82584dc94ef1288235e4291603bfad434c52a Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Thu, 4 Feb 2021 10:07:13 -0800
Subject: [PATCH] [XLA:GPU] Migrate CollectivePermute thunk generation to MLIR

- Also extend mlir_gpu_test to have a compile only API and use that to test
  failure of attribute export for an invalid collective permute op.

PiperOrigin-RevId: 355652447
Change-Id: Icd34714057541acaad168021e519aa6a62c25f9b
---
 .../compiler/mlir/xla/attribute_exporter.cc  | 24 +++++++++-
 .../compiler/mlir/xla/attribute_exporter.h   |  7 ++-
 .../compiler/mlir/xla/mlir_hlo_to_hlo.cc     | 22 +--------
 .../service/gpu/collective_permute_thunk.cc  | 17 +++----
 .../service/gpu/collective_permute_thunk.h   | 15 +++----
 .../compiler/xla/service/gpu/gpu_compiler.cc |  7 +--
 .../xla/service/gpu/ir_emitter_unnested.cc   | 30 +++++++++++--
 .../xla/service/gpu/ir_emitter_unnested.h    |  1 +
 .../compiler/xla/service/gpu/tests/BUILD     | 11 +++++
 .../gpu/tests/mlir_gpu_compile_test.cc       | 40 +++++++++++++++++
 .../service/gpu/tests/mlir_gpu_test_base.cc  | 45 +++++++++++++------
 .../service/gpu/tests/mlir_gpu_test_base.h   |  9 ++++
 12 files changed, 163 insertions(+), 65 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_compile_test.cc

diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.cc b/tensorflow/compiler/mlir/xla/attribute_exporter.cc
index 6f6cc59fc84..ffa05cee707 100644
--- a/tensorflow/compiler/mlir/xla/attribute_exporter.cc
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.cc
@@ -97,7 +97,7 @@ StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
   // rank 0 is num_groups, rank 1 is group size.
   auto replica_group_values_it = input.getValues<int64>().begin();
   std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
-  for (ReplicaGroup &group : replica_groups) {
+  for (ReplicaGroup& group : replica_groups) {
     for (int64 element_idx = 0; element_idx < type.getDimSize(1);
          ++element_idx, ++replica_group_values_it) {
       // For replica group attribute, -1 indicates padding added by
@@ -110,4 +110,26 @@ StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
   }
   return replica_groups;
 }
+
+// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
+// and source-target pairs are defined in HLO.
+StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
+    llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
+  if (!optional_attr.hasValue()) return std::vector<std::pair<int64, int64>>{};
+  mlir::DenseIntElementsAttr attr = *optional_attr;
+  auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
+  if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
+    return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
+  auto it = attr.getValues<int64>().begin();
+  std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
+  for (auto& item : out) {
+    int64 first = *it;
+    ++it;
+    int64 second = *it;
+    ++it;
+    item = {first, second};
+  }
+  return out;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.h b/tensorflow/compiler/mlir/xla/attribute_exporter.h
index 3046b9c3eaa..be0529d5058 100644
--- a/tensorflow/compiler/mlir/xla/attribute_exporter.h
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.h
@@ -20,8 +20,8 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/dnn.h" namespace xla { @@ -36,5 +36,10 @@ StatusOr ConvertConvActivationMode( StatusOr> ConvertReplicaGroups( mlir::DenseIntElementsAttr input); +// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding +// and source-target pairs are defined in HLO. +StatusOr>> ConvertNx2Attribute( + llvm::Optional optional_attr); + } // namespace xla #endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_ diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 622a0b1dc0b..a58dabfb65f 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -163,32 +163,14 @@ static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) { return fft_type_enum; } -// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding -// and source-target pairs are defined in HLO. -static std::vector> Convert_Nx2_attribute( - llvm::Optional optional_attr) { - if (!optional_attr.hasValue()) return {}; - mlir::DenseIntElementsAttr attr = *optional_attr; - auto it = attr.getValues().begin(); - std::vector> out(attr.getNumElements() / 2); - for (auto& item : out) { - int64 first = *it; - ++it; - int64 second = *it; - ++it; - item = {first, second}; - } - return out; -} - static std::vector> Convert_padding( llvm::Optional padding) { - return Convert_Nx2_attribute(padding); + return xla::ConvertNx2Attribute(padding).ValueOrDie(); } static std::vector> Convert_source_target_pairs( llvm::Optional source_target_pairs) { - return Convert_Nx2_attribute(source_target_pairs); + return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie(); } static std::vector Convert_replica_groups( diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index 10d8763133c..d32b517dc04 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -218,19 +218,12 @@ RefcountingHashMap& GlobalRendezvousMap() { } // anonymous namespace -CollectivePermuteConfig GetCollectivePermuteConfig( - const HloInstruction* instr) { - CollectivePermuteConfig config; - auto* collective_permute = Cast(instr); - config.source_target_pairs = collective_permute->source_target_pairs(); - return config; -} - CollectivePermuteThunk::CollectivePermuteThunk( - ThunkInfo thunk_info, CollectivePermuteConfig config, + ThunkInfo thunk_info, + std::vector> source_target_pairs, const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest) : Thunk(kCollectivePermute, thunk_info), - config_(std::move(config)), + source_target_pairs_(std::move(source_target_pairs)), src_(src), dest_(dest) {} @@ -254,7 +247,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { // Figure out which replicas our data is copied to. 
   std::vector<int64> dest_replicas;
-  for (const auto& src_dest : config_.source_target_pairs) {
+  for (const auto& src_dest : source_target_pairs_) {
     if (src_dest.first == replica_id) {
       dest_replicas.push_back(src_dest.second);
     }
@@ -269,7 +262,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
 
   // If no replica writes into us (i.e. we aren't the target of any copies), our
   // contract is that we zero our output.
-  if (absl::c_none_of(config_.source_target_pairs,
+  if (absl::c_none_of(source_target_pairs_,
                       [&](std::pair<int64, int64> src_dest) {
                         return src_dest.second == replica_id;
                       })) {
diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h
index 35dda8dad7d..8bff5dc60a3 100644
--- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h
@@ -24,23 +24,18 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-struct CollectivePermuteConfig {
-  std::vector<std::pair<int64, int64>> source_target_pairs;
-};
-
-CollectivePermuteConfig GetCollectivePermuteConfig(const HloInstruction* instr);
-
 // Thunk that implements the collective-permute HLO.
 class CollectivePermuteThunk : public Thunk {
  public:
-  CollectivePermuteThunk(ThunkInfo thunk_info, CollectivePermuteConfig config,
-                         const BufferAllocation::Slice& src,
-                         const BufferAllocation::Slice& dest);
+  CollectivePermuteThunk(
+      ThunkInfo thunk_info,
+      std::vector<std::pair<int64, int64>> source_target_pairs,
+      const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest);
 
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
-  const CollectivePermuteConfig config_;
+  const std::vector<std::pair<int64, int64>> source_target_pairs_;
   const BufferAllocation::Slice src_;
   const BufferAllocation::Slice dest_;
 };
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index a91c258387f..d104c26964b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -1119,11 +1119,8 @@ StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
       IrEmitterUnnested::Create(module_config, /*hlo_computation=*/nullptr,
                                 ir_emitter_context));
   ThunkSequence thunk_sequence;
-  for (mlir::Operation& op : entry_function.getBody().front()) {
-    // Omit the terminator.
-    if (&op == &entry_function.getBody().front().back()) {
-      continue;
-    }
+  for (mlir::Operation& op :
+       entry_function.getBody().front().without_terminator()) {
     MlirEmitterInput input;
     input.op = &op;
     TF_RETURN_IF_ERROR(ir_emitter->EmitOp(input));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 685976882d5..9c779f74fc6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2940,10 +2940,31 @@ Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
 }
 
 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
-  CollectivePermuteConfig config = GetCollectivePermuteConfig(hlo);
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+  return EmitCollectivePermuteFromMlir(input);
+}
+
+Status IrEmitterUnnested::EmitCollectivePermuteFromMlir(
+    MlirEmitterInput input) {
+  auto collective_permute_op =
+      mlir::cast<mlir::lmhlo::CollectivePermuteOp>(input.op);
+  if (collective_permute_op.channel_id())
+    return Unimplemented("collective permute with channel_id not implemented");
+  using source_dest_pairs_t = std::vector<std::pair<int64, int64>>;
+  TF_ASSIGN_OR_RETURN(
+      source_dest_pairs_t source_dest_pairs,
+      ConvertNx2Attribute(collective_permute_op.source_target_pairs()));
+
+  TF_ASSIGN_OR_RETURN(
+      BufferAllocation::Slice source_slice,
+      GetAllocationSliceForMlir(collective_permute_op.operand()));
+  TF_ASSIGN_OR_RETURN(
+      BufferAllocation::Slice result_slice,
+      GetAllocationSliceForMlir(collective_permute_op.output()));
+
   AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
-      GetThunkInfo(hlo), std::move(config),
-      GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo)));
+      input.thunk_info, std::move(source_dest_pairs), source_slice,
+      result_slice));
   return Status::OK();
 }
 
@@ -5763,6 +5784,9 @@ Status IrEmitterUnnested::EmitOp(MlirEmitterInput mlir_input) {
   if (mlir::isa<mlir::lmhlo::SortOp>(mlir_input.op)) {
     return EmitSortFromMlir(mlir_input);
   }
+  if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(mlir_input.op)) {
+    return EmitCollectivePermuteFromMlir(mlir_input);
+  }
   LOG(FATAL)
       << "This function is for test only, and the op is not implemented: "
       << MlirToString(mlir_input.op);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 6cfed8dea33..0d864511510 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -215,6 +215,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandlePartitionId(HloInstruction* hlo) override;
 
   Status HandleCollectivePermute(HloInstruction* hlo) override;
+  Status EmitCollectivePermuteFromMlir(MlirEmitterInput input);
 
   Status EmitOp(MlirEmitterInput mlir_input);
 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 616c0316498..ef299911153 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -88,6 +88,17 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "mlir_gpu_compile_test",
+    srcs = ["mlir_gpu_compile_test.cc"],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":mlir_gpu_test_base",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 tf_cc_test(
     name = "gemm_rewrite_test",
     srcs = [
diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_compile_test.cc b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_compile_test.cc
new file mode 100644
index 00000000000..a7c6b25cd68
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_compile_test.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+// Tests XLA GPU compilation using MLIR LMHLO dialect as the input.
+class CompileTest : public MlirGpuTestBase {};
+
+TEST_F(CompileTest, InvalidCollectivePermuteOp) {
+  const char* mlir_text = R"(
+      func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
+                 %arg1: memref<4xf32> {lmhlo.alloc = 1 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>}) -> () {
+        "lmhlo.collective_permute"(%arg0, %arg1) {source_target_pairs = dense<[[0, 1, 2]]> : tensor<1x3xi64>} : (memref<4xf32>, memref<4xf32>) -> ()
+        "std.return" () : () -> ()
+      })";
+  auto executable = CompileMlirText(mlir_text);
+  ASSERT_FALSE(executable.ok());
+  EXPECT_STREQ("expected Nx2 attribute to be a tensor of shape Nx2",
+               executable.status().error_message().c_str());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc
index d0ba544289c..6f74909dd6e 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h" #include "llvm/IR/LLVMContext.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" @@ -37,9 +38,8 @@ MlirGpuTestBase::MlirGpuTestBase() { backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie(); } -StatusOr MlirGpuTestBase::RunMlirModule( - mlir::ModuleOp module, se::Stream* stream, - absl::Span arguments) { +StatusOr> MlirGpuTestBase::CompileMlirModule( + mlir::ModuleOp module, se::Stream* stream) { llvm::LLVMContext llvm_context; auto llvm_module = absl::make_unique("", llvm_context); #if TENSORFLOW_USE_ROCM @@ -72,12 +72,16 @@ StatusOr MlirGpuTestBase::RunMlirModule( HloModuleConfig module_config; module_config.set_debug_options(DefaultDebugOptionsIgnoringFlags()); - TF_ASSIGN_OR_RETURN( - auto executable, - CompileLmhloToExecutable(static_cast(backend_->compiler()), - module, "TestModule", module_config, - Compiler::CompileOptions(), "main", stream_exec, - std::move(llvm_module), &ir_emitter_context)); + return CompileLmhloToExecutable( + static_cast(backend_->compiler()), module, "TestModule", + module_config, Compiler::CompileOptions(), "main", stream_exec, + std::move(llvm_module), &ir_emitter_context); +} + +StatusOr MlirGpuTestBase::RunMlirModule( + mlir::ModuleOp module, se::Stream* stream, + absl::Span arguments) { + TF_ASSIGN_OR_RETURN(auto executable, CompileMlirModule(module, stream)); ExecutableRunOptions executable_run_options; executable_run_options.set_stream(stream); @@ -137,10 +141,8 @@ MlirGpuTestBase::RunMlirModuleWithHostBuffers( return host_outputs; } -StatusOr>> -MlirGpuTestBase::RunMlirTextWithHostBuffers( - absl::string_view module_text, std::vector> arguments) { - mlir::MLIRContext context; +mlir::OwningModuleRef MlirGpuTestBase::ParseMlirModule( + absl::string_view module_text, mlir::MLIRContext& context) { context.loadDialect(); @@ -148,8 +150,25 @@ MlirGpuTestBase::RunMlirTextWithHostBuffers( mlir::OwningModuleRef module = parseSourceString( llvm::StringRef(module_text.data(), module_text.size()), &context); CHECK(module); + return module; +} + +StatusOr>> +MlirGpuTestBase::RunMlirTextWithHostBuffers( + absl::string_view module_text, std::vector> arguments) { + mlir::MLIRContext context; + mlir::OwningModuleRef module = ParseMlirModule(module_text, context); return RunMlirModuleWithHostBuffers(*module, arguments); } +StatusOr> MlirGpuTestBase::CompileMlirText( + absl::string_view module_text) { + mlir::MLIRContext context; + mlir::OwningModuleRef module = ParseMlirModule(module_text, context); + auto stream = backend_->BorrowStream(backend_->default_device_ordinal()) + .ConsumeValueOrDie(); + return CompileMlirModule(*module, stream.get()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h index aa0054924b1..eb407ddfce5 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h +++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h @@ -29,6 +29,9 @@ class MlirGpuTestBase : public HloTestBase { StatusOr>> RunMlirTextWithHostBuffers( absl::string_view module_text, std::vector> arguments); + StatusOr> CompileMlirText( + absl::string_view module_text); + template static absl::Span ToUint8Span(std::vector* v) { return 
                              v->size() * sizeof(T));
   }
@@ -46,10 +49,16 @@ class MlirGpuTestBase : public HloTestBase {
   StatusOr<std::vector<std::vector<uint8>>> RunMlirModuleWithHostBuffers(
       mlir::ModuleOp module, std::vector<absl::Span<uint8>> arguments);
 
+  StatusOr<std::unique_ptr<Executable>> CompileMlirModule(mlir::ModuleOp module,
+                                                          se::Stream* stream);
+
   StatusOr<ExecutionOutput> RunMlirModule(
       mlir::ModuleOp module, se::Stream* stream,
      absl::Span<const se::DeviceMemoryBase> arguments);
 
+  mlir::OwningModuleRef ParseMlirModule(absl::string_view module_text,
+                                        mlir::MLIRContext& context);
+
   std::unique_ptr<xla::Backend> backend_;
 };
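
Note (reviewer sketch, not part of the patch): the new mlir_gpu_compile_test only exercises the failure path, where source_target_pairs is a 1x3 tensor and ConvertNx2Attribute returns the "expected Nx2 attribute" error. Assuming a well-formed Nx2 attribute compiles cleanly through the new EmitCollectivePermuteFromMlir path, a companion success-path test could look like the sketch below; the test name ValidCollectivePermuteOp and the exact MLIR text are illustrative, not code from this change.

// Hypothetical companion test for mlir_gpu_compile_test.cc (sketch only).
// Reuses the CompileTest fixture and the new compile-only API from this patch;
// the 2x2 source_target_pairs attribute is assumed to pass attribute export.
TEST_F(CompileTest, ValidCollectivePermuteOp) {
  const char* mlir_text = R"(
      func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
                 %arg1: memref<4xf32> {lmhlo.alloc = 1 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>}) -> () {
        "lmhlo.collective_permute"(%arg0, %arg1) {source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>} : (memref<4xf32>, memref<4xf32>) -> ()
        "std.return" () : () -> ()
      })";
  // A well-formed Nx2 attribute converts to {{0, 1}, {1, 0}} and should compile.
  TF_ASSERT_OK(CompileMlirText(mlir_text).status());
}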