Move shape inference pass description to ODS
This moves the description of the TF shape inference pass to ODS. Generating the documentation for the pass in a follow up. PiperOrigin-RevId: 344145233 Change-Id: Ic435f441a563c39060abc2d5c418b9f574335700
This commit is contained in:
		
							parent
							
								
									daf6e17a73
								
							
						
					
					
						commit
						e623786212
					
				@ -825,6 +825,22 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
gentbl(
 | 
			
		||||
    name = "tf_pass_inc_gen",
 | 
			
		||||
    compatible_with = get_compatible_with_cloud(),
 | 
			
		||||
    tbl_outs = [
 | 
			
		||||
        (
 | 
			
		||||
            "-gen-pass-decls -name TensorFlow",
 | 
			
		||||
            "transforms/tf_passes.h.inc",
 | 
			
		||||
        ),
 | 
			
		||||
    ],
 | 
			
		||||
    tblgen = "@llvm-project//mlir:mlir-tblgen",
 | 
			
		||||
    td_file = "transforms/tf_passes.td",
 | 
			
		||||
    td_srcs = [
 | 
			
		||||
        "@llvm-project//mlir:PassBaseTdFiles",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "tensorflow_passes",
 | 
			
		||||
    srcs = [
 | 
			
		||||
@ -920,6 +936,8 @@ cc_library(
 | 
			
		||||
    includes = ["include"],
 | 
			
		||||
    textual_hdrs = [
 | 
			
		||||
        "ir/tf_ops_helpers.inc",
 | 
			
		||||
        "transforms/passes_detail.h",
 | 
			
		||||
        "transforms/tf_passes.h.inc",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":attribute_utils",
 | 
			
		||||
@ -942,6 +960,7 @@ cc_library(
 | 
			
		||||
        ":tensorflow_optimize_inc_gen",
 | 
			
		||||
        ":tensorflow_types",
 | 
			
		||||
        ":tf_data_optimization",
 | 
			
		||||
        ":tf_pass_inc_gen",
 | 
			
		||||
        ":tpu_rewrite_device_util",
 | 
			
		||||
        ":translate_utils",
 | 
			
		||||
        ":unroll_batch_matmul_pass",
 | 
			
		||||
 | 
			
		||||
@ -193,8 +193,6 @@ std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
 | 
			
		||||
}  // namespace TF
 | 
			
		||||
 | 
			
		||||
namespace tf_executor {
 | 
			
		||||
class GraphOp;
 | 
			
		||||
 | 
			
		||||
// Returns a pass that folds switch nodes with constant predicates.
 | 
			
		||||
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
 | 
			
		||||
 | 
			
		||||
@ -221,14 +219,10 @@ CreateTFExecutorTPUV1IslandInliningPass();
 | 
			
		||||
// Creates a pass to prune tf_executor.graph from dead nodes.
 | 
			
		||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
 | 
			
		||||
 | 
			
		||||
// Prunes unreachable operations of a tf_executor.graph operation.
 | 
			
		||||
void PruneGraph(GraphOp graph);
 | 
			
		||||
 | 
			
		||||
// Sink `tf.Const` operations in the LaunchOp region using them. This is
 | 
			
		||||
// performed in order to limit the number of values implicitly captured in this
 | 
			
		||||
// region before outlining.
 | 
			
		||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
 | 
			
		||||
 | 
			
		||||
}  // namespace tf_executor
 | 
			
		||||
 | 
			
		||||
namespace TFDevice {
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,31 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
 | 
			
		||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/IR/Dialect.h"
 | 
			
		||||
#include "mlir/Pass/Pass.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace TF {
 | 
			
		||||
 | 
			
		||||
#define GEN_PASS_CLASSES
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
 | 
			
		||||
 | 
			
		||||
}  // namespace TF
 | 
			
		||||
}  // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
 | 
			
		||||
@ -35,8 +35,6 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations = 10);
 | 
			
		||||
// If arg_shapes are empty, then argument shapes will be left unchanged.
 | 
			
		||||
// Note: This affects the entire module, and changes are not just scoped to the
 | 
			
		||||
// function being inferred.
 | 
			
		||||
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
 | 
			
		||||
// SCCP pass instead.
 | 
			
		||||
LogicalResult InferShapeForFunction(FuncOp func,
 | 
			
		||||
                                    ArrayRef<ArrayRef<int64_t>> arg_shapes,
 | 
			
		||||
                                    int64_t graph_version,
 | 
			
		||||
 | 
			
		||||
@ -17,11 +17,10 @@ limitations under the License.
 | 
			
		||||
#include "mlir/Pass/Pass.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Support/LLVM.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
 | 
			
		||||
#include "tensorflow/core/framework/shape_inference.h"
 | 
			
		||||
 | 
			
		||||
#define DEBUG_TYPE "tf-shape-inference"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace TF {
 | 
			
		||||
 | 
			
		||||
@ -29,25 +28,12 @@ namespace {
 | 
			
		||||
 | 
			
		||||
// This transformation pass propagate shapes on the TensorFlow graph.
 | 
			
		||||
// It is a ModulePass in order to be able to change function types.
 | 
			
		||||
class ShapeInference
 | 
			
		||||
    : public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
 | 
			
		||||
class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
 | 
			
		||||
 public:
 | 
			
		||||
  ShapeInference() = default;
 | 
			
		||||
  ShapeInference(const ShapeInference &) {}
 | 
			
		||||
  explicit ShapeInference(int64_t max_iterations) {
 | 
			
		||||
    max_iterations_ = max_iterations;
 | 
			
		||||
  }
 | 
			
		||||
  void runOnOperation() override {
 | 
			
		||||
    if (failed(InferModuleShape(getOperation(), max_iterations_))) {
 | 
			
		||||
    if (failed(InferModuleShape(getOperation(), max_iterations_)))
 | 
			
		||||
      return signalPassFailure();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  Option<int64_t> max_iterations_{
 | 
			
		||||
      *this, "max-iterations",
 | 
			
		||||
      llvm::cl::desc("Maximum shape inference iterations."),
 | 
			
		||||
      llvm::cl::init(10)};
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
PassRegistration<ShapeInference> pass(
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										30
									
								
								tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,30 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
include "mlir/Pass/PassBase.td"
 | 
			
		||||
 | 
			
		||||
// TF dialect passes.
 | 
			
		||||
 | 
			
		||||
def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
 | 
			
		||||
  let summary = "Simple Shape Inference on TensorFlow Dialect";
 | 
			
		||||
  // TODO(jpienaar): Write `description`.
 | 
			
		||||
 | 
			
		||||
  let constructor = "CreateTFShapeInferencePass()";
 | 
			
		||||
 | 
			
		||||
  let options = [
 | 
			
		||||
    Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
 | 
			
		||||
           "Maximum shape inference iterations">
 | 
			
		||||
  ];
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user