Added new type constraint TF_SameOperandsAndResultTypeResolveRef
Also fixed testcases and redundant type constraints. PiperOrigin-RevId: 347948243 Change-Id: I61b879ab9392df33eebfa10dcd4acf7bc8ded3f2
This commit is contained in:
parent
bd73bad5b3
commit
a87bf6e6fe
@ -65,6 +65,12 @@ def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
|
||||
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
|
||||
"TF::SameOperandsAndResultElementTypeResolveRef">;
|
||||
|
||||
// Op has the same operand and result types after resolving reference types
|
||||
// (i.e., after converting reference types to their corresponding TensorFlow or
|
||||
// standard types).
|
||||
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
|
||||
"TF::SameOperandsAndResultTypeResolveRef">;
|
||||
|
||||
// Layout agnostic operations do not depend on the operands data layout (data
|
||||
// format), as an example all element wise operations are layout agnostic.
|
||||
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
|
||||
@ -331,8 +337,7 @@ def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;
|
||||
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;
|
||||
|
||||
def TF_Float : AnyTypeOf<
|
||||
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16,
|
||||
TF_Float16Ref, TF_Float32Ref, TF_Float64Ref, TF_Bfloat16Ref],
|
||||
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
|
||||
"floating-point">;
|
||||
|
||||
// Tensor types
|
||||
|
@ -66,6 +66,39 @@ class OperandsSameAsResultsTypeOrRef
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
|
||||
Operation* op) {
|
||||
Type element_type;
|
||||
if (op->getNumResults() > 0) {
|
||||
element_type =
|
||||
mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
|
||||
} else if (op->getNumOperands() > 0) {
|
||||
element_type =
|
||||
mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
|
||||
} else {
|
||||
// Nothing to check.
|
||||
return success();
|
||||
}
|
||||
// Verify that all result element types are compatible to `element_type`.
|
||||
for (const auto& result_type : op->getResultTypes()) {
|
||||
if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type) {
|
||||
return op->emitOpError(
|
||||
"requires compatible element types for all operands and results");
|
||||
}
|
||||
}
|
||||
// Verify that all operand element types are compatible to `element_type`.
|
||||
for (const auto& operand_type : op->getOperandTypes()) {
|
||||
if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) !=
|
||||
element_type) {
|
||||
return op->emitOpError(
|
||||
"requires compatible element types for all operands and results");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Verifies that op has the same operand and result element types (or type
|
||||
// itself, if scalar) after resolving reference types (i.e., after converting
|
||||
// reference types to their corresponding TensorFlow or standard types).
|
||||
@ -75,34 +108,20 @@ class SameOperandsAndResultElementTypeResolveRef
|
||||
SameOperandsAndResultElementTypeResolveRef> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation* op) {
|
||||
Type element_type;
|
||||
if (op->getNumResults() > 0) {
|
||||
element_type =
|
||||
mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
|
||||
} else if (op->getNumOperands() > 0) {
|
||||
element_type =
|
||||
mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
|
||||
} else {
|
||||
// Nothing to check.
|
||||
return success();
|
||||
}
|
||||
// Verify that all result element types are compatible to `element_type`.
|
||||
for (const auto& result_type : op->getResultTypes()) {
|
||||
if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) !=
|
||||
element_type) {
|
||||
return op->emitOpError(
|
||||
"requires compatible element types for all operands and results");
|
||||
}
|
||||
}
|
||||
// Verify that all operand element types are compatible to `element_type`.
|
||||
for (const auto& operand_type : op->getOperandTypes()) {
|
||||
if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) !=
|
||||
element_type) {
|
||||
return op->emitOpError(
|
||||
"requires compatible element types for all operands and results");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
|
||||
}
|
||||
};
|
||||
|
||||
// Verifies that op has the same operand and result types after resolving
|
||||
// reference types (i.e., after converting reference types to their
|
||||
// corresponding TensorFlow or standard types).
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultTypeResolveRef
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation* op) {
|
||||
if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
|
||||
return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -291,7 +291,8 @@ func @next_iteration_sink_control_input() {
|
||||
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32>
|
||||
@ -306,7 +307,7 @@ func @loop_cond_control_input() {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i1>} : () -> tensor<*xi1>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>)
|
||||
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi1>) -> (tensor<*xi1>)
|
||||
tf_executor.yield %const : tensor<*xi1>
|
||||
}
|
||||
%loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1>
|
||||
@ -321,7 +322,7 @@ func @enter_control_input() {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
%enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32>
|
||||
@ -336,7 +337,7 @@ func @switchn_control_input(%arg1: tensor<i32>) {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
%print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
%switchn:4 = tf_executor._SwitchN %island#0, %arg1 of 3: tensor<*xi32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user