Merge pull request #39678 from firejq:patch-1

PiperOrigin-RevId: 321841815
Change-Id: I2810249fbe923b60b3ee1846b208dfc9b3d19039
This commit is contained in:
TensorFlower Gardener 2020-07-17 13:24:47 -07:00
commit 4f80ad304c
2 changed files with 23 additions and 2 deletions

View File

@ -118,13 +118,17 @@ class SparseFillEmptyRowsOp : public OpKernel {
return;
}
bool rows_are_ordered = true;
int64 last_indices_row = 0;
std::vector<int64> csr_offset(dense_rows, 0);
for (int i = 0; i < N; ++i) {
const int64 row = indices(i, 0);
OP_REQUIRES(context, row >= 0 && row < dense_rows,
errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
row, " >= ", dense_rows));
++csr_offset[indices(i, 0)];
++csr_offset[row];
rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
last_indices_row = row;
}
bool all_rows_full = true;
for (int row = 0; row < dense_rows; ++row) {
@ -147,7 +151,7 @@ class SparseFillEmptyRowsOp : public OpKernel {
}
}
if (all_rows_full) {
if (all_rows_full && rows_are_ordered) {
context->set_output(kOutputIndicesOutput, indices_t);
context->set_output(kOutputValuesOutput, values_t);
if (reverse_index_map) {

View File

@ -585,6 +585,23 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(output.dense_shape, [2, 6])
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
def testNoEmptyRowsAndUnordered(self):
with test_util.force_cpu():
sp_input = sparse_tensor.SparseTensor(
indices=np.array([[1, 2], [1, 3], [0, 1], [0, 3]]),
values=np.array([1, 3, 2, 4]),
dense_shape=np.array([2, 5]))
sp_output, empty_row_indicator = (
sparse_ops.sparse_fill_empty_rows(sp_input, -1))
output, empty_row_indicator_out = self.evaluate(
[sp_output, empty_row_indicator])
self.assertAllEqual(output.indices, [[0, 1], [0, 3], [1, 2], [1, 3]])
self.assertAllEqual(output.values, [2, 4, 1, 3])
self.assertAllEqual(output.dense_shape, [2, 5])
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
class SparseAddTest(test_util.TensorFlowTestCase):