Add first test for MLIR GPU backend.

This adds the first IR generation test. The test compiles a simple HLO computation consisting of just an add operation to the LHLO MLIR dialect. It also fixes a couple of issues encountered while writing the test.

PiperOrigin-RevId: 265890901
commit f78a3d92b2 (parent c4fc64c728)
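For orientation before the diff: the core of the change is a FileCheck-style test that runs an HLO module with a single add through the MLIR GPU compiler and matches the LHLO it emits. The sketch below is condensed from mlir_gpu_lhlo_gen_test.cc in this change (CHECK lines trimmed); the supporting MlirIrGenTestBase and the compiler module hook it relies on appear in the hunks that follow.

// Condensed from the test added in this change; the full version lives in
// tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc.
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace mlir_gpu {

class LhloGenTest : public MlirIrGenTestBase {};

TEST_F(LhloGenTest, Add) {
  // Compile the HLO text through the MLIR GPU pipeline and FileCheck the
  // LHLO module captured via the compiler's module hook.
  CompileAndVerifyIr(R"(
HloModule Add

ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
  %x = f32[2,2]{1,0} parameter(0)
  %y = f32[2,2]{1,0} parameter(1)
  ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
})",
                     R"(
;CHECK: func @add(%{{.*}}: memref<2x2xf32>, %{{.*}}: memref<2x2xf32>, %{{.*}}: memref<2x2xf32>)
;CHECK: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %{{.*}})
)");
}

}  // namespace mlir_gpu
}  // namespace xla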
@@ -940,6 +940,16 @@ cc_library(
     ]),
 )
 
+cc_library(
+    name = "mlir_gpu_plugin",
+    deps = [
+        ":service",
+        "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager",
+        "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
 cc_library(
     name = "interpreter_plugin",
     deps = [
@@ -77,3 +77,20 @@ cc_library(
         "@local_config_mlir//:StandardOps",
     ],
 )
+
+cc_library(
+    name = "mlir_irgen_test_base",
+    testonly = True,
+    srcs = ["mlir_irgen_test_base.cc"],
+    hdrs = ["mlir_irgen_test_base.h"],
+    deps = [
+        ":failover_compiler",
+        ":mlir_compiler",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/tests:codegen_test_base",
+        "//tensorflow/compiler/xla/tests:filecheck",
+        "//tensorflow/core:test",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+    ],
+)
@@ -68,6 +68,9 @@ class FailoverCompiler final : public Compiler {
 
   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
 
+  Compiler* GetPrimary() const { return primary_.get(); }
+  Compiler* GetSecondary() const { return secondary_.get(); }
+
  private:
   std::unique_ptr<Compiler> primary_;
   std::unique_ptr<Compiler> secondary_;
@@ -79,8 +79,8 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
       func_builder.create<::mlir::xla_lhlo::MaxOp>(loc, rets, args, attrs);
       break;
     default:
-      return tensorflow::errors::Internal(
-          absl::StrCat("Opcode: ", opcode, " is not supported."));
+      return tensorflow::errors::Internal(absl::StrCat(
+          "Opcode ", HloOpcodeString(opcode), " is not supported."));
   }
   return Status::OK();
 }
@@ -187,7 +187,7 @@ StatusOr<FuncOp> LhloDialectEmitter::CreateFunction(
   mlir_module_.push_back(function);
   function.addEntryBlock();
   instruction_to_mlir_func_[&instr] = function;
-  return Status::OK();
+  return function;
 }
 
 Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
@@ -211,8 +211,12 @@ Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) {
   return ThunkEmitter(this).HandleCustomCall(custom_call);
 }
 
+Status LhloDialectEmitter::HandleParameter(HloInstruction* parameter) {
+  return Status::OK();
+}
+
 Status LhloDialectEmitter::FinishVisit(HloInstruction* root) {
-  LOG(FATAL) << "Not implemented yet.";
+  return Status::OK();
 }
 
 }  // namespace mlir_gpu
@@ -53,6 +53,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
 
   Status HandleFusion(HloInstruction* fusion) override;
   Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status HandleParameter(HloInstruction* parameter) override;
 
   Status FinishVisit(HloInstruction* root) override;
 
@@ -41,6 +41,7 @@ namespace {
 
 using ::mlir::MLIRContext;
 using ::mlir::ModuleOp;
+using ::mlir::OwningModuleRef;
 using ::mlir::UnknownLoc;
 using ::mlir::LLVM::LLVMDialect;
 using ::xla::gpu::GpuExecutable;
@@ -143,9 +144,17 @@ StatusOr<std::unique_ptr<Executable>> MlirCompiler::RunBackend(
   DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
 
   MLIRContext mlir_context;
-  auto mlir_module = ModuleOp::create(UnknownLoc::get(&mlir_context));
+  OwningModuleRef mlir_module =
+      ModuleOp::create(UnknownLoc::get(&mlir_context));
   LhloDialectEmitter lhlo_emitter(*module, *buffer_assignment,
-                                  stream_exec->platform(), mlir_module);
+                                  stream_exec->platform(), *mlir_module);
 
+  TF_RETURN_IF_ERROR(
+      lhlo_emitter.EmitComputation(*module->entry_computation()));
+
+  if (module_hook_.callback && !module_hook_.apply_on_lowered) {
+    module_hook_.callback(*mlir_module);
+  }
+
   // TODO(b/137624192): Emit function per hlo and turn into ptx string and blob.
   std::string ptx;
@@ -181,6 +190,12 @@ MlirCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
   return Unimplemented("Not yet implemented in MLIR compiler");
 }
 
+void MlirCompiler::SetModuleHook(IRHook module_hook) {
+  module_hook_ = module_hook;
+}
+
+void MlirCompiler::RemoveModuleHook() { module_hook_ = {nullptr, false}; }
+
 }  // namespace mlir_gpu
 }  // namespace xla
 
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
 
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/xla/service/compiler.h"
 
 namespace xla {
@@ -55,9 +56,18 @@ class MlirCompiler : public Compiler {
     };
   }
 
+  struct IRHook {
+    std::function<void(mlir::ModuleOp)> callback;
+    bool apply_on_lowered;
+  };
+
+  void SetModuleHook(IRHook module_hook);
+  void RemoveModuleHook();
+
  private:
   ::mlir::MLIRContext context_;
   int64 pointer_size_;
+  IRHook module_hook_;
 };
 
 }  // namespace mlir_gpu
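The hook added above is what lets a test observe the emitted LHLO module without changing the compiler's outputs. Below is a minimal sketch of how a caller can drive that API; CaptureLhloForTest is a hypothetical helper used only for illustration, and the real capture logic is the one in mlir_irgen_test_base.cc further down in this change.

// Illustrative sketch only: capture the MLIR emitted during compilation as a
// string via the new IRHook API. CaptureLhloForTest is hypothetical; the real
// capture lives in MlirIrGenTestBase::CompileAndVerifyIr.
#include <string>

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"

namespace xla {
namespace mlir_gpu {

std::string CaptureLhloForTest(MlirCompiler* compiler) {
  std::string ir;
  // The callback fires from MlirCompiler::RunBackend right after the
  // LhloDialectEmitter has populated the module (apply_on_lowered = false).
  compiler->SetModuleHook({[&ir](mlir::ModuleOp module) {
                             llvm::raw_string_ostream stream(ir);
                             module.print(stream);
                             stream.flush();
                           },
                           /*apply_on_lowered=*/false});
  // ... compile an HloModule through this compiler here ...
  compiler->RemoveModuleHook();
  return ir;
}

}  // namespace mlir_gpu
}  // namespace xla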
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h"
+
+#include <functional>
+#include <utility>
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Module.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace mlir_gpu {
+
+void MlirIrGenTestBase::CompileAndVerifyIr(
+    std::unique_ptr<HloModule> hlo_module, const string& pattern,
+    bool match_lowered_ir) {
+  MlirCompiler* compiler = GetMLIRCompiler();
+  string ir;
+  compiler->SetModuleHook({[&ir](mlir::ModuleOp module) -> Status {
+                             std::string buffer_string;
+                             llvm::raw_string_ostream ostream(buffer_string);
+                             module.print(ostream);
+                             ostream.flush();
+                             ir = buffer_string;
+                             return Status::OK();
+                           },
+                           match_lowered_ir});
+  Status status = CompileToExecutable(std::move(hlo_module)).status();
+  compiler->RemoveModuleHook();
+  TF_ASSERT_OK(status);
+
+  StatusOr<bool> filecheck_result = RunFileCheck(ir, pattern);
+  TF_ASSERT_OK(filecheck_result.status());
+  EXPECT_TRUE(filecheck_result.ValueOrDie());
+}
+
+void MlirIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
+                                           const string& expected_llvm_ir,
+                                           bool match_lowered_ir) {
+  HloModuleConfig config;
+  config.set_debug_options(GetDebugOptionsForTest());
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(hlo_text, config));
+  CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_lowered_ir);
+}
+
+MlirCompiler* MlirIrGenTestBase::GetMLIRCompiler() {
+  // TODO(b/137624192): Remove failover once no longer in place.
+  FailoverCompiler* failover =
+      static_cast<FailoverCompiler*>(backend().compiler());
+  return static_cast<MlirCompiler*>(failover->GetPrimary());
+}
+
+}  // namespace mlir_gpu
+}  // namespace xla
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
+#include "tensorflow/compiler/xla/tests/codegen_test_base.h"
+
+namespace xla {
+namespace mlir_gpu {
+
+// Tests that verify IR emitted by the CPU/GPU backend is as expected.
+class MlirIrGenTestBase : public CodegenTestBase {
+ protected:
+  // Compiles the given HLO module to MLIR IR and verifies the IR matches the
+  // given pattern. `pattern` is in the FileCheck pattern matching syntax
+  // (http://llvm.org/docs/CommandGuide/FileCheck.html).
+  //
+  // This function invokes the JIT compiler.
+  //
+  // If `match_lowered_ir` is true, match the version of the IR after lowering
+  // steps to LLVM IR are applied; otherwise, the IR before lowering is
+  // matched.
+  void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
+                          const string& pattern, bool match_lowered_ir = false);
+
+  // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create
+  // an HLO module.
+  void CompileAndVerifyIr(const string& hlo_text,
+                          const string& expected_llvm_ir,
+                          bool match_lowered_ir = false);
+
+  // Compiles and returns module with optimizations from a given HLO.
+  StatusOr<std::unique_ptr<HloModule>> GetOptimizedModule(
+      absl::string_view hlo);
+
+ private:
+  MlirCompiler* GetMLIRCompiler();
+};
+
+}  // namespace mlir_gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_
tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD (new file, 42 lines)
@@ -0,0 +1,42 @@
+# TODO(herhut): describe this package.
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+    "//tensorflow/core/platform:default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
+
+package(
+    default_visibility = [":friends"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+package_group(
+    name = "friends",
+    includes = [
+        "//tensorflow/compiler/xla:friends",
+    ],
+)
+
+tf_cc_test(
+    name = "mlir_gpu_lhlo_gen_test",
+    srcs = ["mlir_gpu_lhlo_gen_test.cc"],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service:mlir_gpu_plugin",
+        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+        "//tensorflow/compiler/xla/service/mlir_gpu:mlir_irgen_test_base",
+        "//tensorflow/compiler/xla/tests:filecheck",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/memory",
+    ],
+)
@@ -0,0 +1,53 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace mlir_gpu {
+
+class LhloGenTest : public MlirIrGenTestBase {};
+
+TEST_F(LhloGenTest, Add) {
+  CompileAndVerifyIr(R"(
+HloModule Add
+
+ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
+  %x = f32[2,2]{1,0} parameter(0)
+  %y = f32[2,2]{1,0} parameter(1)
+  ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
+})",
+                     R"(
+;CHECK: module {
+;CHECK: func @add(%{{.*}}: memref<2x2xf32>, %{{.*}}: memref<2x2xf32>, %{{.*}}: memref<2x2xf32>) {
+;CHECK: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %{{.*}}) {name = "add"} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
+;CHECK: }
+;CHECK: }
+)");
+}
+
+}  // namespace mlir_gpu
+}  // namespace xla