diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 1077a9ac472..76e8836fef1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -31,7 +31,7 @@ limitations under the License. include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" include "mlir/Interfaces/InferTypeOpInterface.td" -def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> { +def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the absolute value of a tensor."; let description = [{ @@ -1002,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description. TF_Tensor:$output ); - let verifier = [{ - return Verify(*this); - }]; - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>; TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> { @@ -1486,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let hasFolder = 1; } -def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> { +def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise smallest integer not less than x."; let arguments = (ins @@ -3502,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True]) }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TF_Tensor:$x, + TF_Tensor:$y, DefaultValuedAttr:$incompatible_shape_error ); @@ -3843,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation."; let arguments = (ins - F32Tensor:$gradients, - F32Tensor:$inputs, + TF_Float32Tensor:$gradients, + TF_Float32Tensor:$inputs, DefaultValuedAttr:$min, DefaultValuedAttr:$max, @@ -3853,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien ); let results = (outs - F32Tensor:$backprops + TF_Float32Tensor:$backprops ); } @@ -3911,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation."; let arguments = (ins - F32Tensor:$gradients, - F32Tensor:$inputs, - F32Tensor:$min, - F32Tensor:$max, + TF_Float32Tensor:$gradients, + TF_Float32Tensor:$inputs, + TF_Float32Tensor:$min, + TF_Float32Tensor:$max, DefaultValuedAttr:$num_bits, DefaultValuedAttr:$narrow_range ); let results = (outs - F32Tensor:$backprops_wrt_input, - F32Tensor:$backprop_wrt_min, - F32Tensor:$backprop_wrt_max + TF_Float32Tensor:$backprops_wrt_input, + TF_Float32Tensor:$backprop_wrt_min, + TF_Float32Tensor:$backprop_wrt_max ); } @@ -4026,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9] ]; } -def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> { +def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise largest integer not greater than x."; let arguments = (ins @@ -4977,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ }]; let arguments = (ins - F32Tensor:$predictions, + TF_Float32Tensor:$predictions, TF_I32OrI64Tensor:$targets, TF_I32OrI64Tensor:$k ); let results = (outs - I1Tensor:$precision + TF_BoolTensor:$precision ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -7855,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TF_Tensor:$x, + TF_Tensor:$y, DefaultValuedAttr:$incompatible_shape_error ); @@ -8034,7 +8034,7 @@ times by rerunning "MakeIterator". ); } -def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> { +def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of ones with the same shape and type as x."; let arguments = (ins @@ -8432,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> { Computes the QR decomposition of each inner matrix in `tensor` such that `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` +Currently, the gradient for the QR decomposition is well-defined only when +the first `P` columns of the inner matrix are linearly independent, where +`P` is the minimum of `M` and `N`, the 2 inner-most dimmensions of `tensor`. + ```python # a is a tensor. # q is a tensor of orthonormal matrices. @@ -9117,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph. TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; } -def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> { +def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; let description = [{ @@ -9143,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32) }]; } -def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { +def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`."; let arguments = (ins @@ -10541,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> { +def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise integer closest to x."; let description = [{ @@ -10608,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>; } -def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { +def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. }]; @@ -11338,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> { +def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns an element-wise indication of the sign of a number."; let description = [{ @@ -12348,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`. } def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> { - let summary = [{ -Picks the best counter-based RNG algorithm based on device. - }]; + let summary = "Picks the best counter-based RNG algorithm based on device."; let description = [{ This op picks the best counter-based RNG algorithm based on device. }]; + let arguments = (ins); + let results = (outs TF_Int32Tensor:$alg ); @@ -14091,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are scattered onto an existing tensor (as opposed to a zero-tensor). If the memory for the existing tensor cannot be re-used, a copy is made and updated. -If `indices` contains duplicates, then their updates are accumulated (summed). +If `indices` contains duplicates, then we pick the last update for the index. -**WARNING**: The order in which updates are applied is nondeterministic, so the -output will be nondeterministic if `indices` contains duplicates -- because -of some numerical approximation issues, numbers summed in different order -may yield different results. +If an out of bound index is found on CPU, an error is returned. + +**WARNING**: There are some GPU specific semantics for this operation. +- If an out of bound index is found, the index is ignored. +- The order in which updates are applied is nondeterministic, so the output +will be nondeterministic if `indices` contains duplicates. `indices` is an integer tensor containing indices into a new tensor of shape -`shape`. The last dimension of `indices` can be at most the rank of `shape`: +`shape`. - indices.shape[-1] <= shape.rank +* `indices` must have at least 2 axes: `(num_updates, index_depth)`. +* The last axis of `indices` is how deep to index into `tensor` so this index + depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim` -The last dimension of `indices` corresponds to indices into elements -(if `indices.shape[-1] = shape.rank`) or slices -(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -`shape`. `updates` is a tensor with shape +if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements. +if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input +`tensor`. - indices.shape[:-1] + shape[indices.shape[-1]:] +Each `update` has a rank of `tensor.rank - indices.shape[-1]`. +The overall shape of `updates` is: -The simplest form of scatter is to insert individual elements in a tensor by -index. For example, say we want to insert 4 scattered elements in a rank-1 -tensor with 8 elements. +``` +indices.shape[:-1] + tensor.shape[indices.shape[-1]:] +``` -
- -
- -In Python, this scatter operation would look like this: - - >>> indices = tf.constant([[4], [3], [1], [7]]) - >>> updates = tf.constant([9, 10, 11, 12]) - >>> tensor = tf.ones([8], dtype=tf.int32) - >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates)) - tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32) - -We can also, insert entire slices of a higher rank tensor all at once. For -example, if we wanted to insert two slices in the first dimension of a -rank-3 tensor with two matrices of new values. - -In Python, this scatter operation would look like this: - - >>> indices = tf.constant([[0], [2]]) - >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], - ... [7, 7, 7, 7], [8, 8, 8, 8]], - ... [[5, 5, 5, 5], [6, 6, 6, 6], - ... [7, 7, 7, 7], [8, 8, 8, 8]]]) - >>> tensor = tf.ones([4, 4, 4], dtype=tf.int32) - >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()) - [[[5 5 5 5] - [6 6 6 6] - [7 7 7 7] - [8 8 8 8]] - [[1 1 1 1] - [1 1 1 1] - [1 1 1 1] - [1 1 1 1]] - [[5 5 5 5] - [6 6 6 6] - [7 7 7 7] - [8 8 8 8]] - [[1 1 1 1] - [1 1 1 1] - [1 1 1 1] - [1 1 1 1]]] - -Note that on CPU, if an out of bound index is found, an error is returned. -On GPU, if an out of bound index is found, the index is ignored. +For usage examples see the python [tf.tensor_scatter_nd_update]( +https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function }]; let arguments = (ins @@ -15083,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather }]; let arguments = (ins - Arg, [{The array we're gathering from.}]>:$operand, + Arg, [{The array we're gathering from.}]>:$operand, Arg:$start_indices, Arg:$slice_sizes, @@ -15092,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; @@ -15460,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> { +def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of zeros with the same shape and type as x."; let arguments = (ins diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 230db70817f..f24b1b6e6bc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -262,6 +262,13 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> return %0: tensor<*xi1> } +// CHECK-LABEL: func @equal_unsupported_type +func @equal_unsupported_type(%arg0: tensor<*x!tf.string>, %arg1: tensor<*x!tf.string>) -> tensor<*xi1> { + // CHECK: "tf.Equal" + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*x!tf.string>, tensor<*x!tf.string>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + // CHECK-LABEL: func @notequal func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 9518bbc6a41..113d88158d2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -224,7 +224,7 @@ class EqualityPat (HLOClient_BroadcastCompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction, (HLO_DEFAULT_COMPARISON_TYPE)), - [(AreBroadcastCompatible $l, $r)]>; + [(AreBroadcastCompatible $l, $r), (HLO_Tensor $l)]>; def : EqualityPat; def : EqualityPat;