Merge pull request #39678 from firejq:patch-1
PiperOrigin-RevId: 321841815 Change-Id: I2810249fbe923b60b3ee1846b208dfc9b3d19039
This commit is contained in:
commit
4f80ad304c
@ -118,13 +118,17 @@ class SparseFillEmptyRowsOp : public OpKernel {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool rows_are_ordered = true;
|
||||||
|
int64 last_indices_row = 0;
|
||||||
std::vector<int64> csr_offset(dense_rows, 0);
|
std::vector<int64> csr_offset(dense_rows, 0);
|
||||||
for (int i = 0; i < N; ++i) {
|
for (int i = 0; i < N; ++i) {
|
||||||
const int64 row = indices(i, 0);
|
const int64 row = indices(i, 0);
|
||||||
OP_REQUIRES(context, row >= 0 && row < dense_rows,
|
OP_REQUIRES(context, row >= 0 && row < dense_rows,
|
||||||
errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
|
errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
|
||||||
row, " >= ", dense_rows));
|
row, " >= ", dense_rows));
|
||||||
++csr_offset[indices(i, 0)];
|
++csr_offset[row];
|
||||||
|
rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
|
||||||
|
last_indices_row = row;
|
||||||
}
|
}
|
||||||
bool all_rows_full = true;
|
bool all_rows_full = true;
|
||||||
for (int row = 0; row < dense_rows; ++row) {
|
for (int row = 0; row < dense_rows; ++row) {
|
||||||
@ -147,7 +151,7 @@ class SparseFillEmptyRowsOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (all_rows_full) {
|
if (all_rows_full && rows_are_ordered) {
|
||||||
context->set_output(kOutputIndicesOutput, indices_t);
|
context->set_output(kOutputIndicesOutput, indices_t);
|
||||||
context->set_output(kOutputValuesOutput, values_t);
|
context->set_output(kOutputValuesOutput, values_t);
|
||||||
if (reverse_index_map) {
|
if (reverse_index_map) {
|
||||||
|
@ -585,6 +585,23 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(output.dense_shape, [2, 6])
|
self.assertAllEqual(output.dense_shape, [2, 6])
|
||||||
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
|
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
|
||||||
|
|
||||||
|
def testNoEmptyRowsAndUnordered(self):
|
||||||
|
with test_util.force_cpu():
|
||||||
|
sp_input = sparse_tensor.SparseTensor(
|
||||||
|
indices=np.array([[1, 2], [1, 3], [0, 1], [0, 3]]),
|
||||||
|
values=np.array([1, 3, 2, 4]),
|
||||||
|
dense_shape=np.array([2, 5]))
|
||||||
|
sp_output, empty_row_indicator = (
|
||||||
|
sparse_ops.sparse_fill_empty_rows(sp_input, -1))
|
||||||
|
|
||||||
|
output, empty_row_indicator_out = self.evaluate(
|
||||||
|
[sp_output, empty_row_indicator])
|
||||||
|
|
||||||
|
self.assertAllEqual(output.indices, [[0, 1], [0, 3], [1, 2], [1, 3]])
|
||||||
|
self.assertAllEqual(output.values, [2, 4, 1, 3])
|
||||||
|
self.assertAllEqual(output.dense_shape, [2, 5])
|
||||||
|
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
|
||||||
|
|
||||||
|
|
||||||
class SparseAddTest(test_util.TensorFlowTestCase):
|
class SparseAddTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user