Use only positive strides in the XLA Slice op
parent 5b509dd764
commit db58c826ae
@@ -188,7 +188,7 @@ class SplitVOp : public XlaOpKernel {
     std::vector<int64> begin(input_shape.dims(), 0);
     auto dim_sizes = input_shape.dim_sizes();
     std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
-    std::vector<int64> strides(input_shape.dims(), 0);
+    std::vector<int64> strides(input_shape.dims(), 1);
     for (int i = 0; i < num_split; ++i) {
       TensorShape output_shape(input_shape);
       int slice_size = split_sizes_vec[i];
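Aside: a minimal standalone sketch (not part of the diff above) of why the stride vector must be initialized to 1 rather than 0. SplitV carves the input into contiguous chunks, and a contiguous chunk is a slice over a semi-open range [begin, limit) taken with unit stride; a stride of 0 never advances the index at all. The helper name slice_1d is hypothetical and exists only for this illustration.

    // Sketch, assuming 1-D data and semi-open [begin, limit) slice semantics.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> slice_1d(const std::vector<int64_t>& input,
                                  int64_t begin, int64_t limit, int64_t stride) {
      std::vector<int64_t> out;
      for (int64_t i = begin; i < limit; i += stride) out.push_back(input[i]);
      return out;
    }

    int main() {
      std::vector<int64_t> input = {0, 1, 2, 3, 4, 5, 6};
      // Split sizes {3, 4}: each piece is a unit-stride slice over its own range.
      std::vector<int64_t> first = slice_1d(input, 0, 3, /*stride=*/1);
      std::vector<int64_t> second = slice_1d(input, 3, 7, /*stride=*/1);
      assert(first == (std::vector<int64_t>{0, 1, 2}));
      assert(second == (std::vector<int64_t>{3, 4, 5, 6}));
      // With stride 0 the loop above would never terminate, which is why the
      // initializer in the kernel changes from 0 to 1.
      return 0;
    }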
@@ -75,8 +75,30 @@ class StridedSliceOp : public XlaOpKernel {
                         &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
                         &dummy, &dummy, &begin, &end, &strides));
 
-    xla::ComputationDataHandle slice =
-        ctx->builder()->Slice(ctx->Input(0), begin, end, strides);
+    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
+    gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
+
+    for (int i = 0; i < begin.size(); ++i) {
+      if (strides[i] > 0) {
+        slice_begin.push_back(begin[i]);
+        slice_end.push_back(end[i]);
+        slice_strides.push_back(strides[i]);
+      } else {
+        // Negative stride: swap begin and end, add 1 because the interval
+        // is semi-open, and mark the dimension to be reversed.
+        slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
+        slice_end.push_back(input_shape.dim_size(i) - end[i] - 1);
+        slice_strides.push_back(-strides[i]);
+        dimensions_to_reverse.push_back(i);
+      }
+    }
+
+    xla::ComputationDataHandle slice = ctx->Input(0);
+    if (!dimensions_to_reverse.empty()) {
+      slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
+    }
+
+    slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
 
     slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
     ctx->SetOutput(0, slice);
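Aside: a minimal standalone sketch (not part of the diff above) that checks the index arithmetic used in this hunk on a plain 1-D std::vector: a slice with negative stride equals a reverse followed by a positive-stride slice with begin' = size - begin - 1, end' = size - end - 1, stride' = -stride. The helper names negative_stride_slice and reverse_then_slice are hypothetical and exist only for this illustration.

    // Sketch, assuming semi-open slice bounds (begin included, end excluded).
    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> negative_stride_slice(const std::vector<int64_t>& v,
                                               int64_t begin, int64_t end,
                                               int64_t stride) {
      // Walk backwards: i = begin, begin + stride, ... while i > end.
      std::vector<int64_t> out;
      for (int64_t i = begin; i > end; i += stride) out.push_back(v[i]);
      return out;
    }

    std::vector<int64_t> reverse_then_slice(std::vector<int64_t> v, int64_t begin,
                                            int64_t end, int64_t stride) {
      const int64_t n = static_cast<int64_t>(v.size());
      std::reverse(v.begin(), v.end());  // the Rev step
      std::vector<int64_t> out;
      // Positive-stride slice with the transformed bounds.
      for (int64_t i = n - begin - 1; i < n - end - 1; i += -stride)
        out.push_back(v[i]);
      return out;
    }

    int main() {
      std::vector<int64_t> v = {10, 11, 12, 13, 14, 15, 16, 17};
      // v[6:1:-2] in NumPy-style notation: expect {16, 14, 12}.
      assert(negative_stride_slice(v, 6, 1, -2) ==
             (std::vector<int64_t>{16, 14, 12}));
      assert(negative_stride_slice(v, 6, 1, -2) ==
             reverse_then_slice(v, 6, 1, -2));
      return 0;
    }

This is the equivalence the kernel now relies on so that the XLA Slice op only ever sees positive strides.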
@@ -1056,13 +1056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
       }
       sizes.push_back((limit_index - start_index + stride - 1) / stride);
     } else {
-      if (start_index < limit_index) {
-        return InvalidArgument(
-            "limit index (%lld) must be less than or equal to "
-            "start index (%lld) in slice with negative stride",
-            limit_index, start_index);
-      }
-      sizes.push_back((start_index - limit_index - stride - 1) / -stride);
+      return InvalidArgument("Negative strides not supported");
     }
   }
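Aside: a quick standalone check (not part of the diff above) of the size formula kept for positive strides: a semi-open range [start, limit) visited with step stride contains ceil((limit - start) / stride) elements, which the integer expression (limit_index - start_index + stride - 1) / stride computes. The helper name slice_dim_size is hypothetical.

    #include <cassert>
    #include <cstdint>

    int64_t slice_dim_size(int64_t start, int64_t limit, int64_t stride) {
      return (limit - start + stride - 1) / stride;  // integer ceiling division
    }

    int main() {
      // 64 elements taken with stride 4 -> 16 elements.
      assert(slice_dim_size(0, 64, 4) == 16);
      // A non-multiple case: [0, 7) with stride 3 visits indices 0, 3, 6.
      assert(slice_dim_size(0, 7, 3) == 3);
      return 0;
    }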
@@ -697,15 +697,6 @@ TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
 }
 
-TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithNegativeStrides) {
-  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
-  auto inferred_status =
-      ShapeInference::InferSliceShape(matrix_shape, {64, 64}, {32, 0}, {-1, -4});
-  ASSERT_IS_OK(inferred_status.status());
-  Shape inferred = inferred_status.ValueOrDie();
-  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 16}), inferred));
-}
-
 TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
   auto inferred_status =