Legalize TensorFlow Cumprod op to HLO
Also, add verifier for Cumsum op to reject illegal axis value. GetScalarConstOfType doesn't support non int and float element types so reject ops with other element types. PiperOrigin-RevId: 326787027 Change-Id: I54afe4e494d711fa873b6329391603fbd8958c88
This commit is contained in:
parent
71ff216576
commit
a3d2b5a910
@ -2068,6 +2068,62 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
|
||||
let summary = [{
|
||||
Compute the cumulative product of the tensor `x` along `axis`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
By default, this op performs an inclusive cumprod, which means that the first
|
||||
element of the input is identical to the first element of the output:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c]) # => [a, a * b, a * b * c]
|
||||
```
|
||||
|
||||
By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
|
||||
performed instead:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b]
|
||||
```
|
||||
|
||||
By setting the `reverse` kwarg to `True`, the cumprod is performed in the
|
||||
opposite direction:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c]
|
||||
```
|
||||
|
||||
This is more efficient than using separate `tf.reverse` ops.
|
||||
|
||||
The `reverse` and `exclusive` kwargs can also be combined:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
||||
TF_I32OrI64Tensor:$axis,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$exclusive,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$reverse
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
|
||||
let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
|
||||
|
||||
@ -2116,6 +2172,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
|
@ -774,6 +774,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CumsumOp and CumprodOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename OpT, typename std::enable_if<llvm::is_one_of<
|
||||
OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
|
||||
static LogicalResult Verify(OpT op) {
|
||||
if (!IsOfRankOrUnranked(op.axis(), 0))
|
||||
return op.emitOpError("requires scalar axis operand");
|
||||
|
||||
DenseIntElementsAttr axis_attr;
|
||||
if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
|
||||
auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
|
||||
if (input_ty) {
|
||||
int64_t rank = input_ty.getRank();
|
||||
assert(axis_attr.getNumElements() == 1 &&
|
||||
"scalar attribute should have exactly one element");
|
||||
int64_t axis = (*axis_attr.begin()).getSExtValue();
|
||||
if (axis < -rank || axis >= rank) {
|
||||
return op.emitError()
|
||||
<< "axis operand should be within range [" << -rank << ", "
|
||||
<< rank << "); actual value: " << axis;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConcatOffsetOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3373,3 +3373,28 @@ func @testCaseRegionMismatchedResultTypes(%arg0: tensor<i32>, %arg1: tensor<f32>
|
||||
}) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.Cumsum
|
||||
func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> {
|
||||
// expected-error @+1 {{requires scalar axis operand}}
|
||||
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%axis = constant dense<-3> : tensor<i32>
|
||||
// expected-error @+1 {{axis operand should be within range [-2, 2)}}
|
||||
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
@ -4668,6 +4668,20 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cumprod op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @cumprod
|
||||
func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
|
||||
// CHECK: mhlo.mul
|
||||
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Qr op legalization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -5092,17 +5092,19 @@ class ConvertXlaDynamicUpdateSliceOp
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting
|
||||
/// appropriate window dimensions, with 'add' as the reduction function. The
|
||||
/// input tensor needs to have a static shape, and 'axis' must be const. The
|
||||
/// TableGen pattern is not used for this rewrite because it involves regions.
|
||||
class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
|
||||
using OpRewritePattern<TF::CumsumOp>::OpRewritePattern;
|
||||
// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
|
||||
// setting appropriate window dimensions, with the given aggregation op as the
|
||||
// reduction function. The input tensor needs to have a static shape, and 'axis'
|
||||
// must be const. The TableGen pattern is not used for this rewrite because it
|
||||
// involves regions.
|
||||
template <typename OpT, typename AggregationOp>
|
||||
class ConvertCumOp : public OpRewritePattern<OpT> {
|
||||
using OpRewritePattern<OpT>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::CumsumOp op,
|
||||
LogicalResult matchAndRewrite(OpT op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto input = op.x();
|
||||
auto input_type = input.getType().dyn_cast<ShapedType>();
|
||||
auto input_type = input.getType().template dyn_cast<ShapedType>();
|
||||
if (!input_type || !input_type.hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
@ -5135,6 +5137,10 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
|
||||
// Convert if we need to enlarge the element type's bitwidth to avoid
|
||||
// precision loss.
|
||||
Type input_element_type = input_type.getElementType();
|
||||
|
||||
// TODO(hinsu): Handle complex element types.
|
||||
if (!input_element_type.isIntOrFloat()) return failure();
|
||||
|
||||
Type sum_element_type = GetSumAccumulationType(input_element_type);
|
||||
input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
|
||||
|
||||
@ -5148,8 +5154,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
|
||||
RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
|
||||
paddings);
|
||||
|
||||
Value init =
|
||||
GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
|
||||
int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
|
||||
Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
|
||||
&rewriter);
|
||||
|
||||
auto reduce = rewriter.create<ReduceWindowOp>(
|
||||
op.getLoc(), input_type, input, init,
|
||||
@ -5157,7 +5164,7 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
|
||||
GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
|
||||
/*base_dilations=*/DenseIntElementsAttr(),
|
||||
/*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
|
||||
BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
|
||||
BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
|
||||
Value result = reduce.getResult();
|
||||
|
||||
if (op.exclusive()) {
|
||||
@ -5193,6 +5200,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
|
||||
}
|
||||
};
|
||||
|
||||
using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
|
||||
using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
|
||||
|
||||
// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
|
||||
// dialect lowerings. This involves extracting the shape type, extracting and
|
||||
// converting each dimension to a known integer type, and repacking into a final
|
||||
@ -5857,7 +5867,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
|
||||
ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
|
||||
ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
|
||||
ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
|
||||
ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
|
||||
ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
|
||||
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
|
||||
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
|
||||
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
|
||||
|
@ -1165,6 +1165,7 @@ tf_xla_py_test(
|
||||
name = "scan_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["scan_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -129,6 +130,7 @@ class CumsumTest(xla_test.XLATestCase):
|
||||
for axis in range(-6, 6, 3):
|
||||
self._compareAll(x, axis)
|
||||
|
||||
@test_util.disable_mlir_bridge("Error handling")
|
||||
def testInvalidAxis(self):
|
||||
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
|
||||
with self.session(), self.test_scope():
|
||||
@ -207,6 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
|
||||
for axis in range(-6, 6, 3):
|
||||
self._compareAll(x, axis)
|
||||
|
||||
@test_util.disable_mlir_bridge("Error handling")
|
||||
def testInvalidAxis(self):
|
||||
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
|
||||
with self.session(), self.test_scope():
|
||||
|
Loading…
Reference in New Issue
Block a user