[MLIR] Add generic (pre/in/post)-order IR walker

- Use this generic walker which invokes the callback before, in-between, and after visiting
  regions attached to Ops.
- Add a transform that will annotate the IR with the walk order seen using this walker and
  another that does the same with an interrupting walker.
- Add unit test that will use the walk annotation transform to check that the walk visits
  operations in the expected order.

PiperOrigin-RevId: 326475851
Change-Id: I228b46a1bd93ff325f22233066955fc07855987d
This commit is contained in:
Rahul Joshi 2020-08-13 10:34:37 -07:00 committed by TensorFlower Gardener
parent b07e34b7b3
commit 466275b90e
6 changed files with 562 additions and 0 deletions

View File

@ -778,6 +778,7 @@ cc_library(
"transforms/tensor_list_ops_decomposition.cc",
"transforms/test_resource_alias_analysis.cc",
"transforms/test_side_effect_analysis.cc",
"transforms/test_visitor_util.cc",
"transforms/tf_data_optimization_pass.cc",
"transforms/tf_device_assignment.cc",
"transforms/tpu_cluster_formation.cc",
@ -825,6 +826,7 @@ cc_library(
":tpu_rewrite_device_util",
":translate_utils",
":unroll_batch_matmul_pass",
":visitor_util",
":xla_sharding_util",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite:validators",
@ -1785,6 +1787,21 @@ cc_library(
],
)
# Generic (pre/in/post)-order MLIR IR walker utilities (WalkStage/GenericWalk).
cc_library(
    name = "visitor_util",
    srcs = [
        "utils/visitor_util.cc",
    ],
    hdrs = [
        "utils/visitor_util.h",
    ],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)
cc_library(
name = "xla_sharding_util",
srcs = [

View File

@ -0,0 +1,91 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util-interrupt %s

// Test simple operations with no regions and no interrupts. They should be
// visited with stage "before all regions".
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{4: after all regions}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
  // expected-remark@below {{3: before all regions}}
  return %0 : tensor<f32>
}

// -----

// Test simple operations with no regions and interrupts. No remarks after
// the interrupting operation is visited (the walk stops at the tf.Identity
// carrying the interrupt_before_all attribute).
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{2: walk was interrupted}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  %0 = "tf.Identity"(%arg0) {interrupt_before_all = true} : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Test operation with non empty regions. The walk is interrupted after all
// regions of tf.unknownop have been visited (interrupt_after_all).
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{5: walk was interrupted}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  %0 = "tf.unknownop"(%arg0) ({
    // expected-remark@below {{3: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{4: before all regions}}
    "tf.yield"(%1) : (tensor<f32>) -> ()
  }) {interrupt_after_all = true} : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Test operation with multiple regions. The walk is interrupted after region
// index 0, so the second region is never visited.
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{5: walk was interrupted}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  %0 = "tf.unknownop"(%arg0) ({
    // expected-remark@below {{3: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{4: before all regions}}
    "tf.yield"(%1) : (tensor<f32>) -> ()
  }, {
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    "tf.yield"(%1) : (tensor<f32>) -> ()
  }) {interrupt_after_region = 0} : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Test static filtering. Remarks 0-7 come from the generic walk, which is
// interrupted inside the second region; remarks 8-10 come from the second
// walk, which statically filters on tf.IfRegion and runs to completion.
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{7: walk was interrupted}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  // expected-remark@below {{5: before region #1}}
  // expected-remark@below {{8: before all regions}}
  // expected-remark@below {{9: before region #1}}
  // expected-remark@below {{10: after all regions}}
  %0 = "tf.IfRegion"(%arg0) ({
    // expected-remark@below {{3: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{4: before all regions}}
    "tf.Yield"(%1) : (tensor<f32>) -> ()
  }, {
    // expected-remark@below {{6: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    "tf.Yield"(%1) { interrupt_after_all = true } : (tensor<f32>) -> ()
  }) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

View File

@ -0,0 +1,102 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util %s

// Test simple operations with no regions. They should be visited with stage
// "before all regions".
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{4: after all regions}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
  // expected-remark@below {{3: before all regions}}
  return %0 : tensor<f32>
}

// -----

// Test operation with empty regions. The op with a region is visited both
// before and after that region.
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{5: after all regions}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  // expected-remark@below {{3: after all regions}}
  %0 = "tf.unknownop"(%arg0) ({
  }) : (tensor<f32>) -> tensor<f32>
  // expected-remark@below {{4: before all regions}}
  return %0 : tensor<f32>
}

