Fix handling of single op islands when breaking up islands.

It is possible for a single op island return different results from the single op (duplicate results, aliased results). This updates the single op island check with WrapsSingleOp instead of checking if the island body only has a single op excluding the terminator.

PiperOrigin-RevId: 314541656
Change-Id: I10bc8e6348a72764579c0ee0f977d100bf069aa8
This commit is contained in:
Andy Ly 2020-06-03 09:06:05 -07:00 committed by TensorFlower Gardener
parent f72f233547
commit b75ea8de71
2 changed files with 29 additions and 1 deletions

View File

@ -344,3 +344,31 @@ func @switchn_control_input(%arg1: tensor<i32>) {
}
return
}
// CHECK-LABEL: func @single_op_island_forward_block_arg
// CHECK: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
// CHECK: tf_executor.fetch %[[CONST]], %arg0
func @single_op_island_forward_block_arg(%arg0: tensor<?x?x?x?xbf16>) -> (tensor<2048xf32>, tensor<?x?x?x?xbf16>) {
%0:2 = tf_executor.graph {
%outputs:2, %control = tf_executor.island {
%1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2048xf32>} : () -> tensor<2048xf32>
tf_executor.yield %1, %arg0 : tensor<2048xf32>, tensor<?x?x?x?xbf16>
}
tf_executor.fetch %outputs#0, %outputs#1 : tensor<2048xf32>, tensor<?x?x?x?xbf16>
}
return %0#0, %0#1 : tensor<2048xf32>, tensor<?x?x?x?xbf16>
}
// CHECK-LABEL: func @single_op_island_duplicate_result
// CHECK: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
// CHECK: tf_executor.fetch %[[CONST]], %[[CONST]]
func @single_op_island_duplicate_result() -> (tensor<2048xf32>, tensor<2048xf32>) {
%0:2 = tf_executor.graph {
%outputs:2, %control = tf_executor.island {
%1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2048xf32>} : () -> tensor<2048xf32>
tf_executor.yield %1, %1 : tensor<2048xf32>, tensor<2048xf32>
}
tf_executor.fetch %outputs#0, %outputs#1 : tensor<2048xf32>, tensor<2048xf32>
}
return %0#0, %0#1 : tensor<2048xf32>, tensor<2048xf32>
}

View File

@ -219,7 +219,7 @@ void BreakUpIslands::BreakUpIsland(
}
// Skip islands that are already only a single op.
if (hasSingleElement(island_body)) return;
if (island_op.WrapsSingleOp()) return;
auto control_type = tf_executor::ControlType::get(&getContext());
auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());