Handle Array of SymbolRef attributes during tf-executor-tpu-v1-island-outlining pass.

PiperOrigin-RevId: 313671791
Change-Id: I088c53bb45df7f1f9a6284f0dc50173c91bf1b98
This commit is contained in:
Prakalp Srivastava 2020-05-28 15:32:29 -07:00 committed by TensorFlower Gardener
parent 1e22a99527
commit b1cb3f12da
2 changed files with 67 additions and 7 deletions

View File

@ -0,0 +1,47 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail
// CHECK: func @control_input
// CHECK-NOT: func @
// CHECK-LABEL: module @_tpu_v1_compat_outlined
// CHECK: @_tpu_v1_compat_outlined_func0
// CHECK: func @branch_0
// CHECK: func @branch_1
// CHECK: func @branch_2
// CHECK: func @branch_3
// CHECK: func @branch_4
module {
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
%0 = tf_executor.graph {
%output, %control = tf_executor.island {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %result : tensor<i32>
}
tf_executor.fetch %output : tensor<i32>
}
return %0 : tensor<i32>
}
func @branch_0(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @branch_1(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @branch_2(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @branch_3(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @branch_4(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}

View File

@ -49,6 +49,16 @@ struct TPUBridgeExecutorIslandOutlining
void runOnOperation() override; void runOnOperation() override;
}; };
// Move FuncOp referenced by `symbol_ref` from one symbol table to another.
void MoveFuncOp(FlatSymbolRefAttr &symbol_ref, SymbolTable &from,
SymbolTable &to) {
if (to.lookup<FuncOp>(symbol_ref.getValue())) return;
FuncOp callee = from.lookup<FuncOp>(symbol_ref.getValue());
callee.getOperation()->getBlock()->getOperations().remove(
callee.getOperation());
to.insert(callee);
}
void TPUBridgeExecutorIslandOutlining::runOnOperation() { void TPUBridgeExecutorIslandOutlining::runOnOperation() {
MLIRContext *ctx = &getContext(); MLIRContext *ctx = &getContext();
@ -141,14 +151,17 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
for (FuncOp func : outlined_module.getOps<FuncOp>()) { for (FuncOp func : outlined_module.getOps<FuncOp>()) {
func.walk([&](Operation *op) { func.walk([&](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) { for (NamedAttribute attr : op->getAttrs()) {
auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>(); if (auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
if (!symbol_ref) continue; MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
if (outlined_symbol_table.lookup<FuncOp>(symbol_ref.getValue()))
continue; continue;
FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue()); }
callee.getOperation()->getBlock()->getOperations().remove( if (auto array_attr = attr.second.dyn_cast<ArrayAttr>()) {
callee.getOperation()); for (const Attribute &attribute : array_attr) {
outlined_symbol_table.insert(callee); auto symbol_ref = attribute.dyn_cast<FlatSymbolRefAttr>();
if (!symbol_ref) continue;
MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
}
}
} }
}); });
} }