Fix a crash when SparseTensor::Slice receives start/size larger than input tensor.
SparseTensor::Slice already reduces the size of the result if the selection is not fully contained in the input; this fix ensures that the same treatment applies when the start is already beyond the input boundary. PiperOrigin-RevId: 315727283 Change-Id: I65acc099dd37932c32994e8026beb2be59d1c824
This commit is contained in:
parent
770b12dc60
commit
4acad672b7
tensorflow/core/util/sparse
@ -580,10 +580,22 @@ inline SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
|
||||
|
||||
const int dims = input_tensor.dims();
|
||||
for (int dim = 0; dim < dims; dim++) {
|
||||
int64 dim_size = start[dim] + size[dim] < output_shape.dim_size(dim)
|
||||
? size[dim]
|
||||
: output_shape.dim_size(dim) - start[dim];
|
||||
output_shape.set_dim(dim, dim_size);
|
||||
// Determine the size of the result; if the selected slice goes beyond the
|
||||
// input boundary, the result will correspond to the size of the overlap
|
||||
// between the input and the selected slice.
|
||||
const int64 input_size = output_shape.dim_size(dim);
|
||||
const int64 start_index = start[dim];
|
||||
const int64 slice_size = size[dim];
|
||||
if (start_index + slice_size < input_size) {
|
||||
// The entire selection is within input boundaries.
|
||||
output_shape.set_dim(dim, slice_size);
|
||||
} else if (start_index < input_size) {
|
||||
// The selection starts within input boundaries, but goes beyond them.
|
||||
output_shape.set_dim(dim, input_size - start_index);
|
||||
} else {
|
||||
// The selection is entirely out of input boundaries.
|
||||
output_shape.set_dim(dim, 0);
|
||||
}
|
||||
}
|
||||
|
||||
auto input_indices_t = input_tensor.indices().matrix<int64>();
|
||||
|
@ -692,6 +692,29 @@ TEST(SparseTensorTest, Slice) {
|
||||
EXPECT_EQ(slice.indices().matrix<int64>()(2, 1), 2);
|
||||
}
|
||||
|
||||
TEST(SparseTensorTest, SliceReducesOutputDimension) {
|
||||
const int num_rows = 2;
|
||||
const int num_columns = 2;
|
||||
|
||||
Tensor ids(DT_INT64, TensorShape({num_rows, num_columns}));
|
||||
ids.matrix<int64>()(0, 0) = 0;
|
||||
ids.matrix<int64>()(0, 1) = 0;
|
||||
ids.matrix<int64>()(1, 0) = 1;
|
||||
ids.matrix<int64>()(1, 1) = 1;
|
||||
|
||||
Tensor vals(DT_INT64, TensorShape({2}));
|
||||
vals.vec<int64>()(0) = 1;
|
||||
vals.vec<int64>()(1) = 2;
|
||||
|
||||
SparseTensor st;
|
||||
TF_ASSERT_OK(SparseTensor::Create(ids, vals,
|
||||
TensorShape({num_rows, num_columns}), &st));
|
||||
|
||||
SparseTensor slice =
|
||||
SparseTensor::Slice<int64>(st, {num_rows + 1, 1}, {1, num_columns});
|
||||
EXPECT_EQ(TensorShape(slice.shape()), TensorShape({0, 1}));
|
||||
}
|
||||
|
||||
TEST(SparseTensorTest, Dim0SparseTensorToDenseTensor) {
|
||||
Tensor ix(DT_INT64, TensorShape({1, 0}));
|
||||
Tensor vals(DT_INT32, TensorShape({1}));
|
||||
|
Loading…
Reference in New Issue
Block a user