Add checks for equal operand and result (element) types for TF dialect ops
Many TF ops require that all operands and results have the same (element) type but previously this was often not checked, leading to failures later in the flow that are harder to debug. This change adds several such checks. PiperOrigin-RevId: 320511211 Change-Id: Iaba7d7449085d4fc7a3e3c18b0dd26c185b7282c
This commit is contained in:
parent
f5e8685729
commit
78199b484c
@ -87,7 +87,7 @@ tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf]
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
|
||||
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
@ -136,7 +136,7 @@ Inputs must be of same size and shape.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
@ -648,7 +648,7 @@ tf.math.atan(y) # [1.047, 0.785] = x
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = [{
|
||||
Computes arctangent of `y/x` element-wise, respecting signs of the arguments.
|
||||
@ -765,7 +765,7 @@ def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
|
||||
def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Multiplies slices of two tensors in batches.";
|
||||
|
||||
let description = [{
|
||||
@ -806,7 +806,7 @@ It is computed as:
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> {
|
||||
def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Multiplies slices of two tensors in batches.";
|
||||
|
||||
let description = [{
|
||||
@ -1422,7 +1422,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> {
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Clips tensor values to a specified min and max.";
|
||||
|
||||
let description = [{
|
||||
@ -1984,7 +1984,7 @@ Given an input tensor, this function computes hyperbolic cosine of every
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> {
|
||||
def TF_CrossOp : TF_Op<"Cross", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Compute the pairwise cross product.";
|
||||
|
||||
let description = [{
|
||||
@ -2469,7 +2469,7 @@ Computes Psi, the derivative of Lgamma (the log of the absolute value of
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x / y element-wise.";
|
||||
|
||||
@ -2494,7 +2494,7 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns 0 if the denominator is zero.";
|
||||
|
||||
@ -3374,7 +3374,7 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [NoSideEffect, ResultsBroadcastableShape]>
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_FloorModOp : TF_Op<"FloorMod", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_FloorModOp : TF_Op<"FloorMod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = [{
|
||||
Returns element-wise remainder of division. When `x < 0` xor `y < 0` is
|
||||
@ -4111,7 +4111,7 @@ def ApplyG(op, dy, _):
|
||||
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = [{
|
||||
Compute the lower regularized incomplete Gamma function `P(a, x)`.
|
||||
@ -4145,7 +4145,7 @@ Gamma function.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Computes the gradient of `igamma(a, x)` wrt `a`.";
|
||||
|
||||
@ -4161,7 +4161,7 @@ def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableS
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = [{
|
||||
Compute the upper regularized incomplete Gamma function `Q(a, x)`.
|
||||
@ -4928,7 +4928,7 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect]> {
|
||||
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = [{
|
||||
Multiply the matrix "a" by the matrix "b".
|
||||
}];
|
||||
@ -5692,7 +5692,7 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";
|
||||
|
||||
@ -5766,7 +5766,7 @@ retained with length 1.
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns the min of x and y (i.e. x < y ? x : y) element-wise.";
|
||||
|
||||
@ -5899,7 +5899,7 @@ graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.Tensor
|
||||
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = [{
|
||||
Returns element-wise remainder of division. This emulates C semantics in that
|
||||
@ -5925,7 +5925,7 @@ the result here is consistent with a truncating divide. E.g.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x * y element-wise.";
|
||||
|
||||
@ -6426,7 +6426,7 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
|
||||
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_PowOp : TF_Op<"Pow", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_PowOp : TF_Op<"Pow", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Computes the power of one value to another.";
|
||||
|
||||
@ -6905,7 +6905,7 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded.
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_RangeOp : TF_Op<"Range", [NoSideEffect]> {
|
||||
def TF_RangeOp : TF_Op<"Range", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Creates a sequence of numbers.";
|
||||
|
||||
let description = [{
|
||||
@ -9518,7 +9518,7 @@ Examples:
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x - y element-wise.";
|
||||
|
||||
@ -10675,7 +10675,7 @@ Python Semantics.
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = [{
|
||||
Returns element-wise remainder of division. This emulates C semantics in that
|
||||
@ -11148,7 +11148,7 @@ where(input) ==> [[0, 0, 0],
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise.";
|
||||
|
||||
@ -11514,7 +11514,7 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect]> {
|
||||
def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -11529,7 +11529,7 @@ def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect]> {
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise.";
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user