[MLIR] Add generic (pre/in/post)-order IR walker
- Use this generic walker which invokes the callback before, in-between, and after visiting regions attached to Ops. - Add a transform that will annotate the IR with the walk order seen using this walker and another that does the same with an interrupting walker. - Add unit test that will use the walk annotation transform to check that the walk visits operations in the expected order. PiperOrigin-RevId: 326475851 Change-Id: I228b46a1bd93ff325f22233066955fc07855987d
This commit is contained in:
parent
b07e34b7b3
commit
466275b90e
@ -778,6 +778,7 @@ cc_library(
|
||||
"transforms/tensor_list_ops_decomposition.cc",
|
||||
"transforms/test_resource_alias_analysis.cc",
|
||||
"transforms/test_side_effect_analysis.cc",
|
||||
"transforms/test_visitor_util.cc",
|
||||
"transforms/tf_data_optimization_pass.cc",
|
||||
"transforms/tf_device_assignment.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
@ -825,6 +826,7 @@ cc_library(
|
||||
":tpu_rewrite_device_util",
|
||||
":translate_utils",
|
||||
":unroll_batch_matmul_pass",
|
||||
":visitor_util",
|
||||
":xla_sharding_util",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/mlir/lite:validators",
|
||||
@ -1785,6 +1787,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "visitor_util",
|
||||
srcs = [
|
||||
"utils/visitor_util.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/visitor_util.h",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_sharding_util",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,91 @@
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util-interrupt %s
|
||||
|
||||
// Test simple operations with no regions and no interrupts. They should be
|
||||
// visited with stage "before all regions".
|
||||
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{4: after all regions}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test simple operations with no regions and interrupts. No remarks after
|
||||
// the interrupting operation is visited.
|
||||
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{2: walk was interrupted}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
%0 = "tf.Identity"(%arg0) {interrupt_before_all = true} : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test operation with non empty regions.
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{5: walk was interrupted}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
%0 = "tf.unknownop"(%arg0) ({
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{4: before all regions}}
|
||||
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||
}) {interrupt_after_all = true} : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test operation with multiple regions.
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{5: walk was interrupted}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
%0 = "tf.unknownop"(%arg0) ({
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{4: before all regions}}
|
||||
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||
}) {interrupt_after_region = 0} : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test static filtering
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{7: walk was interrupted}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
// expected-remark@below {{5: before region #1}}
|
||||
// expected-remark@below {{8: before all regions}}
|
||||
// expected-remark@below {{9: before region #1}}
|
||||
// expected-remark@below {{10: after all regions}}
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{4: before all regions}}
|
||||
"tf.Yield"(%1) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
// expected-remark@below {{6: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
"tf.Yield"(%1) { interrupt_after_all = true } : (tensor<f32>) -> ()
|
||||
}) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
102
tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir
Normal file
102
tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir
Normal file
@ -0,0 +1,102 @@
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util %s
|
||||
|
||||
// Test simple operations with no regions. They should be visited with stage
|
||||
// = before all regions.
|
||||
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{4: after all regions}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test operation with empty regions.
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{5: after all regions}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
// expected-remark@below {{3: after all regions}}
|
||||
%0 = "tf.unknownop"(%arg0) ({
|
||||
}) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{4: before all regions}}
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test operation with non empty regions.
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{7: after all regions}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
// expected-remark@below {{5: after all regions}}
|
||||
%0 = "tf.unknownop"(%arg0) ({
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{4: before all regions}}
|
||||
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||
}) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{6: before all regions}}
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test operation with multiple regions.
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{10: after all regions}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
// expected-remark@below {{5: before region #1}}
|
||||
// expected-remark@below {{8: after all regions}}
|
||||
%0 = "tf.unknownop"(%arg0) ({
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{4: before all regions}}
|
||||
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
// expected-remark@below {{6: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{7: before all regions}}
|
||||
"tf.yield"(%1) : (tensor<f32>) -> ()
|
||||
}) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{9: before all regions}}
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test static filtering
|
||||
// expected-remark@below {{0: before all regions}}
|
||||
// expected-remark@below {{10: after all regions}}
|
||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// expected-remark@below {{1: before all regions}}
|
||||
%cst = constant dense<1.0> : tensor<f32>
|
||||
// expected-remark@below {{2: before all regions}}
|
||||
// expected-remark@below {{5: before region #1}}
|
||||
// expected-remark@below {{8: after all regions}}
|
||||
// expected-remark@below {{11: before all regions}}
|
||||
// expected-remark@below {{12: before region #1}}
|
||||
// expected-remark@below {{13: after all regions}}
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
// expected-remark@below {{3: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{4: before all regions}}
|
||||
"tf.Yield"(%1) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
// expected-remark@below {{6: before all regions}}
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{7: before all regions}}
|
||||
"tf.Yield"(%1) : (tensor<f32>) -> ()
|
||||
}) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
|
||||
// expected-remark@below {{9: before all regions}}
|
||||
return %0 : tensor<f32>
|
||||
}
|
@ -0,0 +1,114 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
std::string get_stage_description(const WalkStage &stage) {
|
||||
if (stage.IsBeforeAllRegions()) return "before all regions";
|
||||
if (stage.IsAfterAllRegions()) return "after all regions";
|
||||
return "before region #" + std::to_string(stage.GetNextRegion());
|
||||
}
|
||||
|
||||
// A pass that annotates each operation with an remarks that include a unique
|
||||
// step ID and a description of the visitor step.
|
||||
class TestVisitorUtil
|
||||
: public mlir::PassWrapper<TestVisitorUtil, mlir::FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
mlir::FuncOp func = getOperation();
|
||||
int step_id = 0;
|
||||
GenericWalk(func, [&](mlir::Operation *op, const WalkStage &stage) {
|
||||
op->emitRemark() << step_id++ << ": " << get_stage_description(stage);
|
||||
});
|
||||
|
||||
// Exercise static inference of operation type
|
||||
GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
|
||||
op.emitRemark() << step_id++ << ": " << get_stage_description(stage);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
class TestVisitorUtilInterrupt
|
||||
: public mlir::PassWrapper<TestVisitorUtilInterrupt, mlir::FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
mlir::FuncOp func = getOperation();
|
||||
int step_id = 0;
|
||||
|
||||
auto walker = [&](mlir::Operation *op, const WalkStage &stage) {
|
||||
if (auto interrupt_before_all =
|
||||
op->getAttrOfType<mlir::BoolAttr>("interrupt_before_all"))
|
||||
if (interrupt_before_all.getValue() && stage.IsBeforeAllRegions())
|
||||
return mlir::WalkResult::interrupt();
|
||||
|
||||
if (auto interrupt_after_all =
|
||||
op->getAttrOfType<mlir::BoolAttr>("interrupt_after_all"))
|
||||
if (interrupt_after_all.getValue() && stage.IsAfterAllRegions())
|
||||
return mlir::WalkResult::interrupt();
|
||||
|
||||
if (auto interrupt_after_region =
|
||||
op->getAttrOfType<mlir::IntegerAttr>("interrupt_after_region"))
|
||||
if (stage.IsAfterRegion(
|
||||
static_cast<int>(interrupt_after_region.getInt())))
|
||||
return mlir::WalkResult::interrupt();
|
||||
|
||||
op->emitRemark() << step_id++ << ": " << get_stage_description(stage);
|
||||
return mlir::WalkResult::advance();
|
||||
};
|
||||
|
||||
// Interrupt the walk based on attributes on the operation.
|
||||
auto result = GenericWalk(func, walker);
|
||||
|
||||
if (result.wasInterrupted())
|
||||
func.emitRemark() << step_id++ << ": walk was interrupted";
|
||||
|
||||
// Exercise static inference of operation type for interrupting callback.
|
||||
result =
|
||||
GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
|
||||
return walker(op, stage);
|
||||
});
|
||||
|
||||
if (result.wasInterrupted())
|
||||
func.emitRemark() << step_id++ << ": walk was interrupted";
|
||||
}
|
||||
};
|
||||
|
||||
mlir::PassRegistration<TestVisitorUtil> pass(
|
||||
"tf-test-visitor-util",
|
||||
"Add remarks that trace order of visiting operations using TF visitor "
|
||||
"utilities.");
|
||||
|
||||
mlir::PassRegistration<TestVisitorUtilInterrupt> pass_interrupt(
|
||||
"tf-test-visitor-util-interrupt",
|
||||
"Add remarks that trace order of visiting operations using TF visitor "
|
||||
"utilities, interrupt version.");
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
70
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc
Normal file
70
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
|
||||
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
WalkStage::WalkStage(mlir::Operation *op)
|
||||
: num_regions_(op->getNumRegions()), next_region_(0) {}
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Walk all of the operations nested under and including the given operations.
|
||||
void WalkOperations(mlir::Operation *op, VoidCallback callback) {
|
||||
WalkStage stage(op);
|
||||
|
||||
for (auto ®ion : op->getRegions()) {
|
||||
// Invoke callback on the parent op before visiting each child region.
|
||||
callback(op, stage);
|
||||
stage.Advance();
|
||||
|
||||
for (auto &block : region)
|
||||
// Early increment here in the case where the operation is erased.
|
||||
for (auto &nestedOp : llvm::make_early_inc_range(block))
|
||||
WalkOperations(&nestedOp, callback);
|
||||
}
|
||||
|
||||
// Invoke callback after all regions have been visited.
|
||||
callback(op, stage);
|
||||
}
|
||||
|
||||
/// Walk all of the operations nested under and including the given operations.
|
||||
/// This methods walks operations until an interrupt signal is received.
|
||||
mlir::WalkResult WalkOperations(mlir::Operation *op,
|
||||
InterruptCallback callback) {
|
||||
WalkStage stage(op);
|
||||
|
||||
for (auto ®ion : op->getRegions()) {
|
||||
// Invoke callback on the parent op before visiting each child region.
|
||||
if (callback(op, stage).wasInterrupted())
|
||||
return mlir::WalkResult::interrupt();
|
||||
|
||||
stage.Advance();
|
||||
|
||||
for (auto &block : region) {
|
||||
// Early increment here in the case where the operation is erased.
|
||||
for (auto &nestedOp : llvm::make_early_inc_range(block))
|
||||
if (WalkOperations(&nestedOp, callback).wasInterrupted())
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return callback(op, stage);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace tensorflow
|
168
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h
Normal file
168
tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h
Normal file
@ -0,0 +1,168 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
|
||||
// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The
|
||||
// walk() utility that MLIR core provides traverses operations in a block/
|
||||
// blocks in a region in the program order, and these walkers do the same. When
|
||||
// operations have regions attached to them, the core MLIR walkers visit the
|
||||
// regions attached to an Op first, and then visit the op. So within the context
|
||||
// of a single Op, the traversal is post-order (considering the Op as the parent
|
||||
// node and regions as the children). For certain use cases, it may be more
|
||||
// efficient/desirable to visit the parent Op before visiting the attached
|
||||
// regions. As an example, if the attached regions have region arguments that
|
||||
// are related to the operation inputs (tf.WhileRegion is an example), then we
|
||||
// may want to propagate some information from the Op inputs to the region
|
||||
// inputs and then visit the regions to continue progagating that information
|
||||
// within the regions. With just post-order traversal, to acheive the same we
|
||||
// may need to schedule another walk so make sure child regions get visited.
|
||||
// A pre-order walk (within the context of a single operation) will avoid that.
|
||||
// Similarly, for certain operations, we may want to visit the Op both before
|
||||
// and after all regions have been visited (say to propagate information from
|
||||
// inputs -> region arguments and then from region results -> outputs).
|
||||
|
||||
// In general, since the data flow between an operation and its regions is
|
||||
// opaque in MLIR, we may need to visit the operation in-between regions as well
|
||||
// if say region0 is transferring control back to the Op and from then to
|
||||
// region1. So a more general walker that supports pre/in/post-order walk is
|
||||
// desirable. To support this, the generic walkers defined below will invoke
|
||||
// the walk callback on the parent Op at each stage of the child region walk,
|
||||
// i.e., before visiting any region, in between regions, and after visiting all
|
||||
// regions. To indicate the current walk stage, the callback will also get a
|
||||
// `WalkState` parameter. The callback can inspect the current walk stage and
|
||||
// decide to take appropriate actions (incuding not doing anything). With this
|
||||
// the walker below can support pre/in/post-order walks as well as combined
|
||||
// walks (pre+in+post)-order walk.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A class to indicate the current walk stage.
|
||||
class WalkStage {
|
||||
public:
|
||||
explicit WalkStage(mlir::Operation *op);
|
||||
|
||||
bool IsBeforeAllRegions() const { return next_region_ == 0; }
|
||||
bool IsBeforeRegion(int region) const { return next_region_ == region; }
|
||||
bool IsAfterRegion(int region) const { return next_region_ == region + 1; }
|
||||
bool IsAfterAllRegions() const { return next_region_ == num_regions_; }
|
||||
void Advance() { next_region_++; }
|
||||
int GetNextRegion() const { return next_region_; }
|
||||
|
||||
private:
|
||||
const int num_regions_;
|
||||
int next_region_;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// This is similar to MLIR version, but works with multiple argument functions.
|
||||
// Helper templates to deduce the first argument of a callback parameter.
|
||||
template <typename Ret, typename Arg, typename... Rest>
|
||||
Arg first_argument_type(Ret (*)(Arg, Rest...));
|
||||
template <typename Ret, typename F, typename Arg, typename... Rest>
|
||||
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
|
||||
template <typename Ret, typename F, typename Arg, typename... Rest>
|
||||
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
|
||||
template <typename F>
|
||||
decltype(first_argument_type(&F::operator())) first_argument_type(F);
|
||||
|
||||
/// Type definition of the first argument to the given callable 'T'.
|
||||
template <typename T>
|
||||
using first_argument = decltype(first_argument_type(std::declval<T>()));
|
||||
|
||||
using VoidCallback =
|
||||
llvm::function_ref<void(mlir::Operation *, const WalkStage &)>;
|
||||
using InterruptCallback =
|
||||
llvm::function_ref<mlir::WalkResult(mlir::Operation *, const WalkStage &)>;
|
||||
|
||||
// Walk all of the operations nested under and including the given operation.
|
||||
void WalkOperations(mlir::Operation *op, VoidCallback callback);
|
||||
|
||||
// Walk all of the operations nested under and including the given operation.
|
||||
// This methods walks operations until an interrupt result is returned by the
|
||||
// callback.
|
||||
mlir::WalkResult WalkOperations(mlir::Operation *op,
|
||||
InterruptCallback callback);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Walk all of the operations nested under and including the given operation.
|
||||
// This method is selected for stage-aware callbacks that operate on Operation*.
|
||||
//
|
||||
// Example:
|
||||
// tensorflow::walk(op, [](Operation *op, const WalkStage &stage) { ... });
|
||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||
typename RetT = decltype(std::declval<FuncTy>()(
|
||||
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||
typename std::enable_if<std::is_same<ArgT, mlir::Operation *>::value,
|
||||
RetT>::type
|
||||
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
|
||||
return detail::WalkOperations(
|
||||
op, llvm::function_ref<RetT(ArgT, const WalkStage &)>(callback));
|
||||
}
|
||||
|
||||
// Walk all of the operations of type 'ArgT' nested under and including the
|
||||
// given operation. This method is selected for void returning callbacks that
|
||||
// operate on a specific derived operation type.
|
||||
//
|
||||
// Example:
|
||||
// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { ... });
|
||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||
typename RetT = decltype(std::declval<FuncTy>()(
|
||||
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
|
||||
std::is_same<RetT, void>::value,
|
||||
RetT>::type
|
||||
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
|
||||
auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
|
||||
if (auto derivedOp = llvm::dyn_cast<ArgT>(op)) callback(derivedOp, stage);
|
||||
};
|
||||
return detail::WalkOperations(op,
|
||||
static_cast<detail::VoidCallback>(wrapperFn));
|
||||
}
|
||||
|
||||
// Walk all of the operations of type 'ArgT' nested under and including the
|
||||
// given operation. This method is selected for WalkReturn returning
|
||||
// interruptible callbacks that operate on a specific derived operation type.
|
||||
//
|
||||
// Example:
|
||||
// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) {
|
||||
// if (some_invariant)
|
||||
// return WalkResult::interrupt();
|
||||
// return WalkResult::advance();
|
||||
// });
|
||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||
typename RetT = decltype(std::declval<FuncTy>()(
|
||||
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
|
||||
std::is_same<RetT, mlir::WalkResult>::value,
|
||||
RetT>::type
|
||||
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
|
||||
auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
|
||||
if (auto derivedOp = llvm::dyn_cast<ArgT>(op))
|
||||
return callback(derivedOp, stage);
|
||||
return mlir::WalkResult::advance();
|
||||
};
|
||||
return detail::WalkOperations(
|
||||
op, static_cast<detail::InterruptCallback>(wrapperFn));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
|
Loading…
Reference in New Issue
Block a user