[XLA:GPU] Migrate CollectivePermute thunk generation to MLIR
- Also extend mlir_gpu_test to have a compile only API and use that to test failure of attribute export for an invalid collective permute op. PiperOrigin-RevId: 355652447 Change-Id: Icd34714057541acaad168021e519aa6a62c25f9b
This commit is contained in:
parent
f7680f0926
commit
46c82584dc
@ -97,7 +97,7 @@ StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
|
||||
// rank 0 is num_groups, rank 1 is group size.
|
||||
auto replica_group_values_it = input.getValues<uint64_t>().begin();
|
||||
std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
|
||||
for (ReplicaGroup &group : replica_groups) {
|
||||
for (ReplicaGroup& group : replica_groups) {
|
||||
for (int64 element_idx = 0; element_idx < type.getDimSize(1);
|
||||
++element_idx, ++replica_group_values_it) {
|
||||
// For replica group attribute, -1 indicates padding added by
|
||||
@ -110,4 +110,26 @@ StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
|
||||
}
|
||||
return replica_groups;
|
||||
}
|
||||
|
||||
// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
|
||||
// and source-target pairs are defined in HLO.
|
||||
StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
|
||||
if (!optional_attr.hasValue()) return std::vector<std::pair<int64, int64>>{};
|
||||
mlir::DenseIntElementsAttr attr = *optional_attr;
|
||||
auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
|
||||
return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
|
||||
auto it = attr.getValues<int64>().begin();
|
||||
std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
|
||||
for (auto& item : out) {
|
||||
int64 first = *it;
|
||||
++it;
|
||||
int64 second = *it;
|
||||
++it;
|
||||
item = {first, second};
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/dnn.h"
|
||||
|
||||
namespace xla {
|
||||
@ -36,5 +36,10 @@ StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
|
||||
StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
|
||||
mlir::DenseIntElementsAttr input);
|
||||
|
||||
// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
|
||||
// and source-target pairs are defined in HLO.
|
||||
StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optional_attr);
|
||||
|
||||
} // namespace xla
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
|
||||
|
@ -163,32 +163,14 @@ static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
|
||||
return fft_type_enum;
|
||||
}
|
||||
|
||||
// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
|
||||
// and source-target pairs are defined in HLO.
|
||||
static std::vector<std::pair<int64, int64>> Convert_Nx2_attribute(
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
|
||||
if (!optional_attr.hasValue()) return {};
|
||||
mlir::DenseIntElementsAttr attr = *optional_attr;
|
||||
auto it = attr.getValues<int64>().begin();
|
||||
std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
|
||||
for (auto& item : out) {
|
||||
int64 first = *it;
|
||||
++it;
|
||||
int64 second = *it;
|
||||
++it;
|
||||
item = {first, second};
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static std::vector<std::pair<int64, int64>> Convert_padding(
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> padding) {
|
||||
return Convert_Nx2_attribute(padding);
|
||||
return xla::ConvertNx2Attribute(padding).ValueOrDie();
|
||||
}
|
||||
|
||||
static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
|
||||
return Convert_Nx2_attribute(source_target_pairs);
|
||||
return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
|
||||
}
|
||||
|
||||
static std::vector<xla::ReplicaGroup> Convert_replica_groups(
|
||||
|
@ -218,19 +218,12 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
CollectivePermuteConfig GetCollectivePermuteConfig(
|
||||
const HloInstruction* instr) {
|
||||
CollectivePermuteConfig config;
|
||||
auto* collective_permute = Cast<HloCollectivePermuteInstruction>(instr);
|
||||
config.source_target_pairs = collective_permute->source_target_pairs();
|
||||
return config;
|
||||
}
|
||||
|
||||
CollectivePermuteThunk::CollectivePermuteThunk(
|
||||
ThunkInfo thunk_info, CollectivePermuteConfig config,
|
||||
ThunkInfo thunk_info,
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs,
|
||||
const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest)
|
||||
: Thunk(kCollectivePermute, thunk_info),
|
||||
config_(std::move(config)),
|
||||
source_target_pairs_(std::move(source_target_pairs)),
|
||||
src_(src),
|
||||
dest_(dest) {}
|
||||
|
||||
@ -254,7 +247,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
||||
// Figure out which replicas our data is copied to.
|
||||
std::vector<int64> dest_replicas;
|
||||
for (const auto& src_dest : config_.source_target_pairs) {
|
||||
for (const auto& src_dest : source_target_pairs_) {
|
||||
if (src_dest.first == replica_id) {
|
||||
dest_replicas.push_back(src_dest.second);
|
||||
}
|
||||
@ -269,7 +262,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
||||
// If no replica writes into us (i.e. we aren't the target of any copies), our
|
||||
// contract is that we zero our output.
|
||||
if (absl::c_none_of(config_.source_target_pairs,
|
||||
if (absl::c_none_of(source_target_pairs_,
|
||||
[&](std::pair<int64, int64> src_dest) {
|
||||
return src_dest.second == replica_id;
|
||||
})) {
|
||||
|
@ -24,23 +24,18 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
struct CollectivePermuteConfig {
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs;
|
||||
};
|
||||
|
||||
CollectivePermuteConfig GetCollectivePermuteConfig(const HloInstruction* instr);
|
||||
|
||||
// Thunk that implements the collective-permute HLO.
|
||||
class CollectivePermuteThunk : public Thunk {
|
||||
public:
|
||||
CollectivePermuteThunk(ThunkInfo thunk_info, CollectivePermuteConfig config,
|
||||
const BufferAllocation::Slice& src,
|
||||
const BufferAllocation::Slice& dest);
|
||||
CollectivePermuteThunk(
|
||||
ThunkInfo thunk_info,
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs,
|
||||
const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest);
|
||||
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
private:
|
||||
const CollectivePermuteConfig config_;
|
||||
const std::vector<std::pair<int64, int64>> source_target_pairs_;
|
||||
const BufferAllocation::Slice src_;
|
||||
const BufferAllocation::Slice dest_;
|
||||
};
|
||||
|
@ -1119,11 +1119,8 @@ StatusOr<std::unique_ptr<Executable>> CompileLmhloToExecutable(
|
||||
IrEmitterUnnested::Create(module_config, /*hlo_computation=*/nullptr,
|
||||
ir_emitter_context));
|
||||
ThunkSequence thunk_sequence;
|
||||
for (mlir::Operation& op : entry_function.getBody().front()) {
|
||||
// Omit the terminator.
|
||||
if (&op == &entry_function.getBody().front().back()) {
|
||||
continue;
|
||||
}
|
||||
for (mlir::Operation& op :
|
||||
entry_function.getBody().front().without_terminator()) {
|
||||
MlirEmitterInput input;
|
||||
input.op = &op;
|
||||
TF_RETURN_IF_ERROR(ir_emitter->EmitOp(input));
|
||||
|
@ -2940,10 +2940,31 @@ Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
|
||||
CollectivePermuteConfig config = GetCollectivePermuteConfig(hlo);
|
||||
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
|
||||
return EmitCollectivePermuteFromMlir(input);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitCollectivePermuteFromMlir(
|
||||
MlirEmitterInput input) {
|
||||
auto collective_permute_op =
|
||||
mlir::cast<mlir::lmhlo::CollectivePermuteOp>(input.op);
|
||||
if (collective_permute_op.channel_id())
|
||||
return Unimplemented("collective permute with channel_id not implemented");
|
||||
using source_dest_pairs_t = std::vector<std::pair<int64, int64>>;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
source_dest_pairs_t source_dest_pairs,
|
||||
ConvertNx2Attribute(collective_permute_op.source_target_pairs()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
BufferAllocation::Slice source_slice,
|
||||
GetAllocationSliceForMlir(collective_permute_op.operand()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
BufferAllocation::Slice result_slice,
|
||||
GetAllocationSliceForMlir(collective_permute_op.output()));
|
||||
|
||||
AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
|
||||
GetThunkInfo(hlo), std::move(config),
|
||||
GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo)));
|
||||
input.thunk_info, std::move(source_dest_pairs), source_slice,
|
||||
result_slice));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -5763,6 +5784,9 @@ Status IrEmitterUnnested::EmitOp(MlirEmitterInput mlir_input) {
|
||||
if (mlir::isa<mlir::lmhlo::SortOp>(mlir_input.op)) {
|
||||
return EmitSortFromMlir(mlir_input);
|
||||
}
|
||||
if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(mlir_input.op)) {
|
||||
return EmitCollectivePermuteFromMlir(mlir_input);
|
||||
}
|
||||
LOG(FATAL)
|
||||
<< "This function is for test only, and the op is not implemented: "
|
||||
<< MlirToString(mlir_input.op);
|
||||
|
@ -215,6 +215,7 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
Status HandlePartitionId(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleCollectivePermute(HloInstruction* hlo) override;
|
||||
Status EmitCollectivePermuteFromMlir(MlirEmitterInput input);
|
||||
|
||||
Status EmitOp(MlirEmitterInput mlir_input);
|
||||
|
||||
|
@ -88,6 +88,17 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "mlir_gpu_compile_test",
|
||||
srcs = ["mlir_gpu_compile_test.cc"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":mlir_gpu_test_base",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gemm_rewrite_test",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
// Tests XLA GPU compilation using MLIR LMHLO dialect as the input.
|
||||
class CompileTest : public MlirGpuTestBase {};
|
||||
|
||||
TEST_F(CompileTest, InvalidCollectivePermuteOp) {
|
||||
const char* mlir_text = R"(
|
||||
func @main(%arg0: memref<4xf32> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
|
||||
%arg1: memref<4xf32> {lmhlo.alloc = 1 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>}) -> () {
|
||||
"lmhlo.collective_permute"(%arg0, %arg1) {source_target_pairs = dense<[[0, 1, 2]]> : tensor<1x3xi64>} : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
"std.return" () : () -> ()
|
||||
})";
|
||||
auto executable = CompileMlirText(mlir_text);
|
||||
ASSERT_FALSE(executable.ok());
|
||||
EXPECT_STREQ("expected Nx2 attribute to be a tensor of shape Nx2",
|
||||
executable.status().error_message().c_str());
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
|
||||
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/InitAllDialects.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||
@ -37,9 +38,8 @@ MlirGpuTestBase::MlirGpuTestBase() {
|
||||
backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
|
||||
mlir::ModuleOp module, se::Stream* stream,
|
||||
absl::Span<const se::DeviceMemoryBase> arguments) {
|
||||
StatusOr<std::unique_ptr<Executable>> MlirGpuTestBase::CompileMlirModule(
|
||||
mlir::ModuleOp module, se::Stream* stream) {
|
||||
llvm::LLVMContext llvm_context;
|
||||
auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
@ -72,12 +72,16 @@ StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
|
||||
|
||||
HloModuleConfig module_config;
|
||||
module_config.set_debug_options(DefaultDebugOptionsIgnoringFlags());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto executable,
|
||||
CompileLmhloToExecutable(static_cast<GpuCompiler*>(backend_->compiler()),
|
||||
module, "TestModule", module_config,
|
||||
Compiler::CompileOptions(), "main", stream_exec,
|
||||
std::move(llvm_module), &ir_emitter_context));
|
||||
return CompileLmhloToExecutable(
|
||||
static_cast<GpuCompiler*>(backend_->compiler()), module, "TestModule",
|
||||
module_config, Compiler::CompileOptions(), "main", stream_exec,
|
||||
std::move(llvm_module), &ir_emitter_context);
|
||||
}
|
||||
|
||||
StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
|
||||
mlir::ModuleOp module, se::Stream* stream,
|
||||
absl::Span<const se::DeviceMemoryBase> arguments) {
|
||||
TF_ASSIGN_OR_RETURN(auto executable, CompileMlirModule(module, stream));
|
||||
|
||||
ExecutableRunOptions executable_run_options;
|
||||
executable_run_options.set_stream(stream);
|
||||
@ -137,10 +141,8 @@ MlirGpuTestBase::RunMlirModuleWithHostBuffers(
|
||||
return host_outputs;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::vector<uint8>>>
|
||||
MlirGpuTestBase::RunMlirTextWithHostBuffers(
|
||||
absl::string_view module_text, std::vector<absl::Span<uint8>> arguments) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef MlirGpuTestBase::ParseMlirModule(
|
||||
absl::string_view module_text, mlir::MLIRContext& context) {
|
||||
context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
|
||||
mlir::StandardOpsDialect,
|
||||
mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||
@ -148,8 +150,25 @@ MlirGpuTestBase::RunMlirTextWithHostBuffers(
|
||||
mlir::OwningModuleRef module = parseSourceString(
|
||||
llvm::StringRef(module_text.data(), module_text.size()), &context);
|
||||
CHECK(module);
|
||||
return module;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::vector<uint8>>>
|
||||
MlirGpuTestBase::RunMlirTextWithHostBuffers(
|
||||
absl::string_view module_text, std::vector<absl::Span<uint8>> arguments) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module = ParseMlirModule(module_text, context);
|
||||
return RunMlirModuleWithHostBuffers(*module, arguments);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> MlirGpuTestBase::CompileMlirText(
|
||||
absl::string_view module_text) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module = ParseMlirModule(module_text, context);
|
||||
auto stream = backend_->BorrowStream(backend_->default_device_ordinal())
|
||||
.ConsumeValueOrDie();
|
||||
return CompileMlirModule(*module, stream.get());
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -29,6 +29,9 @@ class MlirGpuTestBase : public HloTestBase {
|
||||
StatusOr<std::vector<std::vector<uint8>>> RunMlirTextWithHostBuffers(
|
||||
absl::string_view module_text, std::vector<absl::Span<uint8>> arguments);
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> CompileMlirText(
|
||||
absl::string_view module_text);
|
||||
|
||||
template <typename T>
|
||||
static absl::Span<uint8> ToUint8Span(std::vector<T>* v) {
|
||||
return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
|
||||
@ -46,10 +49,16 @@ class MlirGpuTestBase : public HloTestBase {
|
||||
StatusOr<std::vector<std::vector<uint8>>> RunMlirModuleWithHostBuffers(
|
||||
mlir::ModuleOp module, std::vector<absl::Span<uint8>> arguments);
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> CompileMlirModule(mlir::ModuleOp module,
|
||||
se::Stream* stream);
|
||||
|
||||
StatusOr<ExecutionOutput> RunMlirModule(
|
||||
mlir::ModuleOp module, se::Stream* stream,
|
||||
absl::Span<const se::DeviceMemoryBase> arguments);
|
||||
|
||||
mlir::OwningModuleRef ParseMlirModule(absl::string_view module_text,
|
||||
mlir::MLIRContext& context);
|
||||
|
||||
std::unique_ptr<xla::Backend> backend_;
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user