Lower TensorFlow StridedSlice op to HLO

Currently, this only handles the simple case with constant begin and end operands and no mask attributes except shrink_axis_mask.

StridedSlice op is lowered to HLO Slice op along with Reverse op to handle negative stride values and Reshape op to update the output shape to handle change in the output rank.

Added a new build for the HLO Slice op that computes output type based on the operand type and attributes.

PiperOrigin-RevId: 270360518
This commit is contained in:
Smit Hinsu 2019-09-20 15:15:24 -07:00 committed by TensorFlower Gardener
parent f7464a6c05
commit ddd0a880cb
4 changed files with 269 additions and 2 deletions

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
@ -533,6 +534,60 @@ static LogicalResult Verify(PadOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
void SliceOp::build(Builder* builder, OperationState* result, Value* operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) {
return build(
builder, result,
InferOutputTypes(builder, operand, start_indices, limit_indices, strides),
operand, start_indices, limit_indices, strides);
}
// Returns output dimension size for slice result for the given arguments.
// Returns -1 if arguments are illegal.
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
int64_t stride) {
if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
stride == 0)
return -1;
return llvm::divideCeil(end - start, stride);
}
Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) {
Type ty = operand->getType();
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) return ty;
int64_t rank = ranked_ty.getRank();
// Illegal attributes.
ShapedType attr_ty = start_indices.getType();
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
!attr_ty.getElementType().isInteger(64) ||
limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
return ty;
SmallVector<int64_t, 4> start(start_indices.getValues<int64_t>());
SmallVector<int64_t, 4> limit(limit_indices.getValues<int64_t>());
SmallVector<int64_t, 4> stride_vals(strides.getValues<int64_t>());
SmallVector<int64_t, 4> shape;
shape.reserve(rank);
for (int64_t i = 0, e = rank; i != e; i++) {
shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
stride_vals[i]));
}
return builder->getTensorType(shape, ranked_ty.getElementType());
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

View File

@ -276,6 +276,21 @@ def HLO_SliceOp: HLO_Op<
let results = (outs HLO_Tensor);
let builders = [OpBuilder<
"Builder *builder, OperationState *result, Value *operand, "
"DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
"DenseIntElementsAttr strides"
>];
let extraClassDeclaration = [{
// Infers output type for given operand and attributes. Result type is
// unranked if any of the attributes is illegal.
static Type InferOutputTypes(Builder *builder, Value *operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides);
}];
// TODO(b/129422361) Two of the required arguments comes from the start and
// limit indices which aren't handled by the codegen.
let hasCustomHLOConverter = 1;

View File

