Migrate TF MLIR shape inference pass to use declarative pass registration instead of manually defined pass registration (NFC).
PiperOrigin-RevId: 347430171 Change-Id: Iff3e9c3a6c2ddcca1ef7351adadc2c9ba75e0d4a
This commit is contained in:
parent
89ac5d4c81
commit
e290ea66ba
@ -79,6 +79,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
|
@ -639,6 +639,7 @@ cc_library(
|
||||
":tensorflow_tfrt_ops_inc_gen",
|
||||
":tensorflow_traits",
|
||||
":tensorflow_types",
|
||||
":tf_pass_inc_gen",
|
||||
":tf_saved_model_inc_gen",
|
||||
"//tensorflow/compiler/mlir/lite:validators",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -401,6 +401,9 @@ CreateTPUCompileOpReplicationPass();
|
||||
|
||||
} // namespace TFTPU
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
|
||||
|
@ -36,9 +36,6 @@ class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
|
||||
}
|
||||
};
|
||||
|
||||
PassRegistration<ShapeInference> pass(
|
||||
"tf-shape-inference", "Simple Shape Inference on TensorFlow Dialect");
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {
|
||||
|
@ -21,7 +21,7 @@ def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
|
||||
let summary = "Simple Shape Inference on TensorFlow Dialect";
|
||||
// TODO(jpienaar): Write `description`.
|
||||
|
||||
let constructor = "CreateTFShapeInferencePass()";
|
||||
let constructor = "TF::CreateTFShapeInferencePass()";
|
||||
|
||||
let options = [
|
||||
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
|
||||
@ -29,6 +30,7 @@ int main(int argc, char **argv) {
|
||||
tensorflow::InitMlir y(&argc, &argv);
|
||||
|
||||
mlir::registerAllPasses();
|
||||
mlir::registerTensorFlowPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user