Do not add outside compilation attributes to ops without usages in

TPUHostComputationExpansion pass.

PiperOrigin-RevId: 315998431
Change-Id: I391001219031f6a8fb11acaa6f1cbaf9c565c64d
This commit is contained in:
A. Unique TensorFlower 2020-06-11 16:12:57 -07:00 committed by TensorFlower Gardener
parent f6938bff05
commit e3aa45345b
2 changed files with 24 additions and 8 deletions
tensorflow/compiler/mlir/tensorflow

View File

@ -26,7 +26,7 @@ func @cast_at_head_expanded(%arg0: tensor<?xi32>) {
"tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}
@ -44,7 +44,7 @@ func @check_consecutive_unary_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%2) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}
@ -59,7 +59,7 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%1) : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}
@ -67,9 +67,9 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Cast"
// CHECK-NOT: _xla_outside_compilation = ""
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation = ""
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.B"
"tf_device.cluster"() ( {
%1 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
@ -77,6 +77,19 @@ func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<
"tf.B"(%2) : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) : () -> ()
return
}
// CHECK-LABEL: func @check_op_without_usage_not_outside_compiled
func @check_op_without_usage_not_outside_compiled(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
"tf_device.cluster"() ( {
"tf.Identity"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
"tf.C"() : () -> ()
tf_device.return
}) : () -> ()
return
}

View File

@ -92,10 +92,13 @@ void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,
for (auto head_outside_compiled_op :
llvm::reverse(head_outside_compiled_ops)) {
if (HasOutsideCompilationAttribute(head_outside_compiled_op)) continue;
auto users = head_outside_compiled_op->getUsers();
if (users.empty() ||
HasOutsideCompilationAttribute(head_outside_compiled_op))
continue;
bool should_expand_op_to_host_computation = true;
for (auto consumer_op : head_outside_compiled_op->getUsers()) {
for (auto consumer_op : users) {
if (should_expand_op_to_host_computation &&
!HasOutsideCompilationAttribute(consumer_op)) {
should_expand_op_to_host_computation = false;