From 280665cb81e01691959b478f883cdf5ac89bd152 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar@google.com>
Date: Tue, 26 May 2020 14:16:22 -0700
Subject: [PATCH] Include shape dialect registration

Registering it everywhere where TF dialect is as this will be used for dynamic
shape lowering.

PiperOrigin-RevId: 313265819
Change-Id: Ic14f19324d043f52699052f3c3ce3ac3ea0302ff
---
 tensorflow/compiler/mlir/BUILD                              | 1 +
 tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc    | 4 +++-
 tensorflow/compiler/mlir/tensorflow/BUILD                   | 3 +--
 .../compiler/mlir/tensorflow/ir/dialect_registration.cc     | 2 ++
 .../compiler/mlir/tensorflow/utils/compile_mlir_util.cc     | 6 ++++--
 5 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index c0066ecda03..c4472e1185c 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -104,6 +104,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_set",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:StandardOps",
     ],
     alwayslink = 1,
diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
index 11d3e7332db..b2225ec1c4f 100644
--- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
+++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_os_ostream.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
 static void RegisterDialects() {
   static bool init_once = []() {
     mlir::registerDialect<mlir::StandardOpsDialect>();
+    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
+    mlir::registerDialect<mlir::shape::ShapeDialect>();
     mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
     mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
     return true;
   }();
   (void)init_once;
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index b2b4c09df3b..de0af94f0cb 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -559,8 +559,7 @@ cc_library(
     srcs = ["ir/dialect_registration.cc"],
     deps = [
         ":tensorflow",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:SCFTransforms",
+        "@llvm-project//mlir:Shape",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc
index ac468d9810c..c95d7b7ca7c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@@ -31,5 +32,6 @@ static DialectRegistration<tf_device::TensorFlowDeviceDialect>
     tf_device_dialect;
 static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
     tf_saved_model_dialect;
+static DialectRegistration<mlir::shape::ShapeDialect> shape_dialect;
 
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 03283da0112..fd1ba3b1901 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
@@ -247,9 +248,10 @@ Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes,
 
 static void RegisterDialects() {
   static bool init_once = []() {
-    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
     mlir::registerDialect<mlir::StandardOpsDialect>();
+    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
+    mlir::registerDialect<mlir::shape::ShapeDialect>();
+    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
     mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
     return true;
   }();