[XLA:GPU] Migrate CollectivePermute thunk generation to MLIR

- Also extend mlir_gpu_test with a compile-only API, and use it to test that
  attribute export fails for an invalid collective-permute op.

PiperOrigin-RevId: 355652447
Change-Id: Icd34714057541acaad168021e519aa6a62c25f9b
Author: Rahul Joshi (2021-02-04 10:07:13 -08:00)
Committed by: TensorFlower Gardener
Parent: f7680f0926
Commit: 46c82584dc
12 changed files with 163 additions and 65 deletions

--- a/tensorflow/compiler/mlir/xla/attribute_exporter.cc
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.cc

@@ -97,7 +97,7 @@ StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
   // rank 0 is num_groups, rank 1 is group size.
   auto replica_group_values_it = input.getValues<uint64_t>().begin();
   std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
-  for (ReplicaGroup &group : replica_groups) {
+  for (ReplicaGroup& group : replica_groups) {
     for (int64 element_idx = 0; element_idx < type.getDimSize(1);
          ++element_idx, ++replica_group_values_it) {
       // For replica group attribute, -1 indicates padding added by
@@ -110,4 +110,26 @@ StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
   }
   return replica_groups;
 }
+
+// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
+// and source-target pairs are defined in HLO.
+StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
+    llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
+  if (!optional_attr.hasValue()) return std::vector<std::pair<int64, int64>>{};
+  mlir::DenseIntElementsAttr attr = *optional_attr;
+  auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
+  if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
+    return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
+  auto it = attr.getValues<int64>().begin();
+  std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
+  for (auto& item : out) {
+    int64 first = *it;
+    ++it;
+    int64 second = *it;
+    ++it;
+    item = {first, second};
+  }
+  return out;
+}
+
 }  // namespace xla
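
ConvertNx2Attribute simply walks the attribute's flattened i64 values two at a
time, so dense<[[0, 1], [1, 0]]> : tensor<2x2xi64> becomes {{0, 1}, {1, 0}}.
A dependency-free sketch of that inner loop (a plain std::vector stands in for
mlir::DenseIntElementsAttr, and FlattenedNx2ToPairs is a name made up for this
illustration):

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Mirrors ConvertNx2Attribute's loop: a flattened Nx2 buffer of i64
    // values becomes a list of (first, second) pairs.
    std::vector<std::pair<int64_t, int64_t>> FlattenedNx2ToPairs(
        const std::vector<int64_t>& values) {
      std::vector<std::pair<int64_t, int64_t>> out(values.size() / 2);
      auto it = values.begin();
      for (auto& item : out) {
        int64_t first = *it++;
        int64_t second = *it++;
        item = {first, second};
      }
      return out;
    }

    int main() {
      // source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>
      // flattens to {0, 1, 1, 0} and converts to {(0, 1), (1, 0)}.
      for (const auto& [source, target] : FlattenedNx2ToPairs({0, 1, 1, 0})) {
        std::cout << source << " -> " << target << "\n";
      }
    }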

--- a/tensorflow/compiler/mlir/xla/attribute_exporter.h
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.h

@@ -20,8 +20,8 @@ limitations under the License.
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/types.h"
 #include "tensorflow/stream_executor/dnn.h"
 
 namespace xla {
@@ -36,5 +36,10 @@ StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
 StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
     mlir::DenseIntElementsAttr input);
 
+// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
+// and source-target pairs are defined in HLO.
+StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
+    llvm::Optional<mlir::DenseIntElementsAttr> optional_attr);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_

--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_hlo.cc

@@ -163,32 +163,14 @@ static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
   return fft_type_enum;
 }
 
-// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
-// and source-target pairs are defined in HLO.
-static std::vector<std::pair<int64, int64>> Convert_Nx2_attribute(
-    llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
-  if (!optional_attr.hasValue()) return {};
-  mlir::DenseIntElementsAttr attr = *optional_attr;
-  auto it = attr.getValues<int64>().begin();
-  std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
-  for (auto& item : out) {
-    int64 first = *it;
-    ++it;
-    int64 second = *it;
-    ++it;
-    item = {first, second};
-  }
-  return out;
-}
-
 static std::vector<std::pair<int64, int64>> Convert_padding(
     llvm::Optional<mlir::DenseIntElementsAttr> padding) {
-  return Convert_Nx2_attribute(padding);
+  return xla::ConvertNx2Attribute(padding).ValueOrDie();
 }
 
 static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
     llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
-  return Convert_Nx2_attribute(source_target_pairs);
+  return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
 }
 
 static std::vector<xla::ReplicaGroup> Convert_replica_groups(

--- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc

@@ -218,19 +218,12 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
 }  // anonymous namespace
 
-CollectivePermuteConfig GetCollectivePermuteConfig(
-    const HloInstruction* instr) {
-  CollectivePermuteConfig config;
-  auto* collective_permute = Cast<HloCollectivePermuteInstruction>(instr);
-  config.source_target_pairs = collective_permute->source_target_pairs();
-  return config;
-}
-
 CollectivePermuteThunk::CollectivePermuteThunk(
-    ThunkInfo thunk_info, CollectivePermuteConfig config,
+    ThunkInfo thunk_info,
+    std::vector<std::pair<int64, int64>> source_target_pairs,
     const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest)
     : Thunk(kCollectivePermute, thunk_info),
-      config_(std::move(config)),
+      source_target_pairs_(std::move(source_target_pairs)),
       src_(src),
       dest_(dest) {}
@@ -254,7 +247,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
   // Figure out which replicas our data is copied to.
   std::vector<int64> dest_replicas;
-  for (const auto& src_dest : config_.source_target_pairs) {
+  for (const auto& src_dest : source_target_pairs_) {
     if (src_dest.first == replica_id) {
       dest_replicas.push_back(src_dest.second);
     }
@@ -269,7 +262,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
   // If no replica writes into us (i.e. we aren't the target of any copies), our
   // contract is that we zero our output.
-  if (absl::c_none_of(config_.source_target_pairs,
+  if (absl::c_none_of(source_target_pairs_,
                       [&](std::pair<int64, int64> src_dest) {
                         return src_dest.second == replica_id;
                       })) {
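
The executor logic above boils down to two queries against source_target_pairs_:
which replicas this replica sends to, and whether any replica sends to it (if
not, the thunk zeroes its output buffer). A standalone sketch of those two
queries, with hypothetical free-function names and no XLA dependencies:

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    using SourceTargetPairs = std::vector<std::pair<int64_t, int64_t>>;

    // Replicas this replica copies its data to; may be empty or hold
    // several entries.
    std::vector<int64_t> DestReplicas(const SourceTargetPairs& pairs,
                                      int64_t replica_id) {
      std::vector<int64_t> dests;
      for (const auto& src_dest : pairs) {
        if (src_dest.first == replica_id) dests.push_back(src_dest.second);
      }
      return dests;
    }

    // True if no pair targets this replica, in which case the thunk's
    // contract is to zero its output.
    bool MustZeroOutput(const SourceTargetPairs& pairs, int64_t replica_id) {
      return std::none_of(pairs.begin(), pairs.end(),
                          [&](const std::pair<int64_t, int64_t>& src_dest) {
                            return src_dest.second == replica_id;
                          });
    }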

--- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h

@@ -24,23 +24,18 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-struct CollectivePermuteConfig {
-  std::vector<std::pair<int64, int64>> source_target_pairs;
-};
-
-CollectivePermuteConfig GetCollectivePermuteConfig(const HloInstruction* instr);
-
 // Thunk that implements the collective-permute HLO.
 class CollectivePermuteThunk : public Thunk {
  public:
-  CollectivePermuteThunk(ThunkInfo thunk_info, CollectivePermuteConfig config,
-                         const BufferAllocation::Slice& src,
-                         const BufferAllocation::Slice& dest);
+  CollectivePermuteThunk(
+      ThunkInfo thunk_info,
+      std::vector<std::pair<int64, int64>> source_target_pairs,
+      const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest);
 
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
-  const CollectivePermuteConfig config_;
+  const std::vector<std::pair<int64, int64>> source_target_pairs_;
   const BufferAllocation::Slice src_;
   const BufferAllocation::Slice dest_;
 };

--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc

@@ -1119,11 +1119,8 @@ StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
       IrEmitterUnnested::Create(module_config, /*hlo_computation=*/nullptr,
                                 ir_emitter_context));
   ThunkSequence thunk_sequence;
-  for (mlir::Operation& op : entry_function.getBody().front()) {
-    // Omit the terminator.
-    if (&op == &entry_function.getBody().front().back()) {
-      continue;
-    }
+  for (mlir::Operation& op :
+       entry_function.getBody().front().without_terminator()) {
     MlirEmitterInput input;
     input.op = &op;
     TF_RETURN_IF_ERROR(ir_emitter->EmitOp(input));
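
Block::without_terminator() is existing MLIR API that yields a block's
operations minus the trailing terminator, replacing the manual back()
comparison removed above. A minimal sketch of the idiom (block and process
are assumed names for this illustration):

    // Iterate every op in the block except the terminator (e.g. std.return).
    for (mlir::Operation& op : block.without_terminator()) {
      process(op);  // hypothetical per-op handler
    }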

--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -2940,10 +2940,31 @@ Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
 }
 
 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
-  CollectivePermuteConfig config = GetCollectivePermuteConfig(hlo);
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+  return EmitCollectivePermuteFromMlir(input);
+}
+
+Status IrEmitterUnnested::EmitCollectivePermuteFromMlir(
+    MlirEmitterInput input) {
+  auto collective_permute_op =
+      mlir::cast<mlir::lmhlo::CollectivePermuteOp>(input.op);
+  if (collective_permute_op.channel_id())
+    return Unimplemented("collective permute with channel_id not implemented");
+
+  using source_dest_pairs_t = std::vector<std::pair<int64, int64>>;
+  TF_ASSIGN_OR_RETURN(
+      source_dest_pairs_t source_dest_pairs,
+      ConvertNx2Attribute(collective_permute_op.source_target_pairs()));
+
+  TF_ASSIGN_OR_RETURN(
+      BufferAllocation::Slice source_slice,
+      GetAllocationSliceForMlir(collective_permute_op.operand()));
+  TF_ASSIGN_OR_RETURN(
+      BufferAllocation::Slice result_slice,
+      GetAllocationSliceForMlir(collective_permute_op.output()));
+
   AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
-      GetThunkInfo(hlo), std::move(config),
-      GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo)));
+      input.thunk_info, std::move(source_dest_pairs), source_slice,
+      result_slice));
   return Status::OK();
 }
@@ -5763,6 +5784,9 @@ Status IrEmitterUnnested::EmitOp(MlirEmitterInput mlir_input) {
   if (mlir::isa<mlir::lmhlo::SortOp>(mlir_input.op)) {
     return EmitSortFromMlir(mlir_input);
   }
+  if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(mlir_input.op)) {
+    return EmitCollectivePermuteFromMlir(mlir_input);
+  }
   LOG(FATAL)
       << "This function is for test only, and the op is not implemented: "
       << MlirToString(mlir_input.op);

--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h

@@ -215,6 +215,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandlePartitionId(HloInstruction* hlo) override;
 
   Status HandleCollectivePermute(HloInstruction* hlo) override;
+  Status EmitCollectivePermuteFromMlir(MlirEmitterInput input);
 
   Status EmitOp(MlirEmitterInput mlir_input);

--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD

@@ -88,6 +88,17 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "mlir_gpu_compile_test",
+    srcs = ["mlir_gpu_compile_test.cc"],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":mlir_gpu_test_base",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 tf_cc_test(
     name = "gemm_rewrite_test",
     srcs = [

--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_compile_test.cc

@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+// Tests XLA GPU compilation using MLIR LMHLO dialect as the input.
+class CompileTest : public MlirGpuTestBase {};
+
+TEST_F(CompileTest, InvalidCollectivePermuteOp) {
+  const char* mlir_text = R"(
+      func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
+                 %arg1: memref<4xf32> {lmhlo.alloc = 1 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>}) -> () {
+        "lmhlo.collective_permute"(%arg0, %arg1) {source_target_pairs = dense<[[0, 1, 2]]> : tensor<1x3xi64>} : (memref<4xf32>, memref<4xf32>) -> ()
+        "std.return" () : () -> ()
+      })";
+  auto executable = CompileMlirText(mlir_text);
+  ASSERT_FALSE(executable.ok());
+  EXPECT_STREQ("expected Nx2 attribute to be a tensor of shape Nx2",
+               executable.status().error_message().c_str());
+}
+
+}  // namespace gpu
+}  // namespace xla
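
For contrast with the invalid op above, a well-formed collective_permute
carries an Nx2 source_target_pairs attribute. A hypothetical companion test
(not part of this commit) that should compile cleanly through the new
EmitCollectivePermuteFromMlir path might look like:

    TEST_F(CompileTest, ValidCollectivePermuteOp) {
      // Same module as above, but the pairs are shaped tensor<2x2xi64>:
      // replica 0 sends to replica 1 and replica 1 sends to replica 0.
      const char* mlir_text = R"(
          func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
                     %arg1: memref<4xf32> {lmhlo.alloc = 1 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>}) -> () {
            "lmhlo.collective_permute"(%arg0, %arg1) {source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>} : (memref<4xf32>, memref<4xf32>) -> ()
            "std.return" () : () -> ()
          })";
      TF_ASSERT_OK(CompileMlirText(mlir_text).status());
    }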

--- a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
 
 #include "llvm/IR/LLVMContext.h"
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
@@ -37,9 +38,8 @@ MlirGpuTestBase::MlirGpuTestBase() {
   backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
 }
 
-StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
-    mlir::ModuleOp module, se::Stream* stream,
-    absl::Span<const se::DeviceMemoryBase> arguments) {
+StatusOr<std::unique_ptr<Executable>> MlirGpuTestBase::CompileMlirModule(
+    mlir::ModuleOp module, se::Stream* stream) {
   llvm::LLVMContext llvm_context;
   auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
 #if TENSORFLOW_USE_ROCM
@@ -72,12 +72,16 @@ StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
   HloModuleConfig module_config;
   module_config.set_debug_options(DefaultDebugOptionsIgnoringFlags());
-  TF_ASSIGN_OR_RETURN(
-      auto executable,
-      CompileLmhloToExecutable(static_cast<GpuCompiler*>(backend_->compiler()),
-                               module, "TestModule", module_config,
-                               Compiler::CompileOptions(), "main", stream_exec,
-                               std::move(llvm_module), &ir_emitter_context));
+  return CompileLmhloToExecutable(
+      static_cast<GpuCompiler*>(backend_->compiler()), module, "TestModule",
+      module_config, Compiler::CompileOptions(), "main", stream_exec,
+      std::move(llvm_module), &ir_emitter_context);
+}
+
+StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
+    mlir::ModuleOp module, se::Stream* stream,
+    absl::Span<const se::DeviceMemoryBase> arguments) {
+  TF_ASSIGN_OR_RETURN(auto executable, CompileMlirModule(module, stream));
 
   ExecutableRunOptions executable_run_options;
   executable_run_options.set_stream(stream);
@@ -137,10 +141,8 @@ MlirGpuTestBase::RunMlirModuleWithHostBuffers(
   return host_outputs;
 }
 
-StatusOr<std::vector<std::vector<uint8>>>
-MlirGpuTestBase::RunMlirTextWithHostBuffers(
-    absl::string_view module_text, std::vector<absl::Span<uint8>> arguments) {
-  mlir::MLIRContext context;
+mlir::OwningModuleRef MlirGpuTestBase::ParseMlirModule(
+    absl::string_view module_text, mlir::MLIRContext& context) {
   context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
                       mlir::StandardOpsDialect,
                       mlir::lmhlo_gpu::LmhloGpuDialect>();
@@ -148,8 +150,25 @@ MlirGpuTestBase::RunMlirTextWithHostBuffers(
   mlir::OwningModuleRef module = parseSourceString(
       llvm::StringRef(module_text.data(), module_text.size()), &context);
   CHECK(module);
+  return module;
+}
+
+StatusOr<std::vector<std::vector<uint8>>>
+MlirGpuTestBase::RunMlirTextWithHostBuffers(
+    absl::string_view module_text, std::vector<absl::Span<uint8>> arguments) {
+  mlir::MLIRContext context;
+  mlir::OwningModuleRef module = ParseMlirModule(module_text, context);
   return RunMlirModuleWithHostBuffers(*module, arguments);
 }
+
+StatusOr<std::unique_ptr<Executable>> MlirGpuTestBase::CompileMlirText(
+    absl::string_view module_text) {
+  mlir::MLIRContext context;
+  mlir::OwningModuleRef module = ParseMlirModule(module_text, context);
+  auto stream = backend_->BorrowStream(backend_->default_device_ordinal())
+                    .ConsumeValueOrDie();
+  return CompileMlirModule(*module, stream.get());
+}
 
 }  // namespace gpu
 }  // namespace xla

--- a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h

@@ -29,6 +29,9 @@ class MlirGpuTestBase : public HloTestBase {
   StatusOr<std::vector<std::vector<uint8>>> RunMlirTextWithHostBuffers(
       absl::string_view module_text, std::vector<absl::Span<uint8>> arguments);
 
+  StatusOr<std::unique_ptr<Executable>> CompileMlirText(
+      absl::string_view module_text);
+
   template <typename T>
   static absl::Span<uint8> ToUint8Span(std::vector<T>* v) {
     return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
@@ -46,10 +49,16 @@ class MlirGpuTestBase : public HloTestBase {
   StatusOr<std::vector<std::vector<uint8>>> RunMlirModuleWithHostBuffers(
       mlir::ModuleOp module, std::vector<absl::Span<uint8>> arguments);
 
+  StatusOr<std::unique_ptr<Executable>> CompileMlirModule(mlir::ModuleOp module,
+                                                          se::Stream* stream);
+
   StatusOr<ExecutionOutput> RunMlirModule(
       mlir::ModuleOp module, se::Stream* stream,
       absl::Span<const se::DeviceMemoryBase> arguments);
 
+  mlir::OwningModuleRef ParseMlirModule(absl::string_view module_text,
+                                        mlir::MLIRContext& context);
+
   std::unique_ptr<xla::Backend> backend_;
 };