Modify HLO DynamicSlice op definition such that start indices are a Variadic list rather than tensor.

This allows for DynamicSlice ops where the start indices are the result of op computations which cannot be folded at compile time.

Aside from modifying the op definition, this requires updating:
- the TF Slice to HLO DynamicSlice legalization
- the HLO DynamicSlice to Slice canonicalization
- the HLO importer
- the IREE lowering from HLO DynamicSlice to VMLA

I also added a test for the HLO DynamicSlice exporter, as there wasn't one previously.

PiperOrigin-RevId: 306495966
Change-Id: I5a080c63eaeedfceda034987cdf16177393ed3e8
This commit is contained in:
Lucy Fox 2020-04-14 12:51:16 -07:00 committed by TensorFlower Gardener
parent 85ea23af35
commit 6f1bbf2cc2
11 changed files with 176 additions and 55 deletions

View File

@ -296,9 +296,11 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
std::vector<int64_t> slice_sizes(
instruction->dynamic_slice_sizes().begin(),
instruction->dynamic_slice_sizes().end());
attributes.push_back(
builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
MakeAndReturn(DynamicSliceOp);
return func_builder
->create<mlir::xla_hlo::DynamicSliceOp>(
loc, result_type, operands[0],
makeArrayRef(operands).drop_front(), Convert(slice_sizes))
.getOperation();
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder

View File

@ -836,11 +836,64 @@ static LogicalResult Verify(ConcatenateOp op) {
// DynamicSliceOp
//===----------------------------------------------------------------------===//
namespace {
// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
// This canonicalization is applied the case when the `begin` input values are
// compile time constants and thus can be made into a tensor.
struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
PatternRewriter& rewriter) const override {
Value input = dynamic_slice.operand();
auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
if (!input_tensor) return failure();
SmallVector<int64_t, 4> temp_start_indices;
for (Value start : dynamic_slice.start_indices()) {
APInt val;
if (!matchPattern(start, m_ConstantInt(&val))) {
return failure();
}
temp_start_indices.push_back(*(val.getRawData()));
}
// At this point we've determined that the start indices are all constants;
// pack them into a single tensor.
auto loc = dynamic_slice.getLoc();
int64_t input_rank = input_tensor.getRank();
auto slice_start_indices =
GetI64ElementsAttr(temp_start_indices, &rewriter);
DenseIntElementsAttr slice_limits = BuildSliceLimits(
slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
DenseIntElementsAttr slice_strides =
GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
slice_limits, slice_strides);
rewriter.replaceOp(dynamic_slice, {result});
return success();
}
};
} // namespace
void DynamicSliceOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicSliceToSlice>(context);
}
// Verifies that the number of slice sizes and the number of start indices match
static LogicalResult Verify(DynamicSliceOp op) {
int num_slice_sizes = op.slice_sizes().getNumElements();
int num_start_indices = op.start_indices().size();
if (num_start_indices != num_slice_sizes) {
return op.emitOpError()
<< "has mismatched number of slice sizes (" << num_slice_sizes
<< ") and number of start indices (" << num_start_indices << ")";
}
return success();
}
//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//

View File

@ -661,11 +661,10 @@ def HLO_SliceOp: HLO_Op<
}
def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
AllShapesMatch<["start_indices", "slice_sizes"]>]> {
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
let arguments = (ins
HLO_Tensor:$operand,
HLO_Tensor:$start_indices,
Variadic<HLO_ScalarIntTensor>:$start_indices,
I64ElementsAttr:$slice_sizes
);

View File

@ -1,9 +1,8 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "xla_hlo.dynamic-slice"
%0 = xla_hlo.constant dense<[1, 4]> : tensor<2xi64>
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
}
@ -14,21 +13,22 @@ func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK: return %[[RESULT]] : tensor<2xi32>
%0 = xla_hlo.constant dense<1> : tensor<1xi64>
%2 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
return %2 : tensor<2xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
// CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
// CHECK: return %[[RESULT]] : tensor<1x4xi32>
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = xla_hlo.constant dense<0> : tensor<i64>
%2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
return %2 : tensor<?x4xi32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic

View File

@ -2249,7 +2249,16 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
// CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
// CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START:.*]]) :
// CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]])
// CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} :
// CHECK-DAG-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
// CHECK: return %[[RESULT]] : tensor<2xi32>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@ -2261,7 +2270,12 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: slice_sizes = dense<2> : tensor<1xi64>
// CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
// CHECK: "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
%0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
@ -2272,7 +2286,12 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32>
// CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32>
// CHECK: return %[[RESULT]] : tensor<3xi32>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@ -2284,7 +2303,24 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi
func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
// CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) :
// CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
// CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) :
// CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"
// CHECK-DAG-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]])
// CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} :
// CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
// CHECK: return %[[RESULT]] : tensor<1x4xi32>
%starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
@ -2295,7 +2331,14 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2
// CHECK-LABEL: slice_variable_start
func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
// CHECK: return %[[RESULT]] : tensor<1x4xi32>
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
%0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>

