Fix tpuv1_outline_tpu_island to handle transitive function calls

PiperOrigin-RevId: 295883908
Change-Id: I384bed4144942ebf31b1b3875513fe9e1bae8019
This commit is contained in:
Mehdi Amini 2020-02-18 20:26:53 -08:00 committed by TensorFlower Gardener
parent 6cb8ec0e31
commit f9e9fb9de2
3 changed files with 64 additions and 2 deletions

View File

@ -0,0 +1,48 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail
// CHECK: func @control_input
// CHECK-NOT: func @
// CHECK-LABEL: module @_tpu_v1_compat_outlined
// CHECK: @_tpu_v1_compat_outlined_func0
// CHECK: func @while_body_with_cluster_attr
// CHECK: func @while_cond_with_cluster_attr
// CHECK: func @while_body_without_cluster_attr
// CHECK: func @while_cond_without_cluster_attr
// CHECK: func @callee_func
module {
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
%0:4 = tf_executor.graph {
%outputs:4, %control = tf_executor.island {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%1 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%2 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3 = "tf.While"(%1) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%4 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
tf_executor.yield %1, %2, %3, %4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "tf.PartionedCalledOp"(%arg0) { f = @callee_func} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
}

View File

@ -133,9 +133,23 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() {
/*executor_type=*/builder.getStringAttr(""));
SmallVector<Value, 16> yield_operands(call_op.getResults());
builder.create<YieldOp>(island_op.getLoc(), yield_operands);
}
// TODO(aminim): handle transitively referenced function and clone them in
// the new module.
// Outlined all the transitively called functions by moving them in the
// outlined module.
for (FuncOp func : outlined_module.getOps<FuncOp>()) {
func.walk([&](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) {
auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
if (!symbol_ref) continue;
if (outlined_symbol_table.lookup<FuncOp>(symbol_ref.getValue()))
continue;
FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
callee.getOperation()->getBlock()->getOperations().remove(
callee.getOperation());
outlined_symbol_table.insert(callee);
}
});
}
}