@ -533,3 +533,77 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
%0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
// CHECK-LABEL: simple_strided_slice
func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
%begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
// CHECK: xla_hlo.slice
// CHECK-DAG-SAME: start_indices = dense<[0, 1]>
// CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
// CHECK-DAG-SAME: strides = dense<[1, 3]>
// CHECK-SAME: -> tensor<3x2xf32>
%output = "tf.StridedSlice"(%input, %begin, %end, %strides)
: (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
return %output : tensor<3x2xf32>
}
// CHECK-LABEL: strided_slice_negative_indices
func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
%begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
// CHECK: "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.slice
// CHECK-DAG-SAME: start_indices = dense<[0, 1]>
// CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
// CHECK-DAG-SAME: strides = dense<[1, 3]>
// CHECK-SAME: -> tensor<3x2xf32>
%output = "tf.StridedSlice"(%input, %begin, %end, %strides)
: (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
return %output : tensor<3x2xf32>
}
// CHECK-LABEL: strided_slice_range_clamping
func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<0x3xf32> {
%begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%end = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%strides = "tf.Const"() {value = dense<[-1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
// CHECK: "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
// CHECK: xla_hlo.slice
// CHECK-DAG-SAME: start_indices = dense<[3, 0]>
// CHECK-DAG-SAME: limit_indices = dense<[3, 8]>
// CHECK-DAG-SAME: strides = dense<[1, 3]>
// CHECK-SAME: -> tensor<0x3xf32>
%output = "tf.StridedSlice"(%input, %begin, %end, %strides)
: (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<0x3xf32>
return %output : tensor<0x3xf32>
}
// CHECK-LABEL: strided_slice_shrink_axis
func @strided_slice_shrink_axis(%input: tensor<4x8xf32>) -> tensor<f32> {
%begin = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%end = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>)
%strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
// CHECK: %[[SLICED:.*]] = "xla_hlo.slice"
// CHECK-DAG-SAME: start_indices = dense<[1, 3]>
// CHECK-DAG-SAME: limit_indices = dense<[2, 4]>
// CHECK-DAG-SAME: strides = dense<[1, 3]>
// CHECK-SAME: -> tensor<1x1xf32>
// CHECK: "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x1xf32>) -> tensor<f32>
%output = "tf.StridedSlice"(%input, %begin, %end, %strides) {shrink_axis_mask = 3
: i64} : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<f32>
return %output : tensor<f32>
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
@ -56,6 +57,15 @@ static size_t getFeatureDimension(StringAttr format,
return isDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1;
}
// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
Builder *builder) {
RankedTensorType ty = builder->getTensorType(
{static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
return DenseElementsAttr::get<int64_t>(ty, values)
.cast<DenseIntElementsAttr>();
}
static IntegerAttr GetHLOAxisFromTFAxis(ElementsAttr attr, int64_t rank,
Builder *b) {
SmallVector<uint64_t, 1> index(attr.getType().getRank(), 0);
@ -427,6 +437,117 @@ class ConvertSoftmaxOp : public OpRewritePattern<TF::SoftmaxOp> {
}
};
// Converts StridedSlice op to HLO Slice op along with Reverse op to handle
// negative strides and Reshape op to update the output shape. Indices and
// strides operands are converted to attributes with non-negative indexing.
//
// For example with an op like following,
// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
// : tensor<AxBxf32> -> tensor<Pxf32>
//
// Output would be:
// %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...}
// %sliced = "xla_hlo.Slice" (%input)
// {start_indices = ..., limit_indices = ..., strides = ...}
// %output = "xla_hlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
//
class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
public:
explicit ConvertStridedSliceOp(MLIRContext *context)
: OpRewritePattern<TF::StridedSliceOp>(context, 1) {}
PatternMatchResult matchAndRewrite(TF::StridedSliceOp op,
PatternRewriter &rewriter) const override {
// Input shape needs to be static to convert negative indices in TensorFlow
// to absolute indices required by HLO.
//
// TODO(hinsu): Relax this constraint for ops without negative indices and
// strides.
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
ArrayRef<int64_t> input_shape = input_ty.getShape();
// Output shape needs to be static to apply 'new_axis_mask' or
// 'shrink_axis_mask' by reshaping tensor after slice.
//
// TODO(hinsu): Relax this constraint for ops without the above masks.
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
if (!result_ty || !result_ty.hasStaticShape()) return matchFailure();
// TODO(hinsu): Support non-zero mask values. Currently only
// 'shrink_axis_mask' is supported.
for (StringRef mask :
{"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask"}) {
auto attr = op.getAttrOfType<IntegerAttr>(mask);
if (attr && attr.getValue() != 0) return matchFailure();
}
// TODO(hinsu): Support lowering for ops with dynamic begin and end values
// when it is possible to derive indices based on mask attributes.
DenseIntElementsAttr begin_indices, end_indices, strides;
if (!matchPattern(op.begin(), m_Constant(&begin_indices)) ||
!matchPattern(op.end(), m_Constant(&end_indices)) ||
!matchPattern(op.strides(), m_Constant(&strides)))
return matchFailure();
SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
dims_to_reverse;
int64_t input_rank = input_ty.getRank();
for (auto *vec : {&hlo_begin_indices, &hlo_end_indices, &hlo_strides}) {
vec->reserve(input_rank);
}
int64_t indices_elements = begin_indices.getNumElements();
if (input_rank < indices_elements) return matchFailure();
// Convert from TensorFlow negative or out of range indices and strides
// values to legal HLO Slice attributes.
for (int i = 0, e = indices_elements; i != e; i++) {
int64_t begin = begin_indices.getValue<IntegerAttr>(i).getInt();
int64_t end = end_indices.getValue<IntegerAttr>(i).getInt();
int64_t stride = strides.getValue<IntegerAttr>(i).getInt();
if (begin < 0) begin = input_shape[i] + begin;
if (end < 0) end = input_shape[i] + end;
if (stride < 0) {
// Negative stride means that the output values are computed starting
// from end until begin. Mark the dimension for reversal before slice
// and compute indices for the reversed input.
dims_to_reverse.push_back(i);
begin = (input_shape[i] - 1) - begin;
end = (input_shape[i] - 1) - end;
stride = -stride;
}
// Unlike TensorFlow, HLO requires begin and end values to be within
// range.
begin = std::max(int64_t(0), begin);
end = std::max(begin, end);
end = std::min(end, input_shape[i]);
hlo_begin_indices.push_back(begin);
hlo_end_indices.push_back(end);
hlo_strides.push_back(stride);
}
Location loc = op.getLoc();
auto reversed = rewriter.create<xla_hlo::ReverseOp>(
loc, input_ty, op.input(),
GetI64ElementsAttr(dims_to_reverse, &rewriter));
auto sliced = rewriter.create<xla_hlo::SliceOp>(
loc, reversed.getResult(),
GetI64ElementsAttr(hlo_begin_indices, &rewriter),
GetI64ElementsAttr(hlo_end_indices, &rewriter),
GetI64ElementsAttr(hlo_strides, &rewriter));
// Reshape slice result so that the shape is updated depending on
// 'new_axis_mask' or 'shrink_axis_mask' attributes.
rewriter.replaceOpWithNewOp<xla_hlo::ReshapeOp>(op, op.getType(), sliced);
return matchSuccess();
}
};
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
} // end anonymous namespace
} // end namespace xla
@ -444,8 +565,10 @@ LogicalResult mlir::xla_hlo::legalizeTF(Operation *op) {
// here for lowering to HLO.
mlir::TF::PopulateLoweringTFPatterns(context, &patterns);
patterns.insert<mlir::xla::ConvertMaxPoolOp, mlir::xla::ConvertSigmoidOp,
mlir::xla::ConvertSoftmaxOp>(op->getContext());
patterns
.insert<mlir::xla::ConvertMaxPoolOp, mlir::xla::ConvertSigmoidOp,
mlir::xla::ConvertSoftmaxOp, mlir::xla::ConvertStridedSliceOp>(
op->getContext());
ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>();