// -----

// Test operation with non empty regions.
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{7: after all regions}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  // expected-remark@below {{5: after all regions}}
  %0 = "tf.unknownop"(%arg0) ({
    // expected-remark@below {{3: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{4: before all regions}}
    "tf.yield"(%1) : (tensor<f32>) -> ()
  }) : (tensor<f32>) -> tensor<f32>
  // expected-remark@below {{6: before all regions}}
  return %0 : tensor<f32>
}

// -----

// Test operation with multiple regions. The parent op is also visited
// in-between its two regions (stage "before region #1").
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{10: after all regions}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  // expected-remark@below {{5: before region #1}}
  // expected-remark@below {{8: after all regions}}
  %0 = "tf.unknownop"(%arg0) ({
    // expected-remark@below {{3: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{4: before all regions}}
    "tf.yield"(%1) : (tensor<f32>) -> ()
  }, {
    // expected-remark@below {{6: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{7: before all regions}}
    "tf.yield"(%1) : (tensor<f32>) -> ()
  }) : (tensor<f32>) -> tensor<f32>
  // expected-remark@below {{9: before all regions}}
  return %0 : tensor<f32>
}

// -----

// Test static filtering. Remarks 0-10 come from the generic walk; remarks
// 11-13 come from the second walk, which statically filters on tf.IfRegion
// and hence only visits that op.
// expected-remark@below {{0: before all regions}}
// expected-remark@below {{10: after all regions}}
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
  // expected-remark@below {{1: before all regions}}
  %cst = constant dense<1.0> : tensor<f32>
  // expected-remark@below {{2: before all regions}}
  // expected-remark@below {{5: before region #1}}
  // expected-remark@below {{8: after all regions}}
  // expected-remark@below {{11: before all regions}}
  // expected-remark@below {{12: before region #1}}
  // expected-remark@below {{13: after all regions}}
  %0 = "tf.IfRegion"(%arg0) ({
    // expected-remark@below {{3: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{4: before all regions}}
    "tf.Yield"(%1) : (tensor<f32>) -> ()
  }, {
    // expected-remark@below {{6: before all regions}}
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    // expected-remark@below {{7: before all regions}}
    "tf.Yield"(%1) : (tensor<f32>) -> ()
  }) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
  // expected-remark@below {{9: before all regions}}
  return %0 : tensor<f32>
}

View File

