From 9bea3781431fbac9cc4b43ae33ac0db976387248 Mon Sep 17 00:00:00 2001 From: Lucy Fox Date: Wed, 22 Apr 2020 17:12:10 -0700 Subject: [PATCH] Support non-constant begin and end inputs to tf.StridedSlice in TF to XLA HLO lowering. The TF StridedSliceOp can be lowered to a HLO DynamicSliceOp in the case that all strides have a known value of 1. PiperOrigin-RevId: 307936203 Change-Id: Iec28dd2590f4d638e0f03fd9ece744db98700be0 --- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 6 + .../compiler/mlir/xla/tests/legalize-tf.mlir | 140 +++++++++++- .../mlir/xla/transforms/legalize_tf.cc | 200 ++++++++++++++++-- 3 files changed, 321 insertions(+), 25 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index d40b0a6ba16..1c4d2073413 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -2954,6 +2954,12 @@ void SumOp::build(Builder *builder, OperationState &result, Value input, // StridedSliceOp //===----------------------------------------------------------------------===// +// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to +// tf.SliceOp if both of the following are true: +// - All strides have a known value equal to 1 +// - No masks are set (or masks can be applied by transforming the inputs to +// Slice) + // Verifies that, // // - begin, end and strides operands are 1D and they have the same number of diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 222885c487d..9fd53a7db00 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -2514,8 +2514,8 @@ func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { // Begin: 1, 4, -3 // End: 8, 65, 42 // Stride: 1, 4, -1 - // Begin mask: 1, 0, 0 (= 1) - // End mask: 0, 0, 1 (= 4) + // Begin mask: 0, 0, 1 (= 1) + // End mask: 1, 0, 0 (= 4) // So result shape: // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 @@ -2662,6 +2662,142 @@ func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tenso return %0 : tensor<2x16x2xf32> } +// CHECK-LABEL: strided_slice_nonconstant_begin_end +func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) { + // In this case, the `begin` and `end` inputs are unknown at compile time -- + // so the StridedSlice needs to slice these vectors and use that as input to + // an HLO dynamic slice. + %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + // CHECK: %[[A:.*]] = "xla_hlo.reshape"(%arg0) : (tensor) -> tensor<1xi32> + // CHECK-NEXT: %[[BEGIN:.*]] = "xla_hlo.concatenate"(%[[A]]) + // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor + // CHECK-NEXT: %[[INDEX:.*]] = "xla_hlo.slice"(%[[BEGIN]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[INDEX2]], %[[ZERO]]) + // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor + // CHECK-NEXT: %[[WRAP:.*]] = xla_hlo.add %[[DIM]], %[[INDEX2]] : tensor + // CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor + // CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice" + // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : + // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x97xi32> + // CHECK-NEXT: %[[FINAL:.*]] = "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32> + %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 +func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, + // `strides` must be known. + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 +func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown + // at compile time, `strides` must be known to have all 1 values. + %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count +func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { + %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + // When begin/end are dynamic, the number of output elements must be equal to + // the number of input elements sliced. + // CHECK: tf.StridedSlice + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> + return %0 : tensor<6x10xf32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask +func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Begin mask: When `begin` and `end` inputs are unknown at compile time, we + // can't support a begin mask. + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask +func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // End mask: When `begin` and `end` inputs are unknown at compile time, we + // can't support an end mask. + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask +func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // New axis mask: When `begin` and `end` inputs are unknown at compile time, + // we can't support a new_axis mask. + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask +func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This ellipsis mask is not supported because it does not refer to the last + // dimension. + // [0, 1, 0] = 2 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask +func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This ellipsis mask is supported because it refers to the last dimension. + // [1, 0, 0] = 4 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: xla_hlo.dynamic-slice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask +func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This shrink_axis mask is supported because it refers to a major dimension. + // [1, 1, 1] = 7 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: xla_hlo.dynamic-slice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask +func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This shrink_axis mask is unsupported because it does not refer to a major + // dimension. + // [0, 1, 0] = 2 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + //===----------------------------------------------------------------------===// // Reduction op legalizations. diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index c7f72b921b0..0a0aefeff0d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -2148,11 +2148,16 @@ class ConvertSplitVOp : public OpRewritePattern { // negative strides and Reshape op to update the output shape. Indices and // strides operands are converted to attributes with non-negative indexing. // +// If the begin input is not a compile time constant, the begin input needs to +// be sliced and the slice needs to be lowered to xla_hlo.DynamicSlice. In this +// case, strides must have a known value of 1 (otherwise we have insufficient +// information to conform to XLA's op semantics). +// // For example with an op like following, // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} // : tensor -> tensor // -// Output would be: +// If the %begin input is constant, output would be: // %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...} // %sliced = "xla_hlo.Slice" (%input) // {start_indices = ..., limit_indices = ..., strides = ...} @@ -2162,31 +2167,16 @@ class ConvertStridedSliceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::StridedSliceOp op, - PatternRewriter &rewriter) const override { - // Input shape needs to be static to convert negative indices in TensorFlow - // to absolute indices required by HLO. - // - // TODO(hinsu): Relax this constraint for ops without negative indices and - // strides. - auto input_ty = op.input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return failure(); - ArrayRef input_shape = input_ty.getShape(); - - // Output shape needs to be static to apply 'new_axis_mask' or - // 'shrink_axis_mask' by reshaping tensor after slice. - // - // TODO(hinsu): Relax this constraint for ops without the above masks. - auto result_ty = op.getType().dyn_cast(); - if (!result_ty || !result_ty.hasStaticShape()) return failure(); - - SmallVector begin_indices, end_indices, strides; - if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) - return failure(); - + LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op, + ArrayRef begin_indices, + ArrayRef end_indices, + ArrayRef strides, + RankedTensorType input_ty, + PatternRewriter &rewriter) const { SmallVector hlo_begin_indices, hlo_end_indices, hlo_strides, dims_to_reverse; int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); hlo_begin_indices.reserve(input_rank); hlo_end_indices.reserve(input_rank); hlo_strides.reserve(input_rank); @@ -2238,6 +2228,170 @@ class ConvertStridedSliceOp : public OpRewritePattern { rewriter.replaceOpWithNewOp(op, op.getType(), sliced); return success(); } + + LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op, + RankedTensorType input_ty, + RankedTensorType result_ty, + PatternRewriter &rewriter) const { + // If begin and end values are dynamic, we can only support this lowering + // if strides are a known value of 1. + DenseIntElementsAttr sparse_strides_attr; + if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) { + return rewriter.notifyMatchFailure( + op, + "requires that strides are known when begin/end values are dynamic"); + } + SmallVector strides; + int64_t stride_value; + for (const APInt &stride : sparse_strides_attr) { + if ((stride_value = stride.getSExtValue()) != 1) { + return rewriter.notifyMatchFailure(op, + "requires that strides are all 1 " + "when begin/end values are dynamic"); + } + strides.push_back(stride_value); + } + + ArrayRef input_shape = input_ty.getShape(); + int last_dim = std::max(static_cast(input_shape.size()) - 1, 0); + + // When begin/end values are dynamic, we can only support shrinking a major + // axis. For instance, if there are 4 dims, we can support a + // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no + // other. + bool shrink_axis_mask_ok = op.shrink_axis_mask().isMask(); + if (!shrink_axis_mask_ok) + return rewriter.notifyMatchFailure( + op, + "requires that shrink_axis_mask, if set, refer to a major axis " + "dimension (when begin/end values are dynamic)"); + + // When begin/end values are dynamic, the ellipsis mask, if set, must refer + // to the last dimension. + int ellipsis_mask = op.ellipsis_mask().getZExtValue(); + if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) + return rewriter.notifyMatchFailure( + op, + "requires that ellipsis_mask, if set, refer to the last dimension of " + "input (when begin/end values are dynamic)"); + + APInt begin_mask = op.begin_mask(); + if (!begin_mask.isNullValue()) + return rewriter.notifyMatchFailure( + op, + "requires that begin_mask is either set to 0 or not set when " + "begin/end values are dynamic"); + APInt end_mask = op.end_mask(); + if (!end_mask.isNullValue()) + return rewriter.notifyMatchFailure( + op, + "requires that end_mask is either set to 0 or not set when begin/end " + "values are dynamic"); + APInt new_axis_mask = op.new_axis_mask(); + if (!new_axis_mask.isNullValue()) + return rewriter.notifyMatchFailure( + op, + "requires that new_axis_mask is either set to 0 or not set when " + "begin/end values are dynamic"); + + // In this case where the begin and end values are dynamic, the number of + // output elements has to be equal to the number of input elements that + // are sliced. + int output_elements = result_ty.getNumElements(); + int input_elements_sliced = 1; + + // Begin must be a ranked, 1-dimensional tensor: This is checked by the + // verifier. + int64_t slicing_dim_size = + op.begin().getType().cast().getShape()[0]; + auto input_rank = input_shape.size(); + for (int d = slicing_dim_size; d < input_rank; ++d) { + // We only support slicing major dimensions, so minor dimensions after + // slicing dimensions are all sliced with their full sizes. + input_elements_sliced *= input_shape[d]; + } + if (input_elements_sliced != output_elements) { + return rewriter.notifyMatchFailure( + op, + "requires the number of output elements to be equal to the number of " + "input elements sliced (when begin/end values are dynamic)"); + } + + SmallVector slice_begin_indices; + // For the dimensions that are to be sliced, all have slice sizes of 1. + SmallVector slice_sizes(slicing_dim_size, 1); + auto input_element_ty = input_ty.getElementType(); + // Scalar tensor type. + TensorType type = RankedTensorType::get(/*shape=*/{}, input_element_ty); + Location loc = op.getLoc(); + auto zero = GetScalarConstOfType(input_element_ty, loc, 0, &rewriter); + for (int d = 0; d < slicing_dim_size; ++d) { + auto index = rewriter.create( + loc, op.begin(), GetI64ElementsAttr({d}, &rewriter), + GetI64ElementsAttr({d + 1}, &rewriter), + GetI64ElementsAttr({1}, &rewriter)); + // Convert index to scalar. + auto reshaped_index = rewriter.create(loc, type, index); + // If the index is negative, wrap it around with dimension size. + auto index_negative = + rewriter.create(loc, reshaped_index, zero); + auto input_val = GetScalarConstOfType(input_element_ty, loc, + input_shape[d], &rewriter); + auto wrapped_index = + rewriter.create(loc, input_val, reshaped_index); + auto final_index = rewriter.create( + loc, type, index_negative, wrapped_index, reshaped_index); + slice_begin_indices.push_back(final_index); + } + + // For non-slice dims, get the full slice of that dimension. + for (int d = slicing_dim_size; d < input_shape.size(); ++d) { + slice_sizes.push_back(input_shape[d]); + slice_begin_indices.push_back(zero); + } + + auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter); + // This must be an xla DynamicSlice op due to the inputs that aren't + // constant. + auto sliced = rewriter.create( + loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr); + + // Reshape slice result so that the shape is updated depending on + // 'new_axis_mask' or 'shrink_axis_mask' attributes. + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + return success(); + } + + LogicalResult matchAndRewrite(TF::StridedSliceOp op, + PatternRewriter &rewriter) const override { + // Input shape needs to be static to convert negative indices in TensorFlow + // to absolute indices required by HLO. + // + // TODO(hinsu): Relax this constraint for ops without negative indices and + // strides. + auto input_ty = op.input().getType().dyn_cast(); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); + + // Output shape needs to be static to apply 'new_axis_mask' or + // 'shrink_axis_mask' by reshaping tensor after slice. + // + // TODO(hinsu): Relax this constraint for ops without the above masks. + auto result_ty = op.getType().dyn_cast(); + if (!result_ty || !result_ty.hasStaticShape()) return failure(); + + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; + if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(op.end(), m_Constant(&sparse_end_attr))) { + return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter); + } + + SmallVector begin_indices, end_indices, strides; + if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) { + return failure(); + } + return rewriteWithConstantBegin(op, begin_indices, end_indices, strides, + input_ty, rewriter); + } }; // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.