From 77527d0df17be03f76d2bdc70e0126fbba87caaa Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 23 Mar 2020 15:14:16 -0700 Subject: [PATCH] Auto-generate all unary and binary TensorFlow ops supported by tf2xla bridge PiperOrigin-RevId: 302528398 Change-Id: I1e1c7f4eafc6e08722a1a32126ea68b263110f69 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 865 ++++++++++++++++++ 1 file changed, 865 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 9feeee87374..10a2b4f9451 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -49,6 +49,47 @@ an output element, this operation computes \\(y = |x|\\). TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes acos of x element-wise."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_AcoshOp : TF_Op<"Acosh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes inverse hyperbolic cosine of x element-wise."; + + let description = [{ +Given an input tensor, the function computes inverse hyperbolic cosine of every element. +Input range is `[1, inf]`. It returns `nan` if the input lies outside the range. + +```python +x = tf.constant([-2, -0.5, 1, 1.2, 200, 10000, float("inf")]) +tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf] +``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -149,6 +190,41 @@ retained with length 1. let verifier = [{ return Verify(*this); }]; } +def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Returns the argument of a complex number."; + + let description = [{ +Given a tensor `input` of complex numbers, this operation returns a tensor of +type `float` that is the argument of each element in `input`. All elements in +`input` must be complex numbers of the form \\(a + bj\\), where *a* +is the real part and *b* is the imaginary part. + +The argument returned by this operation is of the form \\(atan2(b, a)\\). + +For example: + +``` +# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +tf.angle(input) ==> [2.0132, 1.056] +``` + +@compatibility(numpy) +Equivalent to np.angle. +@end_compatibility + }]; + + let arguments = (ins + TensorOf<[TF_Complex128, TF_Complex64]>:$input + ); + + let results = (outs + TF_F32OrF64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; +} + def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> { let summary = [{ Computes the "logical or" of elements across dimensions of a tensor. @@ -278,6 +354,63 @@ array([b'3.14', b'2.72'], dtype=object) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AsinOp : TF_Op<"Asin", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the trignometric inverse sine of x element-wise."; + + let description = [{ +The `tf.math.asin` operation returns the inverse of `tf.math.sin`, such that +if `y = tf.math.sin(x)` then, `x = tf.math.asin(y)`. + +**Note**: The output of `tf.math.asin` will lie within the invertible range +of sine, i.e [-pi/2, pi/2]. + +For example: + +```python +# Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] +x = tf.constant([1.047, 0.785]) +y = tf.math.sin(x) # [0.8659266, 0.7068252] + +tf.math.asin(y) # [1.047, 0.785] = x +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_AsinhOp : TF_Op<"Asinh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes inverse hyperbolic sine of x element-wise."; + + let description = [{ +Given an input tensor, this function computes inverse hyperbolic sine + for every element in the tensor. Both input and output has a range of + `[-inf, inf]`. + + ```python + x = tf.constant([-float("inf"), -2, -0.5, 1, 1.2, 200, 10000, float("inf")]) + tf.math.asinh(x) ==> [-inf -1.4436355 -0.4812118 0.8813736 1.0159732 5.991471 9.903487 inf] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AssertOp : TF_Op<"Assert", []> { let summary = "Asserts that the given condition is true."; @@ -354,6 +487,38 @@ this value or a subsequent newer value of the variable. TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<1>; } +def TF_AtanOp : TF_Op<"Atan", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the trignometric inverse tangent of x element-wise."; + + let description = [{ +The `tf.math.atan` operation returns the inverse of `tf.math.tan`, such that +if `y = tf.math.tan(x)` then, `x = tf.math.atan(y)`. + +**Note**: The output of `tf.math.atan` will lie within the invertible range +of tan, i.e (-pi/2, pi/2). + +For example: + +```python +# Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] +x = tf.constant([1.047, 0.785]) +y = tf.math.tan(x) # [1.731261, 0.99920404] + +tf.math.atan(y) # [1.047, 0.785] = x +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = [{ @@ -380,6 +545,33 @@ where \(r = \sqrt(x^2 + y^2) \). TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AtanhOp : TF_Op<"Atanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes inverse hyperbolic tangent of x element-wise."; + + let description = [{ +Given an input tensor, this function computes inverse hyperbolic tangent + for every element in the tensor. Input range is `[-1,1]` and output range is + `[-inf, inf]`. If input is `-1`, output will be `-inf` and if the + input is `1`, output will be `inf`. Values outside the range will have + `nan` as output. + + ```python + x = tf.constant([-float("inf"), -1, -0.5, 1, 0, 0.5, 10, float("inf")]) + tf.math.atanh(x) ==> [nan -inf -0.54930615 inf 0. 0.54930615 nan nan] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AvgPoolOp : TF_Op<"AvgPool", [NoSideEffect]> { let summary = "Performs average pooling on the input."; @@ -546,6 +738,48 @@ reverse of SpaceToBatch. See below for a precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } +def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Bessel i0e function of `x` element-wise."; + + let description = [{ +Exponentially scaled modified Bessel function of order 0 defined as +`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. + +This function is faster and numerically stabler than `bessel_i0(x)`. + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Bessel i1e function of `x` element-wise."; + + let description = [{ +Exponentially scaled modified Bessel function of order 0 defined as +`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. + +This function is faster and numerically stabler than `bessel_i1(x)`. + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { let summary = "Adds `bias` to `value`."; @@ -748,6 +982,44 @@ for dtype in dtype_list: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_BitwiseXorOp : TF_Op<"BitwiseXor", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = "Elementwise computes the bitwise XOR of `x` and `y`."; + + let description = [{ +The result will have those bits set, that are different in `x` and `y`. The +computation is performed on the underlying representations of `x` and `y`. + +For example: + +```python +import tensorflow as tf +from tensorflow.python.ops import bitwise_ops +dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64, + tf.uint8, tf.uint16, tf.uint32, tf.uint64] + +for dtype in dtype_list: + lhs = tf.constant([0, 5, 3, 14], dtype=dtype) + rhs = tf.constant([5, 0, 7, 11], dtype=dtype) + exp = tf.constant([5, 5, 4, 5], dtype=tf.float32) + + res = bitwise_ops.bitwise_xor(lhs, rhs) + tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE +``` + }]; + + let arguments = (ins + TF_IntTensor:$x, + TF_IntTensor:$y + ); + + let results = (outs + TF_IntTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> { let summary = [{ Return the reduction indices for computing gradients of s0 op s1 with broadcast. @@ -1235,6 +1507,31 @@ Given an input tensor, this function computes cosine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CoshOp : TF_Op<"Cosh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes hyperbolic cosine of x element-wise."; + + let description = [{ +Given an input tensor, this function computes hyperbolic cosine of every + element in the tensor. Input range is `[-inf, inf]` and output range + is `[1, inf]`. + + ```python + x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 2, 10, float("inf")]) + tf.math.cosh(x) ==> [inf 4.0515420e+03 1.1276259e+00 1.5430807e+00 1.8106556e+00 3.7621956e+00 1.1013233e+04 inf] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> { let summary = "An Op to sum inputs across replicated TPU instances."; @@ -1461,6 +1758,26 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DigammaOp : TF_Op<"Digamma", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes Psi, the derivative of Lgamma (the log of the absolute value of + }]; + + let description = [{ +`Gamma(x)`), element-wise. + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns x / y element-wise."; @@ -1755,6 +2072,59 @@ tf.math.equal(x, y) ==> array([True, True]) }]; } +def TF_ErfOp : TF_Op<"Erf", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Gauss error function of `x` element-wise."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_ErfcOp : TF_Op<"Erfc", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes the complementary error function of `x` element-wise. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_ErfinvOp : TF_Op<"Erfinv", [NoSideEffect]> { + let summary = ""; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ExpOp : TF_Op<"Exp", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Computes exponential of x element-wise. \\(y = e^x\\). @@ -1854,6 +2224,36 @@ size 1. ]; } +def TF_Expm1Op : TF_Op<"Expm1", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes `exp(x) - 1` element-wise."; + + let description = [{ +i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor. + `e` denotes Euler's number and is approximately equal to 2.718281. + + ```python + x = tf.constant(2.0) + tf.math.expm1(x) ==> 6.389056 + + x = tf.constant([2.0, 8.0]) + tf.math.expm1(x) ==> array([6.389056, 2979.958], dtype=float32) + + x = tf.constant(1 + 1j) + tf.math.expm1(x) ==> (0.46869393991588515+2.2873552871788423j) + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. @@ -2613,6 +3013,92 @@ def ApplyG(op, dy, _): TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; } +def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Compute the lower regularized incomplete Gamma function `P(a, x)`. + }]; + + let description = [{ +The lower regularized incomplete Gamma function is defined as: + + +\\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) + +where + +\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) + +is the lower incomplete Gamma function. + +Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +Gamma function. + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = "Computes the gradient of `igamma(a, x)` wrt `a`."; + + let description = [{ + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Compute the upper regularized incomplete Gamma function `Q(a, x)`. + }]; + + let description = [{ +The upper regularized incomplete Gamma function is defined as: + +\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) + +where + +\\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) + +is the upper incomplete Gama function. + +Note, above `P(a, x)` (`Igamma`) is the lower regularized complete +Gamma function. + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ImagOp : TF_Op<"Imag", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Returns the imaginary part of a complex number."; @@ -2799,6 +3285,60 @@ tf.math.is_finite(x) ==> [True, True, True, False, False] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_IsInfOp : TF_Op<"IsInf", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Returns which elements of x are Inf."; + + let description = [{ +@compatibility(numpy) +Equivalent to np.isinf +@end_compatibility + +Example: + +```python +x = tf.constant([5.0, np.inf, 6.8, np.inf]) +tf.math.is_inf(x) ==> [False, True, False, True] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + I1Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_IsNanOp : TF_Op<"IsNan", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Returns which elements of x are NaN."; + + let description = [{ +@compatibility(numpy) +Equivalent to np.isnan +@end_compatibility + +Example: + +```python +x = tf.constant([5.0, np.nan, 6.8, np.nan, np.inf]) +tf.math.is_nan(x) ==> [False, True, False, True, False] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + I1Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> { let summary = "Gets the next output from the given iterator ."; @@ -3006,6 +3546,34 @@ tf.math.less_equal(x, y) ==> [True, True, True] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_LgammaOp : TF_Op<"Lgamma", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes the log of the absolute value of `Gamma(x)` element-wise. + }]; + + let description = [{ +For positive numbers, this function computes log((input - 1)!) for every element in the tensor. + `lgamma(5) = log((5-1)!) = log(4!) = log(24) = 3.1780539` + +Example: + +```python +x = tf.constant([0, 0.5, 1, 4.5, -4, -5.6]) +tf.math.lgamma(x) ==> [inf, 0.5723649, 0., 2.4537368, inf, -4.6477685] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_LinSpaceOp : TF_Op<"LinSpace", [NoSideEffect]> { let summary = "Generates values in an interval."; @@ -4135,6 +4703,32 @@ graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.Tensor TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; } +def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Returns element-wise remainder of division. This emulates C semantics in that + }]; + + let description = [{ +the result here is consistent with a truncating divide. E.g. +`tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`. + +*NOTE*: `Mod` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TF_FpOrI32OrI64Tensor:$x, + TF_FpOrI32OrI64Tensor:$y + ); + + let results = (outs + TF_FpOrI32OrI64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -4179,6 +4773,23 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> { + let summary = ""; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_NegOp : TF_Op<"Neg", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes numerical negative value element-wise."; @@ -4862,6 +5473,27 @@ the dimension is padded with zeros. TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>; } +def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Computes the derivative of a Gamma random sample w.r.t. `alpha`. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$alpha, + TF_F32OrF64Tensor:$sample + ); + + let results = (outs + TF_F32OrF64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RandomShuffleOp : TF_Op<"RandomShuffle", [SameOperandsAndResultType]> { let summary = "Randomly shuffles a tensor along its first dimension."; @@ -5108,6 +5740,26 @@ I.e., \\(y = 1 / x\\). let hasCanonicalizer = 1; } +def TF_ReciprocalGradOp : TF_Op<"ReciprocalGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the gradient for the inverse of `x` wrt its input."; + + let description = [{ +Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +is the corresponding input gradient. + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$y, + TF_FpOrComplexTensor:$dy + ); + + let results = (outs + TF_FpOrComplexTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; @@ -5632,6 +6284,32 @@ bitwise_ops.right_shift(lhs, rhs) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns element-wise integer closest to x."; + + let description = [{ +If the result is midway between two representable values, +the even representable is chosen. +For example: + +``` +rint(-1.5) ==> -2.0 +rint(0.5000001) ==> 1.0 +rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] +``` + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. @@ -6057,6 +6735,26 @@ Specifically, `y = 1 / (1 + exp(-x))`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SigmoidGradOp : TF_Op<"SigmoidGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the gradient of the sigmoid of `x` wrt its input."; + + let description = [{ +Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and +`dy` is the corresponding input gradient. + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$y, + TF_FpOrComplexTensor:$dy + ); + + let results = (outs + TF_FpOrComplexTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns an element-wise indication of the sign of a number."; @@ -6106,6 +6804,31 @@ Given an input tensor, this function computes sine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SinhOp : TF_Op<"Sinh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes hyperbolic sine of x element-wise."; + + let description = [{ +Given an input tensor, this function computes hyperbolic sine of every + element in the tensor. Input range is `[-inf,inf]` and output range + is `[-inf,inf]`. + + ```python + x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 2, 10, float("inf")]) + tf.math.sinh(x) ==> [-inf -4.0515420e+03 -5.2109528e-01 1.1752012e+00 1.5094614e+00 3.6268604e+00 1.1013232e+04 inf] + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SizeOp : TF_Op<"Size", [NoSideEffect]> { let summary = "Returns the size of a tensor."; @@ -6251,6 +6974,59 @@ def TF_SoftplusOp : TF_Op<"Softplus", [NoSideEffect, SameOperandsAndResultType]> TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SoftplusGradOp : TF_Op<"SoftplusGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes softplus gradients for a softplus operation."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SoftsignOp : TF_Op<"Softsign", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes softsign: `features / (abs(features) + 1)`."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$activations + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SoftsignGradOp : TF_Op<"SoftsignGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes softsign gradients for a softsign operation."; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> { let summary = "SpaceToBatch for N-D tensors of type T."; @@ -7137,6 +7913,32 @@ variables. TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } +def TF_TanOp : TF_Op<"Tan", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes tan of x element-wise."; + + let description = [{ +Given an input tensor, this function computes tangent of every + element in the tensor. Input range is `(-inf, inf)` and + output range is `(-inf, inf)`. If input lies outside the boundary, `nan` + is returned. + + ```python + x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10000, float("inf")]) + tf.math.tan(x) ==> [nan 0.45231566 -0.5463025 1.5574077 2.572152 -1.7925274 0.32097113 nan] + ``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes hyperbolic tangent of `x` element-wise."; @@ -7969,6 +8771,32 @@ Python Semantics. let hasCanonicalizer = 1; } +def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Returns element-wise remainder of division. This emulates C semantics in that + }]; + + let description = [{ +the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +y + truncate_mod(x, y) = x`. + +*NOTE*: `TruncateMod` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TF_FpOrI32OrI64Tensor:$x, + TF_FpOrI32OrI64Tensor:$y + ); + + let results = (outs + TF_FpOrI32OrI64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> { let summary = "Finds unique elements in a 1-D tensor."; @@ -8421,6 +9249,43 @@ An op which shards the input based on the given sharding attribute. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect]> { + let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of zeros with the same shape and type as x.";