BREAKING CHANGE: Fix semantic error in how maybe_batch* handles sparse tensors.
PiperOrigin-RevId: 163276613
This commit is contained in:
parent
6028c071b5
commit
423c1eea0e
tensorflow/python/training
@ -492,8 +492,15 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
|
||||
lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
|
||||
out_tensor.set_shape([None]) # necessary when t.ndims is unknown
|
||||
return out_tensor
|
||||
def _sparse_values_to_keep(t, keep_input):
|
||||
"""Convert a per-row `keep_input` vector to a per-value one."""
|
||||
# Get the rows of every value in the sparse Tensor.
|
||||
row_values = array_ops.reshape(
|
||||
t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0]
|
||||
# The value should be kept iff the row should be kept.
|
||||
return array_ops.gather(keep_input, row_values)
|
||||
if keep_input.shape.ndims == 1:
|
||||
t = sparse_ops.sparse_retain(t, keep_input)
|
||||
t = sparse_ops.sparse_retain(t, _sparse_values_to_keep(t, keep_input))
|
||||
store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name)
|
||||
elif enqueue_many:
|
||||
store_f = _maybe_store_many_sparse
|
||||
|
@ -903,6 +903,29 @@ class BatchTest(test_lib.TestCase):
|
||||
[sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
|
||||
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
|
||||
|
||||
def testMaybeBatchCorrectValues(self):
|
||||
sparse_t = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
|
||||
dense_shape=[2, 4],
|
||||
values=[5, 4, 7, 2])
|
||||
keep = constant_op.constant([True, False])
|
||||
batched = inp.maybe_batch(
|
||||
[sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
|
||||
|
||||
with self.test_session():
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(coord=coord)
|
||||
|
||||
batched_np = batched.eval()
|
||||
|
||||
coord.request_stop()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
|
||||
self.assertAllEqual([5, 4], batched_np.values)
|
||||
self.assertAllEqual([1, 4], batched_np.dense_shape)
|
||||
|
||||
|
||||
class BatchJoinTest(test_lib.TestCase):
|
||||
|
||||
@ -1457,6 +1480,29 @@ class BatchJoinTest(test_lib.TestCase):
|
||||
[[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
|
||||
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
|
||||
|
||||
def testMaybeBatchCorrectValues(self):
|
||||
sparse = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
|
||||
dense_shape=[2, 4],
|
||||
values=[5, 4, 7, 2])
|
||||
keep = constant_op.constant([True, False])
|
||||
batched = inp.maybe_batch_join(
|
||||
[[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
|
||||
|
||||
with self.test_session():
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(coord=coord)
|
||||
|
||||
batched_np = batched.eval()
|
||||
|
||||
coord.request_stop()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
|
||||
self.assertAllEqual([5, 4], batched_np.values)
|
||||
self.assertAllEqual([1, 4], batched_np.dense_shape)
|
||||
|
||||
|
||||
class ShuffleBatchTest(test_lib.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user