Legalize TensorFlow Cumprod op to HLO

Also, add a verifier for the Cumsum and Cumprod ops to reject illegal axis values.

GetScalarConstOfType doesn't support element types other than int and float, so reject ops with other element types.

PiperOrigin-RevId: 326787027
Change-Id: I54afe4e494d711fa873b6329391603fbd8958c88
Authored by Smit Hinsu on 2020-08-14 22:54:05 -07:00; committed by TensorFlower Gardener.
parent 71ff216576
commit a3d2b5a910
7 changed files with 154 additions and 12 deletions


@@ -2068,6 +2068,62 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
  let summary = [{
Compute the cumulative product of the tensor `x` along `axis`.
  }];

  let description = [{
By default, this op performs an inclusive cumprod, which means that the first
element of the input is identical to the first element of the output:

```python
tf.cumprod([a, b, c]) # => [a, a * b, a * b * c]
```

By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
performed instead:

```python
tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b]
```

By setting the `reverse` kwarg to `True`, the cumprod is performed in the
opposite direction:

```python
tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c]
```

This is more efficient than using separate `tf.reverse` ops.
The `reverse` and `exclusive` kwargs can also be combined:

```python
tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
```
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
    TF_I32OrI64Tensor:$axis,

    DefaultValuedAttr<BoolAttr, "false">:$exclusive,
    DefaultValuedAttr<BoolAttr, "false">:$reverse
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;

  let verifier = [{
    return Verify(*this);
  }];
}
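For concreteness, here are the four documented modes with actual numbers, as a runnable sketch using the public `tf.math.cumprod` API (the input values are illustrative, not from this change):

```python
import tensorflow as tf

x = tf.constant([2., 3., 4.])
print(tf.math.cumprod(x).numpy())                                # [ 2.  6. 24.]
print(tf.math.cumprod(x, exclusive=True).numpy())                # [1. 2. 6.]
print(tf.math.cumprod(x, reverse=True).numpy())                  # [24. 12.  4.]
print(tf.math.cumprod(x, exclusive=True, reverse=True).numpy())  # [12.  4.  1.]
```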
def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
  let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
@@ -2116,6 +2172,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;

  let verifier = [{
    return Verify(*this);
  }];
}

def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> {


@@ -774,6 +774,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
//===----------------------------------------------------------------------===//
// CumsumOp and CumprodOp
//===----------------------------------------------------------------------===//
template <typename OpT, typename std::enable_if<llvm::is_one_of<
                            OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
static LogicalResult Verify(OpT op) {
  if (!IsOfRankOrUnranked(op.axis(), 0))
    return op.emitOpError("requires scalar axis operand");

  DenseIntElementsAttr axis_attr;
  if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
    auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
    if (input_ty) {
      int64_t rank = input_ty.getRank();
      assert(axis_attr.getNumElements() == 1 &&
             "scalar attribute should have exactly one element");
      int64_t axis = (*axis_attr.begin()).getSExtValue();
      if (axis < -rank || axis >= rank) {
        return op.emitError()
               << "axis operand should be within range [" << -rank << ", "
               << rank << "); actual value: " << axis;
      }
    }
  }
  return success();
}
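The accepted range follows the usual negative-indexing convention: valid axes for a rank-`r` input span `[-r, r)`, with negative values counting from the last dimension. A minimal Python sketch of the same check (the function name is ours, for illustration only):

```python
def check_cum_axis(axis: int, rank: int) -> None:
    # Mirrors the verifier above: reject a constant scalar axis
    # outside [-rank, rank) for a ranked input.
    if axis < -rank or axis >= rank:
        raise ValueError(
            f"axis operand should be within range [{-rank}, {rank}); "
            f"actual value: {axis}")

check_cum_axis(-2, 2)    # ok: -2 selects the first of two dimensions
# check_cum_axis(-3, 2)  # raises, matching the tf.Cumprod test below
```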
//===----------------------------------------------------------------------===//
// ConcatOffsetOp
//===----------------------------------------------------------------------===//


@@ -3373,3 +3373,28 @@ func @testCaseRegionMismatchedResultTypes(%arg0: tensor<i32>, %arg1: tensor<f32>
  }) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
  return
}
// -----
// Test valid tf.Cumsum
func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<8x16xf32> {
  %0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// -----

func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> {
  // expected-error @+1 {{requires scalar axis operand}}
  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// -----

func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
  %axis = constant dense<-3> : tensor<i32>
  // expected-error @+1 {{axis operand should be within range [-2, 2)}}
  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}


@@ -4668,6 +4668,20 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
//===----------------------------------------------------------------------===//
// Cumprod op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @cumprod
func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
  // CHECK: mhlo.mul
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}
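To see why an init of 1 combined with `mhlo.mul` yields a cumulative product, here is a numpy emulation of the ReduceWindow lowering (our sketch, not TensorFlow code): the window spans the current element and everything before it along the axis, and out-of-range positions are padded with the reduction's identity.

```python
import numpy as np

def cum_reduce(x, combine, identity):
    # Pad n-1 identity values on the low side, then reduce the
    # length-n window ending at each position, as reduce_window does.
    n = x.shape[0]
    padded = np.concatenate([np.full(n - 1, identity, dtype=x.dtype), x])
    return np.array([combine.reduce(padded[i:i + n]) for i in range(n)])

print(cum_reduce(np.array([2., 3., 4.]), np.multiply, 1.0))  # [ 2.  6. 24.]
print(cum_reduce(np.array([2., 3., 4.]), np.add, 0.0))       # [2. 5. 9.]
```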
//===----------------------------------------------------------------------===//
// Qr op legalization
//===----------------------------------------------------------------------===//


@@ -5092,17 +5092,19 @@ class ConvertXlaDynamicUpdateSliceOp
  }
};
-/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting
-/// appropriate window dimensions, with 'add' as the reduction function. The
-/// input tensor needs to have a static shape, and 'axis' must be const. The
-/// TableGen pattern is not used for this rewrite because it involves regions.
-class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
-  using OpRewritePattern<TF::CumsumOp>::OpRewritePattern;
+// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
+// setting appropriate window dimensions, with the given aggregation op as the
+// reduction function. The input tensor needs to have a static shape, and 'axis'
+// must be const. The TableGen pattern is not used for this rewrite because it
+// involves regions.
+template <typename OpT, typename AggregationOp>
+class ConvertCumOp : public OpRewritePattern<OpT> {
+  using OpRewritePattern<OpT>::OpRewritePattern;

-  LogicalResult matchAndRewrite(TF::CumsumOp op,
+  LogicalResult matchAndRewrite(OpT op,
                                PatternRewriter &rewriter) const override {
    auto input = op.x();
-    auto input_type = input.getType().dyn_cast<ShapedType>();
+    auto input_type = input.getType().template dyn_cast<ShapedType>();
    if (!input_type || !input_type.hasStaticShape()) {
      return failure();
    }
@@ -5135,6 +5137,10 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
    // Convert if we need to enlarge the element type's bitwidth to avoid
    // precision loss.
    Type input_element_type = input_type.getElementType();

    // TODO(hinsu): Handle complex element types.
    if (!input_element_type.isIntOrFloat()) return failure();

    Type sum_element_type = GetSumAccumulationType(input_element_type);
    input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
@@ -5148,8 +5154,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
        RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
        paddings);

-    Value init =
-        GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
+    int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
+    Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
+                                      &rewriter);
    auto reduce = rewriter.create<ReduceWindowOp>(
        op.getLoc(), input_type, input, init,
@@ -5157,7 +5164,7 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
        GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
        /*base_dilations=*/DenseIntElementsAttr(),
        /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
-    BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
+    BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);

    Value result = reduce.getResult();

    if (op.exclusive()) {
@@ -5193,6 +5200,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
  }
};

using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
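The `init_value` selection works because each aggregation is seeded with its identity element (0 for `AddOp`, 1 for `MulOp`), so the padded window positions leave the result unchanged; a quick check in Python:

```python
from functools import reduce
import operator

xs = [2, 3, 4]
assert reduce(operator.add, xs, 0) == 9   # 0 is the additive identity
assert reduce(operator.mul, xs, 1) == 24  # 1 is the multiplicative identity
```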
// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
// dialect lowerings. This involves extracting the shape type, extracting and
// converting each dimension to a known integer type, and repacking into a final
@@ -5857,7 +5867,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
          ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
          ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
          ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
-          ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
+          ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
          ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
          ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
          ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,


@@ -1165,6 +1165,7 @@ tf_xla_py_test(
name = "scan_ops_test",
size = "medium",
srcs = ["scan_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip


@@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -129,6 +130,7 @@ class CumsumTest(xla_test.XLATestCase):
      for axis in range(-6, 6, 3):
        self._compareAll(x, axis)

  @test_util.disable_mlir_bridge("Error handling")
  def testInvalidAxis(self):
    x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    with self.session(), self.test_scope():
@@ -207,6 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
      for axis in range(-6, 6, 3):
        self._compareAll(x, axis)

  @test_util.disable_mlir_bridge("Error handling")
  def testInvalidAxis(self):
    x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    with self.session(), self.test_scope():
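The diff truncates both test bodies here. Purely as an illustration of the behavior being tested (our example, not the test's actual code), an out-of-range axis is rejected at runtime in eager TensorFlow:

```python
import numpy as np
import tensorflow as tf

x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
try:
    tf.math.cumprod(x, axis=2)  # valid axes for a rank-2 input are [-2, 2)
except tf.errors.InvalidArgumentError as e:
    print("rejected:", e.message)
```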