Use only positive strides in the XLA stride op

This commit is contained in:
DavidNorman 2017-06-13 07:34:44 +01:00 committed by Martin Wicke
parent 5b509dd764
commit db58c826ae
4 changed files with 26 additions and 19 deletions

View File

@ -188,7 +188,7 @@ class SplitVOp : public XlaOpKernel {
std::vector<int64> begin(input_shape.dims(), 0); std::vector<int64> begin(input_shape.dims(), 0);
auto dim_sizes = input_shape.dim_sizes(); auto dim_sizes = input_shape.dim_sizes();
std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end()); std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
std::vector<int64> strides(input_shape.dims(), 0); std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < num_split; ++i) { for (int i = 0; i < num_split; ++i) {
TensorShape output_shape(input_shape); TensorShape output_shape(input_shape);
int slice_size = split_sizes_vec[i]; int slice_size = split_sizes_vec[i];

View File

@ -75,8 +75,30 @@ class StridedSliceOp : public XlaOpKernel {
&wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
&dummy, &dummy, &begin, &end, &strides)); &dummy, &dummy, &begin, &end, &strides));
xla::ComputationDataHandle slice = gtl::InlinedVector<int64, 4> dimensions_to_reverse;
ctx->builder()->Slice(ctx->Input(0), begin, end, strides); gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
for (int i = 0; i < begin.size(); ++i) {
if (strides[i] > 0) {
slice_begin.push_back(begin[i]);
slice_end.push_back(end[i]);
slice_strides.push_back(strides[i]);
} else {
// Negative stride: swap begin and end, add 1 because the interval
// is semi-open, and mark the dimension to be reversed.
slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
slice_end.push_back(input_shape.dim_size(i) - end[i] - 1);
slice_strides.push_back(-strides[i]);
dimensions_to_reverse.push_back(i);
}
}
xla::ComputationDataHandle slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
}
slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice); ctx->SetOutput(0, slice);

View File

@ -1056,13 +1056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
} }
sizes.push_back((limit_index - start_index + stride - 1) / stride); sizes.push_back((limit_index - start_index + stride - 1) / stride);
} else { } else {
if (start_index < limit_index) { return InvalidArgument("Negative strides not supported");
return InvalidArgument(
"limit index (%lld) must be less than or equal to "
"start index (%lld) in slice with negative stride",
limit_index, start_index);
}
sizes.push_back((start_index - limit_index - stride - 1) / -stride);
} }
} }

View File

@ -697,15 +697,6 @@ TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred)); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
} }
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithNegativeStrides) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status =
ShapeInference::InferSliceShape(matrix_shape, {64, 64}, {32, 0}, {-1, -4});
ASSERT_IS_OK(inferred_status.status());
Shape inferred = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 16}), inferred));
}
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) { TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status = auto inferred_status =