Don't outside compile tf.Assert Op even if it contains string operands.
The tf.Assert op is removed during legalization and shouldn't be outside compiled for performance reasons. PiperOrigin-RevId: 341669459 Change-Id: I956662b63aafaef05a269d4786d5504825d20dca
This commit is contained in:
parent
fa595eb8fa
commit
f4307fa6f5
@ -1,7 +1,7 @@
|
|||||||
// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
|
// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @unsupported_op_no_soft_placement
|
// CHECK-LABEL: func @unsupported_op_missing_soft_placement_attribute
|
||||||
func @unsupported_op_no_soft_placement() -> tensor<i32> {
|
func @unsupported_op_missing_soft_placement_attribute() -> tensor<i32> {
|
||||||
%0 = "tf_device.cluster"() ( {
|
%0 = "tf_device.cluster"() ( {
|
||||||
// CHECK: "tf.UnsupportedOp"
|
// CHECK: "tf.UnsupportedOp"
|
||||||
// CHECK-NOT: _xla_outside_compilation
|
// CHECK-NOT: _xla_outside_compilation
|
||||||
@ -28,6 +28,24 @@ func @unsupported_op_soft_placement_false() -> tensor<i32> {
|
|||||||
return %0 : tensor<i32>
|
return %0 : tensor<i32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @assert_op_string_operand
|
||||||
|
func @assert_op_string_operand(%arg0: tensor<!tf.string>) -> tensor<i32> {
|
||||||
|
%0 = "tf_device.cluster"() ( {
|
||||||
|
// CHECK: "tf.Assert"
|
||||||
|
// CHECK-NOT: _xla_outside_compilation
|
||||||
|
// CHECK: "tf.UnsupportedOp"
|
||||||
|
// CHECK-SAME: _xla_outside_compilation
|
||||||
|
// CHECK: "tf.Identity"
|
||||||
|
// CHECK-NOT: _xla_outside_compilation
|
||||||
|
%t = constant dense<true> : tensor<i1>
|
||||||
|
"tf.Assert"(%t, %arg0) {summarize = 3} : (tensor<i1>, tensor<!tf.string>) -> ()
|
||||||
|
%1 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
|
||||||
|
tf_device.return %2 : tensor<i32>
|
||||||
|
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @unsupported_op
|
// CHECK-LABEL: func @unsupported_op
|
||||||
func @unsupported_op() -> tensor<i32> {
|
func @unsupported_op() -> tensor<i32> {
|
||||||
%0 = "tf_device.cluster"() ( {
|
%0 = "tf_device.cluster"() ( {
|
||||||
|
@ -164,7 +164,9 @@ bool IsSupportedOp(Operation& op,
|
|||||||
const Dialect* tf_dialect) {
|
const Dialect* tf_dialect) {
|
||||||
if (op.getDialect() != tf_dialect)
|
if (op.getDialect() != tf_dialect)
|
||||||
return true;
|
return true;
|
||||||
else
|
// Assert has a legalization that later removes it so we don't want to outside
|
||||||
|
// compile it ever for performance reasons.
|
||||||
|
if (llvm::isa<TF::AssertOp>(op)) return true;
|
||||||
return !HasStringOperand(op) && !HasStringResult(op) &&
|
return !HasStringOperand(op) && !HasStringResult(op) &&
|
||||||
(MatchesPattern(op, supported_ops) ||
|
(MatchesPattern(op, supported_ops) ||
|
||||||
mhlo::IsOpAllowedTf2XlaFallback(&op));
|
mhlo::IsOpAllowedTf2XlaFallback(&op));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user