diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index a72d5916da9..d4c05e8d45d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -65,6 +65,12 @@ def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait< def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait< "TF::SameOperandsAndResultElementTypeResolveRef">; +// Op has the same operand and result types after resolving reference types +// (i.e., after converting reference types to their corresponding TensorFlow or +// standard types). +def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait< + "TF::SameOperandsAndResultTypeResolveRef">; + // Layout agnostic operations do not depend on the operands data layout (data // format), as an example all element wise operations are layout agnostic. def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">; @@ -331,8 +337,7 @@ def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">; def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">; def TF_Float : AnyTypeOf< - [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16, - TF_Float16Ref, TF_Float32Ref, TF_Float64Ref, TF_Bfloat16Ref], + [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16], "floating-point">; // Tensor types diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 5d9013edfa1..db76bd52203 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -66,6 +66,39 @@ class OperandsSameAsResultsTypeOrRef } }; +namespace detail { +inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef( + Operation* op) { + Type element_type; + if (op->getNumResults() > 0) { + element_type = + mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType()); + } else if (op->getNumOperands() > 0) { + element_type = + mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType()); + } else { + // Nothing to check. + return success(); + } + // Verify that all result element types are compatible to `element_type`. + for (const auto& result_type : op->getResultTypes()) { + if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type) { + return op->emitOpError( + "requires compatible element types for all operands and results"); + } + } + // Verify that all operand element types are compatible to `element_type`. + for (const auto& operand_type : op->getOperandTypes()) { + if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) != + element_type) { + return op->emitOpError( + "requires compatible element types for all operands and results"); + } + } + return success(); +} +} // namespace detail + // Verifies that op has the same operand and result element types (or type // itself, if scalar) after resolving reference types (i.e., after converting // reference types to their corresponding TensorFlow or standard types). @@ -75,34 +108,20 @@ class SameOperandsAndResultElementTypeResolveRef SameOperandsAndResultElementTypeResolveRef> { public: static LogicalResult verifyTrait(Operation* op) { - Type element_type; - if (op->getNumResults() > 0) { - element_type = - mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType()); - } else if (op->getNumOperands() > 0) { - element_type = - mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType()); - } else { - // Nothing to check. - return success(); - } - // Verify that all result element types are compatible to `element_type`. - for (const auto& result_type : op->getResultTypes()) { - if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != - element_type) { - return op->emitOpError( - "requires compatible element types for all operands and results"); - } - } - // Verify that all operand element types are compatible to `element_type`. - for (const auto& operand_type : op->getOperandTypes()) { - if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) != - element_type) { - return op->emitOpError( - "requires compatible element types for all operands and results"); - } - } - return success(); + return detail::verifySameOperandsAndResultElementTypeResolveRef(op); + } +}; + +// Verifies that op has the same operand and result types after resolving +// reference types (i.e., after converting reference types to their +// corresponding TensorFlow or standard types). +template +class SameOperandsAndResultTypeResolveRef + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + if (failed(impl::verifySameOperandsAndResultShape(op))) return failure(); + return detail::verifySameOperandsAndResultElementTypeResolveRef(op); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index f06b226c52d..ffa3394c8c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -291,7 +291,8 @@ func @next_iteration_sink_control_input() { %source:3 = tf_executor.NextIteration.Source : tensor<*xi32> %island:2 = tf_executor.island { %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi32> - %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) + %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>) + tf_executor.yield %const : tensor<*xi32> } tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32> @@ -306,7 +307,7 @@ func @loop_cond_control_input() { tf_executor.graph { %island:2 = tf_executor.island { %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi1> - %print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>) + %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi1>) -> (tensor<*xi1>) tf_executor.yield %const : tensor<*xi1> } %loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1> @@ -321,7 +322,7 @@ func @enter_control_input() { tf_executor.graph { %island:2 = tf_executor.island { %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi32> - %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) + %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>) tf_executor.yield %const : tensor<*xi32> } %enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32> @@ -336,7 +337,7 @@ func @switchn_control_input(%arg1: tensor) { tf_executor.graph { %island:2 = tf_executor.island { %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi32> - %print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>) + %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>) tf_executor.yield %const : tensor<*xi32> } %switchn:4 = tf_executor._SwitchN %island#0, %arg1 of 3: tensor<*xi32>