Remove switch fold pass

This pass doesn't do much: constant folding on the executor dialect is limited (the foremost reason), most models go through the upgrade-legacy path, and folks are switching to control flow v2. So the pass has no practical value at the moment, and even less need going forward. Deleting it.
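
For context, the fold only fired when a Switch predicate was a compile-time constant, the pattern the deleted test below exercises; a non-constant predicate left the graph untouched, and control flow v2 expresses conditionals functionally, with no Switch/Merge to fold at all. A minimal sketch (the tf.If form and the @add_one/@sub_one callees are illustrative, not taken from this change):

// Foldable: the predicate is constant, so the pass forwarded the live
// Switch output to its uses and deleted the dead branch and the Merge.
%pred = constant dense<false> : tensor<i1>
%sw:3 = tf_executor.Switch %x, %pred : tensor<i32>
%m:3 = tf_executor.Merge %sw#0, %sw#1 : tensor<i32> {N = 2 : i64}

// Control flow v2: the same conditional as a functional tf.If, which this
// pass never matched.
%r = "tf.If"(%pred, %x) {then_branch = @add_one, else_branch = @sub_one, is_stateless = true} : (tensor<i1>, tensor<i32>) -> tensor<i32>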

PiperOrigin-RevId: 353911764
Change-Id: I4515fa61ba024e29cad8b0903ce57eefff714f3a
Jacques Pienaar 2021-01-26 11:35:25 -08:00 committed by TensorFlower Gardener
parent 8ad4669252
commit 8027470e1e
5 changed files with 0 additions and 645 deletions

@@ -881,7 +881,6 @@ cc_library(
"transforms/executor_tpuv1_island_coarsening.cc",
"transforms/executor_tpuv1_outline_tpu_island.cc",
"transforms/fold_broadcast.cc",
"transforms/fold_switch.cc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/functional_control_flow_to_regions.cc",
"transforms/fused_kernel_matcher.cc",

@@ -1,351 +0,0 @@
// RUN: tf-opt -tf-switch-fold %s | FileCheck %s
// CHECK-LABEL: test_single_branch_direct_f
// CHECK-NOT: Switch
// CHECK-NOT: tf.AddV2
func @test_single_branch_direct_f() -> tensor<i32> {
%cst = constant dense<false> : tensor<i1>
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%7:3 = tf_executor.Switch %cst_0, %cst : tensor<i32>
%8:2 = tf_executor.island {
%12 = "tf.AddV2"(%7#1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%11:3 = tf_executor.Merge %7#0, %8#0 : tensor<i32> {N = 2 : i64}
tf_executor.fetch %11#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// CHECK-LABEL: test_single_branch_direct_t
// CHECK-NOT: Switch
// CHECK: tf.AddV2
func @test_single_branch_direct_t() -> tensor<i32> {
%cst = constant dense<true> : tensor<i1>
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%7:3 = tf_executor.Switch %cst_0, %cst : tensor<i32>
%8:2 = tf_executor.island {
%12 = "tf.AddV2"(%7#1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%11:3 = tf_executor.Merge %7#0, %8#0 : tensor<i32> {N = 2 : i64}
tf_executor.fetch %11#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// CHECK-LABEL: test_single_branch_direct_arg_f
// CHECK: Switch
// CHECK: tf.AddV2
func @test_single_branch_direct_arg_f(%pred : tensor<i1>) -> tensor<i32> {
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%7:3 = tf_executor.Switch %cst_0, %pred : tensor<i32>
%8:2 = tf_executor.island {
%12 = "tf.AddV2"(%7#1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%11:3 = tf_executor.Merge %7#0, %8#0 : tensor<i32> {N = 2 : i64}
tf_executor.fetch %11#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// pred ? x + 1 : x - 1
// CHECK-LABEL: ControlFlowTest.testCond_1f
// CHECK-NOT: Switch
// CHECK-NOT: tf.AddV2
// CHECK: tf.Sub
func @ControlFlowTest.testCond_1f() -> tensor<i32> {
%cst = constant dense<false> : tensor<i1>
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%1:3 = tf_executor.Switch %cst, %cst : tensor<i1> {T = "tfdtype$DT_BOOL"}
%2:2 = tf_executor.island {
%12 = "tf.Identity"(%1#0) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %12 : tensor<i1>
}
%3:2 = tf_executor.island(%2#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%4:2 = tf_executor.island {
%12 = "tf.Identity"(%1#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %12 : tensor<i1>
}
%5:2 = tf_executor.island(%4#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%6:2 = tf_executor.island {
%12 = "tf.Identity"(%cst) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %12 : tensor<i1>
}
%7:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%8:2 = tf_executor.island {
%12 = "tf.AddV2"(%7#1, %5#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%9:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%10:2 = tf_executor.island {
%12 = "tf.Sub"(%9#0, %3#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%11:3 = tf_executor.Merge %10#0, %8#0 : tensor<i32> {N = 2 : i64, T = "tfdtype$DT_INT32"}
tf_executor.fetch %11#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// pred ? x + 1 : x - 1
// CHECK-LABEL: ControlFlowTest.testCond_1t
// CHECK-NOT: Switch
// CHECK: tf.AddV2
// CHECK-NOT: tf.Sub
func @ControlFlowTest.testCond_1t() -> tensor<i32> {
%cst = constant dense<true> : tensor<i1>
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%1:3 = tf_executor.Switch %cst, %cst : tensor<i1> {T = "tfdtype$DT_BOOL"}
%2:2 = tf_executor.island {
%12 = "tf.Identity"(%1#0) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %12 : tensor<i1>
}
%3:2 = tf_executor.island(%2#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%4:2 = tf_executor.island {
%12 = "tf.Identity"(%1#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %12 : tensor<i1>
}
%5:2 = tf_executor.island(%4#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%6:2 = tf_executor.island {
%12 = "tf.Identity"(%cst) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %12 : tensor<i1>
}
%7:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%8:2 = tf_executor.island {
%12 = "tf.AddV2"(%7#1, %5#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%9:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%10:2 = tf_executor.island {
%12 = "tf.Sub"(%9#0, %3#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%11:3 = tf_executor.Merge %10#0, %8#0 : tensor<i32> {N = 2 : i64, T = "tfdtype$DT_INT32"}
tf_executor.fetch %11#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// if (pred)
// return pred ? x + 1 : x - 1
// else
// return x - 1
// CHECK-LABEL: ControlFlowTest.testCond_3f
// CHECK-NOT: Switch
// CHECK-NOT: tf.AddV2
// CHECK: tf.Sub
func @ControlFlowTest.testCond_3f() -> tensor<i32> {
%cst = constant dense<false> : tensor<i1>
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%1:3 = tf_executor.Switch %cst, %cst : tensor<i1> {T = "tfdtype$DT_BOOL"}
%2:2 = tf_executor.island {
%24 = "tf.Identity"(%1#0) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%3:2 = tf_executor.island(%2#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%4:2 = tf_executor.island {
%24 = "tf.Identity"(%1#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%5:2 = tf_executor.island(%4#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%6:2 = tf_executor.island {
%24 = "tf.Identity"(%cst) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%7:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%8:2 = tf_executor.island {
%24 = "tf.Sub"(%7#0, %3#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%9:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%10:3 = tf_executor.Switch %cst, %6#0 : tensor<i1> {T = "tfdtype$DT_BOOL", _class = ["loc:@Less"]}
%11:3 = tf_executor.Switch %10#1, %10#1 : tensor<i1> {T = "tfdtype$DT_BOOL"}
%12:2 = tf_executor.island {
%24 = "tf.Identity"(%11#0) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%13:2 = tf_executor.island(%12#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%14:2 = tf_executor.island {
%24 = "tf.Identity"(%11#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%15:2 = tf_executor.island(%14#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%16:2 = tf_executor.island {
%24 = "tf.Identity"(%10#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%17:3 = tf_executor.Switch %9#1, %16#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%18:2 = tf_executor.island {
%24 = "tf.AddV2"(%17#1, %15#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%19:3 = tf_executor.Switch %9#1, %16#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%20:2 = tf_executor.island {
%24 = "tf.Sub"(%19#0, %13#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%21:3 = tf_executor.Merge %20#0, %18#0 : tensor<i32> {N = 2 : i64, T = "tfdtype$DT_INT32"}
%22:2 = tf_executor.island {
%24 = "tf.AddV2"(%21#0, %5#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%23:3 = tf_executor.Merge %8#0, %22#0 : tensor<i32> {N = 2 : i64, T = "tfdtype$DT_INT32"}
tf_executor.fetch %23#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// if (pred)
// return pred ? x + 1 : x - 1
// else
// return x - 1
// CHECK-LABEL: ControlFlowTest.testCond_3t
// CHECK-NOT: Switch
// CHECK: tf.AddV2
// CHECK-NOT: tf.Sub
// CHECK: tf.AddV2
func @ControlFlowTest.testCond_3t() -> tensor<i32> {
%cst = constant dense<true> : tensor<i1>
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%1:3 = tf_executor.Switch %cst, %cst : tensor<i1> {T = "tfdtype$DT_BOOL"}
%2:2 = tf_executor.island {
%24 = "tf.Identity"(%1#0) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%3:2 = tf_executor.island(%2#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%4:2 = tf_executor.island {
%24 = "tf.Identity"(%1#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%5:2 = tf_executor.island(%4#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%6:2 = tf_executor.island {
%24 = "tf.Identity"(%cst) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%7:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%8:2 = tf_executor.island {
%24 = "tf.Sub"(%7#0, %3#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%9:3 = tf_executor.Switch %cst_0, %6#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%10:3 = tf_executor.Switch %cst, %6#0 : tensor<i1> {T = "tfdtype$DT_BOOL", _class = ["loc:@Less"]}
%11:3 = tf_executor.Switch %10#1, %10#1 : tensor<i1> {T = "tfdtype$DT_BOOL"}
%12:2 = tf_executor.island {
%24 = "tf.Identity"(%11#0) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%13:2 = tf_executor.island(%12#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%14:2 = tf_executor.island {
%24 = "tf.Identity"(%11#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%15:2 = tf_executor.island(%14#1) {
tf_executor.yield %cst_1 : tensor<i32>
}
%16:2 = tf_executor.island {
%24 = "tf.Identity"(%10#1) {T = "tfdtype$DT_BOOL"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %24 : tensor<i1>
}
%17:3 = tf_executor.Switch %9#1, %16#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%18:2 = tf_executor.island {
%24 = "tf.AddV2"(%17#1, %15#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%19:3 = tf_executor.Switch %9#1, %16#0 : tensor<i32> {T = "tfdtype$DT_INT32", _class = ["loc:@Const"]}
%20:2 = tf_executor.island {
%24 = "tf.Sub"(%19#0, %13#0) {T = "tfdtype$DT_INT32"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%21:3 = tf_executor.Merge %20#0, %18#0 : tensor<i32> {N = 2 : i64, T = "tfdtype$DT_INT32"}
%22:2 = tf_executor.island {
%24 = "tf.AddV2"(%21#0, %5#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %24 : tensor<i32>
}
%23:3 = tf_executor.Merge %8#0, %22#0 : tensor<i32> {N = 2 : i64, T = "tfdtype$DT_INT32"}
tf_executor.fetch %23#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// TODO(jpienaar): This needs to be updated post changing send/recv to executor.
// CHECK-LABEL: switch_with_send_recv
// CHECK: Switch
func @switch_with_send_recv() {
%cst = constant dense<true> : tensor<i1>
tf_executor.graph {
%1 = tf_executor.island {
"tf._Send"(%cst#0) {T = "tfdtype$DT_BOOL", client_terminated = false, device = "/job:localhost/replica:0/task:0/device:CPU:0", name = "Const/_0", recv_device = "/job:localhost/replica:0/task:0/device:CPU:0", send_device = "/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation = 1 : i64, tensor_name = "edge_3_Const"} : (tensor<i1>) -> ()
tf_executor.yield
}
%2:2 = tf_executor.island(%1) {
%11 = "tf._Recv"() {client_terminated = false, device = "/job:localhost/replica:0/task:0/device:CPU:0", name = "Const/_1", recv_device = "/job:localhost/replica:0/task:0/device:CPU:0", send_device = "/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation = 1 : i64, tensor_name = "edge_3_Const", tensor_type = "tfdtype$DT_BOOL"} : () -> tensor<*xi1>
tf_executor.yield %11 : tensor<*xi1>
}
%3:3 = tf_executor.Switch %2#0, %cst#0 : tensor<*xi1> {T = "tfdtype$DT_BOOL", device = "/job:localhost/replica:0/task:0/device:CPU:0", name = "cond/Switch"}
%4:2 = tf_executor.island {
%11 = "tf.Identity"(%3#0) {T = "tfdtype$DT_BOOL", _class = ["loc:@cond/control_dependency_1"], device = "/job:localhost/replica:0/task:0/device:CPU:0", name = "cond/switch_f"} : (tensor<*xi1>) -> tensor<*xi1>
tf_executor.yield %11 : tensor<*xi1>
}
%5:2 = tf_executor.island(%4#1) {
%11 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = "tfdtype$DT_BOOL", name = "cond/Assert/Assert/condition", value = dense<false> : tensor<i1>} : () -> tensor<i1>
tf_executor.yield %11 : tensor<i1>
}
%6:2 = tf_executor.island(%4#1) {
%11 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = "tfdtype$DT_STRING", name = "cond/Assert/Assert/data_0", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20225C30313757726F6E67206272616E636821212122"> : tensor<!tf.string>} : () -> tensor<!tf.string>
tf_executor.yield %11 : tensor<!tf.string>
}
%7 = tf_executor.island {
"tf.Assert"(%5#0, %6#0) {T = ["tfdtype$DT_STRING"], device = "/job:localhost/replica:0/task:0/device:CPU:0", name = "cond/Assert/Assert", summarize = 3 : i64} : (tensor<i1>, tensor<!tf.string>) -> ()
tf_executor.yield
}
%8:2 = tf_executor.island(%7) {
%11 = "tf.Identity"(%4#0) {T = "tfdtype$DT_BOOL", device = "/job:localhost/replica:0/task:0/device:CPU:0", name = "cond/control_dependency_1"} : (tensor<*xi1>) -> tensor<*xi1>
tf_executor.yield %11 : tensor<*xi1>
}
%9:3 = tf_executor.Merge %8#0, %cst#0 : (tensor<*xi1>, tensor<i1>) -> (tensor<*xi1>, tensor<i32>, !tf_executor.control) {N = 2 : i64, T = "tfdtype$DT_BOOL", device = "/job:localhost/replica:0/task:0/device:CPU:0", name = "cond/Merge"}
%10 = tf_executor.island {
"tf._Retval"(%9#0) {T = "tfdtype$DT_BOOL", device = "/job:localhost/replica:0/task:0/device:CPU:0", index = 0 : i64, name = "_retval_cond/Merge_0_0"} : (tensor<*xi1>) -> ()
tf_executor.yield
}
tf_executor.fetch
}
return
}

@@ -1,288 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass folds switch and merge nodes.
// This pass assumes/requires:
// 1. All ops in an island execute under the same condition;
// 2. It is run before graph partitioning (i.e., there are no _Send/_Recv nodes
// in the graph);
// 3. No other ops, except _Merge, in the graph execute with dead inputs;
#include <climits>
#include <cstdint>
#include <numeric>
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#define DEBUG_TYPE "tf-switch-fold"
namespace mlir {
namespace {
class SwitchFoldPass : public mlir::PassWrapper<SwitchFoldPass, FunctionPass> {
public:
void runOnFunction() override;
};
} // namespace
// Returns the defining op for a value looking through islands.
static Operation* GetDefiningOp(Value val) {
Operation* op = val.getDefiningOp();
auto island_op = dyn_cast_or_null<tf_executor::IslandOp>(op);
if (!island_op) return op;
auto yield_op = island_op.GetYield();
auto index = val.cast<mlir::OpResult>().getResultNumber();
return yield_op.getOperand(index).getDefiningOp();
}
// Returns the value itself, or the input of the IdentityOp that defines it.
// Note: this should really be handled by constant folding, but identity nodes
// need to be treated specially in general until they are expanded into
// different types of nodes (e.g., recv identity nodes). For conditionals,
// identity nodes are common, so handle them specially when considering the
// predicate, in a minimally invasive way, until identities are handled more
// generally.
static Value LookThroughIdentityOp(Value pred_val) {
if (!pred_val) return pred_val;
auto op = GetDefiningOp(pred_val);
if (auto id_op = dyn_cast_or_null<TF::IdentityOp>(op))
pred_val = id_op.input();
return pred_val;
}
namespace {
// Worklist queue of dead ops that need to be removed from the graph (along
// with their outputs). Merge ops are excluded and tracked separately, as a
// Merge still fires when only some of its inputs are dead.
class DeadQueue {
public:
// Enqueue operation for deletion.
void Enqueue(Operation* op, bool due_to_control_input) {
auto merge_op = dyn_cast<tf_executor::MergeOp>(op);
// Only insert MergeOp if all its inputs are dead.
if (!merge_op) {
dead_ops_.insert(op);
return;
}
if (due_to_control_input) return;
auto pair = merge_nodes_.insert({merge_op, -1});
auto& count = pair.first->second;
if (pair.second) {
// Compute number of non-control inputs. If we have a Switch directly
// feeding into the Merge then we could have a null value here.
count = 0;
for (auto operand : op->getOperands()) {
if (operand && !operand.getType().isa<tf_executor::ControlType>())
++count;
}
}
// Decrement number of unseen inputs.
--count;
if (!count) dead_ops_.insert(op);
}
// Enqueue users of a value.
void EnqueueUsers(Value val) {
for (auto user : val.getUsers()) {
Enqueue(user, val.getType().isa<tf_executor::ControlType>());
}
}
// Delete dead ops while propagating deadness to consumers.
void DeleteDeadOps() {
while (!dead_ops_.empty()) {
auto dead = dead_ops_.pop_back_val();
for (auto res : dead->getResults()) {
EnqueueUsers(res);
}
DeleteOp(dead);
}
}
// Iterator range over the tracked MergeOps. merge_nodes_ maps each merge
// operation to the number of its non-control inputs not yet known to be dead.
using MergeMap = llvm::DenseMap<Operation*, int>;
using const_iterator = MergeMap::const_iterator;
llvm::iterator_range<const_iterator> merge_nodes() const {
return llvm::make_range(merge_nodes_.begin(), merge_nodes_.end());
}
private:
void DeleteOp(Operation* op) {
merge_nodes_.erase(op);
op->dropAllDefinedValueUses();
// If a YieldOp is being deleted, then also remove its IslandOp. This is
// only valid due to requirement that all ops in island execute under same
// conditions. YieldOp is always inside of an IslandOp and if it is dead,
// then so is its parent.
if (isa<tf_executor::YieldOp>(op))
Enqueue(op->getParentOfType<tf_executor::IslandOp>(), false);
op->erase();
}
// Dead ops that need to be removed/deadness propagated.
llvm::SetVector<Operation*> dead_ops_;
// Merge nodes that may be dead.
MergeMap merge_nodes_;
};
} // namespace
// Enqueues values of foldable switch ops.
static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
DeadQueue* queue) {
Value pred_val = LookThroughIdentityOp(switch_op.predicate());
// If predicate or input is null then enqueue entire op for deletion.
if (pred_val == nullptr || switch_op.data() == nullptr) {
queue->Enqueue(switch_op, false);
return;
}
DenseElementsAttr pred;
if (!matchPattern(pred_val, m_Constant(&pred))) return;
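// The constant predicate selects the live output; the other output is dead.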
bool taken = pred.getSplatValue<bool>();
Value dead = taken ? switch_op.falseOutput() : switch_op.trueOutput();
Value live = !taken ? switch_op.falseOutput() : switch_op.trueOutput();
live.replaceAllUsesWith(switch_op.data());
queue->EnqueueUsers(dead);
// Delete switch op.
switch_op.getOperation()->dropAllDefinedValueUses();
switch_op.erase();
}
// Folds merge nodes with only a single non-dead input.
static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
// Create builder for materializing the value_index result of MergeOps.
auto* block = &function.front();
OpBuilder builder = OpBuilder::atBlockEnd(block);
auto type = builder.getIntegerType(32);
auto build_index = [&](Location loc, int value) {
return builder.create<ConstantOp>(loc, type,
builder.getI32IntegerAttr(value));
};
for (auto it : queue.merge_nodes()) {
// Find the valid input to merge node.
Value val = nullptr;
int index = -1;
auto* merge = it.first;
auto merge_op = cast<tf_executor::MergeOp>(merge);
for (auto e : llvm::enumerate(merge->getOperands())) {
Value operand = e.value();
if (!operand) continue;
// Skip control operands.
if (operand.getType().isa<tf_executor::ControlType>()) break;
if (val != nullptr) {
return merge->emitOpError("multiple valid inputs post switch folding");
}
val = operand;
index = e.index();
}
assert(val != nullptr && "merge node should have been deleted");
merge_op.output().replaceAllUsesWith(val);
// Build and insert value_index only if needed.
if (!merge_op.value_index().use_empty()) {
merge_op.value_index().replaceAllUsesWith(
build_index(merge->getLoc(), index));
}
// Propagate control dependencies if used.
if (!merge_op.control().use_empty()) {
// Change control dependencies from the merge to being on the parent of
// the value being propagated.
auto def_op = val.getDefiningOp();
#ifndef NDEBUG
auto exec_dialect =
function.getContext()->getLoadedDialect("tf_executor");
assert(def_op->getDialect() == exec_dialect &&
"unable to forward control dependencies");
#endif
merge_op.control().replaceAllUsesWith(
def_op->getResult(def_op->getNumResults() - 1));
}
merge->erase();
}
return success();
}
// TODO(jpienaar): This should be replaced by checking ops in the executor dialect.
bool HasSendOrReceive(FuncOp function) {
return function
.walk([&](::mlir::Operation* op) {
auto name = op->getName().getStringRef();
if (name == "tf._Send" || name == "tf._Recv")
return WalkResult::interrupt();
return WalkResult::advance();
})
.wasInterrupted();
}
void SwitchFoldPass::runOnFunction() {
if (HasSendOrReceive(getFunction())) return;
DeadQueue queue;
// Initialize dead queue with dead outputs of foldable SwitchOps.
getFunction().walk([&](tf_executor::SwitchOp switch_op) {
MatchSwitchFoldOps(switch_op, &queue);
});
queue.DeleteDeadOps();
if (failed(FoldMergeNodes(getFunction(), queue))) return signalPassFailure();
}
namespace tf_executor {
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass() {
return std::make_unique<SwitchFoldPass>();
}
} // namespace tf_executor
static PassRegistration<SwitchFoldPass> pass(
"tf-switch-fold", "Fold switch nodes with constant predicates");
} // namespace mlir

@@ -148,11 +148,9 @@ void CreateTFStandardPipeline(OpPassManager &pm,
OpPassManager &func_pm = pm.nest<FuncOp>();
// First operates on the executor dialect:
// - eliminate trivial switch/merge.
// - remove dead islands.
// - fuse islands as much as possible.
// - materialize the eventual "pass-through" ops by inlining their content.
func_pm.addPass(tf_executor::CreateSwitchFoldPass());
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
func_pm.addPass(CreateMaterializePassthroughOpPass());

@@ -212,9 +212,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateConstantOpDeviceAssignmentPass();
} // namespace TF
namespace tf_executor {
// Returns a pass that folds switch nodes with constant predicates.
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
// Creates a pass to merge IslandOps from TFExecutor dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorIslandCoarseningPass();