Add SparseTensor support to tf.batch and friends.
Change: 116914274
This commit is contained in:
parent
025c0d21a6
commit
64dd5b58d5
tensorflow/python
@ -576,25 +576,25 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
|
||||
as_ref=False):
|
||||
"""Converts the given object to a `Tensor` or an `IndexedSlices`.
|
||||
|
||||
If `value` is an `IndexedSlices` it is returned
|
||||
If `value` is an `IndexedSlices` or `SparseTensor` it is returned
|
||||
unmodified. Otherwise, it is converted to a `Tensor` using
|
||||
`convert_to_tensor()`.
|
||||
|
||||
Args:
|
||||
value: An `IndexedSlices` or an object that can be consumed by
|
||||
`convert_to_tensor()`.
|
||||
value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
|
||||
by `convert_to_tensor()`.
|
||||
dtype: (Optional.) The required `DType` of the returned `Tensor` or
|
||||
`IndexedSlices`.
|
||||
name: (Optional.) A name to use if a new `Tensor` is created.
|
||||
as_ref: True if the caller wants the results as ref tensors.
|
||||
|
||||
Returns:
|
||||
An `Tensor` or an `IndexedSlices` based on `value`.
|
||||
A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `dtype` does not match the element type of `value`.
|
||||
"""
|
||||
if isinstance(value, IndexedSlices):
|
||||
if isinstance(value, (IndexedSlices, SparseTensor)):
|
||||
if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
|
||||
raise ValueError(
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
|
||||
@ -608,9 +608,12 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
|
||||
as_ref=False):
|
||||
"""Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
|
||||
|
||||
Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
|
||||
unmodified.
|
||||
|
||||
Args:
|
||||
values: A list of `None`, `IndexedSlices`, or objects that can be consumed
|
||||
by `convert_to_tensor()`.
|
||||
values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
|
||||
can be consumed by `convert_to_tensor()`.
|
||||
dtype: (Optional.) The required `DType` of the returned `Tensor` or
|
||||
`IndexedSlices`.
|
||||
name: (Optional.) A name prefix to be used when a new `Tensor` is
|
||||
@ -619,7 +622,7 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
|
||||
as_ref: True if the caller wants the results as ref tensors.
|
||||
|
||||
Returns:
|
||||
A list of `Tensor` and/or `IndexedSlices` objects.
|
||||
A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
|
||||
|
||||
Raises:
|
||||
TypeError: If no conversion function is registered for an element in
|
||||
|
@ -23,6 +23,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -35,6 +37,7 @@ from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import queue_runner
|
||||
|
||||
@ -228,6 +231,54 @@ def _flatten(tensor_list_list):
|
||||
return [tensor for tensor_list in tensor_list_list for tensor in tensor_list]
|
||||
|
||||
|
||||
def _serialize_sparse_tensors(tensor_list, enqueue_many):
  """Serialize SparseTensors for feeding into batch, etc.

  Args:
    tensor_list: An iterable of tensors; entries that are instances of
      `ops.SparseTensor` are replaced by their serialized form, every other
      entry is passed through untouched.
    enqueue_many: If true, sparse entries are serialized with
      `sparse_ops.serialize_many_sparse`; otherwise with
      `sparse_ops.serialize_sparse`.

  Returns:
    A tuple `(serialized_list, is_sparse_list, sparse_dtypes_list)` of three
    lists aligned with `tensor_list`: the (possibly serialized) entries, a
    boolean per entry telling whether it was a `SparseTensor`, and the dtype
    of each sparse entry (`None` for non-sparse entries).
  """
  # Pick the serializer once; it only depends on `enqueue_many`.
  if enqueue_many:
    serialize = sparse_ops.serialize_many_sparse
  else:
    serialize = sparse_ops.serialize_sparse

  serialized_list = []
  is_sparse_list = []
  sparse_dtypes_list = []
  for tensor in tensor_list:
    sparse = isinstance(tensor, ops.SparseTensor)
    is_sparse_list.append(sparse)
    sparse_dtypes_list.append(tensor.dtype if sparse else None)
    serialized_list.append(serialize(tensor) if sparse else tensor)
  return serialized_list, is_sparse_list, sparse_dtypes_list
|
||||
|
||||
|
||||
def _serialize_sparse_tensors_join(tensor_list_list, enqueue_many):
  """Serialize SparseTensors for feeding into batch_join, etc.

  Every inner list is serialized with `_serialize_sparse_tensors`. The first
  inner list acts as the reference: each later list must have
  `SparseTensor`s at the same positions and with the same dtypes.

  Args:
    tensor_list_list: A non-empty sequence of tensor lists (the first element
      is indexed unconditionally).
    enqueue_many: Forwarded to `_serialize_sparse_tensors`.

  Returns:
    A tuple `(serialized_list_list, is_sparse_list, sparse_dtypes_list)`,
    where the last two describe the reference (first) list.

  Raises:
    ValueError: If a later list disagrees with the first one about which
      positions are sparse, or about the dtypes of the sparse entries.
  """
  reference_list = tensor_list_list[0]
  first_serialized, is_sparse_list, sparse_dtypes_list = (
      _serialize_sparse_tensors(reference_list, enqueue_many))

  serialized_list_list = [first_serialized]
  for tensor_list in tensor_list_list[1:]:
    serialized, candidate_is_sparse, candidate_dtypes = (
        _serialize_sparse_tensors(tensor_list, enqueue_many))
    # Sparse positions are checked before dtypes, so a layout mismatch is
    # reported in preference to a dtype mismatch.
    if candidate_is_sparse != is_sparse_list:
      message = "Inconsistent SparseTensors list: %s vs. %s"
      raise ValueError(message % (tensor_list_list[0], tensor_list))
    if candidate_dtypes != sparse_dtypes_list:
      message = "Inconsistent SparseTensor dtypes in list: %s vs. %s"
      raise ValueError(message % (tensor_list_list[0], tensor_list))
    serialized_list_list.append(serialized)

  return (serialized_list_list, is_sparse_list, sparse_dtypes_list)
|
||||
|
||||
|
||||
def _deserialize_sparse_tensors(serialized_list, is_sparse_list, sparse_dtypes):
|
||||
"""Deserialize SparseTensors after dequeue in batch, batch_join, etc."""
|
||||
received_sequence = isinstance(serialized_list, collections.Sequence)
|
||||
if not received_sequence:
|
||||
serialized_list = (serialized_list,)
|
||||
tensors = [sparse_ops.deserialize_many_sparse(s, sparse_dtype) if is_sparse
|
||||
else s
|
||||
for (s, is_sparse, sparse_dtype)
|
||||
in zip(serialized_list, is_sparse_list, sparse_dtypes)]
|
||||
return tensors if received_sequence else tensors[0]
|
||||
|
||||
|
||||
def _validate(tensor_list):
|
||||
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
|
||||
if not tensor_list:
|
||||
@ -343,6 +394,8 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
|
||||
"""
|
||||
with ops.op_scope(tensor_list, name, "batch") as name:
|
||||
tensor_list = _validate(tensor_list)
|
||||
tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
|
||||
tensor_list, enqueue_many)
|
||||
types = _dtypes([tensor_list])
|
||||
shapes = _shapes([tensor_list], shapes, enqueue_many)
|
||||
# TODO(josh11b,mrry): Switch to BatchQueue once it is written.
|
||||
@ -352,7 +405,10 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
|
||||
logging_ops.scalar_summary(
|
||||
"queue/%s/fraction_of_%d_full" % (queue.name, capacity),
|
||||
math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
|
||||
return queue.dequeue_many(batch_size, name=name)
|
||||
|
||||
dequeued = queue.dequeue_many(batch_size, name=name)
|
||||
dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
|
||||
return dequeued
|
||||
|
||||
|
||||
# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be
|
||||
@ -422,6 +478,8 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
|
||||
"""
|
||||
with ops.op_scope(_flatten(tensor_list_list), name, "batch_join") as name:
|
||||
tensor_list_list = _validate_join(tensor_list_list)
|
||||
tensor_list_list, is_sparse, sparse_dtypes = (
|
||||
_serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
|
||||
types = _dtypes(tensor_list_list)
|
||||
shapes = _shapes(tensor_list_list, shapes, enqueue_many)
|
||||
# TODO(josh11b,mrry): Switch to BatchQueue once it is written.
|
||||
@ -431,7 +489,10 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
|
||||
logging_ops.scalar_summary(
|
||||
"queue/%s/fraction_of_%d_full" % (queue.name, capacity),
|
||||
math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
|
||||
return queue.dequeue_many(batch_size, name=name)
|
||||
|
||||
dequeued = queue.dequeue_many(batch_size, name=name)
|
||||
dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
|
||||
return dequeued
|
||||
|
||||
|
||||
def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
|
||||
@ -506,6 +567,8 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
|
||||
"""
|
||||
with ops.op_scope(tensor_list, name, "shuffle_batch") as name:
|
||||
tensor_list = _validate(tensor_list)
|
||||
tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
|
||||
tensor_list, enqueue_many)
|
||||
types = _dtypes([tensor_list])
|
||||
shapes = _shapes([tensor_list], shapes, enqueue_many)
|
||||
queue = data_flow_ops.RandomShuffleQueue(
|
||||
@ -522,7 +585,9 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
|
||||
(name, min_after_dequeue, capacity - min_after_dequeue))
|
||||
logging_ops.scalar_summary(summary_name, full)
|
||||
|
||||
return queue.dequeue_many(batch_size, name=name)
|
||||
dequeued = queue.dequeue_many(batch_size, name=name)
|
||||
dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
|
||||
return dequeued
|
||||
|
||||
|
||||
def shuffle_batch_join(tensor_list_list, batch_size, capacity,
|
||||
@ -587,6 +652,8 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
|
||||
with ops.op_scope(
|
||||
_flatten(tensor_list_list), name, "shuffle_batch_join") as name:
|
||||
tensor_list_list = _validate_join(tensor_list_list)
|
||||
tensor_list_list, is_sparse, sparse_dtypes = (
|
||||
_serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
|
||||
types = _dtypes(tensor_list_list)
|
||||
shapes = _shapes(tensor_list_list, shapes, enqueue_many)
|
||||
queue = data_flow_ops.RandomShuffleQueue(
|
||||
@ -602,4 +669,7 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
|
||||
"queue/%sfraction_over_%d_of_%d_full" %
|
||||
(name, min_after_dequeue, capacity - min_after_dequeue))
|
||||
logging_ops.scalar_summary(summary_name, full)
|
||||
return queue.dequeue_many(batch_size, name=name)
|
||||
|
||||
dequeued = queue.dequeue_many(batch_size, name=name)
|
||||
dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
|
||||
return dequeued
|
||||
|
@ -318,7 +318,12 @@ class BatchTest(tf.test.TestCase):
|
||||
zero64 = tf.constant(0, dtype=tf.int64)
|
||||
examples = tf.Variable(zero64)
|
||||
counter = examples.count_up_to(num_batches * batch_size)
|
||||
batched = tf.train.batch([counter, "string"], batch_size=batch_size)
|
||||
sparse_counter = tf.SparseTensor(
|
||||
indices=tf.reshape(tf.pack([zero64, zero64 + 1]), [2, 1]),
|
||||
values=tf.cast(tf.pack([counter, -counter]), tf.float32),
|
||||
shape=[2])
|
||||
batched = tf.train.batch(
|
||||
[counter, sparse_counter, "string"], batch_size=batch_size)
|
||||
tf.initialize_all_variables().run()
|
||||
threads = tf.train.start_queue_runners()
|
||||
|
||||
@ -326,7 +331,16 @@ class BatchTest(tf.test.TestCase):
|
||||
results = sess.run(batched)
|
||||
self.assertAllEqual(results[0], np.arange(i * batch_size,
|
||||
(i + 1) * batch_size))
|
||||
self.assertAllEqual(results[1], [b"string"] * batch_size)
|
||||
self.assertAllEqual(
|
||||
results[1].indices,
|
||||
np.vstack((np.arange(2 * batch_size) // 2, # 0, 0, 1, 1, ...
|
||||
[0, 1] * batch_size)).T)
|
||||
# [x, -x, x+1, -(x+1), ...]
|
||||
expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
|
||||
expected *= ([1, -1] * batch_size) # mult by [1, -1, 1, -1, ...]
|
||||
self.assertAllEqual(results[1].values, expected)
|
||||
self.assertAllEqual(results[1].shape, [batch_size, 2])
|
||||
self.assertAllEqual(results[2], [b"string"] * batch_size)
|
||||
|
||||
# Reached the limit.
|
||||
with self.assertRaises(tf.errors.OutOfRangeError):
|
||||
@ -341,7 +355,12 @@ class BatchTest(tf.test.TestCase):
|
||||
zero64 = tf.constant(0, dtype=tf.int64)
|
||||
examples = tf.Variable(zero64)
|
||||
counter = examples.count_up_to(num_batches * batch_size)
|
||||
pre_batched = tf.train.batch([counter, "string"], batch_size=2)
|
||||
sparse_counter = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(counter, tf.float32)]),
|
||||
shape=[1])
|
||||
pre_batched = tf.train.batch(
|
||||
[counter, sparse_counter, "string"], batch_size=2)
|
||||
batched = tf.train.batch(pre_batched, enqueue_many=True,
|
||||
batch_size=batch_size)
|
||||
tf.initialize_all_variables().run()
|
||||
@ -351,7 +370,13 @@ class BatchTest(tf.test.TestCase):
|
||||
results = sess.run(batched)
|
||||
self.assertAllEqual(results[0], np.arange(i * batch_size,
|
||||
(i + 1) * batch_size))
|
||||
self.assertAllEqual(results[1], [b"string"] * batch_size)
|
||||
self.assertAllEqual(
|
||||
results[1].indices,
|
||||
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
|
||||
self.assertAllEqual(
|
||||
results[1].values, np.arange(i * batch_size, (i + 1) * batch_size))
|
||||
self.assertAllEqual(results[1].shape, [batch_size, 1])
|
||||
self.assertAllEqual(results[2], [b"string"] * batch_size)
|
||||
|
||||
# Reached the limit.
|
||||
with self.assertRaises(tf.errors.OutOfRangeError):
|
||||
@ -364,10 +389,16 @@ class BatchTest(tf.test.TestCase):
|
||||
batch_size = 10
|
||||
num_batches = 3
|
||||
zero64 = tf.constant(0, dtype=tf.int64)
|
||||
|
||||
examples = tf.Variable(zero64)
|
||||
counter = examples.count_up_to(num_batches * batch_size)
|
||||
batched = tf.train.batch([counter, "string"], batch_size=batch_size,
|
||||
num_threads=4)
|
||||
sparse_counter = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(counter, tf.float32)]),
|
||||
shape=[1])
|
||||
batched = tf.train.batch(
|
||||
[counter, sparse_counter, "string"],
|
||||
batch_size=batch_size, num_threads=4)
|
||||
tf.initialize_all_variables().run()
|
||||
threads = tf.train.start_queue_runners()
|
||||
|
||||
@ -376,8 +407,13 @@ class BatchTest(tf.test.TestCase):
|
||||
results = sess.run(batched)
|
||||
tf.logging.info("Batch %d: %s", i, results[0])
|
||||
self.assertEqual(len(results[0]), batch_size)
|
||||
self.assertAllEqual(results[0], results[1].values)
|
||||
self.assertAllEqual(
|
||||
results[1].indices,
|
||||
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
|
||||
self.assertAllEqual(results[1].shape, [batch_size, 1])
|
||||
all_counts.extend(results[0])
|
||||
self.assertAllEqual(results[1], [b"string"] * batch_size)
|
||||
self.assertAllEqual(results[2], [b"string"] * batch_size)
|
||||
self.assertItemsEqual(all_counts, range(num_batches * batch_size))
|
||||
|
||||
# Reached the limit.
|
||||
@ -411,16 +447,26 @@ class BatchJoinTest(tf.test.TestCase):
|
||||
zero64 = tf.constant(0, dtype=tf.int64)
|
||||
examples = tf.Variable(zero64)
|
||||
counter = examples.count_up_to(num_a)
|
||||
sparse_counter = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(counter, tf.float32)]),
|
||||
shape=[1])
|
||||
|
||||
# The second generates (99, "b") 90 times and then stops.
|
||||
num_b = 90
|
||||
ninety_nine = tf.train.limit_epochs(
|
||||
tf.constant(99, dtype=tf.int64), num_b)
|
||||
sparse_ninety_nine = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(ninety_nine, tf.float32)]),
|
||||
shape=[1])
|
||||
|
||||
# These get joined together and grouped into batches of 5.
|
||||
batch_size = 5
|
||||
batched = tf.train.batch_join([[counter, "a"], [ninety_nine, "b"]],
|
||||
batch_size=batch_size)
|
||||
batched = tf.train.batch_join(
|
||||
[[counter, sparse_counter, "a"],
|
||||
[ninety_nine, sparse_ninety_nine, "b"]],
|
||||
batch_size=batch_size)
|
||||
tf.initialize_all_variables().run()
|
||||
threads = tf.train.start_queue_runners()
|
||||
|
||||
@ -433,9 +479,14 @@ class BatchJoinTest(tf.test.TestCase):
|
||||
results = sess.run(batched)
|
||||
tf.logging.info("Batch %d: %s", i, results[0])
|
||||
self.assertEqual(len(results[0]), batch_size)
|
||||
self.assertEqual(len(results[1]), batch_size)
|
||||
which_a = [i for i, s in enumerate(results[1]) if s == b"a"]
|
||||
which_b = [i for i, s in enumerate(results[1]) if s == b"b"]
|
||||
self.assertEqual(len(results[2]), batch_size)
|
||||
self.assertAllEqual(results[0], results[1].values)
|
||||
self.assertAllEqual(
|
||||
results[1].indices,
|
||||
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
|
||||
self.assertAllEqual(results[1].shape, [batch_size, 1])
|
||||
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
|
||||
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
|
||||
self.assertEqual(len(which_a) + len(which_b), batch_size)
|
||||
if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
|
||||
all_a.extend([results[0][i] for i in which_a])
|
||||
@ -481,8 +532,13 @@ class ShuffleBatchTest(tf.test.TestCase):
|
||||
zero64 = tf.constant(0, dtype=tf.int64)
|
||||
examples = tf.Variable(zero64)
|
||||
counter = examples.count_up_to(num_batches * batch_size)
|
||||
sparse_counter = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(counter, tf.float32)]),
|
||||
shape=[1])
|
||||
batched = tf.train.shuffle_batch(
|
||||
[counter, "string"], batch_size=batch_size, capacity=32,
|
||||
[counter, sparse_counter, "string"],
|
||||
batch_size=batch_size, capacity=32,
|
||||
min_after_dequeue=16, seed=141421)
|
||||
tf.initialize_all_variables().run()
|
||||
threads = tf.train.start_queue_runners()
|
||||
@ -492,7 +548,12 @@ class ShuffleBatchTest(tf.test.TestCase):
|
||||
results = sess.run(batched)
|
||||
self.assertEqual(len(results[0]), batch_size)
|
||||
all_counts.extend(results[0])
|
||||
self.assertAllEqual(results[1], [b"string"] * batch_size)
|
||||
self.assertAllEqual(
|
||||
results[1].indices,
|
||||
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
|
||||
self.assertAllEqual(results[0], results[1].values)
|
||||
self.assertAllEqual(results[1].shape, [batch_size, 1])
|
||||
self.assertAllEqual(results[2], [b"string"] * batch_size)
|
||||
# Results scrambled, but include all the expected numbers.
|
||||
deltas = [all_counts[i + 1] - all_counts[i]
|
||||
for i in range(len(all_counts) - 1)]
|
||||
@ -512,8 +573,13 @@ class ShuffleBatchTest(tf.test.TestCase):
|
||||
zero64 = tf.constant(0, dtype=tf.int64)
|
||||
examples = tf.Variable(zero64)
|
||||
counter = examples.count_up_to(num_batches * batch_size)
|
||||
sparse_counter = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(counter, tf.float32)]),
|
||||
shape=[1])
|
||||
batched = tf.train.shuffle_batch(
|
||||
[counter, "string"], batch_size=batch_size, capacity=32,
|
||||
[counter, sparse_counter, "string"],
|
||||
batch_size=batch_size, capacity=32,
|
||||
min_after_dequeue=16, seed=173205, num_threads=4)
|
||||
tf.initialize_all_variables().run()
|
||||
threads = tf.train.start_queue_runners()
|
||||
@ -524,7 +590,12 @@ class ShuffleBatchTest(tf.test.TestCase):
|
||||
tf.logging.info("Batch %d: %s", i, results[0])
|
||||
self.assertEqual(len(results[0]), batch_size)
|
||||
all_counts.extend(results[0])
|
||||
self.assertAllEqual(results[1], [b"string"] * batch_size)
|
||||
self.assertAllEqual(
|
||||
results[1].indices,
|
||||
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
|
||||
self.assertAllEqual(results[0], results[1].values)
|
||||
self.assertAllEqual(results[1].shape, [batch_size, 1])
|
||||
self.assertAllEqual(results[2], [b"string"] * batch_size)
|
||||
# Results scrambled, but include all the expected numbers.
|
||||
deltas = [all_counts[i + 1] - all_counts[i]
|
||||
for i in range(len(all_counts) - 1)]
|
||||
@ -564,17 +635,27 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
|
||||
zero64 = tf.constant(0, dtype=tf.int64)
|
||||
examples = tf.Variable(zero64)
|
||||
counter = examples.count_up_to(num_a)
|
||||
sparse_counter = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(counter, tf.float32)]),
|
||||
shape=[1])
|
||||
|
||||
# The second generates (99, "b") 35 times and then stops.
|
||||
num_b = 35
|
||||
ninety_nine = tf.train.limit_epochs(
|
||||
tf.constant(99, dtype=tf.int64), num_b)
|
||||
sparse_ninety_nine = tf.SparseTensor(
|
||||
indices=tf.reshape(zero64, [1, 1]),
|
||||
values=tf.pack([tf.cast(ninety_nine, tf.float32)]),
|
||||
shape=[1])
|
||||
|
||||
# These get joined together and grouped into batches of 5.
|
||||
batch_size = 5
|
||||
batched = tf.train.shuffle_batch_join(
|
||||
[[counter, "a"], [ninety_nine, "b"]], batch_size=batch_size,
|
||||
capacity=32, min_after_dequeue=16, seed=223607)
|
||||
[[counter, sparse_counter, "a"],
|
||||
[ninety_nine, sparse_ninety_nine, "b"]],
|
||||
batch_size=batch_size, capacity=32,
|
||||
min_after_dequeue=16, seed=223607)
|
||||
|
||||
tf.initialize_all_variables().run()
|
||||
threads = tf.train.start_queue_runners()
|
||||
@ -588,9 +669,14 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
|
||||
results = sess.run(batched)
|
||||
tf.logging.info("Batch %d: %s", i, results[0])
|
||||
self.assertEqual(len(results[0]), batch_size)
|
||||
self.assertEqual(len(results[1]), batch_size)
|
||||
which_a = [i for i, s in enumerate(results[1]) if s == b"a"]
|
||||
which_b = [i for i, s in enumerate(results[1]) if s == b"b"]
|
||||
self.assertEqual(len(results[2]), batch_size)
|
||||
self.assertAllEqual(results[0], results[1].values)
|
||||
self.assertAllEqual(
|
||||
results[1].indices,
|
||||
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
|
||||
self.assertAllEqual(results[1].shape, [batch_size, 1])
|
||||
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
|
||||
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
|
||||
self.assertEqual(len(which_a) + len(which_b), batch_size)
|
||||
if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
|
||||
all_a.extend([results[0][i] for i in which_a])
|
||||
|
Loading…
Reference in New Issue
Block a user