Do not add outside compilation attributes to ops without usages in
TPUHostComputationExpansion pass. PiperOrigin-RevId: 315998431 Change-Id: I391001219031f6a8fb11acaa6f1cbaf9c565c64d
This commit is contained in:
parent
f6938bff05
commit
e3aa45345b
tensorflow/compiler/mlir/tensorflow
@ -26,7 +26,7 @@ func @cast_at_head_expanded(%arg0: tensor<?xi32>) {
|
||||
"tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
}) {} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ func @check_consecutive_unary_ops_outside_compiled(%arg0: tensor<?xi32>) {
|
||||
"tf.B"(%2) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
}) {} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -59,7 +59,7 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
|
||||
"tf.B"(%1) : (tensor<?xi32>) -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
}) {} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -67,9 +67,9 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
|
||||
func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<?xi32>) {
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.Cast"
|
||||
// CHECK-NOT: _xla_outside_compilation = ""
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-NOT: _xla_outside_compilation = ""
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK-NEXT: "tf.B"
|
||||
"tf_device.cluster"() ( {
|
||||
%1 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
@ -77,6 +77,19 @@ func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<
|
||||
"tf.B"(%2) : (tensor<?xi32>) -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @check_op_without_usage_not_outside_compiled
|
||||
func @check_op_without_usage_not_outside_compiled(%arg0: tensor<?xi32>) {
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.Identity"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
@ -92,10 +92,13 @@ void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,
|
||||
|
||||
for (auto head_outside_compiled_op :
|
||||
llvm::reverse(head_outside_compiled_ops)) {
|
||||
if (HasOutsideCompilationAttribute(head_outside_compiled_op)) continue;
|
||||
auto users = head_outside_compiled_op->getUsers();
|
||||
if (users.empty() ||
|
||||
HasOutsideCompilationAttribute(head_outside_compiled_op))
|
||||
continue;
|
||||
|
||||
bool should_expand_op_to_host_computation = true;
|
||||
for (auto consumer_op : head_outside_compiled_op->getUsers()) {
|
||||
for (auto consumer_op : users) {
|
||||
if (should_expand_op_to_host_computation &&
|
||||
!HasOutsideCompilationAttribute(consumer_op)) {
|
||||
should_expand_op_to_host_computation = false;
|
||||
|
Loading…
Reference in New Issue
Block a user