From 92b36ca4ba91aeb5d5ad60eeac72e8b8a08d0095 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 17 Aug 2020 14:12:28 -0700 Subject: [PATCH] Enable fallback lowering for following TensorFlow ops BetaincOp DepthwiseConv2dNativeBackpropFilterOp DepthwiseConv2dNativeBackpropInputOp ExtractImagePatchesOp IgammaOp IgammacOp IgammaGradOp ListDiffOp LowerBoundOp MatrixInverseOp MatrixSolveOp RollOp UpperBoundOp PiperOrigin-RevId: 327096174 Change-Id: I64e6921ed605b294f1c73ad0030b021580b66ba1 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 319 ++++++++++++++++++ .../xla/transforms/legalize_tf_with_tf2xla.cc | 14 + tensorflow/compiler/tests/BUILD | 8 + tensorflow/compiler/tests/ternary_ops_test.py | 2 - 4 files changed, 341 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 8f31c74cd7c..00e9fddfae4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -965,6 +965,40 @@ reverse of SpaceToBatch. See below for a precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } +def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> { + let summary = [{ +Compute the regularized incomplete beta integral \\(I_x(a, b)\\). + }]; + + let description = [{ +The regularized incomplete beta integral is defined as: + + +\\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) + +where + + +\\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) + + +is the incomplete beta function and \\(B(a, b)\\) is the *complete* +beta function. + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$b, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { let summary = "Adds `bias` to `value`."; @@ -2528,6 +2562,54 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DepthwiseConv2dNativeBackpropFilterOp : TF_Op<"DepthwiseConv2dNativeBackpropFilter", [NoSideEffect]> { + let summary = [{ +Computes the gradients of depthwise convolution with respect to the filter. + }]; + + let arguments = (ins + TF_FpTensor:$input, + I32Tensor:$filter_sizes, + TF_FpTensor:$out_backprop, + + I64ArrayAttr:$strides, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$dilations + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_DepthwiseConv2dNativeBackpropInputOp : TF_Op<"DepthwiseConv2dNativeBackpropInput", [NoSideEffect]> { + let summary = [{ +Computes the gradients of depthwise convolution with respect to the input. + }]; + + let arguments = (ins + I32Tensor:$input_sizes, + TF_FpTensor:$filter, + TF_FpTensor:$out_backprop, + + I64ArrayAttr:$strides, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$dilations + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> { let summary = "Return the index of device the op runs."; @@ -3235,6 +3317,27 @@ i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ExtractImagePatchesOp : TF_Op<"ExtractImagePatches", [NoSideEffect]> { + let summary = [{ +Extract `patches` from `images` and put them in the "depth" output dimension. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$images, + + Confined]>:$ksizes, + Confined]>:$strides, + Confined]>:$rates, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$patches + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_FFTOp : TF_Op<"FFT", [NoSideEffect]> { let summary = "Fast Fourier transform."; @@ -4906,6 +5009,49 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<2>; } +def TF_ListDiffOp : TF_Op<"ListDiff", [NoSideEffect]> { + let summary = [{ +Computes the difference between two lists of numbers or strings. + }]; + + let description = [{ +Given a list `x` and a list `y`, this operation returns a list `out` that +represents all values that are in `x` but not in `y`. The returned list `out` +is sorted in the same order that the numbers appear in `x` (duplicates are +preserved). This operation also returns a list `idx` that represents the +position of each `out` element in `x`. In other words: + +`out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` + +For example, given this input: + +``` +x = [1, 2, 3, 4, 5, 6] +y = [1, 3, 5] +``` + +This operation would return: + +``` +out ==> [2, 4, 6] +idx ==> [1, 3, 5] +``` + }]; + + let arguments = (ins + TF_Tensor:$x, + TF_Tensor:$y + ); + + let results = (outs + TF_Tensor:$out, + TF_I32OrI64Tensor:$idx + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>; +} + def TF_LogOp : TF_Op<"Log", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes natural logarithm of x element-wise."; @@ -5089,6 +5235,44 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { ); } +def TF_LowerBoundOp : TF_Op<"LowerBound", [NoSideEffect]> { + let summary = [{ +Applies lower_bound(sorted_search_values, values) along each row. + }]; + + let description = [{ +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='left')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. + +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = LowerBound(sorted_sequence, values) + + result == [[1, 2, 2], + [0, 1, 5]] + }]; + + let arguments = (ins + TF_Tensor:$sorted_inputs, + TF_Tensor:$values + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; +} + def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = [{ Multiply the matrix "a" by the matrix "b". @@ -5598,6 +5782,36 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixInverseOp : TF_Op<"MatrixInverse", [NoSideEffect]> { + let summary = [{ +Computes the inverse of one or more square invertible matrices or their adjoints (conjugate transposes). + }]; + + let description = [{ +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. The output is a tensor of the same shape as the input +containing the inverse for all input submatrices `[..., :, :]`. + +The op uses LU decomposition with partial pivoting to compute the inverses. + +If a matrix is not invertible there is no guarantee what the op does. It +may detect the condition and raise an exception or it may simply return a +garbage result. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MatrixSetDiagOp : TF_Op<"MatrixSetDiag", [NoSideEffect]> { let summary = [{ Returns a batched matrix tensor with new batched diagonal values. @@ -5849,6 +6063,32 @@ tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT") TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixSolveOp : TF_Op<"MatrixSolve", [NoSideEffect]> { + let summary = "Solves systems of linear equations."; + + let description = [{ +`Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is +a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +If `adjoint` is `True` then each output matrix satisfies +`adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs, + + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MatrixTriangularSolveOp : TF_Op<"MatrixTriangularSolve", [NoSideEffect]> { let summary = [{ Solves systems of linear equations with upper or lower triangular matrices by backsubstitution. @@ -8352,6 +8592,47 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> { + let summary = "Rolls the elements of a tensor along an axis."; + + let description = [{ +The elements are shifted positively (towards larger indices) by the offset of +`shift` along the dimension of `axis`. Negative `shift` values will shift +elements in the opposite direction. Elements that roll passed the last position +will wrap around to the first and vice versa. Multiple shifts along multiple +axes may be specified. + +For example: + +``` +# 't' is [0, 1, 2, 3, 4] +roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] + +# shifting along multiple dimensions +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] + +# shifting along the same axis multiple times +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$shift, + TF_I32OrI64Tensor:$axis + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tshift = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>; +} + def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. @@ -11700,6 +11981,44 @@ tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2) let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }]; } +def TF_UpperBoundOp : TF_Op<"UpperBound", [NoSideEffect]> { + let summary = [{ +Applies upper_bound(sorted_search_values, values) along each row. + }]; + + let description = [{ +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='right')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. + +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = UpperBound(sorted_sequence, values) + + result == [[1, 2, 4], + [0, 2, 5]] + }]; + + let arguments = (ins + TF_Tensor:$sorted_inputs, + TF_Tensor:$values + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; +} + def TF_VarIsInitializedOp : TF_Op<"VarIsInitializedOp", []> { let summary = [{ Checks whether a resource handle-based variable has been initialized. diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 93b1f5c3397..658c3528186 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -82,6 +82,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for // all tf2xla kernels. // clang-format off + static llvm::SmallDenseSet ops = { TypeID::get(), TypeID::get(), @@ -105,6 +106,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -120,6 +122,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -129,6 +133,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -144,6 +149,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -154,13 +162,17 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -185,6 +197,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -213,6 +226,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index ea42c0ab959..ce8b02a7a06 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -392,6 +392,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -414,6 +415,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -537,6 +539,7 @@ tf_xla_py_test( name = "depthwise_conv_op_test", size = "medium", srcs = ["depthwise_conv_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -636,6 +639,7 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -788,6 +792,7 @@ tf_xla_py_test( name = "listdiff_op_test", size = "small", srcs = ["listdiff_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -826,6 +831,7 @@ tf_xla_py_test( name = "manip_ops_test", size = "small", srcs = ["manip_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1576,6 +1582,7 @@ tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1898,6 +1905,7 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], + enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 7bbfecff403..4109fdc64a5 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -214,7 +214,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): upper, expected=np.minimum(np.maximum(x, lower), upper)) - @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetaincSanity(self): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: @@ -252,7 +251,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): 'atol': 2e-4 }, ) - @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetainc(self, sigma, rtol, atol): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: