Migrate TF MLIR shape inference pass to use declarative pass registration instead of manually defined pass registration (NFC).
PiperOrigin-RevId: 347430171 Change-Id: Iff3e9c3a6c2ddcca1ef7351adadc2c9ba75e0d4a
This commit is contained in:
parent
89ac5d4c81
commit
e290ea66ba
@ -79,6 +79,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
||||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||||
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
|
@ -639,6 +639,7 @@ cc_library(
|
|||||||
":tensorflow_tfrt_ops_inc_gen",
|
":tensorflow_tfrt_ops_inc_gen",
|
||||||
":tensorflow_traits",
|
":tensorflow_traits",
|
||||||
":tensorflow_types",
|
":tensorflow_types",
|
||||||
|
":tf_pass_inc_gen",
|
||||||
":tf_saved_model_inc_gen",
|
":tf_saved_model_inc_gen",
|
||||||
"//tensorflow/compiler/mlir/lite:validators",
|
"//tensorflow/compiler/mlir/lite:validators",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -401,6 +401,9 @@ CreateTPUCompileOpReplicationPass();
|
|||||||
|
|
||||||
} // namespace TFTPU
|
} // namespace TFTPU
|
||||||
|
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
|
||||||
|
@ -36,9 +36,6 @@ class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
PassRegistration<ShapeInference> pass(
|
|
||||||
"tf-shape-inference", "Simple Shape Inference on TensorFlow Dialect");
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {
|
||||||
|
@ -21,7 +21,7 @@ def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
|
|||||||
let summary = "Simple Shape Inference on TensorFlow Dialect";
|
let summary = "Simple Shape Inference on TensorFlow Dialect";
|
||||||
// TODO(jpienaar): Write `description`.
|
// TODO(jpienaar): Write `description`.
|
||||||
|
|
||||||
let constructor = "CreateTFShapeInferencePass()";
|
let constructor = "TF::CreateTFShapeInferencePass()";
|
||||||
|
|
||||||
let options = [
|
let options = [
|
||||||
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
|
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
|
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ int main(int argc, char **argv) {
|
|||||||
tensorflow::InitMlir y(&argc, &argv);
|
tensorflow::InitMlir y(&argc, &argv);
|
||||||
|
|
||||||
mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
|
mlir::registerTensorFlowPasses();
|
||||||
mlir::mhlo::registerAllMhloPasses();
|
mlir::mhlo::registerAllMhloPasses();
|
||||||
mlir::lmhlo::registerAllLmhloPasses();
|
mlir::lmhlo::registerAllLmhloPasses();
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user