Move shape inference pass description to ODS

This moves the description of the TF shape inference pass to ODS. Generating the documentation for the pass in a follow up.

PiperOrigin-RevId: 344145233
Change-Id: Ic435f441a563c39060abc2d5c418b9f574335700
This commit is contained in:
Jacques Pienaar 2020-11-24 15:43:50 -08:00 committed by TensorFlower Gardener
parent daf6e17a73
commit e623786212
6 changed files with 83 additions and 25 deletions

View File

@ -825,6 +825,22 @@ cc_library(
],
)
gentbl(
name = "tf_pass_inc_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
(
"-gen-pass-decls -name TensorFlow",
"transforms/tf_passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/tf_passes.td",
td_srcs = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)
cc_library(
name = "tensorflow_passes",
srcs = [
@ -920,6 +936,8 @@ cc_library(
includes = ["include"],
textual_hdrs = [
"ir/tf_ops_helpers.inc",
"transforms/passes_detail.h",
"transforms/tf_passes.h.inc",
],
deps = [
":attribute_utils",
@ -942,6 +960,7 @@ cc_library(
":tensorflow_optimize_inc_gen",
":tensorflow_types",
":tf_data_optimization",
":tf_pass_inc_gen",
":tpu_rewrite_device_util",
":translate_utils",
":unroll_batch_matmul_pass",

View File

@ -193,8 +193,6 @@ std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
} // namespace TF
namespace tf_executor {
class GraphOp;
// Returns a pass that folds switch nodes with constant predicates.
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
@ -221,14 +219,10 @@ CreateTFExecutorTPUV1IslandInliningPass();
// Creates a pass to prune tf_executor.graph from dead nodes.
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
// Prunes unreachable operations of a tf_executor.graph operation.
void PruneGraph(GraphOp graph);
// Sink `tf.Const` operations in the LaunchOp region using them. This is
// performed in order to limit the number of values implicitly captured in this
// region before outlining.
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
} // namespace tf_executor
namespace TFDevice {

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace TF {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_

View File

@ -35,8 +35,6 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations = 10);
// If arg_shapes are empty, then argument shapes will be left unchanged.
// Note: This affects the entire module, and changes are not just scoped to the
// function being inferred.
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
// SCCP pass instead.
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version,

View File

@ -17,11 +17,10 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/core/framework/shape_inference.h"
#define DEBUG_TYPE "tf-shape-inference"
namespace mlir {
namespace TF {
@ -29,25 +28,12 @@ namespace {
// This transformation pass propagate shapes on the TensorFlow graph.
// It is a ModulePass in order to be able to change function types.
class ShapeInference
: public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
public:
ShapeInference() = default;
ShapeInference(const ShapeInference &) {}
explicit ShapeInference(int64_t max_iterations) {
max_iterations_ = max_iterations;
}
void runOnOperation() override {
if (failed(InferModuleShape(getOperation(), max_iterations_))) {
if (failed(InferModuleShape(getOperation(), max_iterations_)))
return signalPassFailure();
}
}
private:
Option<int64_t> max_iterations_{
*this, "max-iterations",
llvm::cl::desc("Maximum shape inference iterations."),
llvm::cl::init(10)};
};
PassRegistration<ShapeInference> pass(

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/Pass/PassBase.td"
// TF dialect passes.
def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
let summary = "Simple Shape Inference on TensorFlow Dialect";
// TODO(jpienaar): Write `description`.
let constructor = "CreateTFShapeInferencePass()";
let options = [
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
"Maximum shape inference iterations">
];
}