From ddd0a880cbc8074ade6fad74f3333781c21b3a86 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Fri, 20 Sep 2019 15:15:24 -0700 Subject: [PATCH] Lower TensorFlow StridedSlice op to HLO Currently, this only handles the simple case with constant begin and end operands and no mask attributes except shrink_axis_mask. StridedSlice op is lowered to HLO Slice op along with Reverse op to handle negative stride values and Reshape op to update the output shape to handle change in the output rank. Added a new build for the HLO Slice op that computes output type based on the operand type and attributes. PiperOrigin-RevId: 270360518 --- tensorflow/compiler/mlir/xla/ir/hlo_ops.cc | 55 ++++++++ tensorflow/compiler/mlir/xla/ir/hlo_ops.td | 15 +++ .../compiler/mlir/xla/tests/legalize-tf.mlir | 74 ++++++++++ .../mlir/xla/transforms/legalize_tf.cc | 127 +++++++++++++++++- 4 files changed, 269 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index de5bc1dee54..f52d724b03d 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -28,6 +28,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Dialect.h" // TF:local_config_mlir @@ -533,6 +534,60 @@ static LogicalResult Verify(PadOp op) { return success(); } +//===----------------------------------------------------------------------===// +// SliceOp +//===----------------------------------------------------------------------===// + +void SliceOp::build(Builder* builder, OperationState* result, Value* operand, + DenseIntElementsAttr start_indices, + DenseIntElementsAttr limit_indices, + DenseIntElementsAttr strides) { + return build( + builder, result, + InferOutputTypes(builder, operand, start_indices, limit_indices, strides), + operand, start_indices, limit_indices, strides); +} + +// Returns output dimension size for slice result for the given arguments. +// Returns -1 if arguments are illegal. +static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end, + int64_t stride) { + if (input_dim == -1 || start < 0 || start > end || end > input_dim || + stride == 0) + return -1; + + return llvm::divideCeil(end - start, stride); +} + +Type SliceOp::InferOutputTypes(Builder* builder, Value* operand, + DenseIntElementsAttr start_indices, + DenseIntElementsAttr limit_indices, + DenseIntElementsAttr strides) { + Type ty = operand->getType(); + RankedTensorType ranked_ty = ty.dyn_cast(); + if (!ranked_ty) return ty; + int64_t rank = ranked_ty.getRank(); + + // Illegal attributes. + ShapedType attr_ty = start_indices.getType(); + if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank || + !attr_ty.getElementType().isInteger(64) || + limit_indices.getType() != attr_ty || strides.getType() != attr_ty) + return ty; + + SmallVector start(start_indices.getValues()); + SmallVector limit(limit_indices.getValues()); + SmallVector stride_vals(strides.getValues()); + + SmallVector shape; + shape.reserve(rank); + for (int64_t i = 0, e = rank; i != e; i++) { + shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i], + stride_vals[i])); + } + return builder->getTensorType(shape, ranked_ty.getElementType()); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index e6e6a74b05c..79f2d79ae23 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -276,6 +276,21 @@ def HLO_SliceOp: HLO_Op< let results = (outs HLO_Tensor); + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *operand, " + "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, " + "DenseIntElementsAttr strides" + >]; + + let extraClassDeclaration = [{ + // Infers output type for given operand and attributes. Result type is + // unranked if any of the attributes is illegal. + static Type InferOutputTypes(Builder *builder, Value *operand, + DenseIntElementsAttr start_indices, + DenseIntElementsAttr limit_indices, + DenseIntElementsAttr strides); + }]; + // TODO(b/129422361) Two of the required arguments comes from the start and // limit indices which aren't handled by the codegen. let hasCustomHLOConverter = 1; diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index be45ecebd65..3b5c65a1c0b 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -533,3 +533,77 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> return %0 : tensor<1x2xf32> } + +// CHECK-LABEL: simple_strided_slice +func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { + %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: xla_hlo.slice + // CHECK-DAG-SAME: start_indices = dense<[0, 1]> + // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> + // CHECK-DAG-SAME: strides = dense<[1, 3]> + // CHECK-SAME: -> tensor<3x2xf32> + + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> + return %output : tensor<3x2xf32> +} + +// CHECK-LABEL: strided_slice_negative_indices +func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { + %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} + + // CHECK: xla_hlo.slice + // CHECK-DAG-SAME: start_indices = dense<[0, 1]> + // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> + // CHECK-DAG-SAME: strides = dense<[1, 3]> + // CHECK-SAME: -> tensor<3x2xf32> + + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> + return %output : tensor<3x2xf32> +} + +// CHECK-LABEL: strided_slice_range_clamping +func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<0x3xf32> { + %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[-1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + + // CHECK: xla_hlo.slice + // CHECK-DAG-SAME: start_indices = dense<[3, 0]> + // CHECK-DAG-SAME: limit_indices = dense<[3, 8]> + // CHECK-DAG-SAME: strides = dense<[1, 3]> + // CHECK-SAME: -> tensor<0x3xf32> + + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<0x3xf32> + return %output : tensor<0x3xf32> +} + +// CHECK-LABEL: strided_slice_shrink_axis +func @strided_slice_shrink_axis(%input: tensor<4x8xf32>) -> tensor { + %begin = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: %[[SLICED:.*]] = "xla_hlo.slice" + // CHECK-DAG-SAME: start_indices = dense<[1, 3]> + // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> + // CHECK-DAG-SAME: strides = dense<[1, 3]> + // CHECK-SAME: -> tensor<1x1xf32> + + // CHECK: "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x1xf32>) -> tensor + + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) {shrink_axis_mask = 3 + : i64} : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + return %output : tensor +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index da6c5635813..95c94f99648 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Matchers.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir @@ -56,6 +57,15 @@ static size_t getFeatureDimension(StringAttr format, return isDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1; } +// Returns 1D 64-bit dense elements attribute with the given values. +static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = builder->getTensorType( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseElementsAttr::get(ty, values) + .cast(); +} + static IntegerAttr GetHLOAxisFromTFAxis(ElementsAttr attr, int64_t rank, Builder *b) { SmallVector index(attr.getType().getRank(), 0); @@ -427,6 +437,117 @@ class ConvertSoftmaxOp : public OpRewritePattern { } }; +// Converts StridedSlice op to HLO Slice op along with Reverse op to handle +// negative strides and Reshape op to update the output shape. Indices and +// strides operands are converted to attributes with non-negative indexing. +// +// For example with an op like following, +// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} +// : tensor -> tensor +// +// Output would be: +// %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...} +// %sliced = "xla_hlo.Slice" (%input) +// {start_indices = ..., limit_indices = ..., strides = ...} +// %output = "xla_hlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor +// +class ConvertStridedSliceOp : public OpRewritePattern { + public: + explicit ConvertStridedSliceOp(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + PatternMatchResult matchAndRewrite(TF::StridedSliceOp op, + PatternRewriter &rewriter) const override { + // Input shape needs to be static to convert negative indices in TensorFlow + // to absolute indices required by HLO. + // + // TODO(hinsu): Relax this constraint for ops without negative indices and + // strides. + auto input_ty = op.input()->getType().dyn_cast(); + if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); + ArrayRef input_shape = input_ty.getShape(); + + // Output shape needs to be static to apply 'new_axis_mask' or + // 'shrink_axis_mask' by reshaping tensor after slice. + // + // TODO(hinsu): Relax this constraint for ops without the above masks. + auto result_ty = op.getType().dyn_cast(); + if (!result_ty || !result_ty.hasStaticShape()) return matchFailure(); + + // TODO(hinsu): Support non-zero mask values. Currently only + // 'shrink_axis_mask' is supported. + for (StringRef mask : + {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask"}) { + auto attr = op.getAttrOfType(mask); + if (attr && attr.getValue() != 0) return matchFailure(); + } + + // TODO(hinsu): Support lowering for ops with dynamic begin and end values + // when it is possible to derive indices based on mask attributes. + DenseIntElementsAttr begin_indices, end_indices, strides; + if (!matchPattern(op.begin(), m_Constant(&begin_indices)) || + !matchPattern(op.end(), m_Constant(&end_indices)) || + !matchPattern(op.strides(), m_Constant(&strides))) + return matchFailure(); + + SmallVector hlo_begin_indices, hlo_end_indices, hlo_strides, + dims_to_reverse; + int64_t input_rank = input_ty.getRank(); + for (auto *vec : {&hlo_begin_indices, &hlo_end_indices, &hlo_strides}) { + vec->reserve(input_rank); + } + + int64_t indices_elements = begin_indices.getNumElements(); + if (input_rank < indices_elements) return matchFailure(); + + // Convert from TensorFlow negative or out of range indices and strides + // values to legal HLO Slice attributes. + for (int i = 0, e = indices_elements; i != e; i++) { + int64_t begin = begin_indices.getValue(i).getInt(); + int64_t end = end_indices.getValue(i).getInt(); + int64_t stride = strides.getValue(i).getInt(); + + if (begin < 0) begin = input_shape[i] + begin; + if (end < 0) end = input_shape[i] + end; + + if (stride < 0) { + // Negative stride means that the output values are computed starting + // from end until begin. Mark the dimension for reversal before slice + // and compute indices for the reversed input. + dims_to_reverse.push_back(i); + begin = (input_shape[i] - 1) - begin; + end = (input_shape[i] - 1) - end; + stride = -stride; + } + + // Unlike TensorFlow, HLO requires begin and end values to be within + // range. + begin = std::max(int64_t(0), begin); + end = std::max(begin, end); + end = std::min(end, input_shape[i]); + + hlo_begin_indices.push_back(begin); + hlo_end_indices.push_back(end); + hlo_strides.push_back(stride); + } + + Location loc = op.getLoc(); + auto reversed = rewriter.create( + loc, input_ty, op.input(), + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + auto sliced = rewriter.create( + loc, reversed.getResult(), + GetI64ElementsAttr(hlo_begin_indices, &rewriter), + GetI64ElementsAttr(hlo_end_indices, &rewriter), + GetI64ElementsAttr(hlo_strides, &rewriter)); + + // Reshape slice result so that the shape is updated depending on + // 'new_axis_mask' or 'shrink_axis_mask' attributes. + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + return matchSuccess(); + } +}; + #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" } // end anonymous namespace } // end namespace xla @@ -444,8 +565,10 @@ LogicalResult mlir::xla_hlo::legalizeTF(Operation *op) { // here for lowering to HLO. mlir::TF::PopulateLoweringTFPatterns(context, &patterns); - patterns.insert(op->getContext()); + patterns + .insert( + op->getContext()); ConversionTarget target(*context); target.addLegalDialect();