Merge pull request #39678 from firejq:patch-1
PiperOrigin-RevId: 321841815 Change-Id: I2810249fbe923b60b3ee1846b208dfc9b3d19039
This commit is contained in:
commit
4f80ad304c
@ -118,13 +118,17 @@ class SparseFillEmptyRowsOp : public OpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
bool rows_are_ordered = true;
|
||||
int64 last_indices_row = 0;
|
||||
std::vector<int64> csr_offset(dense_rows, 0);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
const int64 row = indices(i, 0);
|
||||
OP_REQUIRES(context, row >= 0 && row < dense_rows,
|
||||
errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
|
||||
row, " >= ", dense_rows));
|
||||
++csr_offset[indices(i, 0)];
|
||||
++csr_offset[row];
|
||||
rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
|
||||
last_indices_row = row;
|
||||
}
|
||||
bool all_rows_full = true;
|
||||
for (int row = 0; row < dense_rows; ++row) {
|
||||
@ -147,7 +151,7 @@ class SparseFillEmptyRowsOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
if (all_rows_full) {
|
||||
if (all_rows_full && rows_are_ordered) {
|
||||
context->set_output(kOutputIndicesOutput, indices_t);
|
||||
context->set_output(kOutputValuesOutput, values_t);
|
||||
if (reverse_index_map) {
|
||||
|
@ -585,6 +585,23 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(output.dense_shape, [2, 6])
|
||||
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
|
||||
|
||||
def testNoEmptyRowsAndUnordered(self):
|
||||
with test_util.force_cpu():
|
||||
sp_input = sparse_tensor.SparseTensor(
|
||||
indices=np.array([[1, 2], [1, 3], [0, 1], [0, 3]]),
|
||||
values=np.array([1, 3, 2, 4]),
|
||||
dense_shape=np.array([2, 5]))
|
||||
sp_output, empty_row_indicator = (
|
||||
sparse_ops.sparse_fill_empty_rows(sp_input, -1))
|
||||
|
||||
output, empty_row_indicator_out = self.evaluate(
|
||||
[sp_output, empty_row_indicator])
|
||||
|
||||
self.assertAllEqual(output.indices, [[0, 1], [0, 3], [1, 2], [1, 3]])
|
||||
self.assertAllEqual(output.values, [2, 4, 1, 3])
|
||||
self.assertAllEqual(output.dense_shape, [2, 5])
|
||||
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
|
||||
|
||||
|
||||
class SparseAddTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user