@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
namespace tensorflow {
namespace {
// Renders a walk stage as the human-readable text used in the emitted
// remarks (and matched by the lit tests).
std::string get_stage_description(const WalkStage &stage) {
  if (stage.IsBeforeAllRegions()) return "before all regions";
  return stage.IsAfterAllRegions()
             ? std::string("after all regions")
             : "before region #" + std::to_string(stage.GetNextRegion());
}
// A pass that annotates each visited operation with a remark containing a
// monotonically increasing step ID and a description of the walk stage.
class TestVisitorUtil
    : public mlir::PassWrapper<TestVisitorUtil, mlir::FunctionPass> {
 public:
  void runOnFunction() override {
    mlir::FuncOp func = getOperation();
    int next_step = 0;
    // Generic walk: every operation in the function is annotated.
    GenericWalk(func, [&](mlir::Operation *op, const WalkStage &stage) {
      op->emitRemark() << next_step++ << ": " << get_stage_description(stage);
    });
    // Typed walk: exercises static inference of the operation type, so only
    // tf.IfRegion ops are annotated by this pass over the function.
    GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
      op.emitRemark() << next_step++ << ": " << get_stage_description(stage);
    });
  }
};
// Variant of TestVisitorUtil whose callback can interrupt the walk, driven by
// interrupt_* attributes placed on operations in the test input.
class TestVisitorUtilInterrupt
    : public mlir::PassWrapper<TestVisitorUtilInterrupt, mlir::FunctionPass> {
 public:
  void runOnFunction() override {
    mlir::FuncOp func = getOperation();
    int next_step = 0;
    // Interrupts when the visited op carries an interrupt_* attribute that
    // matches the current stage; otherwise annotates the op with a remark.
    auto annotate_or_interrupt = [&](mlir::Operation *op,
                                     const WalkStage &stage) {
      if (auto attr =
              op->getAttrOfType<mlir::BoolAttr>("interrupt_before_all"))
        if (attr.getValue() && stage.IsBeforeAllRegions())
          return mlir::WalkResult::interrupt();
      if (auto attr = op->getAttrOfType<mlir::BoolAttr>("interrupt_after_all"))
        if (attr.getValue() && stage.IsAfterAllRegions())
          return mlir::WalkResult::interrupt();
      if (auto attr =
              op->getAttrOfType<mlir::IntegerAttr>("interrupt_after_region"))
        if (stage.IsAfterRegion(static_cast<int>(attr.getInt())))
          return mlir::WalkResult::interrupt();
      op->emitRemark() << next_step++ << ": " << get_stage_description(stage);
      return mlir::WalkResult::advance();
    };
    // Generic walk, interrupted based on attributes on the operations.
    if (GenericWalk(func, annotate_or_interrupt).wasInterrupted())
      func.emitRemark() << next_step++ << ": walk was interrupted";
    // Typed walk: exercises static inference of the operation type for an
    // interrupting callback (only tf.IfRegion ops reach the callback).
    auto result = GenericWalk(
        func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
          return annotate_or_interrupt(op, stage);
        });
    if (result.wasInterrupted())
      func.emitRemark() << next_step++ << ": walk was interrupted";
  }
};
// Registers the annotating walker test pass as -tf-test-visitor-util.
mlir::PassRegistration<TestVisitorUtil> pass(
    "tf-test-visitor-util",
    "Add remarks that trace order of visiting operations using TF visitor "
    "utilities.");

// Registers the interrupting variant as -tf-test-visitor-util-interrupt.
mlir::PassRegistration<TestVisitorUtilInterrupt> pass_interrupt(
    "tf-test-visitor-util-interrupt",
    "Add remarks that trace order of visiting operations using TF visitor "
    "utilities, interrupt version.");
} // anonymous namespace
} // namespace tensorflow

View File

@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
#include "mlir/IR/Operation.h" // from @llvm-project
namespace tensorflow {
// A freshly constructed stage has visited no regions yet, so it reports
// IsBeforeAllRegions() (and also IsAfterAllRegions() for region-less ops).
WalkStage::WalkStage(mlir::Operation *op)
    : num_regions_(op->getNumRegions()), next_region_(0) {}
namespace detail {
/// Walk all of the operations nested under and including the given operation.
/// The callback is invoked on `op` once before each attached region and once
/// after the last region; for an op with no regions it fires exactly once.
void WalkOperations(mlir::Operation *op, VoidCallback callback) {
  WalkStage stage(op);
  for (mlir::Region &region : op->getRegions()) {
    // Visit the parent op itself before descending into this region.
    callback(op, stage);
    stage.Advance();
    for (mlir::Block &block : region) {
      // Early-increment iteration tolerates erasure of the op being visited.
      for (mlir::Operation &nested : llvm::make_early_inc_range(block)) {
        WalkOperations(&nested, callback);
      }
    }
  }
  // Final visit once every region has been traversed.
  callback(op, stage);
}
/// Walk all of the operations nested under and including the given operation,
/// stopping as soon as the callback returns an interrupt. The interrupt is
/// propagated out of every level of the recursion.
mlir::WalkResult WalkOperations(mlir::Operation *op,
                                InterruptCallback callback) {
  WalkStage stage(op);
  for (mlir::Region &region : op->getRegions()) {
    // Visit the parent op before this region; bail out on interrupt.
    if (callback(op, stage).wasInterrupted())
      return mlir::WalkResult::interrupt();
    stage.Advance();
    for (mlir::Block &block : region) {
      // Early-increment iteration tolerates erasure of the op being visited.
      for (mlir::Operation &nested : llvm::make_early_inc_range(block)) {
        if (WalkOperations(&nested, callback).wasInterrupted())
          return mlir::WalkResult::interrupt();
      }
    }
  }
  // Final callback invocation after all regions have been visited.
  return callback(op, stage);
}
} // namespace detail
} // namespace tensorflow

