From 423c1eea0e47eb71d3bf3ec7e99e7a4a63c3e433 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2017 16:51:45 -0700 Subject: [PATCH] BREAKING CHANGE: Fix semantic error in how maybe_batch* handles sparse tensors. PiperOrigin-RevId: 163276613 --- tensorflow/python/training/input.py | 9 ++++- tensorflow/python/training/input_test.py | 46 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 396ec11a025..94c5df619ff 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -492,8 +492,15 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input, lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64)) out_tensor.set_shape([None]) # necessary when t.ndims is unknown return out_tensor + def _sparse_values_to_keep(t, keep_input): + """Convert a per-row `keep_input` vector to a per-value one.""" + # Get the rows of every value in the sparse Tensor. + row_values = array_ops.reshape( + t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0] + # The value should be kept iff the row should be kept. + return array_ops.gather(keep_input, row_values) if keep_input.shape.ndims == 1: - t = sparse_ops.sparse_retain(t, keep_input) + t = sparse_ops.sparse_retain(t, _sparse_values_to_keep(t, keep_input)) store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name) elif enqueue_many: store_f = _maybe_store_many_sparse diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index d32768f5da7..3a25bfe3432 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -903,6 +903,29 @@ class BatchTest(test_lib.TestCase): [sparse], keep_input=[True, False], batch_size=2, enqueue_many=True) self.assertIs(None, batched.dense_shape.get_shape().num_elements()) + def testMaybeBatchCorrectValues(self): + sparse_t = sparse_tensor.SparseTensor( + indices=[[0, 1], [0, 2], [1, 0], [1, 3]], + dense_shape=[2, 4], + values=[5, 4, 7, 2]) + keep = constant_op.constant([True, False]) + batched = inp.maybe_batch( + [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True) + + with self.test_session(): + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(coord=coord) + + batched_np = batched.eval() + + coord.request_stop() + for thread in threads: + thread.join() + + self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices) + self.assertAllEqual([5, 4], batched_np.values) + self.assertAllEqual([1, 4], batched_np.dense_shape) + class BatchJoinTest(test_lib.TestCase): @@ -1457,6 +1480,29 @@ class BatchJoinTest(test_lib.TestCase): [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True) self.assertIs(None, batched.dense_shape.get_shape().num_elements()) + def testMaybeBatchCorrectValues(self): + sparse = sparse_tensor.SparseTensor( + indices=[[0, 1], [0, 2], [1, 0], [1, 3]], + dense_shape=[2, 4], + values=[5, 4, 7, 2]) + keep = constant_op.constant([True, False]) + batched = inp.maybe_batch_join( + [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True) + + with self.test_session(): + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(coord=coord) + + batched_np = batched.eval() + + coord.request_stop() + for thread in threads: + thread.join() + + self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices) + self.assertAllEqual([5, 4], batched_np.values) + self.assertAllEqual([1, 4], batched_np.dense_shape) + class ShuffleBatchTest(test_lib.TestCase):