Support non-constant begin and end inputs to tf.StridedSlice in TF to XLA HLO lowering.
The TF StridedSliceOp can be lowered to a HLO DynamicSliceOp in the case that all strides have a known value of 1. PiperOrigin-RevId: 307936203 Change-Id: Iec28dd2590f4d638e0f03fd9ece744db98700be0
This commit is contained in:
parent
50a33eef52
commit
9bea378143
@ -2954,6 +2954,12 @@ void SumOp::build(Builder *builder, OperationState &result, Value input,
|
|||||||
// StridedSliceOp
|
// StridedSliceOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
|
||||||
|
// tf.SliceOp if both of the following are true:
|
||||||
|
// - All strides have a known value equal to 1
|
||||||
|
// - No masks are set (or masks can be applied by transforming the inputs to
|
||||||
|
// Slice)
|
||||||
|
|
||||||
// Verifies that,
|
// Verifies that,
|
||||||
//
|
//
|
||||||
// - begin, end and strides operands are 1D and they have the same number of
|
// - begin, end and strides operands are 1D and they have the same number of
|
||||||
|
@ -2514,8 +2514,8 @@ func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) {
|
|||||||
// Begin: 1, 4, -3
|
// Begin: 1, 4, -3
|
||||||
// End: 8, 65, 42
|
// End: 8, 65, 42
|
||||||
// Stride: 1, 4, -1
|
// Stride: 1, 4, -1
|
||||||
// Begin mask: 1, 0, 0 (= 1)
|
// Begin mask: 0, 0, 1 (= 1)
|
||||||
// End mask: 0, 0, 1 (= 4)
|
// End mask: 1, 0, 0 (= 4)
|
||||||
|
|
||||||
// So result shape:
|
// So result shape:
|
||||||
// Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
|
// Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
|
||||||
@ -2662,6 +2662,142 @@ func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tenso
|
|||||||
return %0 : tensor<2x16x2xf32>
|
return %0 : tensor<2x16x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end
|
||||||
|
func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// In this case, the `begin` and `end` inputs are unknown at compile time --
|
||||||
|
// so the StridedSlice needs to slice these vectors and use that as input to
|
||||||
|
// an HLO dynamic slice.
|
||||||
|
%begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
|
||||||
|
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
%2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||||
|
%end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
|
||||||
|
// CHECK: %[[A:.*]] = "xla_hlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
|
||||||
|
// CHECK-NEXT: %[[BEGIN:.*]] = "xla_hlo.concatenate"(%[[A]])
|
||||||
|
// CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32>
|
||||||
|
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor<i32>
|
||||||
|
// CHECK-NEXT: %[[INDEX:.*]] = "xla_hlo.slice"(%[[BEGIN]])
|
||||||
|
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
|
||||||
|
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
|
||||||
|
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
|
||||||
|
// CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32>
|
||||||
|
// CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[INDEX2]], %[[ZERO]])
|
||||||
|
// CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
// CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor<i32>
|
||||||
|
// CHECK-NEXT: %[[WRAP:.*]] = xla_hlo.add %[[DIM]], %[[INDEX2]] : tensor<i32>
|
||||||
|
// CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
|
||||||
|
// CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||||
|
// CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice"
|
||||||
|
// CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]])
|
||||||
|
// CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} :
|
||||||
|
// CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x97xi32>
|
||||||
|
// CHECK-NEXT: %[[FINAL:.*]] = "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32>
|
||||||
|
%result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
// CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1
|
||||||
|
func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// Dynamic stride: when `begin` and `end` inputs are unknown at compile time,
|
||||||
|
// `strides` must be known.
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2
|
||||||
|
func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown
|
||||||
|
// at compile time, `strides` must be known to have all 1 values.
|
||||||
|
%strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count
|
||||||
|
func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> {
|
||||||
|
%strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||||
|
// When begin/end are dynamic, the number of output elements must be equal to
|
||||||
|
// the number of input elements sliced.
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32>
|
||||||
|
return %0 : tensor<6x10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask
|
||||||
|
func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// Begin mask: When `begin` and `end` inputs are unknown at compile time, we
|
||||||
|
// can't support a begin mask.
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask
|
||||||
|
func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// End mask: When `begin` and `end` inputs are unknown at compile time, we
|
||||||
|
// can't support an end mask.
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask
|
||||||
|
func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// New axis mask: When `begin` and `end` inputs are unknown at compile time,
|
||||||
|
// we can't support a new_axis mask.
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask
|
||||||
|
func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// This ellipsis mask is not supported because it does not refer to the last
|
||||||
|
// dimension.
|
||||||
|
// [0, 1, 0] = 2
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask
|
||||||
|
func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// This ellipsis mask is supported because it refers to the last dimension.
|
||||||
|
// [1, 0, 0] = 4
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: xla_hlo.dynamic-slice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask
|
||||||
|
func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// This shrink_axis mask is supported because it refers to a major dimension.
|
||||||
|
// [1, 1, 1] = 7
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: xla_hlo.dynamic-slice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask
|
||||||
|
func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
|
||||||
|
// This shrink_axis mask is unsupported because it does not refer to a major
|
||||||
|
// dimension.
|
||||||
|
// [0, 1, 0] = 2
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: tf.StridedSlice
|
||||||
|
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
|
||||||
|
return %result : tensor<1x97xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Reduction op legalizations.
|
// Reduction op legalizations.
|
||||||
|
@ -2148,11 +2148,16 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
|
|||||||
// negative strides and Reshape op to update the output shape. Indices and
|
// negative strides and Reshape op to update the output shape. Indices and
|
||||||
// strides operands are converted to attributes with non-negative indexing.
|
// strides operands are converted to attributes with non-negative indexing.
|
||||||
//
|
//
|
||||||
|
// If the begin input is not a compile time constant, the begin input needs to
|
||||||
|
// be sliced and the slice needs to be lowered to xla_hlo.DynamicSlice. In this
|
||||||
|
// case, strides must have a known value of 1 (otherwise we have insufficient
|
||||||
|
// information to conform to XLA's op semantics).
|
||||||
|
//
|
||||||
// For example with an op like following,
|
// For example with an op like following,
|
||||||
// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
|
// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
|
||||||
// : tensor<AxBxf32> -> tensor<Pxf32>
|
// : tensor<AxBxf32> -> tensor<Pxf32>
|
||||||
//
|
//
|
||||||
// Output would be:
|
// If the %begin input is constant, output would be:
|
||||||
// %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...}
|
// %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...}
|
||||||
// %sliced = "xla_hlo.Slice" (%input)
|
// %sliced = "xla_hlo.Slice" (%input)
|
||||||
// {start_indices = ..., limit_indices = ..., strides = ...}
|
// {start_indices = ..., limit_indices = ..., strides = ...}
|
||||||
@ -2162,31 +2167,16 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
|
|||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(TF::StridedSliceOp op,
|
LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
ArrayRef<int64_t> begin_indices,
|
||||||
// Input shape needs to be static to convert negative indices in TensorFlow
|
ArrayRef<int64_t> end_indices,
|
||||||
// to absolute indices required by HLO.
|
ArrayRef<int64_t> strides,
|
||||||
//
|
RankedTensorType input_ty,
|
||||||
// TODO(hinsu): Relax this constraint for ops without negative indices and
|
PatternRewriter &rewriter) const {
|
||||||
// strides.
|
|
||||||
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!input_ty || !input_ty.hasStaticShape()) return failure();
|
|
||||||
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
|
||||||
|
|
||||||
// Output shape needs to be static to apply 'new_axis_mask' or
|
|
||||||
// 'shrink_axis_mask' by reshaping tensor after slice.
|
|
||||||
//
|
|
||||||
// TODO(hinsu): Relax this constraint for ops without the above masks.
|
|
||||||
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!result_ty || !result_ty.hasStaticShape()) return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> begin_indices, end_indices, strides;
|
|
||||||
if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
|
SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
|
||||||
dims_to_reverse;
|
dims_to_reverse;
|
||||||
int64_t input_rank = input_ty.getRank();
|
int64_t input_rank = input_ty.getRank();
|
||||||
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
hlo_begin_indices.reserve(input_rank);
|
hlo_begin_indices.reserve(input_rank);
|
||||||
hlo_end_indices.reserve(input_rank);
|
hlo_end_indices.reserve(input_rank);
|
||||||
hlo_strides.reserve(input_rank);
|
hlo_strides.reserve(input_rank);
|
||||||
@ -2238,6 +2228,170 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
|
|||||||
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
|
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
|
||||||
|
RankedTensorType input_ty,
|
||||||
|
RankedTensorType result_ty,
|
||||||
|
PatternRewriter &rewriter) const {
|
||||||
|
// If begin and end values are dynamic, we can only support this lowering
|
||||||
|
// if strides are a known value of 1.
|
||||||
|
DenseIntElementsAttr sparse_strides_attr;
|
||||||
|
if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"requires that strides are known when begin/end values are dynamic");
|
||||||
|
}
|
||||||
|
SmallVector<int64_t, 4> strides;
|
||||||
|
int64_t stride_value;
|
||||||
|
for (const APInt &stride : sparse_strides_attr) {
|
||||||
|
if ((stride_value = stride.getSExtValue()) != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"requires that strides are all 1 "
|
||||||
|
"when begin/end values are dynamic");
|
||||||
|
}
|
||||||
|
strides.push_back(stride_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
|
int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
|
||||||
|
|
||||||
|
// When begin/end values are dynamic, we can only support shrinking a major
|
||||||
|
// axis. For instance, if there are 4 dims, we can support a
|
||||||
|
// shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
|
||||||
|
// other.
|
||||||
|
bool shrink_axis_mask_ok = op.shrink_axis_mask().isMask();
|
||||||
|
if (!shrink_axis_mask_ok)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"requires that shrink_axis_mask, if set, refer to a major axis "
|
||||||
|
"dimension (when begin/end values are dynamic)");
|
||||||
|
|
||||||
|
// When begin/end values are dynamic, the ellipsis mask, if set, must refer
|
||||||
|
// to the last dimension.
|
||||||
|
int ellipsis_mask = op.ellipsis_mask().getZExtValue();
|
||||||
|
if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"requires that ellipsis_mask, if set, refer to the last dimension of "
|
||||||
|
"input (when begin/end values are dynamic)");
|
||||||
|
|
||||||
|
APInt begin_mask = op.begin_mask();
|
||||||
|
if (!begin_mask.isNullValue())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"requires that begin_mask is either set to 0 or not set when "
|
||||||
|
"begin/end values are dynamic");
|
||||||
|
APInt end_mask = op.end_mask();
|
||||||
|
if (!end_mask.isNullValue())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"requires that end_mask is either set to 0 or not set when begin/end "
|
||||||
|
"values are dynamic");
|
||||||
|
APInt new_axis_mask = op.new_axis_mask();
|
||||||
|
if (!new_axis_mask.isNullValue())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"requires that new_axis_mask is either set to 0 or not set when "
|
||||||
|
"begin/end values are dynamic");
|
||||||
|
|
||||||
|
// In this case where the begin and end values are dynamic, the number of
|
||||||
|
// output elements has to be equal to the number of input elements that
|
||||||
|
// are sliced.
|
||||||
|
int output_elements = result_ty.getNumElements();
|
||||||
|
int input_elements_sliced = 1;
|
||||||
|
|
||||||
|
// Begin must be a ranked, 1-dimensional tensor: This is checked by the
|
||||||
|
// verifier.
|
||||||
|
int64_t slicing_dim_size =
|
||||||
|
op.begin().getType().cast<RankedTensorType>().getShape()[0];
|
||||||
|
auto input_rank = input_shape.size();
|
||||||
|
for (int d = slicing_dim_size; d < input_rank; ++d) {
|
||||||
|
// We only support slicing major dimensions, so minor dimensions after
|
||||||
|
// slicing dimensions are all sliced with their full sizes.
|
||||||
|
input_elements_sliced *= input_shape[d];
|
||||||
|
}
|
||||||
|
if (input_elements_sliced != output_elements) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"requires the number of output elements to be equal to the number of "
|
||||||
|
"input elements sliced (when begin/end values are dynamic)");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> slice_begin_indices;
|
||||||
|
// For the dimensions that are to be sliced, all have slice sizes of 1.
|
||||||
|
SmallVector<int64_t, 4> slice_sizes(slicing_dim_size, 1);
|
||||||
|
auto input_element_ty = input_ty.getElementType();
|
||||||
|
// Scalar tensor type.
|
||||||
|
TensorType type = RankedTensorType::get(/*shape=*/{}, input_element_ty);
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto zero = GetScalarConstOfType(input_element_ty, loc, 0, &rewriter);
|
||||||
|
for (int d = 0; d < slicing_dim_size; ++d) {
|
||||||
|
auto index = rewriter.create<SliceOp>(
|
||||||
|
loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
|
||||||
|
GetI64ElementsAttr({d + 1}, &rewriter),
|
||||||
|
GetI64ElementsAttr({1}, &rewriter));
|
||||||
|
// Convert index to scalar.
|
||||||
|
auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
|
||||||
|
// If the index is negative, wrap it around with dimension size.
|
||||||
|
auto index_negative =
|
||||||
|
rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
|
||||||
|
auto input_val = GetScalarConstOfType(input_element_ty, loc,
|
||||||
|
input_shape[d], &rewriter);
|
||||||
|
auto wrapped_index =
|
||||||
|
rewriter.create<TF::AddOp>(loc, input_val, reshaped_index);
|
||||||
|
auto final_index = rewriter.create<SelectOp>(
|
||||||
|
loc, type, index_negative, wrapped_index, reshaped_index);
|
||||||
|
slice_begin_indices.push_back(final_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-slice dims, get the full slice of that dimension.
|
||||||
|
for (int d = slicing_dim_size; d < input_shape.size(); ++d) {
|
||||||
|
slice_sizes.push_back(input_shape[d]);
|
||||||
|
slice_begin_indices.push_back(zero);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
|
||||||
|
// This must be an xla DynamicSlice op due to the inputs that aren't
|
||||||
|
// constant.
|
||||||
|
auto sliced = rewriter.create<DynamicSliceOp>(
|
||||||
|
loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr);
|
||||||
|
|
||||||
|
// Reshape slice result so that the shape is updated depending on
|
||||||
|
// 'new_axis_mask' or 'shrink_axis_mask' attributes.
|
||||||
|
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(TF::StridedSliceOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Input shape needs to be static to convert negative indices in TensorFlow
|
||||||
|
// to absolute indices required by HLO.
|
||||||
|
//
|
||||||
|
// TODO(hinsu): Relax this constraint for ops without negative indices and
|
||||||
|
// strides.
|
||||||
|
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!input_ty || !input_ty.hasStaticShape()) return failure();
|
||||||
|
|
||||||
|
// Output shape needs to be static to apply 'new_axis_mask' or
|
||||||
|
// 'shrink_axis_mask' by reshaping tensor after slice.
|
||||||
|
//
|
||||||
|
// TODO(hinsu): Relax this constraint for ops without the above masks.
|
||||||
|
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!result_ty || !result_ty.hasStaticShape()) return failure();
|
||||||
|
|
||||||
|
DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
|
||||||
|
if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
|
||||||
|
!matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
|
||||||
|
return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> begin_indices, end_indices, strides;
|
||||||
|
if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
|
||||||
|
input_ty, rewriter);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
|
// Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
|
||||||
|
Loading…
Reference in New Issue
Block a user