diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index d26f09c5c91..7e341653bce 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -825,6 +825,22 @@ cc_library( ], ) +gentbl( + name = "tf_pass_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + "-gen-pass-decls -name TensorFlow", + "transforms/tf_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/tf_passes.td", + td_srcs = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + cc_library( name = "tensorflow_passes", srcs = [ @@ -920,6 +936,8 @@ cc_library( includes = ["include"], textual_hdrs = [ "ir/tf_ops_helpers.inc", + "transforms/passes_detail.h", + "transforms/tf_passes.h.inc", ], deps = [ ":attribute_utils", @@ -942,6 +960,7 @@ cc_library( ":tensorflow_optimize_inc_gen", ":tensorflow_types", ":tf_data_optimization", + ":tf_pass_inc_gen", ":tpu_rewrite_device_util", ":translate_utils", ":unroll_batch_matmul_pass", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 25945182c20..d00f19bfd38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -193,8 +193,6 @@ std::unique_ptr CreateClusterTFOpsByHostPass(); } // namespace TF namespace tf_executor { -class GraphOp; - // Returns a pass that folds switch nodes with constant predicates. std::unique_ptr> CreateSwitchFoldPass(); @@ -221,14 +219,10 @@ CreateTFExecutorTPUV1IslandInliningPass(); // Creates a pass to prune tf_executor.graph from dead nodes. std::unique_ptr> CreateTFExecutorGraphPruningPass(); -// Prunes unreachable operations of a tf_executor.graph operation. -void PruneGraph(GraphOp graph); - // Sink `tf.Const` operations in the LaunchOp region using them. This is // performed in order to limit the number of values implicitly captured in this // region before outlining. std::unique_ptr> CreateTFExecutorConstantSinkingPass(); - } // namespace tf_executor namespace TFDevice { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h new file mode 100644 index 00000000000..a6c069bfe93 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace TF { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 8eb3d355113..bbec8fa3ee6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -35,8 +35,6 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations = 10); // If arg_shapes are empty, then argument shapes will be left unchanged. // Note: This affects the entire module, and changes are not just scoped to the // function being inferred. -// TODO(b/154065712): Remove propagate_caller_callee_constants once using -// SCCP pass instead. LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, int64_t graph_version, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index 15923280ff6..36f62a7d3e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -17,11 +17,10 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h" -#define DEBUG_TYPE "tf-shape-inference" - namespace mlir { namespace TF { @@ -29,25 +28,12 @@ namespace { // This transformation pass propagate shapes on the TensorFlow graph. // It is a ModulePass in order to be able to change function types. -class ShapeInference - : public PassWrapper> { +class ShapeInference : public TensorFlowShapeInferencePassBase { public: - ShapeInference() = default; - ShapeInference(const ShapeInference &) {} - explicit ShapeInference(int64_t max_iterations) { - max_iterations_ = max_iterations; - } void runOnOperation() override { - if (failed(InferModuleShape(getOperation(), max_iterations_))) { + if (failed(InferModuleShape(getOperation(), max_iterations_))) return signalPassFailure(); - } } - - private: - Option max_iterations_{ - *this, "max-iterations", - llvm::cl::desc("Maximum shape inference iterations."), - llvm::cl::init(10)}; }; PassRegistration pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td new file mode 100644 index 00000000000..4a8076db651 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +// TF dialect passes. + +def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> { + let summary = "Simple Shape Inference on TensorFlow Dialect"; + // TODO(jpienaar): Write `description`. + + let constructor = "CreateTFShapeInferencePass()"; + + let options = [ + Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10", + "Maximum shape inference iterations"> + ]; +}