[MLIR] Add generic (pre/in/post)-order IR walker
- Use this generic walker which invokes the callback before, in-between, and after visiting regions attached to Ops. - Add a transform that will annotate the IR with the walk order seen using this walker and another that does the same with an interrupting walker. - Add unit test that will use the walk annotation transform to check that the walk visits operations in the expected order. PiperOrigin-RevId: 326475851 Change-Id: I228b46a1bd93ff325f22233066955fc07855987d
This commit is contained in:
parent
b07e34b7b3
commit
466275b90e
@ -778,6 +778,7 @@ cc_library(
|
|||||||
"transforms/tensor_list_ops_decomposition.cc",
|
"transforms/tensor_list_ops_decomposition.cc",
|
||||||
"transforms/test_resource_alias_analysis.cc",
|
"transforms/test_resource_alias_analysis.cc",
|
||||||
"transforms/test_side_effect_analysis.cc",
|
"transforms/test_side_effect_analysis.cc",
|
||||||
|
"transforms/test_visitor_util.cc",
|
||||||
"transforms/tf_data_optimization_pass.cc",
|
"transforms/tf_data_optimization_pass.cc",
|
||||||
"transforms/tf_device_assignment.cc",
|
"transforms/tf_device_assignment.cc",
|
||||||
"transforms/tpu_cluster_formation.cc",
|
"transforms/tpu_cluster_formation.cc",
|
||||||
@ -825,6 +826,7 @@ cc_library(
|
|||||||
":tpu_rewrite_device_util",
|
":tpu_rewrite_device_util",
|
||||||
":translate_utils",
|
":translate_utils",
|
||||||
":unroll_batch_matmul_pass",
|
":unroll_batch_matmul_pass",
|
||||||
|
":visitor_util",
|
||||||
":xla_sharding_util",
|
":xla_sharding_util",
|
||||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||||
"//tensorflow/compiler/mlir/lite:validators",
|
"//tensorflow/compiler/mlir/lite:validators",
|
||||||
@ -1785,6 +1787,21 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "visitor_util",
|
||||||
|
srcs = [
|
||||||
|
"utils/visitor_util.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"utils/visitor_util.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_sharding_util",
|
name = "xla_sharding_util",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -0,0 +1,91 @@
|
|||||||
|
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util-interrupt %s
|
||||||
|
|
||||||
|
// Test simple operations with no regions and no interrupts. They should be
|
||||||
|
// visited with stage "before all regions".
|
||||||
|
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{4: after all regions}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
%0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test simple operations with no regions and interrupts. No remarks after
|
||||||
|
// the interrupting operation is visited.
|
||||||
|
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{2: walk was interrupted}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
%0 = "tf.Identity"(%arg0) {interrupt_before_all = true} : (tensor<f32>) -> tensor<f32>
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test operation with non empty regions.
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{5: walk was interrupted}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
%0 = "tf.unknownop"(%arg0) ({
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{4: before all regions}}
|
||||||
|
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}) {interrupt_after_all = true} : (tensor<f32>) -> tensor<f32>
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test operation with multiple regions.
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{5: walk was interrupted}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
%0 = "tf.unknownop"(%arg0) ({
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{4: before all regions}}
|
||||||
|
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}) {interrupt_after_region = 0} : (tensor<f32>) -> tensor<f32>
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test static filtering
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{7: walk was interrupted}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
// expected-remark@below {{5: before region #1}}
|
||||||
|
// expected-remark@below {{8: before all regions}}
|
||||||
|
// expected-remark@below {{9: before region #1}}
|
||||||
|
// expected-remark@below {{10: after all regions}}
|
||||||
|
%0 = "tf.IfRegion"(%arg0) ({
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{4: before all regions}}
|
||||||
|
"tf.Yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
// expected-remark@below {{6: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
"tf.Yield"(%1) { interrupt_after_all = true } : (tensor<f32>) -> ()
|
||||||
|
}) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
102
tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir
Normal file
102
tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util %s
|
||||||
|
|
||||||
|
// Test simple operations with no regions. They should be visited with stage
|
||||||
|
// = before all regions.
|
||||||
|
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{4: after all regions}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
%0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test operation with empty regions.
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{5: after all regions}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
// expected-remark@below {{3: after all regions}}
|
||||||
|
%0 = "tf.unknownop"(%arg0) ({
|
||||||
|
}) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{4: before all regions}}
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test operation with non empty regions.
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{7: after all regions}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
// expected-remark@below {{5: after all regions}}
|
||||||
|
%0 = "tf.unknownop"(%arg0) ({
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{4: before all regions}}
|
||||||
|
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{6: before all regions}}
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test operation with multiple regions.
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{10: after all regions}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
// expected-remark@below {{5: before region #1}}
|
||||||
|
// expected-remark@below {{8: after all regions}}
|
||||||
|
%0 = "tf.unknownop"(%arg0) ({
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{4: before all regions}}
|
||||||
|
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
// expected-remark@below {{6: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{7: before all regions}}
|
||||||
|
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{9: before all regions}}
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// Test static filtering
|
||||||
|
// expected-remark@below {{0: before all regions}}
|
||||||
|
// expected-remark@below {{10: after all regions}}
|
||||||
|
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
// expected-remark@below {{1: before all regions}}
|
||||||
|
%cst = constant dense<1.0> : tensor<f32>
|
||||||
|
// expected-remark@below {{2: before all regions}}
|
||||||
|
// expected-remark@below {{5: before region #1}}
|
||||||
|
// expected-remark@below {{8: after all regions}}
|
||||||
|
// expected-remark@below {{11: before all regions}}
|
||||||
|
// expected-remark@below {{12: before region #1}}
|
||||||
|
// expected-remark@below {{13: after all regions}}
|
||||||
|
%0 = "tf.IfRegion"(%arg0) ({
|
||||||
|
// expected-remark@below {{3: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{4: before all regions}}
|
||||||
|
"tf.Yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
// expected-remark@below {{6: before all regions}}
|
||||||
|
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{7: before all regions}}
|
||||||
|
"tf.Yield"(%1) : (tensor<f32>) -> ()
|
||||||
|
}) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
|
||||||
|
// expected-remark@below {{9: before all regions}}
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
@ -0,0 +1,114 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
|
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||||
|
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::string get_stage_description(const WalkStage &stage) {
|
||||||
|
if (stage.IsBeforeAllRegions()) return "before all regions";
|
||||||
|
if (stage.IsAfterAllRegions()) return "after all regions";
|
||||||
|
return "before region #" + std::to_string(stage.GetNextRegion());
|
||||||
|
}
|
||||||
|
|
||||||
|
// A pass that annotates each operation with an remarks that include a unique
|
||||||
|
// step ID and a description of the visitor step.
|
||||||
|
class TestVisitorUtil
|
||||||
|
: public mlir::PassWrapper<TestVisitorUtil, mlir::FunctionPass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override {
|
||||||
|
mlir::FuncOp func = getOperation();
|
||||||
|
int step_id = 0;
|
||||||
|
GenericWalk(func, [&](mlir::Operation *op, const WalkStage &stage) {
|
||||||
|
op->emitRemark() << step_id++ << ": " << get_stage_description(stage);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Exercise static inference of operation type
|
||||||
|
GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
|
||||||
|
op.emitRemark() << step_id++ << ": " << get_stage_description(stage);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class TestVisitorUtilInterrupt
|
||||||
|
: public mlir::PassWrapper<TestVisitorUtilInterrupt, mlir::FunctionPass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override {
|
||||||
|
mlir::FuncOp func = getOperation();
|
||||||
|
int step_id = 0;
|
||||||
|
|
||||||
|
auto walker = [&](mlir::Operation *op, const WalkStage &stage) {
|
||||||
|
if (auto interrupt_before_all =
|
||||||
|
op->getAttrOfType<mlir::BoolAttr>("interrupt_before_all"))
|
||||||
|
if (interrupt_before_all.getValue() && stage.IsBeforeAllRegions())
|
||||||
|
return mlir::WalkResult::interrupt();
|
||||||
|
|
||||||
|
if (auto interrupt_after_all =
|
||||||
|
op->getAttrOfType<mlir::BoolAttr>("interrupt_after_all"))
|
||||||
|
if (interrupt_after_all.getValue() && stage.IsAfterAllRegions())
|
||||||
|
return mlir::WalkResult::interrupt();
|
||||||
|
|
||||||
|
if (auto interrupt_after_region =
|
||||||
|
op->getAttrOfType<mlir::IntegerAttr>("interrupt_after_region"))
|
||||||
|
if (stage.IsAfterRegion(
|
||||||
|
static_cast<int>(interrupt_after_region.getInt())))
|
||||||
|
return mlir::WalkResult::interrupt();
|
||||||
|
|
||||||
|
op->emitRemark() << step_id++ << ": " << get_stage_description(stage);
|
||||||
|
return mlir::WalkResult::advance();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Interrupt the walk based on attributes on the operation.
|
||||||
|
auto result = GenericWalk(func, walker);
|
||||||
|
|
||||||
|
if (result.wasInterrupted())
|
||||||
|
func.emitRemark() << step_id++ << ": walk was interrupted";
|
||||||
|
|
||||||
|
// Exercise static inference of operation type for interrupting callback.
|
||||||
|
result =
|
||||||
|
GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
|
||||||
|
return walker(op, stage);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (result.wasInterrupted())
|
||||||
|
func.emitRemark() << step_id++ << ": walk was interrupted";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::PassRegistration<TestVisitorUtil> pass(
|
||||||
|
"tf-test-visitor-util",
|
||||||
|
"Add remarks that trace order of visiting operations using TF visitor "
|
||||||
|
"utilities.");
|
||||||
|
|
||||||
|
mlir::PassRegistration<TestVisitorUtilInterrupt> pass_interrupt(
|
||||||
|
"tf-test-visitor-util-interrupt",
|
||||||
|
"Add remarks that trace order of visiting operations using TF visitor "
|
||||||
|
"utilities, interrupt version.");
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
} // namespace tensorflow
|
70
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc
Normal file
70
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
WalkStage::WalkStage(mlir::Operation *op)
|
||||||
|
: num_regions_(op->getNumRegions()), next_region_(0) {}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
/// Walk all of the operations nested under and including the given operations.
|
||||||
|
void WalkOperations(mlir::Operation *op, VoidCallback callback) {
|
||||||
|
WalkStage stage(op);
|
||||||
|
|
||||||
|
for (auto ®ion : op->getRegions()) {
|
||||||
|
// Invoke callback on the parent op before visiting each child region.
|
||||||
|
callback(op, stage);
|
||||||
|
stage.Advance();
|
||||||
|
|
||||||
|
for (auto &block : region)
|
||||||
|
// Early increment here in the case where the operation is erased.
|
||||||
|
for (auto &nestedOp : llvm::make_early_inc_range(block))
|
||||||
|
WalkOperations(&nestedOp, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invoke callback after all regions have been visited.
|
||||||
|
callback(op, stage);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Walk all of the operations nested under and including the given operations.
|
||||||
|
/// This methods walks operations until an interrupt signal is received.
|
||||||
|
mlir::WalkResult WalkOperations(mlir::Operation *op,
|
||||||
|
InterruptCallback callback) {
|
||||||
|
WalkStage stage(op);
|
||||||
|
|
||||||
|
for (auto ®ion : op->getRegions()) {
|
||||||
|
// Invoke callback on the parent op before visiting each child region.
|
||||||
|
if (callback(op, stage).wasInterrupted())
|
||||||
|
return mlir::WalkResult::interrupt();
|
||||||
|
|
||||||
|
stage.Advance();
|
||||||
|
|
||||||
|
for (auto &block : region) {
|
||||||
|
// Early increment here in the case where the operation is erased.
|
||||||
|
for (auto &nestedOp : llvm::make_early_inc_range(block))
|
||||||
|
if (WalkOperations(&nestedOp, callback).wasInterrupted())
|
||||||
|
return mlir::WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return callback(op, stage);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace tensorflow
|
168
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h
Normal file
168
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||||
|
|
||||||
|
// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The
|
||||||
|
// walk() utility that MLIR core provides traverses operations in a block/
|
||||||
|
// blocks in a region in the program order, and these walkers do the same. When
|
||||||
|
// operations have regions attached to them, the core MLIR walkers visit the
|
||||||
|
// regions attached to an Op first, and then visit the op. So within the context
|
||||||
|
// of a single Op, the traversal is post-order (considering the Op as the parent
|
||||||
|
// node and regions as the children). For certain use cases, it may be more
|
||||||
|
// efficient/desirable to visit the parent Op before visiting the attached
|
||||||
|
// regions. As an example, if the attached regions have region arguments that
|
||||||
|
// are related to the operation inputs (tf.WhileRegion is an example), then we
|
||||||
|
// may want to propagate some information from the Op inputs to the region
|
||||||
|
// inputs and then visit the regions to continue progagating that information
|
||||||
|
// within the regions. With just post-order traversal, to acheive the same we
|
||||||
|
// may need to schedule another walk so make sure child regions get visited.
|
||||||
|
// A pre-order walk (within the context of a single operation) will avoid that.
|
||||||
|
// Similarly, for certain operations, we may want to visit the Op both before
|
||||||
|
// and after all regions have been visited (say to propagate information from
|
||||||
|
// inputs -> region arguments and then from region results -> outputs).
|
||||||
|
|
||||||
|
// In general, since the data flow between an operation and its regions is
|
||||||
|
// opaque in MLIR, we may need to visit the operation in-between regions as well
|
||||||
|
// if say region0 is transferring control back to the Op and from then to
|
||||||
|
// region1. So a more general walker that supports pre/in/post-order walk is
|
||||||
|
// desirable. To support this, the generic walkers defined below will invoke
|
||||||
|
// the walk callback on the parent Op at each stage of the child region walk,
|
||||||
|
// i.e., before visiting any region, in between regions, and after visiting all
|
||||||
|
// regions. To indicate the current walk stage, the callback will also get a
|
||||||
|
// `WalkState` parameter. The callback can inspect the current walk stage and
|
||||||
|
// decide to take appropriate actions (incuding not doing anything). With this
|
||||||
|
// the walker below can support pre/in/post-order walks as well as combined
|
||||||
|
// walks (pre+in+post)-order walk.
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// A class to indicate the current walk stage.
|
||||||
|
class WalkStage {
|
||||||
|
public:
|
||||||
|
explicit WalkStage(mlir::Operation *op);
|
||||||
|
|
||||||
|
bool IsBeforeAllRegions() const { return next_region_ == 0; }
|
||||||
|
bool IsBeforeRegion(int region) const { return next_region_ == region; }
|
||||||
|
bool IsAfterRegion(int region) const { return next_region_ == region + 1; }
|
||||||
|
bool IsAfterAllRegions() const { return next_region_ == num_regions_; }
|
||||||
|
void Advance() { next_region_++; }
|
||||||
|
int GetNextRegion() const { return next_region_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const int num_regions_;
|
||||||
|
int next_region_;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
// This is similar to MLIR version, but works with multiple argument functions.
|
||||||
|
// Helper templates to deduce the first argument of a callback parameter.
|
||||||
|
template <typename Ret, typename Arg, typename... Rest>
|
||||||
|
Arg first_argument_type(Ret (*)(Arg, Rest...));
|
||||||
|
template <typename Ret, typename F, typename Arg, typename... Rest>
|
||||||
|
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
|
||||||
|
template <typename Ret, typename F, typename Arg, typename... Rest>
|
||||||
|
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
|
||||||
|
template <typename F>
|
||||||
|
decltype(first_argument_type(&F::operator())) first_argument_type(F);
|
||||||
|
|
||||||
|
/// Type definition of the first argument to the given callable 'T'.
|
||||||
|
template <typename T>
|
||||||
|
using first_argument = decltype(first_argument_type(std::declval<T>()));
|
||||||
|
|
||||||
|
using VoidCallback =
|
||||||
|
llvm::function_ref<void(mlir::Operation *, const WalkStage &)>;
|
||||||
|
using InterruptCallback =
|
||||||
|
llvm::function_ref<mlir::WalkResult(mlir::Operation *, const WalkStage &)>;
|
||||||
|
|
||||||
|
// Walk all of the operations nested under and including the given operation.
|
||||||
|
void WalkOperations(mlir::Operation *op, VoidCallback callback);
|
||||||
|
|
||||||
|
// Walk all of the operations nested under and including the given operation.
|
||||||
|
// This methods walks operations until an interrupt result is returned by the
|
||||||
|
// callback.
|
||||||
|
mlir::WalkResult WalkOperations(mlir::Operation *op,
|
||||||
|
InterruptCallback callback);
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
// Walk all of the operations nested under and including the given operation.
|
||||||
|
// This method is selected for stage-aware callbacks that operate on Operation*.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// tensorflow::walk(op, [](Operation *op, const WalkStage &stage) { ... });
|
||||||
|
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||||
|
typename RetT = decltype(std::declval<FuncTy>()(
|
||||||
|
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||||
|
typename std::enable_if<std::is_same<ArgT, mlir::Operation *>::value,
|
||||||
|
RetT>::type
|
||||||
|
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
|
||||||
|
return detail::WalkOperations(
|
||||||
|
op, llvm::function_ref<RetT(ArgT, const WalkStage &)>(callback));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk all of the operations of type 'ArgT' nested under and including the
|
||||||
|
// given operation. This method is selected for void returning callbacks that
|
||||||
|
// operate on a specific derived operation type.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { ... });
|
||||||
|
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||||
|
typename RetT = decltype(std::declval<FuncTy>()(
|
||||||
|
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||||
|
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
|
||||||
|
std::is_same<RetT, void>::value,
|
||||||
|
RetT>::type
|
||||||
|
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
|
||||||
|
auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
|
||||||
|
if (auto derivedOp = llvm::dyn_cast<ArgT>(op)) callback(derivedOp, stage);
|
||||||
|
};
|
||||||
|
return detail::WalkOperations(op,
|
||||||
|
static_cast<detail::VoidCallback>(wrapperFn));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk all of the operations of type 'ArgT' nested under and including the
|
||||||
|
// given operation. This method is selected for WalkReturn returning
|
||||||
|
// interruptible callbacks that operate on a specific derived operation type.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) {
|
||||||
|
// if (some_invariant)
|
||||||
|
// return WalkResult::interrupt();
|
||||||
|
// return WalkResult::advance();
|
||||||
|
// });
|
||||||
|
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||||
|
typename RetT = decltype(std::declval<FuncTy>()(
|
||||||
|
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||||
|
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
|
||||||
|
std::is_same<RetT, mlir::WalkResult>::value,
|
||||||
|
RetT>::type
|
||||||
|
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
|
||||||
|
auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
|
||||||
|
if (auto derivedOp = llvm::dyn_cast<ArgT>(op))
|
||||||
|
return callback(derivedOp, stage);
|
||||||
|
return mlir::WalkResult::advance();
|
||||||
|
};
|
||||||
|
return detail::WalkOperations(
|
||||||
|
op, static_cast<detail::InterruptCallback>(wrapperFn));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
|
Loading…
Reference in New Issue
Block a user