Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/tfrt/...
PiperOrigin-RevId: 328168585 Change-Id: I702964790826638b194eb7c5b21f98492e90c727
This commit is contained in:
parent
db4bab7ccc
commit
2d0592a000
@ -133,7 +133,6 @@ cc_library(
|
||||
":hlo_utils",
|
||||
":mlir_hlo_to_hlo",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -334,6 +333,7 @@ cc_library(
|
||||
":mlir_hlo_to_hlo",
|
||||
"//tensorflow/compiler/jit:xla_cpu_jit",
|
||||
"//tensorflow/compiler/jit:xla_gpu_jit",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -342,6 +342,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Translation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -42,13 +42,13 @@ class XlaBuilderTest : public ::testing::Test {
|
||||
protected:
|
||||
XlaBuilderTest()
|
||||
: name_(SetupTest()),
|
||||
context_(),
|
||||
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
|
||||
builder_(&module_->getBodyRegion()),
|
||||
xla_builder_(name_, builder_, module_->getLoc()) {}
|
||||
xla_builder_(name_, builder_, module_->getLoc()) {
|
||||
context_.loadDialect<mlir::mhlo::MhloDialect>();
|
||||
}
|
||||
|
||||
string SetupTest() {
|
||||
mlir::registerDialect<mlir::mhlo::MhloDialect>();
|
||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,6 @@
|
||||
|
||||
// CHECK: Opaque elements attr not supported
|
||||
func @main() {
|
||||
%0 = "tf.Const"() {value = opaque<"tf", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
%0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
return
|
||||
}
|
||||
|
@ -69,6 +69,10 @@ namespace {
|
||||
constexpr char kShardingAttr[] = "mhlo.sharding";
|
||||
|
||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
LegalizeTF() = default;
|
||||
LegalizeTF(const LegalizeTF &) {}
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "mlir/IR/AffineMap.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
@ -34,6 +35,8 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassOptions.h" // from @llvm-project
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
@ -133,6 +136,11 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
|
||||
// MLIR LHLO.
|
||||
class XlaHloToLhloPass
|
||||
: public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
|
||||
mlir::lmhlo::LmhloDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
XlaHloToLhloPass() = default;
|
||||
XlaHloToLhloPass(const XlaHloToLhloPass&) {}
|
||||
@ -438,7 +446,7 @@ Status LhloDialectEmitter::Initialize() {
|
||||
builder_.setInsertionPointToEnd(block);
|
||||
|
||||
auto return_op = builder_.create<ReturnOp>(builder_.getUnknownLoc());
|
||||
builder_ = mlir::OpBuilder(return_op);
|
||||
builder_ = OpBuilder(return_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -449,6 +457,9 @@ std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
|
||||
|
||||
Status HloToLhloModule(const BufferAssignment& assignment,
|
||||
const HloModule& hlo_module, ModuleOp module) {
|
||||
module.getContext()
|
||||
->loadDialect<StandardOpsDialect, mhlo::MhloDialect,
|
||||
lmhlo::LmhloDialect>();
|
||||
HloComputation* computation = hlo_module.entry_computation();
|
||||
|
||||
LhloDialectEmitter emitter(assignment, *computation, module);
|
||||
@ -462,15 +473,14 @@ Status HloToLhloModule(const BufferAssignment& assignment,
|
||||
return computation->AcceptOrdered(&emitter, ordering);
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef HloTextToLhloTranslateFunction(
|
||||
llvm::StringRef input, mlir::MLIRContext* context) {
|
||||
OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
|
||||
MLIRContext* context) {
|
||||
StatusOr<std::unique_ptr<HloModule>> maybe_module =
|
||||
xla::ParseAndReturnUnverifiedModule(
|
||||
absl::string_view(input.data(), input.size()));
|
||||
TF_CHECK_OK(maybe_module.status());
|
||||
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context));
|
||||
|
||||
TF_CHECK_OK(
|
||||
ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host"));
|
||||
|
@ -64,7 +64,6 @@ inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
|
||||
|
||||
TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
|
||||
MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
Builder b(&context);
|
||||
|
||||
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
|
||||
@ -75,7 +74,6 @@ TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
|
||||
|
||||
TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
|
||||
MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
Builder b(&context);
|
||||
|
||||
EXPECT_TRUE(
|
||||
@ -97,7 +95,6 @@ TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
|
||||
|
||||
TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
|
||||
MLIRContext context;
|
||||
context.loadAllGloballyRegisteredDialects();
|
||||
Builder b(&context);
|
||||
|
||||
// Memref without any affine map. Note: memory space is ignored for shape.
|
||||
|
@ -17,8 +17,11 @@ limitations under the License.
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
|
||||
@ -173,11 +176,17 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
|
||||
|
||||
} // namespace xla
|
||||
|
||||
static void RegisterInputDialects(mlir::DialectRegistry& registry) {
|
||||
registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect>();
|
||||
}
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate(
|
||||
"mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction);
|
||||
"mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction,
|
||||
RegisterInputDialects);
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate(
|
||||
"mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction);
|
||||
"mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction,
|
||||
RegisterInputDialects);
|
||||
|
||||
static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
|
||||
"hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);
|
||||
|
Loading…
x
Reference in New Issue
Block a user