Move shape inference pass description to ODS
This moves the description of the TF shape inference pass to ODS. Generating the documentation for the pass in a follow up. PiperOrigin-RevId: 344145233 Change-Id: Ic435f441a563c39060abc2d5c418b9f574335700
This commit is contained in:
parent
daf6e17a73
commit
e623786212
@ -825,6 +825,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tf_pass_inc_gen",
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-pass-decls -name TensorFlow",
|
||||
"transforms/tf_passes.h.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/tf_passes.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:PassBaseTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_passes",
|
||||
srcs = [
|
||||
@ -920,6 +936,8 @@ cc_library(
|
||||
includes = ["include"],
|
||||
textual_hdrs = [
|
||||
"ir/tf_ops_helpers.inc",
|
||||
"transforms/passes_detail.h",
|
||||
"transforms/tf_passes.h.inc",
|
||||
],
|
||||
deps = [
|
||||
":attribute_utils",
|
||||
@ -942,6 +960,7 @@ cc_library(
|
||||
":tensorflow_optimize_inc_gen",
|
||||
":tensorflow_types",
|
||||
":tf_data_optimization",
|
||||
":tf_pass_inc_gen",
|
||||
":tpu_rewrite_device_util",
|
||||
":translate_utils",
|
||||
":unroll_batch_matmul_pass",
|
||||
|
||||
@ -193,8 +193,6 @@ std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
|
||||
} // namespace TF
|
||||
|
||||
namespace tf_executor {
|
||||
class GraphOp;
|
||||
|
||||
// Returns a pass that folds switch nodes with constant predicates.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
|
||||
|
||||
@ -221,14 +219,10 @@ CreateTFExecutorTPUV1IslandInliningPass();
|
||||
// Creates a pass to prune tf_executor.graph from dead nodes.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
|
||||
|
||||
// Prunes unreachable operations of a tf_executor.graph operation.
|
||||
void PruneGraph(GraphOp graph);
|
||||
|
||||
// Sink `tf.Const` operations in the LaunchOp region using them. This is
|
||||
// performed in order to limit the number of values implicitly captured in this
|
||||
// region before outlining.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
|
||||
|
||||
} // namespace tf_executor
|
||||
|
||||
namespace TFDevice {
|
||||
|
||||
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_PASS_DETAIL_H_
|
||||
@ -35,8 +35,6 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations = 10);
|
||||
// If arg_shapes are empty, then argument shapes will be left unchanged.
|
||||
// Note: This affects the entire module, and changes are not just scoped to the
|
||||
// function being inferred.
|
||||
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
|
||||
// SCCP pass instead.
|
||||
LogicalResult InferShapeForFunction(FuncOp func,
|
||||
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
||||
int64_t graph_version,
|
||||
|
||||
@ -17,11 +17,10 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-shape-inference"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
@ -29,25 +28,12 @@ namespace {
|
||||
|
||||
// This transformation pass propagate shapes on the TensorFlow graph.
|
||||
// It is a ModulePass in order to be able to change function types.
|
||||
class ShapeInference
|
||||
: public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
|
||||
class ShapeInference : public TensorFlowShapeInferencePassBase<ShapeInference> {
|
||||
public:
|
||||
ShapeInference() = default;
|
||||
ShapeInference(const ShapeInference &) {}
|
||||
explicit ShapeInference(int64_t max_iterations) {
|
||||
max_iterations_ = max_iterations;
|
||||
}
|
||||
void runOnOperation() override {
|
||||
if (failed(InferModuleShape(getOperation(), max_iterations_))) {
|
||||
if (failed(InferModuleShape(getOperation(), max_iterations_)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Option<int64_t> max_iterations_{
|
||||
*this, "max-iterations",
|
||||
llvm::cl::desc("Maximum shape inference iterations."),
|
||||
llvm::cl::init(10)};
|
||||
};
|
||||
|
||||
PassRegistration<ShapeInference> pass(
|
||||
|
||||
30
tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
Normal file
30
tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
// TF dialect passes.
|
||||
|
||||
def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> {
|
||||
let summary = "Simple Shape Inference on TensorFlow Dialect";
|
||||
// TODO(jpienaar): Write `description`.
|
||||
|
||||
let constructor = "CreateTFShapeInferencePass()";
|
||||
|
||||
let options = [
|
||||
Option<"max_iterations_", "max-iterations", "int64_t", /*default=*/"10",
|
||||
"Maximum shape inference iterations">
|
||||
];
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user