Enable fallback lowering for following TensorFlow ops

BetaincOp
DepthwiseConv2dNativeBackpropFilterOp
DepthwiseConv2dNativeBackpropInputOp
ExtractImagePatchesOp
IgammaOp
IgammacOp
IgammaGradOp
ListDiffOp
LowerBoundOp
MatrixInverseOp
MatrixSolveOp
RollOp
UpperBoundOp

PiperOrigin-RevId: 327096174
Change-Id: I64e6921ed605b294f1c73ad0030b021580b66ba1
This commit is contained in:
Smit Hinsu 2020-08-17 14:12:28 -07:00 committed by TensorFlower Gardener
parent c6769e20bf
commit 92b36ca4ba
4 changed files with 341 additions and 2 deletions
tensorflow/compiler

View File

@ -965,6 +965,40 @@ reverse of SpaceToBatch. See below for a precise description.
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
}
def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
let summary = [{
Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
}];
let description = [{
The regularized incomplete beta integral is defined as:
\\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\)
where
\\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\)
is the incomplete beta function and \\(B(a, b)\\) is the *complete*
beta function.
}];
let arguments = (ins
TF_F32OrF64Tensor:$a,
TF_F32OrF64Tensor:$b,
TF_F32OrF64Tensor:$x
);
let results = (outs
TF_F32OrF64Tensor:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> {
let summary = "Adds `bias` to `value`.";
@ -2528,6 +2562,54 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_DepthwiseConv2dNativeBackpropFilterOp : TF_Op<"DepthwiseConv2dNativeBackpropFilter", [NoSideEffect]> {
let summary = [{
Computes the gradients of depthwise convolution with respect to the filter.
}];
let arguments = (ins
TF_FpTensor:$input,
I32Tensor:$filter_sizes,
TF_FpTensor:$out_backprop,
I64ArrayAttr:$strides,
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_DepthwiseConv2dNativeBackpropInputOp : TF_Op<"DepthwiseConv2dNativeBackpropInput", [NoSideEffect]> {
let summary = [{
Computes the gradients of depthwise convolution with respect to the input.
}];
let arguments = (ins
I32Tensor:$input_sizes,
TF_FpTensor:$filter,
TF_FpTensor:$out_backprop,
I64ArrayAttr:$strides,
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
let summary = "Return the index of device the op runs.";
@ -3235,6 +3317,27 @@ i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ExtractImagePatchesOp : TF_Op<"ExtractImagePatches", [NoSideEffect]> {
let summary = [{
Extract `patches` from `images` and put them in the "depth" output dimension.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$images,
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$ksizes,
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$strides,
Confined<I64ArrayAttr, [ArrayMinCount<4>]>:$rates,
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$patches
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_FFTOp : TF_Op<"FFT", [NoSideEffect]> {
let summary = "Fast Fourier transform.";
@ -4906,6 +5009,49 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<2>;
}
def TF_ListDiffOp : TF_Op<"ListDiff", [NoSideEffect]> {
let summary = [{
Computes the difference between two lists of numbers or strings.
}];
let description = [{
Given a list `x` and a list `y`, this operation returns a list `out` that
represents all values that are in `x` but not in `y`. The returned list `out`
is sorted in the same order that the numbers appear in `x` (duplicates are
preserved). This operation also returns a list `idx` that represents the
position of each `out` element in `x`. In other words:
`out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]`
For example, given this input:
```
x = [1, 2, 3, 4, 5, 6]
y = [1, 3, 5]
```
This operation would return:
```
out ==> [2, 4, 6]
idx ==> [1, 3, 5]
```
}];
let arguments = (ins
TF_Tensor:$x,
TF_Tensor:$y
);
let results = (outs
TF_Tensor:$out,
TF_I32OrI64Tensor:$idx
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>;
}
def TF_LogOp : TF_Op<"Log", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes natural logarithm of x element-wise.";
@ -5089,6 +5235,44 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> {
);
}
def TF_LowerBoundOp : TF_Op<"LowerBound", [NoSideEffect]> {
let summary = [{
Applies lower_bound(sorted_search_values, values) along each row.
}];
let description = [{
Each set of rows with the same index in (sorted_inputs, values) is treated
independently. The resulting row is the equivalent of calling
`np.searchsorted(sorted_inputs, values, side='left')`.
The result is not a global index to the entire
`Tensor`, but rather just the index in the last dimension.
A 2-D example:
sorted_sequence = [[0, 3, 9, 9, 10],
[1, 2, 3, 4, 5]]
values = [[2, 4, 9],
[0, 2, 6]]
result = LowerBound(sorted_sequence, values)
result == [[1, 2, 2],
[0, 1, 5]]
}];
let arguments = (ins
TF_Tensor:$sorted_inputs,
TF_Tensor:$values
);
let results = (outs
TF_I32OrI64Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>;
}
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
let summary = [{
Multiply the matrix "a" by the matrix "b".
@ -5598,6 +5782,36 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixInverseOp : TF_Op<"MatrixInverse", [NoSideEffect]> {
let summary = [{
Computes the inverse of one or more square invertible matrices or their adjoints (conjugate transposes).
}];
let description = [{
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices. The output is a tensor of the same shape as the input
containing the inverse for all input submatrices `[..., :, :]`.
The op uses LU decomposition with partial pivoting to compute the inverses.
If a matrix is not invertible there is no guarantee what the op does. It
may detect the condition and raise an exception or it may simply return a
garbage result.
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input,
DefaultValuedAttr<BoolAttr, "false">:$adjoint
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixSetDiagOp : TF_Op<"MatrixSetDiag", [NoSideEffect]> {
let summary = [{
Returns a batched matrix tensor with new batched diagonal values.
@ -5849,6 +6063,32 @@ tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT")
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixSolveOp : TF_Op<"MatrixSolve", [NoSideEffect]> {
let summary = "Solves systems of linear equations.";
let description = [{
`Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is
a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix
satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
If `adjoint` is `True` then each output matrix satisfies
`adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix,
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs,
DefaultValuedAttr<BoolAttr, "false">:$adjoint
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixTriangularSolveOp : TF_Op<"MatrixTriangularSolve", [NoSideEffect]> {
let summary = [{
Solves systems of linear equations with upper or lower triangular matrices by backsubstitution.
@ -8352,6 +8592,47 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
let summary = "Rolls the elements of a tensor along an axis.";
let description = [{
The elements are shifted positively (towards larger indices) by the offset of
`shift` along the dimension of `axis`. Negative `shift` values will shift
elements in the opposite direction. Elements that roll passed the last position
will wrap around to the first and vice versa. Multiple shifts along multiple
axes may be specified.
For example:
```
# 't' is [0, 1, 2, 3, 4]
roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2]
# shifting along multiple dimensions
# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]]
# shifting along the same axis multiple times
# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
```
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$shift,
TF_I32OrI64Tensor:$axis
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr Tshift = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
}
def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
Rounds the values of a tensor to the nearest integer, element-wise.
@ -11700,6 +11981,44 @@ tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2)
let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }];
}
def TF_UpperBoundOp : TF_Op<"UpperBound", [NoSideEffect]> {
let summary = [{
Applies upper_bound(sorted_search_values, values) along each row.
}];
let description = [{
Each set of rows with the same index in (sorted_inputs, values) is treated
independently. The resulting row is the equivalent of calling
`np.searchsorted(sorted_inputs, values, side='right')`.
The result is not a global index to the entire
`Tensor`, but rather just the index in the last dimension.
A 2-D example:
sorted_sequence = [[0, 3, 9, 9, 10],
[1, 2, 3, 4, 5]]
values = [[2, 4, 9],
[0, 2, 6]]
result = UpperBound(sorted_sequence, values)
result == [[1, 2, 4],
[0, 2, 5]]
}];
let arguments = (ins
TF_Tensor:$sorted_inputs,
TF_Tensor:$values
);
let results = (outs
TF_I32OrI64Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>;
}
def TF_VarIsInitializedOp : TF_Op<"VarIsInitializedOp", []> {
let summary = [{
Checks whether a resource handle-based variable has been initialized.

View File

@ -82,6 +82,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
// TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
// all tf2xla kernels.
// clang-format off
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
TypeID::get<TF::AbsOp>(),
TypeID::get<TF::AcoshOp>(),
@ -105,6 +106,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::BatchToSpaceOp>(),
TypeID::get<TF::BesselI0eOp>(),
TypeID::get<TF::BesselI1eOp>(),
TypeID::get<TF::BetaincOp>(),
TypeID::get<TF::BiasAddGradOp>(),
TypeID::get<TF::BiasAddOp>(),
TypeID::get<TF::BitwiseAndOp>(),
@ -120,6 +122,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::CrossOp>(),
TypeID::get<TF::DataFormatDimMapOp>(),
TypeID::get<TF::DataFormatVecPermuteOp>(),
TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
TypeID::get<TF::DiagOp>(),
TypeID::get<TF::DigammaOp>(),
TypeID::get<TF::DivNoNanOp>(),
@ -129,6 +133,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::ErfcOp>(),
TypeID::get<TF::ErfOp>(),
TypeID::get<TF::Expm1Op>(),
TypeID::get<TF::ExtractImagePatchesOp>(),
TypeID::get<TF::FFT2DOp>(),
TypeID::get<TF::FFT3DOp>(),
TypeID::get<TF::FFTOp>(),
@ -144,6 +149,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::IRFFT2DOp>(),
TypeID::get<TF::IRFFT3DOp>(),
TypeID::get<TF::IRFFTOp>(),
TypeID::get<TF::IgammaOp>(),
TypeID::get<TF::IgammacOp>(),
TypeID::get<TF::IgammaGradAOp>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::LRNOp>(),
@ -154,13 +162,17 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::LessEqualOp>(),
TypeID::get<TF::LessOp>(),
TypeID::get<TF::LgammaOp>(),
TypeID::get<TF::ListDiffOp>(),
TypeID::get<TF::LogicalAndOp>(),
TypeID::get<TF::LogicalNotOp>(),
TypeID::get<TF::LogicalOrOp>(),
TypeID::get<TF::LogOp>(),
TypeID::get<TF::LowerBoundOp>(),
TypeID::get<TF::MatMulOp>(),
TypeID::get<TF::MatrixDiagV3Op>(),
TypeID::get<TF::MatrixInverseOp>(),
TypeID::get<TF::MatrixSetDiagV3Op>(),
TypeID::get<TF::MatrixSolveOp>(),
TypeID::get<TF::MatrixTriangularSolveOp>(),
TypeID::get<TF::MirrorPadOp>(),
TypeID::get<TF::MulOp>(),
@ -185,6 +197,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::ReverseSequenceOp>(),
TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::RintOp>(),
TypeID::get<TF::RollOp>(),
TypeID::get<TF::RoundOp>(),
TypeID::get<TF::SelectV2Op>(),
TypeID::get<TF::SelfAdjointEigV2Op>(),
@ -213,6 +226,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::TruncatedNormalOp>(),
TypeID::get<TF::TruncateModOp>(),
TypeID::get<TF::UnpackOp>(),
TypeID::get<TF::UpperBoundOp>(),
TypeID::get<TF::XdivyOp>(),
TypeID::get<TF::XlaBroadcastHelperOp>(),
TypeID::get<TF::XlaConvOp>(),

View File

@ -392,6 +392,7 @@ tf_xla_py_test(
size = "small",
timeout = "moderate",
srcs = ["matrix_inverse_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -414,6 +415,7 @@ tf_xla_py_test(
size = "small",
timeout = "moderate",
srcs = ["matrix_solve_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -537,6 +539,7 @@ tf_xla_py_test(
name = "depthwise_conv_op_test",
size = "medium",
srcs = ["depthwise_conv_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 5,
tags = [
@ -636,6 +639,7 @@ tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -788,6 +792,7 @@ tf_xla_py_test(
name = "listdiff_op_test",
size = "small",
srcs = ["listdiff_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -826,6 +831,7 @@ tf_xla_py_test(
name = "manip_ops_test",
size = "small",
srcs = ["manip_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -1576,6 +1582,7 @@ tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -1898,6 +1905,7 @@ tf_xla_py_test(
name = "special_math_test",
size = "medium",
srcs = ["special_math_test.py"],
enable_mlir_bridge = True,
shard_count = 5,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -214,7 +214,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
upper,
expected=np.minimum(np.maximum(x, lower), upper))
@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
def testBetaincSanity(self):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}:
@ -252,7 +251,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
'atol': 2e-4
},
)
@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
def testBetainc(self, sigma, rtol, atol):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}: