From 466275b90e95a975230cffdf4c33cd8ef2eb16a5 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 13 Aug 2020 10:34:37 -0700 Subject: [PATCH] [MLIR] Add generic (pre/in/post)-order IR walker - Use this generic walker which invokes the callback before, in-between, and after visiting regions attached to Ops. - Add a transform that will annotate the IR with the walk order seen using this walker and another that does the same with an interrupting walker. - Add unit test that will use the walk annotation transform to check that the walk visits operations in the expected order. PiperOrigin-RevId: 326475851 Change-Id: I228b46a1bd93ff325f22233066955fc07855987d --- tensorflow/compiler/mlir/tensorflow/BUILD | 17 ++ .../tests/visitor-interrupt-util.mlir | 91 ++++++++++ .../mlir/tensorflow/tests/visitor-util.mlir | 102 +++++++++++ .../transforms/test_visitor_util.cc | 114 ++++++++++++ .../mlir/tensorflow/utils/visitor_util.cc | 70 ++++++++ .../mlir/tensorflow/utils/visitor_util.h | 168 ++++++++++++++++++ 6 files changed, 562 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/test_visitor_util.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index d2e57f72774..f9b1abcccc6 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -778,6 +778,7 @@ cc_library( "transforms/tensor_list_ops_decomposition.cc", "transforms/test_resource_alias_analysis.cc", "transforms/test_side_effect_analysis.cc", + "transforms/test_visitor_util.cc", "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", "transforms/tpu_cluster_formation.cc", @@ -825,6 +826,7 @@ cc_library( ":tpu_rewrite_device_util", ":translate_utils", ":unroll_batch_matmul_pass", + ":visitor_util", ":xla_sharding_util", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:validators", @@ -1785,6 +1787,21 @@ cc_library( ], ) +cc_library( + name = "visitor_util", + srcs = [ + "utils/visitor_util.cc", + ], + hdrs = [ + "utils/visitor_util.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "xla_sharding_util", srcs = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir b/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir new file mode 100644 index 00000000000..1770b4e146d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir @@ -0,0 +1,91 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util-interrupt %s + +// Test simple operations with no regions and no interrupts. They should be +// visited with stage "before all regions". + +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{4: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{3: before all regions}} + return %0 : tensor +} + +// ----- + +// Test simple operations with no regions and interrupts. No remarks after +// the interrupting operation is visited. + +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{2: walk was interrupted}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + %0 = "tf.Identity"(%arg0) {interrupt_before_all = true} : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// Test operation with non empty regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{5: walk was interrupted}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }) {interrupt_after_all = true} : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// Test operation with multiple regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{5: walk was interrupted}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + "tf.yield"(%1) : (tensor) -> () + }) {interrupt_after_region = 0} : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// Test static filtering +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{7: walk was interrupted}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: before region #1}} + // expected-remark@below {{8: before all regions}} + // expected-remark@below {{9: before region #1}} + // expected-remark@below {{10: after all regions}} + %0 = "tf.IfRegion"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.Yield"(%1) : (tensor) -> () + }, { + // expected-remark@below {{6: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + "tf.Yield"(%1) { interrupt_after_all = true } : (tensor) -> () + }) {is_stateless = true}: (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir b/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir new file mode 100644 index 00000000000..d376fad5c33 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir @@ -0,0 +1,102 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util %s + +// Test simple operations with no regions. They should be visited with stage +// = before all regions. + +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{4: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{3: before all regions}} + return %0 : tensor +} + +// ----- +// Test operation with empty regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{5: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{3: after all regions}} + %0 = "tf.unknownop"(%arg0) ({ + }) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + return %0 : tensor +} + +// ----- +// Test operation with non empty regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{7: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: after all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }) : (tensor) -> tensor + // expected-remark@below {{6: before all regions}} + return %0 : tensor +} + +// ----- +// Test operation with multiple regions. +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{10: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: before region #1}} + // expected-remark@below {{8: after all regions}} + %0 = "tf.unknownop"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }, { + // expected-remark@below {{6: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{7: before all regions}} + "tf.yield"(%1) : (tensor) -> () + }) : (tensor) -> tensor + // expected-remark@below {{9: before all regions}} + return %0 : tensor +} + +// ----- +// Test static filtering +// expected-remark@below {{0: before all regions}} +// expected-remark@below {{10: after all regions}} +func @foo(%arg0: tensor) -> tensor { + // expected-remark@below {{1: before all regions}} + %cst = constant dense<1.0> : tensor + // expected-remark@below {{2: before all regions}} + // expected-remark@below {{5: before region #1}} + // expected-remark@below {{8: after all regions}} + // expected-remark@below {{11: before all regions}} + // expected-remark@below {{12: before region #1}} + // expected-remark@below {{13: after all regions}} + %0 = "tf.IfRegion"(%arg0) ({ + // expected-remark@below {{3: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{4: before all regions}} + "tf.Yield"(%1) : (tensor) -> () + }, { + // expected-remark@below {{6: before all regions}} + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + // expected-remark@below {{7: before all regions}} + "tf.Yield"(%1) : (tensor) -> () + }) {is_stateless = true}: (tensor) -> tensor + // expected-remark@below {{9: before all regions}} + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_visitor_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_visitor_util.cc new file mode 100644 index 00000000000..689becb796b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_visitor_util.cc @@ -0,0 +1,114 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h" + +namespace tensorflow { +namespace { + +std::string get_stage_description(const WalkStage &stage) { + if (stage.IsBeforeAllRegions()) return "before all regions"; + if (stage.IsAfterAllRegions()) return "after all regions"; + return "before region #" + std::to_string(stage.GetNextRegion()); +} + +// A pass that annotates each operation with an remarks that include a unique +// step ID and a description of the visitor step. +class TestVisitorUtil + : public mlir::PassWrapper { + public: + void runOnFunction() override { + mlir::FuncOp func = getOperation(); + int step_id = 0; + GenericWalk(func, [&](mlir::Operation *op, const WalkStage &stage) { + op->emitRemark() << step_id++ << ": " << get_stage_description(stage); + }); + + // Exercise static inference of operation type + GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) { + op.emitRemark() << step_id++ << ": " << get_stage_description(stage); + }); + } +}; + +class TestVisitorUtilInterrupt + : public mlir::PassWrapper { + public: + void runOnFunction() override { + mlir::FuncOp func = getOperation(); + int step_id = 0; + + auto walker = [&](mlir::Operation *op, const WalkStage &stage) { + if (auto interrupt_before_all = + op->getAttrOfType("interrupt_before_all")) + if (interrupt_before_all.getValue() && stage.IsBeforeAllRegions()) + return mlir::WalkResult::interrupt(); + + if (auto interrupt_after_all = + op->getAttrOfType("interrupt_after_all")) + if (interrupt_after_all.getValue() && stage.IsAfterAllRegions()) + return mlir::WalkResult::interrupt(); + + if (auto interrupt_after_region = + op->getAttrOfType("interrupt_after_region")) + if (stage.IsAfterRegion( + static_cast(interrupt_after_region.getInt()))) + return mlir::WalkResult::interrupt(); + + op->emitRemark() << step_id++ << ": " << get_stage_description(stage); + return mlir::WalkResult::advance(); + }; + + // Interrupt the walk based on attributes on the operation. + auto result = GenericWalk(func, walker); + + if (result.wasInterrupted()) + func.emitRemark() << step_id++ << ": walk was interrupted"; + + // Exercise static inference of operation type for interrupting callback. + result = + GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) { + return walker(op, stage); + }); + + if (result.wasInterrupted()) + func.emitRemark() << step_id++ << ": walk was interrupted"; + } +}; + +mlir::PassRegistration pass( + "tf-test-visitor-util", + "Add remarks that trace order of visiting operations using TF visitor " + "utilities."); + +mlir::PassRegistration pass_interrupt( + "tf-test-visitor-util-interrupt", + "Add remarks that trace order of visiting operations using TF visitor " + "utilities, interrupt version."); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc new file mode 100644 index 00000000000..0647d42f315 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h" + +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace tensorflow { + +WalkStage::WalkStage(mlir::Operation *op) + : num_regions_(op->getNumRegions()), next_region_(0) {} + +namespace detail { + +/// Walk all of the operations nested under and including the given operations. +void WalkOperations(mlir::Operation *op, VoidCallback callback) { + WalkStage stage(op); + + for (auto ®ion : op->getRegions()) { + // Invoke callback on the parent op before visiting each child region. + callback(op, stage); + stage.Advance(); + + for (auto &block : region) + // Early increment here in the case where the operation is erased. + for (auto &nestedOp : llvm::make_early_inc_range(block)) + WalkOperations(&nestedOp, callback); + } + + // Invoke callback after all regions have been visited. + callback(op, stage); +} + +/// Walk all of the operations nested under and including the given operations. +/// This methods walks operations until an interrupt signal is received. +mlir::WalkResult WalkOperations(mlir::Operation *op, + InterruptCallback callback) { + WalkStage stage(op); + + for (auto ®ion : op->getRegions()) { + // Invoke callback on the parent op before visiting each child region. + if (callback(op, stage).wasInterrupted()) + return mlir::WalkResult::interrupt(); + + stage.Advance(); + + for (auto &block : region) { + // Early increment here in the case where the operation is erased. + for (auto &nestedOp : llvm::make_early_inc_range(block)) + if (WalkOperations(&nestedOp, callback).wasInterrupted()) + return mlir::WalkResult::interrupt(); + } + } + return callback(op, stage); +} + +} // namespace detail +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h new file mode 100644 index 00000000000..31c1f4b62e6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h @@ -0,0 +1,168 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_ + +#include + +#include "mlir/IR/Visitors.h" // from @llvm-project + +// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The +// walk() utility that MLIR core provides traverses operations in a block/ +// blocks in a region in the program order, and these walkers do the same. When +// operations have regions attached to them, the core MLIR walkers visit the +// regions attached to an Op first, and then visit the op. So within the context +// of a single Op, the traversal is post-order (considering the Op as the parent +// node and regions as the children). For certain use cases, it may be more +// efficient/desirable to visit the parent Op before visiting the attached +// regions. As an example, if the attached regions have region arguments that +// are related to the operation inputs (tf.WhileRegion is an example), then we +// may want to propagate some information from the Op inputs to the region +// inputs and then visit the regions to continue progagating that information +// within the regions. With just post-order traversal, to acheive the same we +// may need to schedule another walk so make sure child regions get visited. +// A pre-order walk (within the context of a single operation) will avoid that. +// Similarly, for certain operations, we may want to visit the Op both before +// and after all regions have been visited (say to propagate information from +// inputs -> region arguments and then from region results -> outputs). + +// In general, since the data flow between an operation and its regions is +// opaque in MLIR, we may need to visit the operation in-between regions as well +// if say region0 is transferring control back to the Op and from then to +// region1. So a more general walker that supports pre/in/post-order walk is +// desirable. To support this, the generic walkers defined below will invoke +// the walk callback on the parent Op at each stage of the child region walk, +// i.e., before visiting any region, in between regions, and after visiting all +// regions. To indicate the current walk stage, the callback will also get a +// `WalkState` parameter. The callback can inspect the current walk stage and +// decide to take appropriate actions (incuding not doing anything). With this +// the walker below can support pre/in/post-order walks as well as combined +// walks (pre+in+post)-order walk. + +namespace tensorflow { + +// A class to indicate the current walk stage. +class WalkStage { + public: + explicit WalkStage(mlir::Operation *op); + + bool IsBeforeAllRegions() const { return next_region_ == 0; } + bool IsBeforeRegion(int region) const { return next_region_ == region; } + bool IsAfterRegion(int region) const { return next_region_ == region + 1; } + bool IsAfterAllRegions() const { return next_region_ == num_regions_; } + void Advance() { next_region_++; } + int GetNextRegion() const { return next_region_; } + + private: + const int num_regions_; + int next_region_; +}; + +namespace detail { +// This is similar to MLIR version, but works with multiple argument functions. +// Helper templates to deduce the first argument of a callback parameter. +template +Arg first_argument_type(Ret (*)(Arg, Rest...)); +template +Arg first_argument_type(Ret (F::*)(Arg, Rest...)); +template +Arg first_argument_type(Ret (F::*)(Arg, Rest...) const); +template +decltype(first_argument_type(&F::operator())) first_argument_type(F); + +/// Type definition of the first argument to the given callable 'T'. +template +using first_argument = decltype(first_argument_type(std::declval())); + +using VoidCallback = + llvm::function_ref; +using InterruptCallback = + llvm::function_ref; + +// Walk all of the operations nested under and including the given operation. +void WalkOperations(mlir::Operation *op, VoidCallback callback); + +// Walk all of the operations nested under and including the given operation. +// This methods walks operations until an interrupt result is returned by the +// callback. +mlir::WalkResult WalkOperations(mlir::Operation *op, + InterruptCallback callback); + +} // namespace detail + +// Walk all of the operations nested under and including the given operation. +// This method is selected for stage-aware callbacks that operate on Operation*. +// +// Example: +// tensorflow::walk(op, [](Operation *op, const WalkStage &stage) { ... }); +template , + typename RetT = decltype(std::declval()( + std::declval(), std::declval()))> +typename std::enable_if::value, + RetT>::type +GenericWalk(mlir::Operation *op, FuncTy &&callback) { + return detail::WalkOperations( + op, llvm::function_ref(callback)); +} + +// Walk all of the operations of type 'ArgT' nested under and including the +// given operation. This method is selected for void returning callbacks that +// operate on a specific derived operation type. +// +// Example: +// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { ... }); +template , + typename RetT = decltype(std::declval()( + std::declval(), std::declval()))> +typename std::enable_if::value && + std::is_same::value, + RetT>::type +GenericWalk(mlir::Operation *op, FuncTy &&callback) { + auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) { + if (auto derivedOp = llvm::dyn_cast(op)) callback(derivedOp, stage); + }; + return detail::WalkOperations(op, + static_cast(wrapperFn)); +} + +// Walk all of the operations of type 'ArgT' nested under and including the +// given operation. This method is selected for WalkReturn returning +// interruptible callbacks that operate on a specific derived operation type. +// +// Example: +// tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { +// if (some_invariant) +// return WalkResult::interrupt(); +// return WalkResult::advance(); +// }); +template , + typename RetT = decltype(std::declval()( + std::declval(), std::declval()))> +typename std::enable_if::value && + std::is_same::value, + RetT>::type +GenericWalk(mlir::Operation *op, FuncTy &&callback) { + auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) { + if (auto derivedOp = llvm::dyn_cast(op)) + return callback(derivedOp, stage); + return mlir::WalkResult::advance(); + }; + return detail::WalkOperations( + op, static_cast(wrapperFn)); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_