Legalize TensorFlow Cumprod op to HLO
Also, add a verifier for the Cumsum op to reject illegal axis values. GetScalarConstOfType doesn't support element types other than int and float, so ops with other element types are rejected.

PiperOrigin-RevId: 326787027
Change-Id: I54afe4e494d711fa873b6329391603fbd8958c88
commit a3d2b5a910
parent 71ff216576
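For orientation, this is how the op behaves at the TensorFlow Python level; an illustrative sketch using the public `tf.math.cumprod` API, not part of the change itself:

```python
import tensorflow as tf

x = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])  # rank 2, so valid axis values lie in [-2, 2)

print(tf.math.cumprod(x, axis=1))                  # [[1., 2., 6.], [4., 20., 120.]]
print(tf.math.cumprod(x, axis=1, exclusive=True))  # [[1., 1., 2.], [1., 4., 20.]]

# An out-of-range axis such as -3 is rejected; the new verifier reports the
# same kind of error at the MLIR level, before legalization.
```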
@@ -2068,6 +2068,62 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
+  let summary = [{
+Compute the cumulative product of the tensor `x` along `axis`.
+  }];
+
+  let description = [{
+By default, this op performs an inclusive cumprod, which means that the first
+element of the input is identical to the first element of the output:
+
+```python
+tf.cumprod([a, b, c])  # => [a, a * b, a * b * c]
+```
+
+By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
+performed instead:
+
+```python
+tf.cumprod([a, b, c], exclusive=True)  # => [1, a, a * b]
+```
+
+By setting the `reverse` kwarg to `True`, the cumprod is performed in the
+opposite direction:
+
+```python
+tf.cumprod([a, b, c], reverse=True)  # => [a * b * c, b * c, c]
+```
+
+This is more efficient than using separate `tf.reverse` ops.
+
+The `reverse` and `exclusive` kwargs can also be combined:
+
+```python
+tf.cumprod([a, b, c], exclusive=True, reverse=True)  # => [b * c, c, 1]
+```
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
+    TF_I32OrI64Tensor:$axis,
+
+    DefaultValuedAttr<BoolAttr, "false">:$exclusive,
+    DefaultValuedAttr<BoolAttr, "false">:$reverse
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
+}
+
 def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
   let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
@@ -2116,6 +2172,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> {
@@ -774,6 +774,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              context);
 }
 
+//===----------------------------------------------------------------------===//
+// CumsumOp and CumprodOp
+//===----------------------------------------------------------------------===//
+
+template <typename OpT, typename std::enable_if<llvm::is_one_of<
+                            OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
+static LogicalResult Verify(OpT op) {
+  if (!IsOfRankOrUnranked(op.axis(), 0))
+    return op.emitOpError("requires scalar axis operand");
+
+  DenseIntElementsAttr axis_attr;
+  if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
+    auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
+    if (input_ty) {
+      int64_t rank = input_ty.getRank();
+      assert(axis_attr.getNumElements() == 1 &&
+             "scalar attribute should have exactly one element");
+      int64_t axis = (*axis_attr.begin()).getSExtValue();
+      if (axis < -rank || axis >= rank) {
+        return op.emitError()
+               << "axis operand should be within range [" << -rank << ", "
+               << rank << "); actual value: " << axis;
+      }
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOffsetOp
 //===----------------------------------------------------------------------===//
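The range check in the verifier above boils down to this small Python sketch (the helper name `check_cum_axis` is hypothetical, used only for illustration):

```python
def check_cum_axis(axis: int, rank: int) -> None:
    """Mirror of the verifier's constraint: axis must lie in [-rank, rank)."""
    if not (-rank <= axis < rank):
        raise ValueError(
            f"axis operand should be within range [{-rank}, {rank}); "
            f"actual value: {axis}")

check_cum_axis(1, 2)    # OK for a rank-2 input
check_cum_axis(-3, 2)   # raises ValueError, matching the new verifier error
```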
@@ -3373,3 +3373,28 @@ func @testCaseRegionMismatchedResultTypes(%arg0: tensor<i32>, %arg1: tensor<f32>
   }) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
   return
 }
+
+// -----
+
+// Test valid tf.Cumsum
+func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<8x16xf32> {
+  %0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// -----
+
+func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> {
+  // expected-error @+1 {{requires scalar axis operand}}
+  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// -----
+
+func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %axis = constant dense<-3> : tensor<i32>
+  // expected-error @+1 {{axis operand should be within range [-2, 2)}}
+  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
@@ -4668,6 +4668,20 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// Cumprod op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @cumprod
+func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
+  // CHECK: mhlo.mul
+  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Qr op legalization
 //===----------------------------------------------------------------------===//
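As a rough mental model of the lowering the FileCheck test above exercises (an illustrative sketch only, not the actual HLO): each output element is a windowed reduction over a prefix of the input, seeded with the multiplicative identity.

```python
import numpy as np

def cumprod_via_reduce_window(x, exclusive=False, reverse=False):
    """Conceptual model of the ReduceWindow-based lowering (illustrative only)."""
    x = x[::-1] if reverse else x
    out = np.empty_like(x)
    for i in range(len(x)):
        window = x[:i] if exclusive else x[:i + 1]
        acc = 1.0  # init value: the identity of the multiply reduction
        for v in window:
            acc *= v
        out[i] = acc
    return out[::-1] if reverse else out

print(cumprod_via_reduce_window(np.array([2., 3., 4.])))                  # [ 2.  6. 24.]
print(cumprod_via_reduce_window(np.array([2., 3., 4.]), exclusive=True))  # [ 1.  2.  6.]
```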
@@ -5092,17 +5092,19 @@ class ConvertXlaDynamicUpdateSliceOp
   }
 };
 
-/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting
-/// appropriate window dimensions, with 'add' as the reduction function. The
-/// input tensor needs to have a static shape, and 'axis' must be const. The
-/// TableGen pattern is not used for this rewrite because it involves regions.
-class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
-  using OpRewritePattern<TF::CumsumOp>::OpRewritePattern;
+// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
+// setting appropriate window dimensions, with the given aggregation op as the
+// reduction function. The input tensor needs to have a static shape, and 'axis'
+// must be const. The TableGen pattern is not used for this rewrite because it
+// involves regions.
+template <typename OpT, typename AggregationOp>
+class ConvertCumOp : public OpRewritePattern<OpT> {
+  using OpRewritePattern<OpT>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TF::CumsumOp op,
+  LogicalResult matchAndRewrite(OpT op,
                                 PatternRewriter &rewriter) const override {
     auto input = op.x();
-    auto input_type = input.getType().dyn_cast<ShapedType>();
+    auto input_type = input.getType().template dyn_cast<ShapedType>();
     if (!input_type || !input_type.hasStaticShape()) {
       return failure();
     }
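The refactoring above parameterizes a single rewrite over the aggregation op. Conceptually this is the same as parameterizing a scan over its binary operator, as in this illustrative Python sketch (not the actual pattern code):

```python
import operator

def scan(xs, binop, identity):
    """Generic inclusive scan parameterized by its binary operator."""
    out = []
    acc = identity
    for v in xs:
        acc = binop(acc, v)
        out.append(acc)
    return out

print(scan([2, 3, 4], operator.add, 0))  # [2, 5, 9]  ~ Cumsum
print(scan([2, 3, 4], operator.mul, 1))  # [2, 6, 24] ~ Cumprod
```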
@@ -5135,6 +5137,10 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
     // Convert if we need to enlarge the element type's bitwidth to avoid
     // precision loss.
     Type input_element_type = input_type.getElementType();
+
+    // TODO(hinsu): Handle complex element types.
+    if (!input_element_type.isIntOrFloat()) return failure();
+
     Type sum_element_type = GetSumAccumulationType(input_element_type);
     input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
 
@@ -5148,8 +5154,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
         RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
         paddings);
 
-    Value init =
-        GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
+    int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
+    Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
+                                      &rewriter);
 
     auto reduce = rewriter.create<ReduceWindowOp>(
         op.getLoc(), input_type, input, init,
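A quick sanity check of the init values chosen above, using NumPy's conventions for empty reductions (illustrative only): the seed is the identity element of the aggregation.

```python
import numpy as np

# The window reduction is seeded with the identity of the aggregation op:
# 0 for an additive scan (Cumsum), 1 for a multiplicative scan (Cumprod).
assert np.sum([]) == 0.0
assert np.prod([]) == 1.0
```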
@@ -5157,7 +5164,7 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
         /*base_dilations=*/DenseIntElementsAttr(),
         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
-    BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
+    BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
     Value result = reduce.getResult();
 
     if (op.exclusive()) {
@@ -5193,6 +5200,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
   }
 };
 
+using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
+using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
+
 // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
 // dialect lowerings. This involves extracting the shape type, extracting and
 // converting each dimension to a known integer type, and repacking into a final
@@ -5857,7 +5867,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
       ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
       ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
       ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
-      ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
+      ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
       ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
       ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
@@ -1165,6 +1165,7 @@ tf_xla_py_test(
     name = "scan_ops_test",
     size = "medium",
     srcs = ["scan_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -129,6 +130,7 @@ class CumsumTest(xla_test.XLATestCase):
     for axis in range(-6, 6, 3):
       self._compareAll(x, axis)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testInvalidAxis(self):
     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
     with self.session(), self.test_scope():
@@ -207,6 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
     for axis in range(-6, 6, 3):
       self._compareAll(x, axis)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testInvalidAxis(self):
     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
     with self.session(), self.test_scope():