Migrate TF MLIR shape inference pass to use declarative pass registration instead of manually defined pass registration (NFC).

PiperOrigin-RevId: 347430171
Change-Id: Iff3e9c3a6c2ddcca1ef7351adadc2c9ba75e0d4a
This commit is contained in:
Andy Ly 2020-12-14 11:27:49 -08:00 committed by TensorFlower Gardener
parent 89ac5d4c81
commit e290ea66ba
6 changed files with 8 additions and 4 deletions

View File

@ -79,6 +79,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"//tensorflow/core:lib",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",

View File

@ -639,6 +639,7 @@ cc_library(
":tensorflow_tfrt_ops_inc_gen",
":tensorflow_traits",
":tensorflow_types",
":tf_pass_inc_gen",
":tf_saved_model_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/core:framework",

View File

@ -401,6 +401,9 @@ CreateTPUCompileOpReplicationPass();
} // namespace TFTPU
#define GEN_PASS_REGISTRATION
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_

View File

@ -36,9 +36,6 @@ class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
}
};
PassRegistration<ShapeInference> pass(
"tf-shape-inference", "Simple Shape Inference on TensorFlow Dialect");
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {

View File

@ -21,7 +21,7 @@ def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
let summary = "Simple Shape Inference on TensorFlow Dialect";
// TODO(jpienaar): Write `description`.
let constructor = "CreateTFShapeInferencePass()";
let constructor = "TF::CreateTFShapeInferencePass()";
let options = [
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/core/platform/init_main.h"
@ -29,6 +30,7 @@ int main(int argc, char **argv) {
tensorflow::InitMlir y(&argc, &argv);
mlir::registerAllPasses();
mlir::registerTensorFlowPasses();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();