Modify HLO DynamicSlice op definition such that start indices are a Variadic list rather than tensor.

This allows for DynamicSlice ops where the start indices are the result of op computations which cannot be folded at compile time.

Aside from modifying the op definition, this requires updating:
- the TF Slice to HLO DynamicSlice legalization
- the HLO DynamicSlice to Slice canonicalization
- the HLO importer
- the IREE lowering from HLO DynamicSlice to VMLA

I also added a test for the HLO DynamicSlice exporter, as there wasn't one previously.

PiperOrigin-RevId: 306495966
Change-Id: I5a080c63eaeedfceda034987cdf16177393ed3e8
parent 85ea23af35
commit 6f1bbf2cc2
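To make the motivation concrete, here is a minimal illustrative sketch (not taken from this change; the function and value names are hypothetical) of IR that the new variadic definition accepts: each start index is a separate scalar integer tensor, so the indices can come from runtime computations that cannot be folded at compile time.

func @variadic_start_indices(%input: tensor<3x4xi32>, %i: tensor<i64>, %j: tensor<i64>) -> tensor<1x4xi32> {
  // The row index is computed at runtime, so it cannot be folded into a constant tensor.
  %row = "xla_hlo.add"(%i, %j) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  %0 = "xla_hlo.dynamic-slice"(%input, %row, %j) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

Under the old definition the same op took a single start-indices tensor (e.g. tensor<2xi64>), as the removed test cases in the diff below show.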
@@ -296,9 +296,11 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
       std::vector<int64_t> slice_sizes(
           instruction->dynamic_slice_sizes().begin(),
           instruction->dynamic_slice_sizes().end());
-      attributes.push_back(
-          builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
-      MakeAndReturn(DynamicSliceOp);
+      return func_builder
+          ->create<mlir::xla_hlo::DynamicSliceOp>(
+              loc, result_type, operands[0],
+              makeArrayRef(operands).drop_front(), Convert(slice_sizes))
+          .getOperation();
     }
     case HloOpcode::kDynamicUpdateSlice: {
       return func_builder
@@ -836,11 +836,64 @@ static LogicalResult Verify(ConcatenateOp op) {
 // DynamicSliceOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
+// This canonicalization is applied the case when the `begin` input values are
+// compile time constants and thus can be made into a tensor.
+struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
+  using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
+                                PatternRewriter& rewriter) const override {
+    Value input = dynamic_slice.operand();
+    auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
+    if (!input_tensor) return failure();
+
+    SmallVector<int64_t, 4> temp_start_indices;
+    for (Value start : dynamic_slice.start_indices()) {
+      APInt val;
+      if (!matchPattern(start, m_ConstantInt(&val))) {
+        return failure();
+      }
+      temp_start_indices.push_back(*(val.getRawData()));
+    }
+
+    // At this point we've determined that the start indices are all constants;
+    // pack them into a single tensor.
+    auto loc = dynamic_slice.getLoc();
+    int64_t input_rank = input_tensor.getRank();
+    auto slice_start_indices =
+        GetI64ElementsAttr(temp_start_indices, &rewriter);
+    DenseIntElementsAttr slice_limits = BuildSliceLimits(
+        slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
+    DenseIntElementsAttr slice_strides =
+        GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
+    auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
+                                           slice_limits, slice_strides);
+    rewriter.replaceOp(dynamic_slice, {result});
+    return success();
+  }
+};
+
+}  // namespace
+
 void DynamicSliceOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
   results.insert<DynamicSliceToSlice>(context);
 }
 
+// Verifies that the number of slice sizes and the number of start indices match
+static LogicalResult Verify(DynamicSliceOp op) {
+  int num_slice_sizes = op.slice_sizes().getNumElements();
+  int num_start_indices = op.start_indices().size();
+  if (num_start_indices != num_slice_sizes) {
+    return op.emitOpError()
+           << "has mismatched number of slice sizes (" << num_slice_sizes
+           << ") and number of start indices (" << num_start_indices << ")";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InfeedOp
 //===----------------------------------------------------------------------===//
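As a rough sketch of what the new C++ canonicalization above does (assuming the xla_hlo syntax used in this change's tests; the values shown are illustrative), a dynamic-slice whose start indices are all constants is rewritten to a static slice whose limit indices are the starts plus the slice sizes and whose strides are all ones:

// Before: every start index is a foldable xla_hlo.constant.
%c1 = xla_hlo.constant dense<1> : tensor<i64>
%c0 = xla_hlo.constant dense<0> : tensor<i64>
%0 = "xla_hlo.dynamic-slice"(%arg0, %c1, %c0) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>

// After DynamicSliceToSlice: a static slice with limits = starts + slice_sizes.
%1 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>

The new verifier above would instead reject an op that, say, provided only one start index alongside two slice sizes.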
@@ -661,11 +661,10 @@ def HLO_SliceOp: HLO_Op<
 }
 
 def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
-      [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
-       AllShapesMatch<["start_indices", "slice_sizes"]>]> {
+      [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
   let arguments = (ins
     HLO_Tensor:$operand,
-    HLO_Tensor:$start_indices,
+    Variadic<HLO_ScalarIntTensor>:$start_indices,
     I64ElementsAttr:$slice_sizes
   );
 
@@ -1,9 +1,8 @@
 // RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
 
-func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
+func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
   // CHECK: "xla_hlo.dynamic-slice"
-  %0 = xla_hlo.constant dense<[1, 4]> : tensor<2xi64>
-  %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
+  %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
   return %1 : tensor<1x4xi32>
 }
 
@@ -14,21 +13,22 @@ func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
   // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
   // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
   // CHECK: return %[[RESULT]] : tensor<2xi32>
-  %0 = xla_hlo.constant dense<1> : tensor<1xi64>
-  %2 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
-  return %2 : tensor<2xi32>
+  %0 = xla_hlo.constant dense<1> : tensor<i64>
+  %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
+  return %1 : tensor<2xi32>
 }
 
 // CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
-func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
+func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
   // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
   // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
   // CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
   // CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
-  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
-  %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
-  %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
-  return %1 : tensor<1x4xi32>
+  // CHECK: return %[[RESULT]] : tensor<?x4xi32>
+  %0 = xla_hlo.constant dense<1> : tensor<i64>
+  %1 = xla_hlo.constant dense<0> : tensor<i64>
+  %2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
+  return %2 : tensor<?x4xi32>
 }
 
 // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
@@ -2249,7 +2249,16 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
 func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
   // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
-  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
+  // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
+  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
+  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
+  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64>
+  // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START:.*]]) :
+  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]])
+  // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} :
+  // CHECK-DAG-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
   // CHECK: return %[[RESULT]] : tensor<2xi32>
   %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
   %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@@ -2261,7 +2270,12 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
 func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32>
   // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64>
-  // CHECK: slice_sizes = dense<2> : tensor<1xi64>
+  // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
+  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
+  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
+  // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
+  // CHECK: "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
   %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
   %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
   %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
@@ -2272,7 +2286,12 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
 func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
   // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
-  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32>
+  // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
+  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
+  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
+  // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32>
   // CHECK: return %[[RESULT]] : tensor<3xi32>
   %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
   %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@@ -2284,7 +2303,24 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi
 func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
   // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64>
-  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
+  // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]])
+  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
+  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
+  // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
+  // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) :
+  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
+  // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]])
+  // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>,
+  // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>,
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
+  // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
+  // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) :
+  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"
+  // CHECK-DAG-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]])
+  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} :
+  // CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
   // CHECK: return %[[RESULT]] : tensor<1x4xi32>
   %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
   %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
@@ -2295,7 +2331,14 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2
 // CHECK-LABEL: slice_variable_start
 func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
   // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64>
-  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
+  // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]])
+  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
+  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
+  // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor<i64>
+  // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
+  // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor<i64>
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
   // CHECK: return %[[RESULT]] : tensor<1x4xi32>
   %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
   %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
@ -551,37 +551,45 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @dynamic_slice
|
// CHECK-LABEL: func @dynamic_slice
|
||||||
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||||
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
|
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
|
||||||
return %0 : tensor<1x4xi32>
|
return %0 : tensor<1x4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||||
// expected-error@+1 {{failed to verify that all of {start_indices, slice_sizes} have same shape}}
|
// expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}}
|
||||||
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
|
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
|
||||||
return %0 : tensor<1x4xi32>
|
return %0 : tensor<1x4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @dynamic_slice_different_indice_element_type
|
// CHECK-LABEL: @dynamic_slice_different_indice_element_type
|
||||||
func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<1xi32>) -> tensor<1x4xi32> {
|
func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<i32>) -> tensor<1x4xi32> {
|
||||||
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<1xi32>) -> tensor<1x4xi32>
|
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<i32>) -> tensor<1x4xi32>
|
||||||
return %0 : tensor<1x4xi32>
|
return %0 : tensor<1x4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xf32> {
|
func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xf32> {
|
||||||
// expected-error@+1 {{failed to verify that all of {operand, result} have same element type}}
|
// expected-error@+1 {{failed to verify that all of {operand, result} have same element type}}
|
||||||
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xf32>
|
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xf32>
|
||||||
return %0 : tensor<1x4xf32>
|
return %0 : tensor<1x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
||||||
|
// expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
|
||||||
|
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
|
||||||
|
return %0 : tensor<1x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @transpose
|
// CHECK-LABEL: func @transpose
|
||||||
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||||
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||||
|
@@ -860,6 +860,21 @@ func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
 
 // -----
 
+// CHECK: HloModule
+func @main(%arg: tensor<3x4xi32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xi32> {
+  %0 = "xla_hlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
+  return %0 : tensor<1x4xi32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[ARG:.*]] = s32[3,4] parameter(0)
+// CHECK: %[[ARG1:.*]] = s64[] parameter(1)
+// CHECK: %[[ARG2:.*]] = s64[] parameter(2)
+// CHECK: ROOT
+// CHECK-SAME: s32[1,4] dynamic-slice(s32[3,4] %[[ARG]], s64[] %[[ARG1]], s64[] %[[ARG2]]), dynamic_slice_sizes={1,4}
+
+// -----
+
 // CHECK: HloModule
 func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   "xla_hlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> ()
@@ -347,13 +347,15 @@ add {
 }
 
 // CHECK-LABEL: func @test_dynamic_slice
-// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_INDICES:%.*]]: tensor<3xi32>
+// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_IDX_1:%.*]]: tensor<i32>, [[START_IDX_2:%.*]]: tensor<i32>, [[START_IDX_3:%.*]]: tensor<i32>
 %test_dynamic_slice (operand: s32[2,2,258], start_indices: s32[3]) -> s32[1,1,32] {
   %operand = s32[2,2,258] parameter(0)
-  %start_indices = s32[3] parameter(1)
-  // CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_INDICES]])
+  %start_idx_1 = s32[] parameter(1)
+  %start_idx_2 = s32[] parameter(2)
+  %start_idx_3 = s32[] parameter(3)
+  // CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]])
   // CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64>
-  ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[3] %start_indices), dynamic_slice_sizes={1,1,32}
+  ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
 }
 
 // CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4x4xf32> {
@@ -19,25 +19,6 @@ include "mlir/IR/OpBase.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
 
-//===----------------------------------------------------------------------===//
-// DynamicSlice op patterns.
-//===----------------------------------------------------------------------===//
-
-def BuildSliceLimits : NativeCodeCall<
-  "BuildSliceLimits($0.cast<DenseIntElementsAttr>(),"
-  "$1.cast<DenseIntElementsAttr>(), &$_builder)">;
-
-def BuildSliceStrides : NativeCodeCall<
-  "GetI64ElementsAttr(SmallVector<int64_t, 4>("
-  "$0.getType().cast<RankedTensorType>().getRank(), 1), &$_builder)">;
-
-def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input,
-           (HLO_ConstOp I64ElementsAttr:$starting_indices),
-           I64ElementsAttr:$slice_sizes),
-          (HLO_SliceOp $input, (CastIntElementsAttr $starting_indices),
-           (BuildSliceLimits $starting_indices, $slice_sizes),
-           (BuildSliceStrides $input))>;
-
 def UnaryToBinaryEinsumEq : NativeCodeCall<
   "$_builder.getStringAttr(\",\" + $0.getValue().str())">;
 
@@ -168,6 +168,20 @@ static ConvertOp CastValueToI64(Location loc, Value value,
   return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
 }
 
+// Creates an unpack op along the 0th dimension of the tensor. The `value` input
+// must be a ranked tensor.
+static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
+                                             PatternRewriter *rewriter) {
+  auto indices_type = value.getType().cast<RankedTensorType>();
+  int num_outputs = indices_type.getShape().front();
+  SmallVector<Type, 2> unpacked_indices_type(
+      num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
+  auto unpacked_indices = rewriter->create<TF::UnpackOp>(
+      loc, unpacked_indices_type, value,
+      IntegerAttr::get(rewriter->getIntegerType(64), 0));
+  return unpacked_indices;
+}
+
 // Returns size of dimension at the specified index, if ranked tensor.
 // Otherwise, returns -1.
 //
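For intuition, a hedged sketch (the value names are illustrative, not taken from this change) of the IR the UnpackTensorAlongZeroDim helper above is expected to build for a two-element start-indices vector:

// Unpack a rank-1 indices tensor into one scalar tensor per element along dim 0.
%unpacked:2 = "tf.Unpack"(%start_indices) {axis = 0 : i64} : (tensor<2xi64>) -> (tensor<i64>, tensor<i64>)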
@@ -479,6 +479,9 @@ def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featu
 // Slice op patterns.
 //===----------------------------------------------------------------------===//
 
+def CastToI64AndUnpackTensor: NativeCodeCall<
+  "UnpackTensorAlongZeroDim($0.getLoc(), CastValueToI64($0.getLoc(), $1, &$_builder), &$_builder).output()">;
+
 def CanBeTranslatedToDynamicSlice : Constraint<CPred<
   "CanBeTranslatedToDynamicSlice($0, $1, $2.cast<DenseIntElementsAttr>())">>;
 
@@ -488,7 +491,8 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall<
 
 def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
            (TF_ConstOp $slice_sizes)),
-          (HLO_DynamicSliceOp $input, (CastValueToI64 $op, $starting_indices),
+          (HLO_DynamicSliceOp $input,
+           (CastToI64AndUnpackTensor $op, $starting_indices),
            (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)),
          [(CanBeTranslatedToDynamicSlice $input, $starting_indices,
            $slice_sizes)]>;
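Putting the pattern and the helper together, a rough sketch (illustrative names; not a verbatim trace of the pass) of how a tf.Slice with non-constant start indices now legalizes, before the intermediate tf.Unpack is itself lowered to the xla_hlo.slice/xla_hlo.reshape sequence the updated tests above expect:

// Input: tf.Slice with runtime start indices (%sizes must still come from a tf.Const).
%0 = "tf.Slice"(%input, %starts, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>

// Rewritten per the pattern: cast the starts to i64, unpack them into scalars,
// then feed the scalars to the variadic dynamic-slice.
%i64_starts = "xla_hlo.convert"(%starts) : (tensor<2xi64>) -> tensor<2xi64>
%unpacked:2 = "tf.Unpack"(%i64_starts) {axis = 0 : i64} : (tensor<2xi64>) -> (tensor<i64>, tensor<i64>)
%1 = "xla_hlo.dynamic-slice"(%input, %unpacked#0, %unpacked#1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>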