diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_tpu_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/executor_tpuv1_outline_tpu_island.mlir similarity index 100% rename from tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_tpu_island.mlir rename to tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/executor_tpuv1_outline_tpu_island.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/while_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/while_op.mlir new file mode 100644 index 00000000000..b1dee63ca03 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/while_op.mlir @@ -0,0 +1,48 @@ +// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail + +// CHECK: func @control_input +// CHECK-NOT: func @ +// CHECK-LABEL: module @_tpu_v1_compat_outlined +// CHECK: @_tpu_v1_compat_outlined_func0 +// CHECK: func @while_body_with_cluster_attr +// CHECK: func @while_cond_with_cluster_attr +// CHECK: func @while_body_without_cluster_attr +// CHECK: func @while_cond_without_cluster_attr +// CHECK: func @callee_func +module { + func @control_input(%arg0: tensor) -> tensor { + %0:4 = tf_executor.graph { + %outputs:4, %control = tf_executor.island { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> () + %1 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %2 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor) -> tensor + %3 = "tf.While"(%1) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor) -> tensor + %4 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor) -> tensor + tf_executor.yield %1, %2, %3, %4 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor, tensor, tensor, tensor + + } + return %0#0 : tensor + } + func @while_body_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @while_cond_with_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @while_body_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor + } + func @while_cond_without_cluster_attr(%arg0: tensor) -> tensor { + %0 = "tf.PartionedCalledOp"(%arg0) { f = @callee_func} : (tensor) -> tensor + return %0 : tensor + } + func @callee_func(%arg0: tensor) -> tensor { + %0 = "some.op"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index b553a74d097..57ea1822b5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -133,9 +133,23 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() { /*executor_type=*/builder.getStringAttr("")); SmallVector yield_operands(call_op.getResults()); builder.create(island_op.getLoc(), yield_operands); + } - // TODO(aminim): handle transitively referenced function and clone them in - // the new module. + // Outlined all the transitively called functions by moving them in the + // outlined module. + for (FuncOp func : outlined_module.getOps()) { + func.walk([&](Operation *op) { + for (NamedAttribute attr : op->getAttrs()) { + auto symbol_ref = attr.second.dyn_cast(); + if (!symbol_ref) continue; + if (outlined_symbol_table.lookup(symbol_ref.getValue())) + continue; + FuncOp callee = symbol_table.lookup(symbol_ref.getValue()); + callee.getOperation()->getBlock()->getOperations().remove( + callee.getOperation()); + outlined_symbol_table.insert(callee); + } + }); } }