Fix a crash when SparseTensor::Slice receives start/size larger than input tensor.

SparseTensor::Slice already reduces the size of the result if the selection is not fully contained in the input; this fix ensures that the same treatment applies when the start is already beyond the input boundary.

PiperOrigin-RevId: 315727283
Change-Id: I65acc099dd37932c32994e8026beb2be59d1c824
This commit is contained in:
A. Unique TensorFlower 2020-06-10 11:11:14 -07:00 committed by TensorFlower Gardener
parent 770b12dc60
commit 4acad672b7
2 changed files with 39 additions and 4 deletions
tensorflow/core/util/sparse

View File

@ -580,10 +580,22 @@ inline SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
const int dims = input_tensor.dims();
for (int dim = 0; dim < dims; dim++) {
int64 dim_size = start[dim] + size[dim] < output_shape.dim_size(dim)
? size[dim]
: output_shape.dim_size(dim) - start[dim];
output_shape.set_dim(dim, dim_size);
// Determine the size of the result; if the selected slice goes beyond the
// input boundary, the result will correspond to the size of the overlap
// between the input and the selected slice.
const int64 input_size = output_shape.dim_size(dim);
const int64 start_index = start[dim];
const int64 slice_size = size[dim];
if (start_index + slice_size < input_size) {
// The entire selection is within input boundaries.
output_shape.set_dim(dim, slice_size);
} else if (start_index < input_size) {
// The selection starts within input boundaries, but goes beyond them.
output_shape.set_dim(dim, input_size - start_index);
} else {
// The selection is entirely out of input boundaries.
output_shape.set_dim(dim, 0);
}
}
auto input_indices_t = input_tensor.indices().matrix<int64>();

View File

@ -692,6 +692,29 @@ TEST(SparseTensorTest, Slice) {
EXPECT_EQ(slice.indices().matrix<int64>()(2, 1), 2);
}
TEST(SparseTensorTest, SliceReducesOutputDimension) {
const int num_rows = 2;
const int num_columns = 2;
Tensor ids(DT_INT64, TensorShape({num_rows, num_columns}));
ids.matrix<int64>()(0, 0) = 0;
ids.matrix<int64>()(0, 1) = 0;
ids.matrix<int64>()(1, 0) = 1;
ids.matrix<int64>()(1, 1) = 1;
Tensor vals(DT_INT64, TensorShape({2}));
vals.vec<int64>()(0) = 1;
vals.vec<int64>()(1) = 2;
SparseTensor st;
TF_ASSERT_OK(SparseTensor::Create(ids, vals,
TensorShape({num_rows, num_columns}), &st));
SparseTensor slice =
SparseTensor::Slice<int64>(st, {num_rows + 1, 1}, {1, num_columns});
EXPECT_EQ(TensorShape(slice.shape()), TensorShape({0, 1}));
}
TEST(SparseTensorTest, Dim0SparseTensorToDenseTensor) {
Tensor ix(DT_INT64, TensorShape({1, 0}));
Tensor vals(DT_INT32, TensorShape({1}));