Include shape dialect registration
Registering it everywhere where TF dialect is as this will be used for dynamic shape lowering. PiperOrigin-RevId: 313265819 Change-Id: Ic14f19324d043f52699052f3c3ce3ac3ea0302ff
This commit is contained in:
parent
8182ab3bfc
commit
280665cb81
@ -104,6 +104,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::shape::ShapeDialect>();
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
|
@ -559,8 +559,7 @@ cc_library(
|
||||
srcs = ["ir/dialect_registration.cc"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
"@llvm-project//mlir:Shape",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
@ -31,5 +32,6 @@ static DialectRegistration<tf_device::TensorFlowDeviceDialect>
|
||||
tf_device_dialect;
|
||||
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
|
||||
tf_saved_model_dialect;
|
||||
static DialectRegistration<mlir::shape::ShapeDialect> shape_dialect;
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
@ -247,9 +248,10 @@ Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::shape::ShapeDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
|
||||
return true;
|
||||
}();
|
||||
|
Loading…
Reference in New Issue
Block a user