From 280665cb81e01691959b478f883cdf5ac89bd152 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar <jpienaar@google.com> Date: Tue, 26 May 2020 14:16:22 -0700 Subject: [PATCH] Include shape dialect registration Registering it everywhere where TF dialect is as this will be used for dynamic shape lowering. PiperOrigin-RevId: 313265819 Change-Id: Ic14f19324d043f52699052f3c3ce3ac3ea0302ff --- tensorflow/compiler/mlir/BUILD | 1 + tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc | 4 +++- tensorflow/compiler/mlir/tensorflow/BUILD | 3 +-- .../compiler/mlir/tensorflow/ir/dialect_registration.cc | 2 ++ .../compiler/mlir/tensorflow/utils/compile_mlir_util.cc | 6 ++++-- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index c0066ecda03..c4472e1185c 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -104,6 +104,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 11d3e7332db..b2225ec1c4f 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_os_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { static void RegisterDialects() { static bool init_once = []() { mlir::registerDialect<mlir::StandardOpsDialect>(); + mlir::registerDialect<mlir::TF::TensorFlowDialect>(); + mlir::registerDialect<mlir::shape::ShapeDialect>(); mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>(); mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); - mlir::registerDialect<mlir::TF::TensorFlowDialect>(); return true; }(); (void)init_once; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index b2b4c09df3b..de0af94f0cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -559,8 +559,7 @@ cc_library( srcs = ["ir/dialect_registration.cc"], deps = [ ":tensorflow", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:Shape", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc index ac468d9810c..c95d7b7ca7c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -31,5 +32,6 @@ static DialectRegistration<tf_device::TensorFlowDeviceDialect> tf_device_dialect; static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect> tf_saved_model_dialect; +static DialectRegistration<mlir::shape::ShapeDialect> shape_dialect; } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 03283da0112..fd1ba3b1901 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -247,9 +248,10 @@ Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes, static void RegisterDialects() { static bool init_once = []() { - mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); - mlir::registerDialect<mlir::TF::TensorFlowDialect>(); mlir::registerDialect<mlir::StandardOpsDialect>(); + mlir::registerDialect<mlir::TF::TensorFlowDialect>(); + mlir::registerDialect<mlir::shape::ShapeDialect>(); + mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>(); return true; }();