Sync all TensorFlow ODS op defs with TensorFlow registry

Manual changes:
* Added explicit check on element type for Equal and NotEqual lowering to HLO
* Moved Verifier declaration for BatchToSpace op
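The element-type guard matters because `tf.Equal` accepts operands of any dtype, while HLO comparisons have no string type; without it the lowering pattern would fire on operands it cannot represent (see the new `@equal_unsupported_type` test below). A minimal illustration in plain TF Python, not part of this change:

```python
import tensorflow as tf

# tf.Equal is defined for every TF dtype, including strings...
print(tf.equal(tf.constant(["a", "b"]), tf.constant(["a", "c"])))  # [ True False]
# ...but HLO compare has no string element type, so the TF->HLO pattern
# must check the element type and leave string Equal ops un-lowered.
```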

PiperOrigin-RevId: 343631256
Change-Id: Iedf0a88898375874a00964220c423ff5056deb76
Author: Smit Hinsu, 2020-11-21 01:44:50 -08:00; committed by TensorFlower Gardener
Parent: c0f7434dec
Commit: 2d263ad1ca
3 changed files with 68 additions and 95 deletions


@@ -31,7 +31,7 @@ limitations under the License.
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the absolute value of a tensor.";
let description = [{
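The `Idempotent` trait being added to Abs here (and to Ceil, Floor, OnesLike, Relu, Relu6, Rint, Round, Sign, and ZerosLike below) asserts that applying the op twice equals applying it once, which lets the canonicalizer fold nested applications. A quick sketch of the property, using plain `tf.abs` for illustration:

```python
import tensorflow as tf

x = tf.constant([-2.0, 0.0, 3.5])
# Idempotence: abs(abs(x)) == abs(x), so a nested tf.Abs can fold away.
assert bool(tf.reduce_all(tf.abs(tf.abs(x)) == tf.abs(x)))
```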
@@ -1002,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description.
TF_Tensor:$output
);
let verifier = [{
return Verify(*this);
}];
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
@@ -1486,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
let hasFolder = 1;
}
def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise smallest integer not less than x.";
let arguments = (ins
@@ -3502,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True])
}];
let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
TF_Tensor:$x,
TF_Tensor:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
);
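Relaxing `$x`/`$y` from the long `TensorOf<[...]>` list to `TF_Tensor` matches the TensorFlow registry, where `Equal`'s `T` is unconstrained. A hedged usage sketch, again plain TF Python rather than part of the diff:

```python
import tensorflow as tf

# Equal broadcasts its operands and works on any dtype, matching the
# relaxed ODS operand types.
print(tf.math.equal([1, 2], 1))                           # [ True False]
print(tf.math.equal(tf.constant("x"), tf.constant("x")))  # True
```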
@@ -3843,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation.";
let arguments = (ins
F32Tensor:$gradients,
F32Tensor:$inputs,
TF_Float32Tensor:$gradients,
TF_Float32Tensor:$inputs,
DefaultValuedAttr<F32Attr, "-6.0f">:$min,
DefaultValuedAttr<F32Attr, "6.0f">:$max,
@@ -3853,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
);
let results = (outs
F32Tensor:$backprops
TF_Float32Tensor:$backprops
);
}
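These hunks only rename type constraints (`F32Tensor` to `TF_Float32Tensor`); the op semantics are unchanged. For reference, the forward op these gradients belong to, sketched with the public `tf.quantization` API (illustrative only):

```python
import tensorflow as tf

x = tf.constant([-10.0, -1.0, 0.0, 1.0, 10.0])
# Inputs are clamped to [min, max] and snapped to one of 2**num_bits levels.
y = tf.quantization.fake_quant_with_min_max_args(x, min=-6.0, max=6.0,
                                                 num_bits=8)
print(y)  # out-of-range endpoints clamp; interior values quantize
```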
@@ -3911,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien
let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation.";
let arguments = (ins
F32Tensor:$gradients,
F32Tensor:$inputs,
F32Tensor:$min,
F32Tensor:$max,
TF_Float32Tensor:$gradients,
TF_Float32Tensor:$inputs,
TF_Float32Tensor:$min,
TF_Float32Tensor:$max,
DefaultValuedAttr<I64Attr, "8">:$num_bits,
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
);
let results = (outs
F32Tensor:$backprops_wrt_input,
F32Tensor:$backprop_wrt_min,
F32Tensor:$backprop_wrt_max
TF_Float32Tensor:$backprops_wrt_input,
TF_Float32Tensor:$backprop_wrt_min,
TF_Float32Tensor:$backprop_wrt_max
);
}
@@ -4026,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9]
];
}
def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> {
def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise largest integer not greater than x.";
let arguments = (ins
@@ -4977,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
}];
let arguments = (ins
F32Tensor:$predictions,
TF_Float32Tensor:$predictions,
TF_I32OrI64Tensor:$targets,
TF_I32OrI64Tensor:$k
);
let results = (outs
I1Tensor:$precision
TF_BoolTensor:$precision
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
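`InTopKV2`'s types tighten to a `TF_Float32Tensor` for `$predictions` and a `TF_BoolTensor` result. A usage sketch via `tf.math.in_top_k`, which I believe is the public binding for this op (illustrative):

```python
import tensorflow as tf

predictions = tf.constant([[0.1, 0.8, 0.1],
                           [0.6, 0.3, 0.1]])
targets = tf.constant([1, 2])
# precision[i] is True iff targets[i] is among the top-k entries of row i.
print(tf.math.in_top_k(targets, predictions, k=1))  # [ True False]
```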
@@ -7855,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
}];
let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
TF_Tensor:$x,
TF_Tensor:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
);
@@ -8034,7 +8034,7 @@ times by rerunning "MakeIterator".
);
}
def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of ones with the same shape and type as x.";
let arguments = (ins
@@ -8432,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> {
Computes the QR decomposition of each inner matrix in `tensor` such that
`tensor[..., :, :] = q[..., :, :] * r[..., :, :]`
Currently, the gradient for the QR decomposition is well-defined only when
the first `P` columns of the inner matrix are linearly independent, where
`P` is the minimum of `M` and `N`, the two innermost dimensions of `tensor`.
```python
# a is a tensor.
# q is a tensor of orthonormal matrices.
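# Editorial sketch (assumes tf.linalg.qr, the Python binding for this op;
# not part of the diff):
#   a = tf.random.normal([2, 3, 3])
#   q, r = tf.linalg.qr(a)  # q orthonormal, r upper triangular, q @ r ~= a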
@@ -9117,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph.
TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
}
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
let summary = "Computes rectified linear: `max(features, 0)`.";
let description = [{
@@ -9143,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32)
}];
}
def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`.";
let arguments = (ins
@@ -10541,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs)
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> {
def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise integer closest to x.";
let description = [{
@@ -10608,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
}
def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
Rounds the values of a tensor to the nearest integer, element-wise.
}];
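Like the other ops gaining `Idempotent` in this sync, rounding an already-rounded tensor is a no-op. Note that `tf.round` rounds halves to even ("banker's rounding"); a quick illustration:

```python
import tensorflow as tf

x = tf.constant([0.5, 1.5, 2.5, -0.5])
r = tf.round(x)
print(r)  # [ 0.  2.  2. -0.] -- halfway values round to the nearest even
assert bool(tf.reduce_all(tf.round(r) == r))  # idempotent
```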
@@ -11338,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> {
def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns an element-wise indication of the sign of a number.";
let description = [{
@@ -12348,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
}
def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
let summary = [{
Picks the best counter-based RNG algorithm based on device.
}];
let summary = "Picks the best counter-based RNG algorithm based on device.";
let description = [{
This op picks the best counter-based RNG algorithm based on device.
}];
let arguments = (ins);
let results = (outs
TF_Int32Tensor:$alg
);
@@ -14091,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
for the existing tensor cannot be re-used, a copy is made and updated.
If `indices` contains duplicates, then their updates are accumulated (summed).
If `indices` contains duplicates, then we pick the last update for the index.
**WARNING**: The order in which updates are applied is nondeterministic, so the
output will be nondeterministic if `indices` contains duplicates -- because
of some numerical approximation issues, numbers summed in different order
may yield different results.
If an out of bound index is found on CPU, an error is returned.
**WARNING**: There are some GPU specific semantics for this operation.
- If an out of bound index is found, the index is ignored.
- The order in which updates are applied is nondeterministic, so the output
will be nondeterministic if `indices` contains duplicates.
`indices` is an integer tensor containing indices into a new tensor of shape
`shape`. The last dimension of `indices` can be at most the rank of `shape`:
`shape`.
indices.shape[-1] <= shape.rank
* `indices` must have at least 2 axes: `(num_updates, index_depth)`.
* The last axis of `indices` is how deep to index into `tensor` so this index
depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim`
The last dimension of `indices` corresponds to indices into elements
(if `indices.shape[-1] = shape.rank`) or slices
(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of
`shape`. `updates` is a tensor with shape
If `indices.shape[-1] = tensor.rank`, this Op indexes and updates scalar elements.
If `indices.shape[-1] < tensor.rank`, it indexes and updates slices of the input
`tensor`.
indices.shape[:-1] + shape[indices.shape[-1]:]
Each `update` has a rank of `tensor.rank - indices.shape[-1]`.
The overall shape of `updates` is:
The simplest form of scatter is to insert individual elements in a tensor by
index. For example, say we want to insert 4 scattered elements in a rank-1
tensor with 8 elements.
```
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
```
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt>
</div>
In Python, this scatter operation would look like this:
>>> indices = tf.constant([[4], [3], [1], [7]])
>>> updates = tf.constant([9, 10, 11, 12])
>>> tensor = tf.ones([8], dtype=tf.int32)
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32)
We can also insert entire slices of a higher-rank tensor all at once. For
example, suppose we want to insert two slices in the first dimension of a
rank-3 tensor with two matrices of new values.
In Python, this scatter operation would look like this:
>>> indices = tf.constant([[0], [2]])
>>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
... [7, 7, 7, 7], [8, 8, 8, 8]],
... [[5, 5, 5, 5], [6, 6, 6, 6],
... [7, 7, 7, 7], [8, 8, 8, 8]]])
>>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
[[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[1 1 1 1]
[1 1 1 1]
[1 1 1 1]
[1 1 1 1]]
[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[1 1 1 1]
[1 1 1 1]
[1 1 1 1]
[1 1 1 1]]]
Note that on CPU, if an out of bound index is found, an error is returned.
On GPU, if an out of bound index is found, the index is ignored.
For usage examples see the python [tf.tensor_scatter_nd_update](
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
}];
let arguments = (ins
@@ -15083,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
}];
let arguments = (ins
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices,
Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes,
@@ -15092,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
);
let results = (outs
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
@@ -15460,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of zeros with the same shape and type as x.";
let arguments = (ins


@@ -262,6 +262,13 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @equal_unsupported_type
func @equal_unsupported_type(%arg0: tensor<*x!tf.string>, %arg1: tensor<*x!tf.string>) -> tensor<*xi1> {
// CHECK: "tf.Equal"
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*x!tf.string>, tensor<*x!tf.string>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @notequal
func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"}


@@ -224,7 +224,7 @@ class EqualityPat<Op FromOp, StrEnumAttrCase direction>
(HLOClient_BroadcastCompareOp
$l, $r, (BinBroadcastDimensions $l, $r), direction,
(HLO_DEFAULT_COMPARISON_TYPE)),
[(AreBroadcastCompatible $l, $r)]>;
[(AreBroadcastCompatible $l, $r), (HLO_Tensor $l)]>;
def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>;