diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index c8a6d5489c3..ae5f98da85f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -1,7 +1,7 @@ // RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s -// CHECK-LABEL: func @unsupported_op_no_soft_placement -func @unsupported_op_no_soft_placement() -> tensor { +// CHECK-LABEL: func @unsupported_op_missing_soft_placement_attribute +func @unsupported_op_missing_soft_placement_attribute() -> tensor { %0 = "tf_device.cluster"() ( { // CHECK: "tf.UnsupportedOp" // CHECK-NOT: _xla_outside_compilation @@ -28,6 +28,24 @@ func @unsupported_op_soft_placement_false() -> tensor { return %0 : tensor } +// CHECK-LABEL: func @assert_op_string_operand +func @assert_op_string_operand(%arg0: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Assert" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.UnsupportedOp" + // CHECK-SAME: _xla_outside_compilation + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %t = constant dense : tensor + "tf.Assert"(%t, %arg0) {summarize = 3} : (tensor, tensor) -> () + %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + // CHECK-LABEL: func @unsupported_op func @unsupported_op() -> tensor { %0 = "tf_device.cluster"() ( { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index ac844b925ce..b5607d63af9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -164,10 +164,12 @@ bool IsSupportedOp(Operation& op, const Dialect* tf_dialect) { if (op.getDialect() != tf_dialect) return true; - else - return !HasStringOperand(op) && !HasStringResult(op) && - (MatchesPattern(op, supported_ops) || - mhlo::IsOpAllowedTf2XlaFallback(&op)); + // Assert has a legalization that later removes it so we don't want to outside + // compile it ever for performance reasons. + if (llvm::isa(op)) return true; + return !HasStringOperand(op) && !HasStringResult(op) && + (MatchesPattern(op, supported_ops) || + mhlo::IsOpAllowedTf2XlaFallback(&op)); } // Checks all regions of `op` for captured string operands.