diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 4c14bcf8960..32a2ed1c272 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -133,7 +133,6 @@ cc_library( ":hlo_utils", ":mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", @@ -334,6 +333,7 @@ cc_library( ":mlir_hlo_to_hlo", "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/jit:xla_gpu_jit", + "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -342,6 +342,7 @@ cc_library( "//tensorflow/core:lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Translation", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc index 1a3f0c16247..de8d6fc697b 100644 --- a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc +++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc @@ -42,13 +42,13 @@ class XlaBuilderTest : public ::testing::Test { protected: XlaBuilderTest() : name_(SetupTest()), - context_(), module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))), builder_(&module_->getBodyRegion()), - xla_builder_(name_, builder_, module_->getLoc()) {} + xla_builder_(name_, builder_, module_->getLoc()) { + context_.loadDialect(); + } string SetupTest() { - mlir::registerDialect(); return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir index 97c53cb5f9f..0c2aee5a2fd 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir @@ -2,6 +2,6 @@ // CHECK: Opaque elements attr not supported func @main() { - %0 = "tf.Const"() {value = opaque<"tf", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32> + %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32> return } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 3462b3b7a5a..2c733bb5ca2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -69,6 +69,10 @@ namespace { constexpr char kShardingAttr[] = "mhlo.sharding"; class LegalizeTF : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 22462428367..ef362d95b97 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -34,6 +35,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" @@ -133,6 +136,11 @@ Status ConvertModule(std::unique_ptr hlo_module, ModuleOp module, // MLIR LHLO. class XlaHloToLhloPass : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: XlaHloToLhloPass() = default; XlaHloToLhloPass(const XlaHloToLhloPass&) {} @@ -438,7 +446,7 @@ Status LhloDialectEmitter::Initialize() { builder_.setInsertionPointToEnd(block); auto return_op = builder_.create(builder_.getUnknownLoc()); - builder_ = mlir::OpBuilder(return_op); + builder_ = OpBuilder(return_op); return Status::OK(); } @@ -449,6 +457,9 @@ std::unique_ptr> createXlaHloToLhloWithXlaPass() { Status HloToLhloModule(const BufferAssignment& assignment, const HloModule& hlo_module, ModuleOp module) { + module.getContext() + ->loadDialect(); HloComputation* computation = hlo_module.entry_computation(); LhloDialectEmitter emitter(assignment, *computation, module); @@ -462,15 +473,14 @@ Status HloToLhloModule(const BufferAssignment& assignment, return computation->AcceptOrdered(&emitter, ordering); } -mlir::OwningModuleRef HloTextToLhloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context) { +OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input, + MLIRContext* context) { StatusOr> maybe_module = xla::ParseAndReturnUnverifiedModule( absl::string_view(input.data(), input.size())); TF_CHECK_OK(maybe_module.status()); - mlir::OwningModuleRef module = - mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context)); TF_CHECK_OK( ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host")); diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index ce709b10462..a4a2bc42d99 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -64,7 +64,6 @@ inline ::testing::PolymorphicMatcher EqualsProto( TEST(TypeToShapeTest, ConvertPrimitiveTypes) { MLIRContext context; - context.loadAllGloballyRegisteredDialects(); Builder b(&context); EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32); @@ -75,7 +74,6 @@ TEST(TypeToShapeTest, ConvertPrimitiveTypes) { TEST(TypeToShapeTest, ConvertBasicTypesToTypes) { MLIRContext context; - context.loadAllGloballyRegisteredDialects(); Builder b(&context); EXPECT_TRUE( @@ -97,7 +95,6 @@ TEST(TypeToShapeTest, ConvertBasicTypesToTypes) { TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) { MLIRContext context; - context.loadAllGloballyRegisteredDialects(); Builder b(&context); // Memref without any affine map. Note: memory space is ignored for shape. diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index 158671a6242..d5c598615b7 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -17,8 +17,11 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" @@ -173,11 +176,17 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( } // namespace xla +static void RegisterInputDialects(mlir::DialectRegistry& registry) { + registry.insert(); +} + static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate( - "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction); + "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction, + RegisterInputDialects); static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate( - "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction); + "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction, + RegisterInputDialects); static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);