Don't outside compile tf.Assert Op even if it contains string operands.

The tf.Assert op is removed during legalization and shouldn't be outside compiled for performance reasons.

PiperOrigin-RevId: 341669459
Change-Id: I956662b63aafaef05a269d4786d5504825d20dca
This commit is contained in:
Ken Franko 2020-11-10 12:05:51 -08:00 committed by TensorFlower Gardener
parent fa595eb8fa
commit f4307fa6f5
2 changed files with 26 additions and 6 deletions

View File

@ -1,7 +1,7 @@
// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
// CHECK-LABEL: func @unsupported_op_no_soft_placement
func @unsupported_op_no_soft_placement() -> tensor<i32> {
// CHECK-LABEL: func @unsupported_op_missing_soft_placement_attribute
func @unsupported_op_missing_soft_placement_attribute() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
// CHECK: "tf.UnsupportedOp"
// CHECK-NOT: _xla_outside_compilation
@ -28,6 +28,24 @@ func @unsupported_op_soft_placement_false() -> tensor<i32> {
return %0 : tensor<i32>
}
// CHECK-LABEL: func @assert_op_string_operand
func @assert_op_string_operand(%arg0: tensor<!tf.string>) -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
// CHECK: "tf.Assert"
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.UnsupportedOp"
// CHECK-SAME: _xla_outside_compilation
// CHECK: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
%t = constant dense<true> : tensor<i1>
"tf.Assert"(%t, %arg0) {summarize = 3} : (tensor<i1>, tensor<!tf.string>) -> ()
%1 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
return %0 : tensor<i32>
}
// CHECK-LABEL: func @unsupported_op
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {

View File

@ -164,10 +164,12 @@ bool IsSupportedOp(Operation& op,
const Dialect* tf_dialect) {
if (op.getDialect() != tf_dialect)
return true;
else
return !HasStringOperand(op) && !HasStringResult(op) &&
(MatchesPattern(op, supported_ops) ||
mhlo::IsOpAllowedTf2XlaFallback(&op));
// Assert has a legalization that later removes it so we don't want to outside
// compile it ever for performance reasons.
if (llvm::isa<TF::AssertOp>(op)) return true;
return !HasStringOperand(op) && !HasStringResult(op) &&
(MatchesPattern(op, supported_ops) ||
mhlo::IsOpAllowedTf2XlaFallback(&op));
}
// Checks all regions of `op` for captured string operands.