View File

@ -0,0 +1,168 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
#include <type_traits>
#include <utility>

#include "mlir/IR/Visitors.h"  // from @llvm-project
// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The
// walk() utility that MLIR core provides traverses operations in a block/
// blocks in a region in the program order, and these walkers do the same. When
// operations have regions attached to them, the core MLIR walkers visit the
// regions attached to an Op first, and then visit the op. So within the context
// of a single Op, the traversal is post-order (considering the Op as the parent
// node and regions as the children). For certain use cases, it may be more
// efficient/desirable to visit the parent Op before visiting the attached
// regions. As an example, if the attached regions have region arguments that
// are related to the operation inputs (tf.WhileRegion is an example), then we
// may want to propagate some information from the Op inputs to the region
// inputs and then visit the regions to continue propagating that information
// within the regions. With just post-order traversal, to achieve the same we
// may need to schedule another walk to make sure child regions get visited.
// A pre-order walk (within the context of a single operation) will avoid that.
// Similarly, for certain operations, we may want to visit the Op both before
// and after all regions have been visited (say to propagate information from
// inputs -> region arguments and then from region results -> outputs).
// In general, since the data flow between an operation and its regions is
// opaque in MLIR, we may need to visit the operation in-between regions as well
// if say region0 is transferring control back to the Op and from then to
// region1. So a more general walker that supports pre/in/post-order walk is
// desirable. To support this, the generic walkers defined below will invoke
// the walk callback on the parent Op at each stage of the child region walk,
// i.e., before visiting any region, in between regions, and after visiting all
// regions. To indicate the current walk stage, the callback will also get a
// `WalkStage` parameter. The callback can inspect the current walk stage and
// decide to take appropriate actions (including not doing anything). With this
// the walker below can support pre/in/post-order walks as well as combined
// walks (pre+in+post)-order walk.
namespace tensorflow {
// A class to indicate the current walk stage.
//
// A walk of an operation with N regions invokes the callback N+1 times: once
// before region 0, once between each pair of consecutive regions, and once
// after region N-1. `next_region_` counts how many regions have been visited
// so far and therefore identifies the current stage.
class WalkStage {
 public:
  explicit WalkStage(mlir::Operation *op);

  // True before any attached region has been visited. Note that for an op
  // with no regions, this holds simultaneously with IsAfterAllRegions().
  bool IsBeforeAllRegions() const { return next_region_ == 0; }
  // True immediately before region `region` is visited.
  bool IsBeforeRegion(int region) const { return next_region_ == region; }
  // True immediately after region `region` has been visited.
  bool IsAfterRegion(int region) const { return next_region_ == region + 1; }
  // True once every attached region has been visited.
  bool IsAfterAllRegions() const { return next_region_ == num_regions_; }

  // Advances past the region just visited; called by the walkers between
  // regions.
  void Advance() { next_region_++; }

  // Index of the region that will be visited next.
  int GetNextRegion() const { return next_region_; }