View File

@ -551,37 +551,45 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
// -----
// CHECK-LABEL: func @dynamic_slice
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// expected-error@+1 {{failed to verify that all of {start_indices, slice_sizes} have same shape}}
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}}
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: @dynamic_slice_different_indice_element_type
func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<1xi32>) -> tensor<1x4xi32> {
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<1xi32>) -> tensor<1x4xi32>
func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<i32>) -> tensor<1x4xi32> {
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<i32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xf32> {
func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xf32> {
// expected-error@+1 {{failed to verify that all of {operand, result} have same element type}}
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xf32>
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>

View File

@ -860,6 +860,21 @@ func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
// -----
// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xi32> {
%0 = "xla_hlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// CHECK: ENTRY
// CHECK: %[[ARG:.*]] = s32[3,4] parameter(0)
// CHECK: %[[ARG1:.*]] = s64[] parameter(1)
// CHECK: %[[ARG2:.*]] = s64[] parameter(2)
// CHECK: ROOT
// CHECK-SAME: s32[1,4] dynamic-slice(s32[3,4] %[[ARG]], s64[] %[[ARG1]], s64[] %[[ARG2]]), dynamic_slice_sizes={1,4}
// -----
// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
"xla_hlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> ()

View File

@ -347,13 +347,15 @@ add {
}
// CHECK-LABEL: func @test_dynamic_slice
// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_INDICES:%.*]]: tensor<3xi32>
// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_IDX_1:%.*]]: tensor<i32>, [[START_IDX_2:%.*]]: tensor<i32>, [[START_IDX_3:%.*]]: tensor<i32>
%test_dynamic_slice (operand: s32[2,2,258], start_indices: s32[3]) -> s32[1,1,32] {
%operand = s32[2,2,258] parameter(0)
%start_indices = s32[3] parameter(1)
// CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_INDICES]])
%start_idx_1 = s32[] parameter(1)
%start_idx_2 = s32[] parameter(2)
%start_idx_3 = s32[] parameter(3)
// CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]])
// CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64>
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[3] %start_indices), dynamic_slice_sizes={1,1,32}
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
}
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4x4xf32> {

View File

@ -19,25 +19,6 @@ include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
//===----------------------------------------------------------------------===//
// DynamicSlice op patterns.
//===----------------------------------------------------------------------===//
def BuildSliceLimits : NativeCodeCall<
"BuildSliceLimits($0.cast<DenseIntElementsAttr>(),"
"$1.cast<DenseIntElementsAttr>(), &$_builder)">;
def BuildSliceStrides : NativeCodeCall<
"GetI64ElementsAttr(SmallVector<int64_t, 4>("
"$0.getType().cast<RankedTensorType>().getRank(), 1), &$_builder)">;
def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input,
(HLO_ConstOp I64ElementsAttr:$starting_indices),
I64ElementsAttr:$slice_sizes),
(HLO_SliceOp $input, (CastIntElementsAttr $starting_indices),
(BuildSliceLimits $starting_indices, $slice_sizes),
(BuildSliceStrides $input))>;
def UnaryToBinaryEinsumEq : NativeCodeCall<
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;

View File

@ -168,6 +168,20 @@ static ConvertOp CastValueToI64(Location loc, Value value,
return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}
// Creates an unpack op along the 0th dimension of the tensor. The `value` input
// must be a ranked tensor.
static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
PatternRewriter *rewriter) {
auto indices_type = value.getType().cast<RankedTensorType>();
int num_outputs = indices_type.getShape().front();
SmallVector<Type, 2> unpacked_indices_type(
num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
auto unpacked_indices = rewriter->create<TF::UnpackOp>(
loc, unpacked_indices_type, value,
IntegerAttr::get(rewriter->getIntegerType(64), 0));
return unpacked_indices;
}
// Returns size of dimension at the specified index, if ranked tensor.
// Otherwise, returns -1.
//

View File

@ -479,6 +479,9 @@ def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featu
// Slice op patterns.
//===----------------------------------------------------------------------===//
def CastToI64AndUnpackTensor: NativeCodeCall<
"UnpackTensorAlongZeroDim($0.getLoc(), CastValueToI64($0.getLoc(), $1, &$_builder), &$_builder).output()">;
def CanBeTranslatedToDynamicSlice : Constraint<CPred<
"CanBeTranslatedToDynamicSlice($0, $1, $2.cast<DenseIntElementsAttr>())">>;
@ -488,7 +491,8 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall<
def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
(TF_ConstOp $slice_sizes)),
(HLO_DynamicSliceOp $input, (CastValueToI64 $op, $starting_indices),
(HLO_DynamicSliceOp $input,
(CastToI64AndUnpackTensor $op, $starting_indices),
(TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)),
[(CanBeTranslatedToDynamicSlice $input, $starting_indices,
$slice_sizes)]>;