Include shape dialect registration

Registering it everywhere where TF dialect is as this will be used for dynamic
shape lowering.

PiperOrigin-RevId: 313265819
Change-Id: Ic14f19324d043f52699052f3c3ce3ac3ea0302ff
This commit is contained in:
Jacques Pienaar 2020-05-26 14:16:22 -07:00 committed by TensorFlower Gardener
parent 8182ab3bfc
commit 280665cb81
5 changed files with 11 additions and 5 deletions

View File

@ -104,6 +104,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;

View File

@ -559,8 +559,7 @@ cc_library(
srcs = ["ir/dialect_registration.cc"],
deps = [
":tensorflow",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Shape",
],
alwayslink = 1,
)

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@ -31,5 +32,6 @@ static DialectRegistration<tf_device::TensorFlowDeviceDialect>
tf_device_dialect;
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
tf_saved_model_dialect;
static DialectRegistration<mlir::shape::ShapeDialect> shape_dialect;
} // namespace mlir

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
@ -247,9 +248,10 @@ Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes,
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
return true;
}();