Don't outside compile tf.Assert Op even if it contains string operands.
The tf.Assert op is removed during legalization and shouldn't be outside compiled for performance reasons. PiperOrigin-RevId: 341669459 Change-Id: I956662b63aafaef05a269d4786d5504825d20dca
This commit is contained in:
parent
fa595eb8fa
commit
f4307fa6f5
@ -1,7 +1,7 @@
|
||||
// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @unsupported_op_no_soft_placement
|
||||
func @unsupported_op_no_soft_placement() -> tensor<i32> {
|
||||
// CHECK-LABEL: func @unsupported_op_missing_soft_placement_attribute
|
||||
func @unsupported_op_missing_soft_placement_attribute() -> tensor<i32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.UnsupportedOp"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
@ -28,6 +28,24 @@ func @unsupported_op_soft_placement_false() -> tensor<i32> {
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @assert_op_string_operand
|
||||
func @assert_op_string_operand(%arg0: tensor<!tf.string>) -> tensor<i32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.Assert"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf.UnsupportedOp"
|
||||
// CHECK-SAME: _xla_outside_compilation
|
||||
// CHECK: "tf.Identity"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
%t = constant dense<true> : tensor<i1>
|
||||
"tf.Assert"(%t, %arg0) {summarize = 3} : (tensor<i1>, tensor<!tf.string>) -> ()
|
||||
%1 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
|
||||
tf_device.return %2 : tensor<i32>
|
||||
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unsupported_op
|
||||
func @unsupported_op() -> tensor<i32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
|
@ -164,10 +164,12 @@ bool IsSupportedOp(Operation& op,
|
||||
const Dialect* tf_dialect) {
|
||||
if (op.getDialect() != tf_dialect)
|
||||
return true;
|
||||
else
|
||||
return !HasStringOperand(op) && !HasStringResult(op) &&
|
||||
(MatchesPattern(op, supported_ops) ||
|
||||
mhlo::IsOpAllowedTf2XlaFallback(&op));
|
||||
// Assert has a legalization that later removes it so we don't want to outside
|
||||
// compile it ever for performance reasons.
|
||||
if (llvm::isa<TF::AssertOp>(op)) return true;
|
||||
return !HasStringOperand(op) && !HasStringResult(op) &&
|
||||
(MatchesPattern(op, supported_ops) ||
|
||||
mhlo::IsOpAllowedTf2XlaFallback(&op));
|
||||
}
|
||||
|
||||
// Checks all regions of `op` for captured string operands.
|
||||
|
Loading…
x
Reference in New Issue
Block a user