 private:
  const int num_regions_;
  int next_region_;
};
namespace detail {
// This is similar to the MLIR version, but works with multiple-argument
// functions: helper templates to deduce the first argument of a callback.

// Overload for free functions / function pointers.
template <typename Ret, typename Arg, typename... Rest>
Arg first_argument_type(Ret (*)(Arg, Rest...));

// Overload for pointers to non-const member functions (mutable lambdas).
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...));

// Overload for pointers to const member functions (the call operator of an
// ordinary, non-mutable lambda).
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);

// Overload for callable objects: defers to the type of &F::operator().
template <typename F>
decltype(first_argument_type(&F::operator())) first_argument_type(F);

/// Type definition of the first argument to the given callable 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));

// Callback signature for walks that cannot be interrupted.
using VoidCallback =
    llvm::function_ref<void(mlir::Operation *, const WalkStage &)>;
// Callback signature for walks that may interrupt by returning
// mlir::WalkResult::interrupt().
using InterruptCallback =
    llvm::function_ref<mlir::WalkResult(mlir::Operation *, const WalkStage &)>;

// Walk all of the operations nested under and including the given operation.
void WalkOperations(mlir::Operation *op, VoidCallback callback);

// Walk all of the operations nested under and including the given operation.
// This method walks operations until an interrupt result is returned by the
// callback.
mlir::WalkResult WalkOperations(mlir::Operation *op,
                                InterruptCallback callback);
}  // namespace detail
// Walk all of the operations nested under and including the given operation.
// This overload is selected for stage-aware callbacks that operate on plain
// mlir::Operation* (no static filtering on the operation type). Returns void
// for plain callbacks and mlir::WalkResult for interruptible ones.
//
// Example:
//   tensorflow::GenericWalk(
//       op, [](mlir::Operation *op, const WalkStage &stage) { ... });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
          typename RetT = decltype(std::declval<FuncTy>()(
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<std::is_same<ArgT, mlir::Operation *>::value,
                        RetT>::type
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
  // The callback already has the generic signature; wrap it in a function_ref
  // so overload resolution picks the matching detail::WalkOperations (void vs
  // interruptible) based on RetT.
  return detail::WalkOperations(
      op, llvm::function_ref<RetT(ArgT, const WalkStage &)>(callback));
}
// Walk all of the operations of type 'ArgT' nested under and including the
// given operation. This method is selected for void returning callbacks that
// operate on a specific derived operation type.
//
// Example:
//   tensorflow::GenericWalk(
//       op, [](ReturnOp op, const WalkStage &stage) { ... });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
          typename RetT = decltype(std::declval<FuncTy>()(
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
                            std::is_same<RetT, void>::value,
                        RetT>::type
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
  // Adapt the typed callback to a generic one: operations that are not an
  // 'ArgT' are still traversed, they just never reach the callback.
  auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
    if (auto derivedOp = llvm::dyn_cast<ArgT>(op)) callback(derivedOp, stage);
  };
  return detail::WalkOperations(op,
                                static_cast<detail::VoidCallback>(wrapperFn));
}
// Walk all of the operations of type 'ArgT' nested under and including the
// given operation. This method is selected for WalkResult returning
// interruptible callbacks that operate on a specific derived operation type.
//
// Example:
//   tensorflow::GenericWalk(op, [](ReturnOp op, const WalkStage &stage) {
//     if (some_invariant)
//       return WalkResult::interrupt();
//     return WalkResult::advance();
//   });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
          typename RetT = decltype(std::declval<FuncTy>()(
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
                            std::is_same<RetT, mlir::WalkResult>::value,
                        RetT>::type
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
  // Adapt the typed callback to a generic one: operations that are not an
  // 'ArgT' never reach the callback and the walk simply advances past them.
  auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
    if (auto derivedOp = llvm::dyn_cast<ArgT>(op))
      return callback(derivedOp, stage);
    return mlir::WalkResult::advance();
  };
  return detail::WalkOperations(
      op, static_cast<detail::InterruptCallback>(wrapperFn));
}
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_