Sync all TensorFlow ODS op defs with TensorFlow registry
Manual changes: * Added explicit check on element type for Equal and NotEqual lowering to HLO * Moved Verifier declaration for BatchToSpace op PiperOrigin-RevId: 343631256 Change-Id: Iedf0a88898375874a00964220c423ff5056deb76
This commit is contained in:
parent
c0f7434dec
commit
2d263ad1ca
@ -31,7 +31,7 @@ limitations under the License.
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
||||
def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the absolute value of a tensor.";
|
||||
|
||||
let description = [{
|
||||
@ -1002,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description.
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
|
||||
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
|
||||
@ -1486,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns element-wise smallest integer not less than x.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -3502,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True])
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
|
||||
TF_Tensor:$x,
|
||||
TF_Tensor:$y,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
|
||||
);
|
||||
@ -3843,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
|
||||
let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation.";
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$gradients,
|
||||
F32Tensor:$inputs,
|
||||
TF_Float32Tensor:$gradients,
|
||||
TF_Float32Tensor:$inputs,
|
||||
|
||||
DefaultValuedAttr<F32Attr, "-6.0f">:$min,
|
||||
DefaultValuedAttr<F32Attr, "6.0f">:$max,
|
||||
@ -3853,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
F32Tensor:$backprops
|
||||
TF_Float32Tensor:$backprops
|
||||
);
|
||||
}
|
||||
|
||||
@ -3911,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien
|
||||
let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation.";
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$gradients,
|
||||
F32Tensor:$inputs,
|
||||
F32Tensor:$min,
|
||||
F32Tensor:$max,
|
||||
TF_Float32Tensor:$gradients,
|
||||
TF_Float32Tensor:$inputs,
|
||||
TF_Float32Tensor:$min,
|
||||
TF_Float32Tensor:$max,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "8">:$num_bits,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
F32Tensor:$backprops_wrt_input,
|
||||
F32Tensor:$backprop_wrt_min,
|
||||
F32Tensor:$backprop_wrt_max
|
||||
TF_Float32Tensor:$backprops_wrt_input,
|
||||
TF_Float32Tensor:$backprop_wrt_min,
|
||||
TF_Float32Tensor:$backprop_wrt_max
|
||||
);
|
||||
}
|
||||
|
||||
@ -4026,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9]
|
||||
];
|
||||
}
|
||||
|
||||
def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns element-wise largest integer not greater than x.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -4977,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$predictions,
|
||||
TF_Float32Tensor:$predictions,
|
||||
TF_I32OrI64Tensor:$targets,
|
||||
TF_I32OrI64Tensor:$k
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$precision
|
||||
TF_BoolTensor:$precision
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
@ -7855,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
|
||||
TF_Tensor:$x,
|
||||
TF_Tensor:$y,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
|
||||
);
|
||||
@ -8034,7 +8034,7 @@ times by rerunning "MakeIterator".
|
||||
);
|
||||
}
|
||||
|
||||
def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns a tensor of ones with the same shape and type as x.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -8432,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> {
|
||||
Computes the QR decomposition of each inner matrix in `tensor` such that
|
||||
`tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
|
||||
|
||||
Currently, the gradient for the QR decomposition is well-defined only when
|
||||
the first `P` columns of the inner matrix are linearly independent, where
|
||||
`P` is the minimum of `M` and `N`, the 2 inner-most dimmensions of `tensor`.
|
||||
|
||||
```python
|
||||
# a is a tensor.
|
||||
# q is a tensor of orthonormal matrices.
|
||||
@ -9117,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph.
|
||||
TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
|
||||
def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
|
||||
let summary = "Computes rectified linear: `max(features, 0)`.";
|
||||
|
||||
let description = [{
|
||||
@ -9143,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32)
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -10541,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs)
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns element-wise integer closest to x.";
|
||||
|
||||
let description = [{
|
||||
@ -10608,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
|
||||
TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
|
||||
}
|
||||
|
||||
def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = [{
|
||||
Rounds the values of a tensor to the nearest integer, element-wise.
|
||||
}];
|
||||
@ -11338,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns an element-wise indication of the sign of a number.";
|
||||
|
||||
let description = [{
|
||||
@ -12348,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
|
||||
}
|
||||
|
||||
def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
|
||||
let summary = [{
|
||||
Picks the best counter-based RNG algorithm based on device.
|
||||
}];
|
||||
let summary = "Picks the best counter-based RNG algorithm based on device.";
|
||||
|
||||
let description = [{
|
||||
This op picks the best counter-based RNG algorithm based on device.
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs
|
||||
TF_Int32Tensor:$alg
|
||||
);
|
||||
@ -14091,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
|
||||
scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
|
||||
for the existing tensor cannot be re-used, a copy is made and updated.
|
||||
|
||||
If `indices` contains duplicates, then their updates are accumulated (summed).
|
||||
If `indices` contains duplicates, then we pick the last update for the index.
|
||||
|
||||
**WARNING**: The order in which updates are applied is nondeterministic, so the
|
||||
output will be nondeterministic if `indices` contains duplicates -- because
|
||||
of some numerical approximation issues, numbers summed in different order
|
||||
may yield different results.
|
||||
If an out of bound index is found on CPU, an error is returned.
|
||||
|
||||
**WARNING**: There are some GPU specific semantics for this operation.
|
||||
- If an out of bound index is found, the index is ignored.
|
||||
- The order in which updates are applied is nondeterministic, so the output
|
||||
will be nondeterministic if `indices` contains duplicates.
|
||||
|
||||
`indices` is an integer tensor containing indices into a new tensor of shape
|
||||
`shape`. The last dimension of `indices` can be at most the rank of `shape`:
|
||||
`shape`.
|
||||
|
||||
indices.shape[-1] <= shape.rank
|
||||
* `indices` must have at least 2 axes: `(num_updates, index_depth)`.
|
||||
* The last axis of `indices` is how deep to index into `tensor` so this index
|
||||
depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim`
|
||||
|
||||
The last dimension of `indices` corresponds to indices into elements
|
||||
(if `indices.shape[-1] = shape.rank`) or slices
|
||||
(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of
|
||||
`shape`. `updates` is a tensor with shape
|
||||
if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements.
|
||||
if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input
|
||||
`tensor`.
|
||||
|
||||
indices.shape[:-1] + shape[indices.shape[-1]:]
|
||||
Each `update` has a rank of `tensor.rank - indices.shape[-1]`.
|
||||
The overall shape of `updates` is:
|
||||
|
||||
The simplest form of scatter is to insert individual elements in a tensor by
|
||||
index. For example, say we want to insert 4 scattered elements in a rank-1
|
||||
tensor with 8 elements.
|
||||
```
|
||||
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
|
||||
```
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt>
|
||||
</div>
|
||||
|
||||
In Python, this scatter operation would look like this:
|
||||
|
||||
>>> indices = tf.constant([[4], [3], [1], [7]])
|
||||
>>> updates = tf.constant([9, 10, 11, 12])
|
||||
>>> tensor = tf.ones([8], dtype=tf.int32)
|
||||
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
|
||||
tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32)
|
||||
|
||||
We can also, insert entire slices of a higher rank tensor all at once. For
|
||||
example, if we wanted to insert two slices in the first dimension of a
|
||||
rank-3 tensor with two matrices of new values.
|
||||
|
||||
In Python, this scatter operation would look like this:
|
||||
|
||||
>>> indices = tf.constant([[0], [2]])
|
||||
>>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
... [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
... [[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
... [7, 7, 7, 7], [8, 8, 8, 8]]])
|
||||
>>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
|
||||
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
|
||||
[[[5 5 5 5]
|
||||
[6 6 6 6]
|
||||
[7 7 7 7]
|
||||
[8 8 8 8]]
|
||||
[[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]]
|
||||
[[5 5 5 5]
|
||||
[6 6 6 6]
|
||||
[7 7 7 7]
|
||||
[8 8 8 8]]
|
||||
[[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]]]
|
||||
|
||||
Note that on CPU, if an out of bound index is found, an error is returned.
|
||||
On GPU, if an out of bound index is found, the index is ignored.
|
||||
For usage examples see the python [tf.tensor_scatter_nd_update](
|
||||
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@ -15083,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
|
||||
Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
|
||||
Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices,
|
||||
Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes,
|
||||
|
||||
@ -15092,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
@ -15460,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns a tensor of zeros with the same shape and type as x.";
|
||||
|
||||
let arguments = (ins
|
||||
|
@ -262,6 +262,13 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>
|
||||
return %0: tensor<*xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @equal_unsupported_type
|
||||
func @equal_unsupported_type(%arg0: tensor<*x!tf.string>, %arg1: tensor<*x!tf.string>) -> tensor<*xi1> {
|
||||
// CHECK: "tf.Equal"
|
||||
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*x!tf.string>, tensor<*x!tf.string>) -> tensor<*xi1>
|
||||
return %0: tensor<*xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @notequal
|
||||
func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"}
|
||||
|
@ -224,7 +224,7 @@ class EqualityPat<Op FromOp, StrEnumAttrCase direction>
|
||||
(HLOClient_BroadcastCompareOp
|
||||
$l, $r, (BinBroadcastDimensions $l, $r), direction,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
[(AreBroadcastCompatible $l, $r), (HLO_Tensor $l)]>;
|
||||
|
||||
def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
|
||||
def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>;
|
||||
|
Loading…
x
Reference in New Issue
Block a user