Move shape inference pass description to ODS
This moves the description of the TF shape inference pass to ODS. Generating the documentation for the pass in a follow up. PiperOrigin-RevId: 344145233 Change-Id: Ic435f441a563c39060abc2d5c418b9f574335700
This commit is contained in:
parent
daf6e17a73
commit
e623786212
@ -825,6 +825,22 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gentbl(
|
||||||
|
name = "tf_pass_inc_gen",
|
||||||
|
compatible_with = get_compatible_with_cloud(),
|
||||||
|
tbl_outs = [
|
||||||
|
(
|
||||||
|
"-gen-pass-decls -name TensorFlow",
|
||||||
|
"transforms/tf_passes.h.inc",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||||
|
td_file = "transforms/tf_passes.td",
|
||||||
|
td_srcs = [
|
||||||
|
"@llvm-project//mlir:PassBaseTdFiles",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tensorflow_passes",
|
name = "tensorflow_passes",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -920,6 +936,8 @@ cc_library(
|
|||||||
includes = ["include"],
|
includes = ["include"],
|
||||||
textual_hdrs = [
|
textual_hdrs = [
|
||||||
"ir/tf_ops_helpers.inc",
|
"ir/tf_ops_helpers.inc",
|
||||||
|
"transforms/passes_detail.h",
|
||||||
|
"transforms/tf_passes.h.inc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":attribute_utils",
|
":attribute_utils",
|
||||||
@ -942,6 +960,7 @@ cc_library(
|
|||||||
":tensorflow_optimize_inc_gen",
|
":tensorflow_optimize_inc_gen",
|
||||||
":tensorflow_types",
|
":tensorflow_types",
|
||||||
":tf_data_optimization",
|
":tf_data_optimization",
|
||||||
|
":tf_pass_inc_gen",
|
||||||
":tpu_rewrite_device_util",
|
":tpu_rewrite_device_util",
|
||||||
":translate_utils",
|
":translate_utils",
|
||||||
":unroll_batch_matmul_pass",
|
":unroll_batch_matmul_pass",
|
||||||
|
|||||||
@ -193,8 +193,6 @@ std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
|
|||||||
} // namespace TF
|
} // namespace TF
|
||||||
|
|
||||||
namespace tf_executor {
|
namespace tf_executor {
|
||||||
class GraphOp;
|
|
||||||
|
|
||||||
// Returns a pass that folds switch nodes with constant predicates.
|
// Returns a pass that folds switch nodes with constant predicates.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
|
||||||
|
|
||||||
@ -221,14 +219,10 @@ CreateTFExecutorTPUV1IslandInliningPass();
|
|||||||
// Creates a pass to prune tf_executor.graph from dead nodes.
|
// Creates a pass to prune tf_executor.graph from dead nodes.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
|
||||||
|
|
||||||
// Prunes unreachable operations of a tf_executor.graph operation.
|
|
||||||
void PruneGraph(GraphOp graph);
|
|
||||||
|
|
||||||
// Sink `tf.Const` operations in the LaunchOp region using them. This is
|
// Sink `tf.Const` operations in the LaunchOp region using them. This is
|
||||||
// performed in order to limit the number of values implicitly captured in this
|
// performed in order to limit the number of values implicitly captured in this
|
||||||
// region before outlining.
|
// region before outlining.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
|
||||||
|
|
||||||
} // namespace tf_executor
|
} // namespace tf_executor
|
||||||
|
|
||||||
namespace TFDevice {
|
namespace TFDevice {
|
||||||
|
|||||||
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
|
||||||
|
|
||||||
|
} // namespace TF
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
|
||||||
@ -35,8 +35,6 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations = 10);
|
|||||||
// If arg_shapes are empty, then argument shapes will be left unchanged.
|
// If arg_shapes are empty, then argument shapes will be left unchanged.
|
||||||
// Note: This affects the entire module, and changes are not just scoped to the
|
// Note: This affects the entire module, and changes are not just scoped to the
|
||||||
// function being inferred.
|
// function being inferred.
|
||||||
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
|
|
||||||
// SCCP pass instead.
|
|
||||||
LogicalResult InferShapeForFunction(FuncOp func,
|
LogicalResult InferShapeForFunction(FuncOp func,
|
||||||
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
||||||
int64_t graph_version,
|
int64_t graph_version,
|
||||||
|
|||||||
@ -17,11 +17,10 @@ limitations under the License.
|
|||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "tf-shape-inference"
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TF {
|
namespace TF {
|
||||||
|
|
||||||
@ -29,25 +28,12 @@ namespace {
|
|||||||
|
|
||||||
// This transformation pass propagate shapes on the TensorFlow graph.
|
// This transformation pass propagate shapes on the TensorFlow graph.
|
||||||
// It is a ModulePass in order to be able to change function types.
|
// It is a ModulePass in order to be able to change function types.
|
||||||
class ShapeInference
|
class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
|
||||||
: public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
|
|
||||||
public:
|
public:
|
||||||
ShapeInference() = default;
|
|
||||||
ShapeInference(const ShapeInference &) {}
|
|
||||||
explicit ShapeInference(int64_t max_iterations) {
|
|
||||||
max_iterations_ = max_iterations;
|
|
||||||
}
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
if (failed(InferModuleShape(getOperation(), max_iterations_))) {
|
if (failed(InferModuleShape(getOperation(), max_iterations_)))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
Option<int64_t> max_iterations_{
|
|
||||||
*this, "max-iterations",
|
|
||||||
llvm::cl::desc("Maximum shape inference iterations."),
|
|
||||||
llvm::cl::init(10)};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
PassRegistration<ShapeInference> pass(
|
PassRegistration<ShapeInference> pass(
|
||||||
|
|||||||
30
tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
Normal file
30
tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
// TF dialect passes.
|
||||||
|
|
||||||
|
def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
|
||||||
|
let summary = "Simple Shape Inference on TensorFlow Dialect";
|
||||||
|
// TODO(jpienaar): Write `description`.
|
||||||
|
|
||||||
|
let constructor = "CreateTFShapeInferencePass()";
|
||||||
|
|
||||||
|
let options = [
|
||||||
|
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
|
||||||
|
"Maximum shape inference iterations">
|
||||||
|
];
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user