diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir new file mode 100644 index 00000000000..de6f9b42ba4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir @@ -0,0 +1,47 @@ +// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail + +// CHECK: func @control_input +// CHECK-NOT: func @ +// CHECK-LABEL: module @_tpu_v1_compat_outlined +// CHECK: @_tpu_v1_compat_outlined_func0 +// CHECK: func @branch_0 +// CHECK: func @branch_1 +// CHECK: func @branch_2 +// CHECK: func @branch_3 +// CHECK: func @branch_4 +module { + func @control_input(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %output, %control = tf_executor.island { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> () + %index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor, tensor) -> tensor + tf_executor.yield %result : tensor + } + tf_executor.fetch %output : tensor + + } + return %0 : tensor + } + func @branch_0(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_1(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_2(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_3(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_4(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index 08645333d5d..e04f6bf3daa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -49,6 +49,16 @@ struct TPUBridgeExecutorIslandOutlining void runOnOperation() override; }; +// Move FuncOp referenced by `symbol_ref` from one symbol table to another. +void MoveFuncOp(FlatSymbolRefAttr &symbol_ref, SymbolTable &from, + SymbolTable &to) { + if (to.lookup(symbol_ref.getValue())) return; + FuncOp callee = from.lookup(symbol_ref.getValue()); + callee.getOperation()->getBlock()->getOperations().remove( + callee.getOperation()); + to.insert(callee); +} + void TPUBridgeExecutorIslandOutlining::runOnOperation() { MLIRContext *ctx = &getContext(); @@ -141,14 +151,17 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() { for (FuncOp func : outlined_module.getOps()) { func.walk([&](Operation *op) { for (NamedAttribute attr : op->getAttrs()) { - auto symbol_ref = attr.second.dyn_cast(); - if (!symbol_ref) continue; - if (outlined_symbol_table.lookup(symbol_ref.getValue())) + if (auto symbol_ref = attr.second.dyn_cast()) { + MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); continue; - FuncOp callee = symbol_table.lookup(symbol_ref.getValue()); - callee.getOperation()->getBlock()->getOperations().remove( - callee.getOperation()); - outlined_symbol_table.insert(callee); + } + if (auto array_attr = attr.second.dyn_cast()) { + for (const Attribute &attribute : array_attr) { + auto symbol_ref = attribute.dyn_cast(); + if (!symbol_ref) continue; + MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); + } + } } }); }