Support non-constant begin and end inputs to tf.StridedSlice in TF to XLA HLO lowering.

The TF StridedSliceOp can be lowered to an HLO DynamicSliceOp when all strides have a known value of 1.

PiperOrigin-RevId: 307936203
Change-Id: Iec28dd2590f4d638e0f03fd9ece744db98700be0
Lucy Fox 2020-04-22 17:12:10 -07:00 committed by TensorFlower Gardener
parent 50a33eef52
commit 9bea378143
3 changed files with 321 additions and 25 deletions


@@ -2954,6 +2954,12 @@ void SumOp::build(Builder *builder, OperationState &result, Value input,
// StridedSliceOp
//===----------------------------------------------------------------------===//
// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
// tf.SliceOp if both of the following are true:
// - All strides have a known value equal to 1
// - No masks are set (or masks can be applied by transforming the inputs to
// Slice)
// Verifies that:
//
// - begin, end, and strides operands are 1D and have the same number of

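As a standalone illustration of the operand check described above, here is a minimal sketch (hypothetical helper, using std::vector as a stand-in for 1-D tensors; the real verifier works on MLIR types):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the verifier condition: begin, end, and strides
// are modeled as 1-D vectors and must agree on the number of elements.
bool OperandShapesValid(const std::vector<int64_t> &begin,
                        const std::vector<int64_t> &end,
                        const std::vector<int64_t> &strides) {
  return begin.size() == end.size() && begin.size() == strides.size();
}

int main() {
  assert(OperandShapesValid({1, 4}, {8, 65}, {1, 4}));
  assert(!OperandShapesValid({1}, {8, 65}, {1, 4}));
}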

@@ -2514,8 +2514,8 @@ func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) {
// Begin: 1, 4, -3
// End: 8, 65, 42
// Stride: 1, 4, -1
-// Begin mask: 1, 0, 0 (= 1)
-// End mask: 0, 0, 1 (= 4)
+// Begin mask: 0, 0, 1 (= 1)
+// End mask: 1, 0, 0 (= 4)
// So result shape:
// Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
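The mask comments above treat each attribute as a bitfield over dimensions: bit d set in begin_mask (or end_mask) means the begin (or end) value for dimension d is ignored and the full extent is used. A small self-contained C++ sketch of that decoding (independent of the MLIR code in this change):

#include <cstdint>
#include <cstdio>

// begin_mask / end_mask are bitfields: bit d set means "ignore the begin/end
// value for dimension d and use the full extent of that dimension instead".
bool IsDimMasked(int64_t mask, int dim) { return (mask >> dim) & 1; }

int main() {
  // begin_mask = 1 sets bit 0 (dimension #0); end_mask = 4 sets bit 2
  // (dimension #2) -- matching the corrected comments above.
  int64_t begin_mask = 1, end_mask = 4;
  for (int d = 0; d < 3; ++d)
    std::printf("dim %d: begin masked=%d, end masked=%d\n", d,
                (int)IsDimMasked(begin_mask, d), (int)IsDimMasked(end_mask, d));
}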
@@ -2662,6 +2662,142 @@ func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tenso
return %0 : tensor<2x16x2xf32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end
func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) {
// In this case, the `begin` and `end` inputs are unknown at compile time, so
// the lowering must slice the `begin` vector, canonicalize each index, and
// use the results as the start indices of an HLO dynamic slice.
%begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
// CHECK: %[[A:.*]] = "xla_hlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: %[[BEGIN:.*]] = "xla_hlo.concatenate"(%[[A]])
// CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor<i32>
// CHECK-NEXT: %[[INDEX:.*]] = "xla_hlo.slice"(%[[BEGIN]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[INDEX2]], %[[ZERO]])
// CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor<i32>
// CHECK-NEXT: %[[WRAP:.*]] = xla_hlo.add %[[DIM]], %[[INDEX2]] : tensor<i32>
// CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
// CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice"
// CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]])
// CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} :
// CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x97xi32>
// CHECK-NEXT: %[[FINAL:.*]] = "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32>
%result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
// CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1
func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) {
// Dynamic stride: when `begin` and `end` inputs are unknown at compile time,
// `strides` must be known.
// CHECK: tf.StridedSlice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2
func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown
// at compile time, `strides` must be known to have all 1 values (see the
// sketch after this function).
%strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: tf.StridedSlice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
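Both stride tests above fail to legalize for the same reason. A hedged sketch of the requirement, mirroring the check added in the pattern code later in this change (std::nullopt models strides that are not compile-time constants):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// With dynamic begin/end, the lowering requires strides that are (a) known
// at compile time and (b) all equal to 1.
bool StridesSupported(const std::optional<std::vector<int64_t>> &strides) {
  if (!strides) return false;  // strides themselves are dynamic
  for (int64_t s : *strides)
    if (s != 1) return false;
  return true;
}

int main() {
  assert(!StridesSupported(std::nullopt));             // stride_1 test above
  assert(!StridesSupported(std::vector<int64_t>{2}));  // stride_2 test above
  assert(StridesSupported(std::vector<int64_t>{1}));
}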
// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count
func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> {
%strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
// When begin/end are dynamic, the number of output elements must be equal to
// the number of input elements sliced (see the sketch after this function).
// CHECK: tf.StridedSlice
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32>
return %0 : tensor<6x10xf32>
}
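A standalone sketch of the element-count condition referenced in the comment above: only the leading (major) dimensions are sliced, each down to size 1, and trailing dimensions are taken whole, so the product of the trailing dimension sizes must equal the number of output elements.

#include <cassert>
#include <cstdint>
#include <vector>

// Number of input elements that survive when the first `slicing_dims`
// dimensions are each sliced down to size 1 and the rest are taken whole.
int64_t InputElementsSliced(const std::vector<int64_t> &input_shape,
                            size_t slicing_dims) {
  int64_t elements = 1;
  for (size_t d = slicing_dims; d < input_shape.size(); ++d)
    elements *= input_shape[d];
  return elements;
}

int main() {
  // tensor<4x8xf32> with 2-element begin/end: both dims are sliced, so one
  // element survives; the tensor<6x10xf32> result (60 elements) cannot match.
  assert(InputElementsSliced({4, 8}, 2) == 1);
  // tensor<32x1x97xi32> with 1-element begin/end: 1 * 97 = 97 elements,
  // which matches a tensor<1x97xi32> result.
  assert(InputElementsSliced({32, 1, 97}, 1) == 97);
}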
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask
func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// Begin mask: When `begin` and `end` inputs are unknown at compile time, we
// can't support a begin mask.
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: tf.StridedSlice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask
func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// End mask: When `begin` and `end` inputs are unknown at compile time, we
// can't support an end mask.
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: tf.StridedSlice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask
func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// New axis mask: When `begin` and `end` inputs are unknown at compile time,
// we can't support a new_axis mask.
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: tf.StridedSlice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask
func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// This ellipsis mask is not supported because it does not refer to the last
// dimension.
// [0, 1, 0] = 2
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: tf.StridedSlice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask
func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// This ellipsis mask is supported because it refers to the last dimension.
// [1, 0, 0] = 4 (see the bit-test sketch after this function)
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: xla_hlo.dynamic-slice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
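The supported/unsupported split in the two ellipsis tests above comes down to a single bit test, mirroring the `ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)` condition in the pattern code below:

#include <cassert>
#include <cstdint>

// With dynamic begin/end, ellipsis_mask must be unset or refer only to the
// last input dimension.
bool EllipsisMaskOk(int64_t ellipsis_mask, int64_t input_rank) {
  int64_t last_dim = input_rank > 0 ? input_rank - 1 : 0;
  return ellipsis_mask == 0 || ellipsis_mask == (int64_t{1} << last_dim);
}

int main() {
  assert(!EllipsisMaskOk(/*ellipsis_mask=*/2, /*input_rank=*/3));  // dim #1
  assert(EllipsisMaskOk(/*ellipsis_mask=*/4, /*input_rank=*/3));   // dim #2
  assert(EllipsisMaskOk(/*ellipsis_mask=*/0, /*input_rank=*/3));   // unset
}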
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask
func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// This shrink_axis mask is supported because it refers to a contiguous run
// of major dimensions.
// [1, 1, 1] = 7
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: xla_hlo.dynamic-slice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask
func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
// This shrink_axis mask is unsupported because it does not refer to a major
// dimension.
// [0, 1, 0] = 2 (see the mask sketch after this function)
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: tf.StridedSlice
%result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
return %result : tensor<1x97xi32>
}
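The shrink_axis_mask values accepted here are exactly those that form a contiguous run of low bits (0001, 0011, 0111, ...), which is what `APInt::isMask()` tests in the pattern code below. A standalone equivalent using the usual bit trick:

#include <cassert>
#include <cstdint>

// True when `mask` is a contiguous run of ones starting at bit 0, e.g.
// 0b0001, 0b0011, 0b0111, 0b1111. Adding 1 to such a value carries through
// and clears every set bit (0b0111 + 1 == 0b1000), hence the bit trick.
bool IsLowBitsMask(uint64_t mask) {
  return mask != 0 && ((mask + 1) & mask) == 0;
}

int main() {
  assert(IsLowBitsMask(1));   // shrink dim #0 only: supported
  assert(IsLowBitsMask(7));   // shrink dims #0-#2: supported
  assert(!IsLowBitsMask(2));  // shrink dim #1 only: rejected
  assert(!IsLowBitsMask(5));  // non-contiguous bits: rejected
}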
//===----------------------------------------------------------------------===//
// Reduction op legalizations.


@@ -2148,11 +2148,16 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
// negative strides and Reshape op to update the output shape. Indices and
// strides operands are converted to attributes with non-negative indexing.
//
// If the begin input is not a compile-time constant, the begin values are
// instead sliced out of the begin operand at runtime, and the op is lowered
// to xla_hlo.DynamicSlice. In this case, all strides must have a known value
// of 1 (otherwise we have insufficient information to conform to XLA's op
// semantics).
//
// For example, with an op like the following:
// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
// : tensor<AxBxf32> -> tensor<Pxf32>
//
// Output would be:
// If the %begin input is constant, the output would be:
// %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...}
// %sliced = "xla_hlo.Slice" (%input)
// {start_indices = ..., limit_indices = ..., strides = ...}
@@ -2162,31 +2167,16 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;
-LogicalResult matchAndRewrite(TF::StridedSliceOp op,
-PatternRewriter &rewriter) const override {
-// Input shape needs to be static to convert negative indices in TensorFlow
-// to absolute indices required by HLO.
-//
-// TODO(hinsu): Relax this constraint for ops without negative indices and
-// strides.
-auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
-if (!input_ty || !input_ty.hasStaticShape()) return failure();
-ArrayRef<int64_t> input_shape = input_ty.getShape();
-// Output shape needs to be static to apply 'new_axis_mask' or
-// 'shrink_axis_mask' by reshaping tensor after slice.
-//
-// TODO(hinsu): Relax this constraint for ops without the above masks.
-auto result_ty = op.getType().dyn_cast<RankedTensorType>();
-if (!result_ty || !result_ty.hasStaticShape()) return failure();
-SmallVector<int64_t, 4> begin_indices, end_indices, strides;
-if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides))
-return failure();
+LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
+ArrayRef<int64_t> begin_indices,
+ArrayRef<int64_t> end_indices,
+ArrayRef<int64_t> strides,
+RankedTensorType input_ty,
+PatternRewriter &rewriter) const {
SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
dims_to_reverse;
int64_t input_rank = input_ty.getRank();
ArrayRef<int64_t> input_shape = input_ty.getShape();
hlo_begin_indices.reserve(input_rank);
hlo_end_indices.reserve(input_rank);
hlo_strides.reserve(input_rank);
@@ -2238,6 +2228,170 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
return success();
}
LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
RankedTensorType input_ty,
RankedTensorType result_ty,
PatternRewriter &rewriter) const {
// If begin and end values are dynamic, we can only support this lowering
// if all strides have a known value of 1.
DenseIntElementsAttr sparse_strides_attr;
if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
return rewriter.notifyMatchFailure(
op,
"requires that strides are known when begin/end values are dynamic");
}
SmallVector<int64_t, 4> strides;
int64_t stride_value;
for (const APInt &stride : sparse_strides_attr) {
if ((stride_value = stride.getSExtValue()) != 1) {
return rewriter.notifyMatchFailure(op,
"requires that strides are all 1 "
"when begin/end values are dynamic");
}
strides.push_back(stride_value);
}
ArrayRef<int64_t> input_shape = input_ty.getShape();
int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
// When begin/end values are dynamic, we can only support shrinking a major
// axis. For instance, if there are 4 dims, we can support a
// shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
// others.
bool shrink_axis_mask_ok = op.shrink_axis_mask().isMask();
if (!shrink_axis_mask_ok)
return rewriter.notifyMatchFailure(
op,
"requires that shrink_axis_mask, if set, refer to a major axis "
"dimension (when begin/end values are dynamic)");
// When begin/end values are dynamic, the ellipsis mask, if set, must refer
// to the last dimension.
int ellipsis_mask = op.ellipsis_mask().getZExtValue();
if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
return rewriter.notifyMatchFailure(
op,
"requires that ellipsis_mask, if set, refer to the last dimension of "
"input (when begin/end values are dynamic)");
APInt begin_mask = op.begin_mask();
if (!begin_mask.isNullValue())
return rewriter.notifyMatchFailure(
op,
"requires that begin_mask is either set to 0 or not set when "
"begin/end values are dynamic");
APInt end_mask = op.end_mask();
if (!end_mask.isNullValue())
return rewriter.notifyMatchFailure(
op,
"requires that end_mask is either set to 0 or not set when begin/end "
"values are dynamic");
APInt new_axis_mask = op.new_axis_mask();
if (!new_axis_mask.isNullValue())
return rewriter.notifyMatchFailure(
op,
"requires that new_axis_mask is either set to 0 or not set when "
"begin/end values are dynamic");
// When the begin and end values are dynamic, the number of output elements
// must be equal to the number of input elements sliced.
int output_elements = result_ty.getNumElements();
int input_elements_sliced = 1;
// Begin must be a ranked, 1-dimensional tensor; this is checked by the
// verifier.
int64_t slicing_dim_size =
op.begin().getType().cast<RankedTensorType>().getShape()[0];
auto input_rank = input_shape.size();
for (int d = slicing_dim_size; d < input_rank; ++d) {
// We only support slicing major dimensions, so minor dimensions after
// slicing dimensions are all sliced with their full sizes.
input_elements_sliced *= input_shape[d];
}
if (input_elements_sliced != output_elements) {
return rewriter.notifyMatchFailure(
op,
"requires the number of output elements to be equal to the number of "
"input elements sliced (when begin/end values are dynamic)");
}
SmallVector<Value, 4> slice_begin_indices;
// For the dimensions that are to be sliced, all have slice sizes of 1.
SmallVector<int64_t, 4> slice_sizes(slicing_dim_size, 1);
auto input_element_ty = input_ty.getElementType();
// Scalar tensor type.
TensorType type = RankedTensorType::get(/*shape=*/{}, input_element_ty);
Location loc = op.getLoc();
auto zero = GetScalarConstOfType(input_element_ty, loc, 0, &rewriter);
for (int d = 0; d < slicing_dim_size; ++d) {
auto index = rewriter.create<SliceOp>(
loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
GetI64ElementsAttr({d + 1}, &rewriter),
GetI64ElementsAttr({1}, &rewriter));
// Convert index to scalar.
auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
// If the index is negative, wrap it around with the dimension size (see the
// scalar sketch after this class).
auto index_negative =
rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
auto input_val = GetScalarConstOfType(input_element_ty, loc,
input_shape[d], &rewriter);
auto wrapped_index =
rewriter.create<TF::AddOp>(loc, input_val, reshaped_index);
auto final_index = rewriter.create<SelectOp>(
loc, type, index_negative, wrapped_index, reshaped_index);
slice_begin_indices.push_back(final_index);
}
// For non-slice dims, get the full slice of that dimension.
for (int d = slicing_dim_size; d < input_shape.size(); ++d) {
slice_sizes.push_back(input_shape[d]);
slice_begin_indices.push_back(zero);
}
auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
// This must be an xla_hlo DynamicSlice op because the begin indices are not
// compile-time constants.
auto sliced = rewriter.create<DynamicSliceOp>(
loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr);
// Reshape slice result so that the shape is updated depending on
// 'new_axis_mask' or 'shrink_axis_mask' attributes.
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
return success();
}
LogicalResult matchAndRewrite(TF::StridedSliceOp op,
PatternRewriter &rewriter) const override {
// Input shape needs to be static to convert negative indices in TensorFlow
// to absolute indices required by HLO.
//
// TODO(hinsu): Relax this constraint for ops without negative indices and
// strides.
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
if (!input_ty || !input_ty.hasStaticShape()) return failure();
// Output shape needs to be static to apply 'new_axis_mask' or
// 'shrink_axis_mask' by reshaping tensor after slice.
//
// TODO(hinsu): Relax this constraint for ops without the above masks.
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
if (!result_ty || !result_ty.hasStaticShape()) return failure();
DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
!matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
}
SmallVector<int64_t, 4> begin_indices, end_indices, strides;
if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
return failure();
}
return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
input_ty, rewriter);
}
};
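The compare/add/select sequence emitted per sliced dimension in rewriteWithUnknownBegin is the tensor form of ordinary negative-index wrapping. A scalar sketch of the same canonicalization:

#include <cassert>
#include <cstdint>

// Scalar equivalent of the LessOp/AddOp/SelectOp sequence: a negative begin
// index counts backward from the end of the dimension.
int64_t WrapIndex(int64_t index, int64_t dim_size) {
  return index < 0 ? dim_size + index : index;
}

int main() {
  assert(WrapIndex(-3, 32) == 29);  // begin index -3 in a size-32 dimension
  assert(WrapIndex(5, 32) == 5);    // non-negative indices pass through
}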
// Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.