[MLIR] Add generic (pre/in/post)-order IR walker
- Use this generic walker which invokes the callback before, in-between, and after visiting regions attached to Ops. - Add a transform that will annotate the IR with the walk order seen using this walker and another that does the same with an interrupting walker. - Add unit test that will use the walk annotation transform to check that the walk visits operations in the expected order. PiperOrigin-RevId: 326475851 Change-Id: I228b46a1bd93ff325f22233066955fc07855987d
This commit is contained in:
		
							parent
							
								
									b07e34b7b3
								
							
						
					
					
						commit
						466275b90e
					
				@ -778,6 +778,7 @@ cc_library(
 | 
			
		||||
        "transforms/tensor_list_ops_decomposition.cc",
 | 
			
		||||
        "transforms/test_resource_alias_analysis.cc",
 | 
			
		||||
        "transforms/test_side_effect_analysis.cc",
 | 
			
		||||
        "transforms/test_visitor_util.cc",
 | 
			
		||||
        "transforms/tf_data_optimization_pass.cc",
 | 
			
		||||
        "transforms/tf_device_assignment.cc",
 | 
			
		||||
        "transforms/tpu_cluster_formation.cc",
 | 
			
		||||
@ -825,6 +826,7 @@ cc_library(
 | 
			
		||||
        ":tpu_rewrite_device_util",
 | 
			
		||||
        ":translate_utils",
 | 
			
		||||
        ":unroll_batch_matmul_pass",
 | 
			
		||||
        ":visitor_util",
 | 
			
		||||
        ":xla_sharding_util",
 | 
			
		||||
        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
 | 
			
		||||
        "//tensorflow/compiler/mlir/lite:validators",
 | 
			
		||||
@ -1785,6 +1787,21 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "visitor_util",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "utils/visitor_util.cc",
 | 
			
		||||
    ],
 | 
			
		||||
    hdrs = [
 | 
			
		||||
        "utils/visitor_util.h",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "@llvm-project//llvm:Support",
 | 
			
		||||
        "@llvm-project//mlir:IR",
 | 
			
		||||
        "@llvm-project//mlir:Support",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "xla_sharding_util",
 | 
			
		||||
    srcs = [
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,91 @@
 | 
			
		||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util-interrupt %s
 | 
			
		||||
 | 
			
		||||
// Test simple operations with no regions and no interrupts. They should be
 | 
			
		||||
// visited with stage "before all regions".
 | 
			
		||||
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{4: after all regions}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  // expected-remark@below {{3: before all regions}}
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Test simple operations with no regions and interrupts. No remarks after
 | 
			
		||||
// the interrupting operation is visited.
 | 
			
		||||
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{2: walk was interrupted}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  %0 = "tf.Identity"(%arg0)  {interrupt_before_all = true} : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
// Test operation with non empty regions.
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{5: walk was interrupted}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  %0 = "tf.unknownop"(%arg0) ({
 | 
			
		||||
    // expected-remark@below {{3: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{4: before all regions}}
 | 
			
		||||
    "tf.yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }) {interrupt_after_all = true} : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
// Test operation with multiple regions.
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{5: walk was interrupted}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  %0 = "tf.unknownop"(%arg0) ({
 | 
			
		||||
    // expected-remark@below {{3: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{4: before all regions}}
 | 
			
		||||
    "tf.yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }, {
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    "tf.yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }) {interrupt_after_region = 0} : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
// Test static filtering
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{7: walk was interrupted}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  // expected-remark@below {{5: before region #1}}
 | 
			
		||||
  // expected-remark@below {{8: before all regions}}
 | 
			
		||||
  // expected-remark@below {{9: before region #1}}
 | 
			
		||||
  // expected-remark@below {{10: after all regions}}
 | 
			
		||||
  %0 = "tf.IfRegion"(%arg0) ({
 | 
			
		||||
    // expected-remark@below {{3: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{4: before all regions}}
 | 
			
		||||
    "tf.Yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }, {
 | 
			
		||||
    // expected-remark@below {{6: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    "tf.Yield"(%1) { interrupt_after_all = true } : (tensor<f32>) -> ()
 | 
			
		||||
  }) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										102
									
								
								tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,102 @@
 | 
			
		||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-test-visitor-util %s
 | 
			
		||||
 | 
			
		||||
// Test simple operations with no regions. They should be visited with stage
 | 
			
		||||
// = before all regions.
 | 
			
		||||
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{4: after all regions}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  // expected-remark@below {{3: before all regions}}
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
// Test operation with empty regions.
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{5: after all regions}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  // expected-remark@below {{3: after all regions}}
 | 
			
		||||
  %0 = "tf.unknownop"(%arg0) ({
 | 
			
		||||
  }) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  // expected-remark@below {{4: before all regions}}
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
// Test operation with non empty regions.
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{7: after all regions}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  // expected-remark@below {{5: after all regions}}
 | 
			
		||||
  %0 = "tf.unknownop"(%arg0) ({
 | 
			
		||||
    // expected-remark@below {{3: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{4: before all regions}}
 | 
			
		||||
    "tf.yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  // expected-remark@below {{6: before all regions}}
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
// Test operation with multiple regions.
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{10: after all regions}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  // expected-remark@below {{5: before region #1}}
 | 
			
		||||
  // expected-remark@below {{8: after all regions}}
 | 
			
		||||
  %0 = "tf.unknownop"(%arg0) ({
 | 
			
		||||
    // expected-remark@below {{3: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{4: before all regions}}
 | 
			
		||||
    "tf.yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }, {
 | 
			
		||||
    // expected-remark@below {{6: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{7: before all regions}}
 | 
			
		||||
    "tf.yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  // expected-remark@below {{9: before all regions}}
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
// Test static filtering
 | 
			
		||||
// expected-remark@below {{0: before all regions}}
 | 
			
		||||
// expected-remark@below {{10: after all regions}}
 | 
			
		||||
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
  // expected-remark@below {{1: before all regions}}
 | 
			
		||||
  %cst = constant dense<1.0> : tensor<f32>
 | 
			
		||||
  // expected-remark@below {{2: before all regions}}
 | 
			
		||||
  // expected-remark@below {{5: before region #1}}
 | 
			
		||||
  // expected-remark@below {{8: after all regions}}
 | 
			
		||||
  // expected-remark@below {{11: before all regions}}
 | 
			
		||||
  // expected-remark@below {{12: before region #1}}
 | 
			
		||||
  // expected-remark@below {{13: after all regions}}
 | 
			
		||||
  %0 = "tf.IfRegion"(%arg0) ({
 | 
			
		||||
    // expected-remark@below {{3: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{4: before all regions}}
 | 
			
		||||
    "tf.Yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }, {
 | 
			
		||||
    // expected-remark@below {{6: before all regions}}
 | 
			
		||||
    %1 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
    // expected-remark@below {{7: before all regions}}
 | 
			
		||||
    "tf.Yield"(%1) : (tensor<f32>) -> ()
 | 
			
		||||
  }) {is_stateless = true}: (tensor<f32>) -> tensor<f32>
 | 
			
		||||
  // expected-remark@below {{9: before all regions}}
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,114 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
#include "mlir/IR/Attributes.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/IR/Visitors.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Pass/Pass.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Pass/PassManager.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Support/LLVM.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Transforms/Passes.h"  // from @llvm-project
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
std::string get_stage_description(const WalkStage &stage) {
 | 
			
		||||
  if (stage.IsBeforeAllRegions()) return "before all regions";
 | 
			
		||||
  if (stage.IsAfterAllRegions()) return "after all regions";
 | 
			
		||||
  return "before region #" + std::to_string(stage.GetNextRegion());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A pass that annotates each operation with an remarks that include a unique
 | 
			
		||||
// step ID and a description of the visitor step.
 | 
			
		||||
class TestVisitorUtil
 | 
			
		||||
    : public mlir::PassWrapper<TestVisitorUtil, mlir::FunctionPass> {
 | 
			
		||||
 public:
 | 
			
		||||
  void runOnFunction() override {
 | 
			
		||||
    mlir::FuncOp func = getOperation();
 | 
			
		||||
    int step_id = 0;
 | 
			
		||||
    GenericWalk(func, [&](mlir::Operation *op, const WalkStage &stage) {
 | 
			
		||||
      op->emitRemark() << step_id++ << ": " << get_stage_description(stage);
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    // Exercise static inference of operation type
 | 
			
		||||
    GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
 | 
			
		||||
      op.emitRemark() << step_id++ << ": " << get_stage_description(stage);
 | 
			
		||||
    });
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class TestVisitorUtilInterrupt
 | 
			
		||||
    : public mlir::PassWrapper<TestVisitorUtilInterrupt, mlir::FunctionPass> {
 | 
			
		||||
 public:
 | 
			
		||||
  void runOnFunction() override {
 | 
			
		||||
    mlir::FuncOp func = getOperation();
 | 
			
		||||
    int step_id = 0;
 | 
			
		||||
 | 
			
		||||
    auto walker = [&](mlir::Operation *op, const WalkStage &stage) {
 | 
			
		||||
      if (auto interrupt_before_all =
 | 
			
		||||
              op->getAttrOfType<mlir::BoolAttr>("interrupt_before_all"))
 | 
			
		||||
        if (interrupt_before_all.getValue() && stage.IsBeforeAllRegions())
 | 
			
		||||
          return mlir::WalkResult::interrupt();
 | 
			
		||||
 | 
			
		||||
      if (auto interrupt_after_all =
 | 
			
		||||
              op->getAttrOfType<mlir::BoolAttr>("interrupt_after_all"))
 | 
			
		||||
        if (interrupt_after_all.getValue() && stage.IsAfterAllRegions())
 | 
			
		||||
          return mlir::WalkResult::interrupt();
 | 
			
		||||
 | 
			
		||||
      if (auto interrupt_after_region =
 | 
			
		||||
              op->getAttrOfType<mlir::IntegerAttr>("interrupt_after_region"))
 | 
			
		||||
        if (stage.IsAfterRegion(
 | 
			
		||||
                static_cast<int>(interrupt_after_region.getInt())))
 | 
			
		||||
          return mlir::WalkResult::interrupt();
 | 
			
		||||
 | 
			
		||||
      op->emitRemark() << step_id++ << ": " << get_stage_description(stage);
 | 
			
		||||
      return mlir::WalkResult::advance();
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // Interrupt the walk based on attributes on the operation.
 | 
			
		||||
    auto result = GenericWalk(func, walker);
 | 
			
		||||
 | 
			
		||||
    if (result.wasInterrupted())
 | 
			
		||||
      func.emitRemark() << step_id++ << ": walk was interrupted";
 | 
			
		||||
 | 
			
		||||
    // Exercise static inference of operation type for interrupting callback.
 | 
			
		||||
    result =
 | 
			
		||||
        GenericWalk(func, [&](mlir::TF::IfRegionOp op, const WalkStage &stage) {
 | 
			
		||||
          return walker(op, stage);
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
    if (result.wasInterrupted())
 | 
			
		||||
      func.emitRemark() << step_id++ << ": walk was interrupted";
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
mlir::PassRegistration<TestVisitorUtil> pass(
 | 
			
		||||
    "tf-test-visitor-util",
 | 
			
		||||
    "Add remarks that trace order of visiting operations using TF visitor "
 | 
			
		||||
    "utilities.");
 | 
			
		||||
 | 
			
		||||
mlir::PassRegistration<TestVisitorUtilInterrupt> pass_interrupt(
 | 
			
		||||
    "tf-test-visitor-util-interrupt",
 | 
			
		||||
    "Add remarks that trace order of visiting operations using TF visitor "
 | 
			
		||||
    "utilities, interrupt version.");
 | 
			
		||||
 | 
			
		||||
}  // anonymous namespace
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
							
								
								
									
										70
									
								
								tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								tensorflow/compiler/mlir/tensorflow/utils/visitor_util.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,70 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
 | 
			
		||||
 | 
			
		||||
#include "mlir/IR/Operation.h"  // from @llvm-project
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
WalkStage::WalkStage(mlir::Operation *op)
 | 
			
		||||
    : num_regions_(op->getNumRegions()), next_region_(0) {}
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
/// Walk all of the operations nested under and including the given operations.
 | 
			
		||||
void WalkOperations(mlir::Operation *op, VoidCallback callback) {
 | 
			
		||||
  WalkStage stage(op);
 | 
			
		||||
 | 
			
		||||
  for (auto ®ion : op->getRegions()) {
 | 
			
		||||
    // Invoke callback on the parent op before visiting each child region.
 | 
			
		||||
    callback(op, stage);
 | 
			
		||||
    stage.Advance();
 | 
			
		||||
 | 
			
		||||
    for (auto &block : region)
 | 
			
		||||
      // Early increment here in the case where the operation is erased.
 | 
			
		||||
      for (auto &nestedOp : llvm::make_early_inc_range(block))
 | 
			
		||||
        WalkOperations(&nestedOp, callback);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Invoke callback after all regions have been visited.
 | 
			
		||||
  callback(op, stage);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Walk all of the operations nested under and including the given operations.
 | 
			
		||||
/// This methods walks operations until an interrupt signal is received.
 | 
			
		||||
mlir::WalkResult WalkOperations(mlir::Operation *op,
 | 
			
		||||
                                InterruptCallback callback) {
 | 
			
		||||
  WalkStage stage(op);
 | 
			
		||||
 | 
			
		||||
  for (auto ®ion : op->getRegions()) {
 | 
			
		||||
    // Invoke callback on the parent op before visiting each child region.
 | 
			
		||||
    if (callback(op, stage).wasInterrupted())
 | 
			
		||||
      return mlir::WalkResult::interrupt();
 | 
			
		||||
 | 
			
		||||
    stage.Advance();
 | 
			
		||||
 | 
			
		||||
    for (auto &block : region) {
 | 
			
		||||
      // Early increment here in the case where the operation is erased.
 | 
			
		||||
      for (auto &nestedOp : llvm::make_early_inc_range(block))
 | 
			
		||||
        if (WalkOperations(&nestedOp, callback).wasInterrupted())
 | 
			
		||||
          return mlir::WalkResult::interrupt();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return callback(op, stage);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace detail
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
							
								
								
									
										168
									
								
								tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,168 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
 | 
			
		||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
 | 
			
		||||
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
#include "mlir/IR/Visitors.h"  // from @llvm-project
 | 
			
		||||
 | 
			
		||||
// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The
 | 
			
		||||
// walk() utility that MLIR core provides traverses operations in a block/
 | 
			
		||||
// blocks in a region in the program order, and these walkers do the same. When
 | 
			
		||||
// operations have regions attached to them, the core MLIR walkers visit the
 | 
			
		||||
// regions attached to an Op first, and then visit the op. So within the context
 | 
			
		||||
// of a single Op, the traversal is post-order (considering the Op as the parent
 | 
			
		||||
// node and regions as the children). For certain use cases, it may be more
 | 
			
		||||
// efficient/desirable to visit the parent Op before visiting the attached
 | 
			
		||||
// regions. As an example, if the attached regions have region arguments that
 | 
			
		||||
// are related to the operation inputs (tf.WhileRegion is an example), then we
 | 
			
		||||
// may want to propagate some information from the Op inputs to the region
 | 
			
		||||
// inputs and then visit the regions to continue progagating that information
 | 
			
		||||
// within the regions. With just post-order traversal, to acheive the same we
 | 
			
		||||
// may need to schedule another walk so make sure child regions get visited.
 | 
			
		||||
// A pre-order walk (within the context of a single operation) will avoid that.
 | 
			
		||||
// Similarly, for certain operations, we may want to visit the Op both before
 | 
			
		||||
// and after all regions have been visited (say to propagate information from
 | 
			
		||||
// inputs -> region arguments and then from region results -> outputs).
 | 
			
		||||
 | 
			
		||||
// In general, since the data flow between an operation and its regions is
 | 
			
		||||
// opaque in MLIR, we may need to visit the operation in-between regions as well
 | 
			
		||||
// if say region0 is transferring control back to the Op and from then to
 | 
			
		||||
// region1. So a more general walker that supports pre/in/post-order walk is
 | 
			
		||||
// desirable. To support this, the generic walkers defined below will invoke
 | 
			
		||||
// the walk callback on the parent Op at each stage of the child region walk,
 | 
			
		||||
// i.e., before visiting any region, in between regions, and after visiting all
 | 
			
		||||
// regions. To indicate the current walk stage, the callback will also get a
 | 
			
		||||
// `WalkState` parameter. The callback can inspect the current walk stage and
 | 
			
		||||
// decide to take appropriate actions (incuding not doing anything). With this
 | 
			
		||||
// the walker below can support pre/in/post-order walks as well as combined
 | 
			
		||||
// walks (pre+in+post)-order walk.
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
// A class to indicate the current walk stage.
 | 
			
		||||
class WalkStage {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit WalkStage(mlir::Operation *op);
 | 
			
		||||
 | 
			
		||||
  bool IsBeforeAllRegions() const { return next_region_ == 0; }
 | 
			
		||||
  bool IsBeforeRegion(int region) const { return next_region_ == region; }
 | 
			
		||||
  bool IsAfterRegion(int region) const { return next_region_ == region + 1; }
 | 
			
		||||
  bool IsAfterAllRegions() const { return next_region_ == num_regions_; }
 | 
			
		||||
  void Advance() { next_region_++; }
 | 
			
		||||
  int GetNextRegion() const { return next_region_; }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  const int num_regions_;
 | 
			
		||||
  int next_region_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
// This is similar to MLIR version, but works with multiple argument functions.
 | 
			
		||||
// Helper templates to deduce the first argument of a callback parameter.
 | 
			
		||||
template <typename Ret, typename Arg, typename... Rest>
 | 
			
		||||
Arg first_argument_type(Ret (*)(Arg, Rest...));
 | 
			
		||||
template <typename Ret, typename F, typename Arg, typename... Rest>
 | 
			
		||||
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
 | 
			
		||||
template <typename Ret, typename F, typename Arg, typename... Rest>
 | 
			
		||||
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
 | 
			
		||||
template <typename F>
 | 
			
		||||
decltype(first_argument_type(&F::operator())) first_argument_type(F);
 | 
			
		||||
 | 
			
		||||
/// Type definition of the first argument to the given callable 'T'.
 | 
			
		||||
template <typename T>
 | 
			
		||||
using first_argument = decltype(first_argument_type(std::declval<T>()));
 | 
			
		||||
 | 
			
		||||
using VoidCallback =
 | 
			
		||||
    llvm::function_ref<void(mlir::Operation *, const WalkStage &)>;
 | 
			
		||||
using InterruptCallback =
 | 
			
		||||
    llvm::function_ref<mlir::WalkResult(mlir::Operation *, const WalkStage &)>;
 | 
			
		||||
 | 
			
		||||
// Walk all of the operations nested under and including the given operation.
 | 
			
		||||
void WalkOperations(mlir::Operation *op, VoidCallback callback);
 | 
			
		||||
 | 
			
		||||
// Walk all of the operations nested under and including the given operation.
 | 
			
		||||
// This methods walks operations until an interrupt result is returned by the
 | 
			
		||||
// callback.
 | 
			
		||||
mlir::WalkResult WalkOperations(mlir::Operation *op,
 | 
			
		||||
                                InterruptCallback callback);
 | 
			
		||||
 | 
			
		||||
}  // namespace detail
 | 
			
		||||
 | 
			
		||||
// Walk all of the operations nested under and including the given operation.
 | 
			
		||||
// This method is selected for stage-aware callbacks that operate on Operation*.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//   tensorflow::walk(op, [](Operation *op, const WalkStage &stage) { ... });
 | 
			
		||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
 | 
			
		||||
          typename RetT = decltype(std::declval<FuncTy>()(
 | 
			
		||||
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
 | 
			
		||||
typename std::enable_if<std::is_same<ArgT, mlir::Operation *>::value,
 | 
			
		||||
                        RetT>::type
 | 
			
		||||
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
 | 
			
		||||
  return detail::WalkOperations(
 | 
			
		||||
      op, llvm::function_ref<RetT(ArgT, const WalkStage &)>(callback));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Walk all of the operations of type 'ArgT' nested under and including the
 | 
			
		||||
// given operation. This method is selected for void returning callbacks that
 | 
			
		||||
// operate on a specific derived operation type.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//   tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { ... });
 | 
			
		||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
 | 
			
		||||
          typename RetT = decltype(std::declval<FuncTy>()(
 | 
			
		||||
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
 | 
			
		||||
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
 | 
			
		||||
                            std::is_same<RetT, void>::value,
 | 
			
		||||
                        RetT>::type
 | 
			
		||||
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
 | 
			
		||||
  auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
 | 
			
		||||
    if (auto derivedOp = llvm::dyn_cast<ArgT>(op)) callback(derivedOp, stage);
 | 
			
		||||
  };
 | 
			
		||||
  return detail::WalkOperations(op,
 | 
			
		||||
                                static_cast<detail::VoidCallback>(wrapperFn));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Walk all of the operations of type 'ArgT' nested under and including the
 | 
			
		||||
// given operation. This method is selected for WalkReturn returning
 | 
			
		||||
// interruptible callbacks that operate on a specific derived operation type.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//   tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) {
 | 
			
		||||
//     if (some_invariant)
 | 
			
		||||
//       return WalkResult::interrupt();
 | 
			
		||||
//     return WalkResult::advance();
 | 
			
		||||
//   });
 | 
			
		||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
 | 
			
		||||
          typename RetT = decltype(std::declval<FuncTy>()(
 | 
			
		||||
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
 | 
			
		||||
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
 | 
			
		||||
                            std::is_same<RetT, mlir::WalkResult>::value,
 | 
			
		||||
                        RetT>::type
 | 
			
		||||
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
 | 
			
		||||
  auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
 | 
			
		||||
    if (auto derivedOp = llvm::dyn_cast<ArgT>(op))
 | 
			
		||||
      return callback(derivedOp, stage);
 | 
			
		||||
    return mlir::WalkResult::advance();
 | 
			
		||||
  };
 | 
			
		||||
  return detail::WalkOperations(
 | 
			
		||||
      op, static_cast<detail::InterruptCallback>(wrapperFn));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user