Added new type constraint TF_SameOperandsAndResultTypeResolveRef

Also fixed testcases and redundant type constraints.

PiperOrigin-RevId: 347948243
Change-Id: I61b879ab9392df33eebfa10dcd4acf7bc8ded3f2
This commit is contained in:
Michael Gester 2020-12-16 20:22:44 -08:00 committed by TensorFlower Gardener
parent bd73bad5b3
commit a87bf6e6fe
3 changed files with 59 additions and 34 deletions

View File

@ -65,6 +65,12 @@ def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
"TF::SameOperandsAndResultElementTypeResolveRef">;
// Op has the same operand and result types after resolving reference types
// (i.e., after converting reference types to their corresponding TensorFlow or
// standard types).
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
"TF::SameOperandsAndResultTypeResolveRef">;
// Layout agnostic operations do not depend on the operands data layout (data
// format), as an example all element wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
@ -331,8 +337,7 @@ def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;
def TF_Float : AnyTypeOf<
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16,
TF_Float16Ref, TF_Float32Ref, TF_Float64Ref, TF_Bfloat16Ref],
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
"floating-point">;
// Tensor types

View File

@ -66,6 +66,39 @@ class OperandsSameAsResultsTypeOrRef
}
};
namespace detail {
inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
Operation* op) {
Type element_type;
if (op->getNumResults() > 0) {
element_type =
mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
} else if (op->getNumOperands() > 0) {
element_type =
mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
} else {
// Nothing to check.
return success();
}
// Verify that all result element types are compatible to `element_type`.
for (const auto& result_type : op->getResultTypes()) {
if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type) {
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
// Verify that all operand element types are compatible to `element_type`.
for (const auto& operand_type : op->getOperandTypes()) {
if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) !=
element_type) {
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
return success();
}
} // namespace detail
// Verifies that op has the same operand and result element types (or type
// itself, if scalar) after resolving reference types (i.e., after converting
// reference types to their corresponding TensorFlow or standard types).
@ -75,34 +108,20 @@ class SameOperandsAndResultElementTypeResolveRef
SameOperandsAndResultElementTypeResolveRef> {
public:
static LogicalResult verifyTrait(Operation* op) {
Type element_type;
if (op->getNumResults() > 0) {
element_type =
mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
} else if (op->getNumOperands() > 0) {
element_type =
mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
} else {
// Nothing to check.
return success();
}
// Verify that all result element types are compatible to `element_type`.
for (const auto& result_type : op->getResultTypes()) {
if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) !=
element_type) {
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
// Verify that all operand element types are compatible to `element_type`.
for (const auto& operand_type : op->getOperandTypes()) {
if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) !=
element_type) {
return op->emitOpError(
"requires compatible element types for all operands and results");
}
}
return success();
return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
}
};
// Verifies that op has the same operand and result types after resolving
// reference types (i.e., after converting reference types to their
// corresponding TensorFlow or standard types).
template <typename ConcreteType>
class SameOperandsAndResultTypeResolveRef
: public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef> {
public:
static LogicalResult verifyTrait(Operation* op) {
if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
}
};

View File

@ -291,7 +291,8 @@ func @next_iteration_sink_control_input() {
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %const : tensor<*xi32>
}
tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32>
@ -306,7 +307,7 @@ func @loop_cond_control_input() {
tf_executor.graph {
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i1>} : () -> tensor<*xi1>
%print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>)
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi1>) -> (tensor<*xi1>)
tf_executor.yield %const : tensor<*xi1>
}
%loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1>
@ -321,7 +322,7 @@ func @enter_control_input() {
tf_executor.graph {
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %const : tensor<*xi32>
}
%enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32>
@ -336,7 +337,7 @@ func @switchn_control_input(%arg1: tensor<i32>) {
tf_executor.graph {
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %const : tensor<*xi32>
}
%switchn:4 = tf_executor._SwitchN %island#0, %arg1 of 3: tensor<*xi32>