Add SparseTensor support to tf.batch and friends.

Change: 116914274
commit: 64dd5b58d5
parent: 025c0d21a6
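What the change enables, as a usage sketch (illustrative, not taken from the diff; it follows the TF 0.7-era API used throughout this commit, e.g. `tf.pack` and the `shape=` argument of `SparseTensor`):

import tensorflow as tf

# One example: a dense Tensor plus a rank-1 SparseTensor.
zero64 = tf.constant(0, dtype=tf.int64)
dense_part = tf.constant(7.0)
sparse_part = tf.SparseTensor(indices=tf.reshape(zero64, [1, 1]),
                              values=[1.0], shape=[1])

# tf.train.batch now accepts the SparseTensor directly; the dequeued sparse
# component comes back as a single SparseTensor with a leading batch
# dimension, dense shape [32, 1] in this sketch.
dense_batch, sparse_batch = tf.train.batch(
    [dense_part, sparse_part], batch_size=32)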
@@ -576,25 +576,25 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
                                         as_ref=False):
   """Converts the given object to a `Tensor` or an `IndexedSlices`.
 
-  If `value` is an `IndexedSlices` it is returned
+  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
   unmodified. Otherwise, it is converted to a `Tensor` using
   `convert_to_tensor()`.
 
   Args:
-    value: An `IndexedSlices` or an object that can be consumed by
-      `convert_to_tensor()`.
+    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
+      by `convert_to_tensor()`.
     dtype: (Optional.) The required `DType` of the returned `Tensor` or
       `IndexedSlices`.
     name: (Optional.) A name to use if a new `Tensor` is created.
     as_ref: True if the caller wants the results as ref tensors.
 
   Returns:
-    An `Tensor` or an `IndexedSlices` based on `value`.
+    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
 
   Raises:
     ValueError: If `dtype` does not match the element type of `value`.
   """
-  if isinstance(value, IndexedSlices):
+  if isinstance(value, (IndexedSlices, SparseTensor)):
     if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
       raise ValueError(
           "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
@@ -608,9 +608,12 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
                                           as_ref=False):
   """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
 
+  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
+  unmodified.
+
   Args:
-    values: A list of `None`, `IndexedSlices`, or objects that can be consumed
-      by `convert_to_tensor()`.
+    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
+      can be consumed by `convert_to_tensor()`.
     dtype: (Optional.) The required `DType` of the returned `Tensor`
       `IndexedSlices`.
     name: (Optional.) A name prefix to used when a new `Tensor` is
@@ -619,7 +622,7 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
     as_ref: True if the caller wants the results as ref tensors.
 
   Returns:
-    A list of `Tensor` and/or `IndexedSlices` objects.
+    A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
 
   Raises:
     TypeError: If no conversion function is registered for an element in
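A pass-through sketch of the new behavior in both conversion functions (assumed setup, not part of the diff; `ops` is tensorflow.python.framework.ops, as in the code above):

import tensorflow as tf
from tensorflow.python.framework import ops

zero64 = tf.constant(0, dtype=tf.int64)
sp = tf.SparseTensor(indices=tf.reshape(zero64, [1, 1]),
                     values=[1.0], shape=[1])

# Returned unmodified: the isinstance check now covers SparseTensor.
assert ops.convert_to_tensor_or_indexed_slices(sp) is sp

# In the list form, sparse entries pass through while plain values are still
# run through convert_to_tensor().
converted = ops.convert_n_to_tensor_or_indexed_slices([sp, 3.0])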
@@ -23,6 +23,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import dtypes
@@ -35,6 +37,7 @@ from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import queue_runner
 
@@ -228,6 +231,54 @@ def _flatten(tensor_list_list):
   return [tensor for tensor_list in tensor_list_list for tensor in tensor_list]
 
 
+def _serialize_sparse_tensors(tensor_list, enqueue_many):
+  """Serialize SparseTensors for feeding into batch, etc."""
+  is_sparse_list = [isinstance(t, ops.SparseTensor) for t in tensor_list]
+  sparse_dtypes_list = [
+      t.dtype if isinstance(t, ops.SparseTensor) else None
+      for t in tensor_list]
+
+  def _maybe_serialize(t, is_sparse):
+    if not is_sparse:
+      return t
+    return (sparse_ops.serialize_many_sparse(t) if enqueue_many
+            else sparse_ops.serialize_sparse(t))
+  serialized_list = [
+      _maybe_serialize(t, is_sparse)
+      for (t, is_sparse) in zip(tensor_list, is_sparse_list)]
+  return serialized_list, is_sparse_list, sparse_dtypes_list
+
+
+def _serialize_sparse_tensors_join(tensor_list_list, enqueue_many):
+  """Serialize SparseTensors for feeding into batch_join, etc."""
+  (s0, is_sparse_list, sparse_dtypes_list) = _serialize_sparse_tensors(
+      tensor_list_list[0], enqueue_many)
+  serialized_list_list = [s0]
+  for tensor_list in tensor_list_list[1:]:
+    (s, is_sparse_candidate, sparse_dtypes_candidate) = (
+        _serialize_sparse_tensors(tensor_list, enqueue_many))
+    if is_sparse_candidate != is_sparse_list:
+      raise ValueError("Inconsistent SparseTensors list: %s vs. %s"
+                       % (tensor_list_list[0], tensor_list))
+    if sparse_dtypes_candidate != sparse_dtypes_list:
+      raise ValueError("Inconsistent SparseTensor dtypes in list: %s vs. %s"
+                       % (tensor_list_list[0], tensor_list))
+    serialized_list_list.append(s)
+  return (serialized_list_list, is_sparse_list, sparse_dtypes_list)
+
+
+def _deserialize_sparse_tensors(serialized_list, is_sparse_list, sparse_dtypes):
+  """Deserialize SparseTensors after dequeue in batch, batch_join, etc."""
+  received_sequence = isinstance(serialized_list, collections.Sequence)
+  if not received_sequence:
+    serialized_list = (serialized_list,)
+  tensors = [sparse_ops.deserialize_many_sparse(s, sparse_dtype) if is_sparse
+             else s
+             for (s, is_sparse, sparse_dtype)
+             in zip(serialized_list, is_sparse_list, sparse_dtypes)]
+  return tensors if received_sequence else tensors[0]
+
+
 def _validate(tensor_list):
   tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
   if not tensor_list:
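Why serialization: the queues only carry dense tensors, so each SparseTensor is encoded as strings before enqueue and restored after dequeue. A round-trip sketch using the same ops the helpers call (illustrative; `tf.pack` stands in for what `dequeue_many` produces):

import tensorflow as tf
from tensorflow.python.ops import sparse_ops

indices = tf.constant([[0], [3]], dtype=tf.int64)
sp = tf.SparseTensor(indices=indices, values=[1.0, 2.0], shape=[4])

# serialize_sparse encodes one SparseTensor as a 1-D string tensor of length
# 3 (indices, values, shape); serialize_many_sparse is the enqueue_many
# variant and yields shape [N, 3] for a batch of N.
serialized = sparse_ops.serialize_sparse(sp)

# dequeue_many(b) stacks b such vectors into a [b, 3] matrix; tf.pack fakes
# that here with b = 2.
stacked = tf.pack([serialized, serialized])

# deserialize_many_sparse rebuilds a single SparseTensor with a new leading
# batch dimension: dense shape [2, 4] in this sketch.
roundtrip = sparse_ops.deserialize_many_sparse(stacked, tf.float32)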
@@ -343,6 +394,8 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
   """
   with ops.op_scope(tensor_list, name, "batch") as name:
     tensor_list = _validate(tensor_list)
+    tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
+        tensor_list, enqueue_many)
     types = _dtypes([tensor_list])
     shapes = _shapes([tensor_list], shapes, enqueue_many)
     # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
@@ -352,7 +405,10 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
     logging_ops.scalar_summary(
         "queue/%s/fraction_of_%d_full" % (queue.name, capacity),
         math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
-    return queue.dequeue_many(batch_size, name=name)
+
+    dequeued = queue.dequeue_many(batch_size, name=name)
+    dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+    return dequeued
 
 
 # TODO(josh11b): Add a thread_multiplier or num_threads (that has to be
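The net shape contract for batch() with a sparse input, as a comment-level trace (assuming a rank-1 SparseTensor and enqueue_many=False):

# sp with dense shape [d] in tensor_list:
#   _serialize_sparse_tensors(...)   -> string tensor, shape [3]
#   queue.dequeue_many(batch_size)   -> string tensor, shape [batch_size, 3]
#   _deserialize_sparse_tensors(...) -> one SparseTensor, dense shape
#                                       [batch_size, d]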
@@ -422,6 +478,8 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
   """
   with ops.op_scope(_flatten(tensor_list_list), name, "batch_join") as name:
     tensor_list_list = _validate_join(tensor_list_list)
+    tensor_list_list, is_sparse, sparse_dtypes = (
+        _serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
     types = _dtypes(tensor_list_list)
     shapes = _shapes(tensor_list_list, shapes, enqueue_many)
     # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
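Because one queue serves every list, _serialize_sparse_tensors_join insists that SparseTensors occupy the same positions, with the same dtypes, across tensor_list_list. A hypothetical illustration (dense_a, sparse_a, etc. are placeholders, not names from the diff):

# Accepted: sparse entries line up position-by-position across the lists.
tf.train.batch_join([[dense_a, sparse_a], [dense_b, sparse_b]], batch_size=5)

# Raises ValueError ("Inconsistent SparseTensors list: ..."): position 0 is
# dense in the first list but sparse in the second.
tf.train.batch_join([[dense_a, sparse_a], [sparse_b, dense_b]], batch_size=5)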
@@ -431,7 +489,10 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
     logging_ops.scalar_summary(
         "queue/%s/fraction_of_%d_full" % (queue.name, capacity),
         math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
-    return queue.dequeue_many(batch_size, name=name)
+
+    dequeued = queue.dequeue_many(batch_size, name=name)
+    dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+    return dequeued
 
 
 def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
@@ -506,6 +567,8 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
   """
   with ops.op_scope(tensor_list, name, "shuffle_batch") as name:
     tensor_list = _validate(tensor_list)
+    tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
+        tensor_list, enqueue_many)
     types = _dtypes([tensor_list])
     shapes = _shapes([tensor_list], shapes, enqueue_many)
     queue = data_flow_ops.RandomShuffleQueue(
@@ -522,7 +585,9 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
         (name, min_after_dequeue, capacity - min_after_dequeue))
     logging_ops.scalar_summary(summary_name, full)
 
-    return queue.dequeue_many(batch_size, name=name)
+    dequeued = queue.dequeue_many(batch_size, name=name)
+    dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+    return dequeued
 
 
 def shuffle_batch_join(tensor_list_list, batch_size, capacity,
@@ -587,6 +652,8 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
   with ops.op_scope(
       _flatten(tensor_list_list), name, "shuffle_batch_join") as name:
     tensor_list_list = _validate_join(tensor_list_list)
+    tensor_list_list, is_sparse, sparse_dtypes = (
+        _serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
     types = _dtypes(tensor_list_list)
     shapes = _shapes(tensor_list_list, shapes, enqueue_many)
     queue = data_flow_ops.RandomShuffleQueue(
@@ -602,4 +669,7 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
         "queue/%sfraction_over_%d_of_%d_full" %
         (name, min_after_dequeue, capacity - min_after_dequeue))
     logging_ops.scalar_summary(summary_name, full)
-    return queue.dequeue_many(batch_size, name=name)
+
+    dequeued = queue.dequeue_many(batch_size, name=name)
+    dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+    return dequeued
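Taken together, all four entry points (batch, batch_join, shuffle_batch, shuffle_batch_join) get the same wrap. Distilled into a sketch (not code from the diff):

def _sparse_aware_dequeue(tensor_list, enqueue_many, dequeue_batch):
  # 1. Pre-enqueue: swap each SparseTensor for a dense string tensor.
  tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
      tensor_list, enqueue_many)
  # 2. The queue path is untouched; it only ever sees dense tensors.
  dequeued = dequeue_batch(tensor_list)
  # 3. Post-dequeue: rebuild SparseTensors, now with a leading batch dim.
  return _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)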
@@ -318,7 +318,12 @@ class BatchTest(tf.test.TestCase):
       zero64 = tf.constant(0, dtype=tf.int64)
       examples = tf.Variable(zero64)
       counter = examples.count_up_to(num_batches * batch_size)
-      batched = tf.train.batch([counter, "string"], batch_size=batch_size)
+      sparse_counter = tf.SparseTensor(
+          indices=tf.reshape(tf.pack([zero64, zero64 + 1]), [2, 1]),
+          values=tf.cast(tf.pack([counter, -counter]), tf.float32),
+          shape=[2])
+      batched = tf.train.batch(
+          [counter, sparse_counter, "string"], batch_size=batch_size)
       tf.initialize_all_variables().run()
       threads = tf.train.start_queue_runners()
 
@@ -326,7 +331,16 @@ class BatchTest(tf.test.TestCase):
         results = sess.run(batched)
         self.assertAllEqual(results[0], np.arange(i * batch_size,
                                                   (i + 1) * batch_size))
-        self.assertAllEqual(results[1], [b"string"] * batch_size)
+        self.assertAllEqual(
+            results[1].indices,
+            np.vstack((np.arange(2 * batch_size) // 2,  # 0, 0, 1, 1, ...
+                       [0, 1] * batch_size)).T)
+        #  [x, -x, x+1, -(x+1), ...]
+        expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
+        expected *= ([1, -1] * batch_size)  # mult by [1, -1, 1, -1, ...]
+        self.assertAllEqual(results[1].values, expected)
+        self.assertAllEqual(results[1].shape, [batch_size, 2])
+        self.assertAllEqual(results[2], [b"string"] * batch_size)
 
       # Reached the limit.
       with self.assertRaises(tf.errors.OutOfRangeError):
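To make the index arithmetic above concrete, the expected values for batch_size=2 on the first batch (i=0), computed with the same expressions in plain numpy:

import numpy as np

batch_size, i = 2, 0
indices = np.vstack((np.arange(2 * batch_size) // 2,
                     [0, 1] * batch_size)).T
# -> [[0, 0], [0, 1], [1, 0], [1, 1]]: two sparse values per example.
expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
expected = expected * ([1, -1] * batch_size)
# -> [0, 0, 1, -1], i.e. [x, -x, x+1, -(x+1), ...] with x = 0.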
@@ -341,7 +355,12 @@ class BatchTest(tf.test.TestCase):
       zero64 = tf.constant(0, dtype=tf.int64)
       examples = tf.Variable(zero64)
       counter = examples.count_up_to(num_batches * batch_size)
-      pre_batched = tf.train.batch([counter, "string"], batch_size=2)
+      sparse_counter = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(counter, tf.float32)]),
+          shape=[1])
+      pre_batched = tf.train.batch(
+          [counter, sparse_counter, "string"], batch_size=2)
       batched = tf.train.batch(pre_batched, enqueue_many=True,
                                batch_size=batch_size)
       tf.initialize_all_variables().run()
@@ -351,7 +370,13 @@ class BatchTest(tf.test.TestCase):
         results = sess.run(batched)
         self.assertAllEqual(results[0], np.arange(i * batch_size,
                                                   (i + 1) * batch_size))
-        self.assertAllEqual(results[1], [b"string"] * batch_size)
+        self.assertAllEqual(
+            results[1].indices,
+            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+        self.assertAllEqual(
+            results[1].values, np.arange(i * batch_size, (i + 1) * batch_size))
+        self.assertAllEqual(results[1].shape, [batch_size, 1])
+        self.assertAllEqual(results[2], [b"string"] * batch_size)
 
       # Reached the limit.
       with self.assertRaises(tf.errors.OutOfRangeError):
@@ -364,10 +389,16 @@ class BatchTest(tf.test.TestCase):
       batch_size = 10
       num_batches = 3
       zero64 = tf.constant(0, dtype=tf.int64)
+
       examples = tf.Variable(zero64)
       counter = examples.count_up_to(num_batches * batch_size)
-      batched = tf.train.batch([counter, "string"], batch_size=batch_size,
-                               num_threads=4)
+      sparse_counter = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(counter, tf.float32)]),
+          shape=[1])
+      batched = tf.train.batch(
+          [counter, sparse_counter, "string"],
+          batch_size=batch_size, num_threads=4)
       tf.initialize_all_variables().run()
       threads = tf.train.start_queue_runners()
 
@@ -376,8 +407,13 @@ class BatchTest(tf.test.TestCase):
         results = sess.run(batched)
         tf.logging.info("Batch %d: %s", i, results[0])
         self.assertEqual(len(results[0]), batch_size)
+        self.assertAllEqual(results[0], results[1].values)
+        self.assertAllEqual(
+            results[1].indices,
+            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+        self.assertAllEqual(results[1].shape, [batch_size, 1])
         all_counts.extend(results[0])
-        self.assertAllEqual(results[1], [b"string"] * batch_size)
+        self.assertAllEqual(results[2], [b"string"] * batch_size)
       self.assertItemsEqual(all_counts, range(num_batches * batch_size))
 
       # Reached the limit.
@@ -411,16 +447,26 @@ class BatchJoinTest(tf.test.TestCase):
       zero64 = tf.constant(0, dtype=tf.int64)
       examples = tf.Variable(zero64)
       counter = examples.count_up_to(num_a)
+      sparse_counter = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(counter, tf.float32)]),
+          shape=[1])
 
       # The second generates (99, "b") 90 times and then stops.
       num_b = 90
       ninety_nine = tf.train.limit_epochs(
           tf.constant(99, dtype=tf.int64), num_b)
+      sparse_ninety_nine = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(ninety_nine, tf.float32)]),
+          shape=[1])
 
       # These get joined together and grouped into batches of 5.
       batch_size = 5
-      batched = tf.train.batch_join([[counter, "a"], [ninety_nine, "b"]],
-                                    batch_size=batch_size)
+      batched = tf.train.batch_join(
+          [[counter, sparse_counter, "a"],
+           [ninety_nine, sparse_ninety_nine, "b"]],
+          batch_size=batch_size)
       tf.initialize_all_variables().run()
       threads = tf.train.start_queue_runners()
 
@@ -433,9 +479,14 @@ class BatchJoinTest(tf.test.TestCase):
         results = sess.run(batched)
         tf.logging.info("Batch %d: %s", i, results[0])
         self.assertEqual(len(results[0]), batch_size)
-        self.assertEqual(len(results[1]), batch_size)
-        which_a = [i for i, s in enumerate(results[1]) if s == b"a"]
-        which_b = [i for i, s in enumerate(results[1]) if s == b"b"]
+        self.assertEqual(len(results[2]), batch_size)
+        self.assertAllEqual(results[0], results[1].values)
+        self.assertAllEqual(
+            results[1].indices,
+            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+        self.assertAllEqual(results[1].shape, [batch_size, 1])
+        which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
+        which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
         self.assertEqual(len(which_a) + len(which_b), batch_size)
         if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
         all_a.extend([results[0][i] for i in which_a])
@@ -481,8 +532,13 @@ class ShuffleBatchTest(tf.test.TestCase):
       zero64 = tf.constant(0, dtype=tf.int64)
       examples = tf.Variable(zero64)
       counter = examples.count_up_to(num_batches * batch_size)
+      sparse_counter = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(counter, tf.float32)]),
+          shape=[1])
       batched = tf.train.shuffle_batch(
-          [counter, "string"], batch_size=batch_size, capacity=32,
+          [counter, sparse_counter, "string"],
+          batch_size=batch_size, capacity=32,
           min_after_dequeue=16, seed=141421)
       tf.initialize_all_variables().run()
       threads = tf.train.start_queue_runners()
@@ -492,7 +548,12 @@ class ShuffleBatchTest(tf.test.TestCase):
         results = sess.run(batched)
         self.assertEqual(len(results[0]), batch_size)
         all_counts.extend(results[0])
-        self.assertAllEqual(results[1], [b"string"] * batch_size)
+        self.assertAllEqual(
+            results[1].indices,
+            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+        self.assertAllEqual(results[0], results[1].values)
+        self.assertAllEqual(results[1].shape, [batch_size, 1])
+        self.assertAllEqual(results[2], [b"string"] * batch_size)
       # Results scrambled, but include all the expected numbers.
       deltas = [all_counts[i + 1] - all_counts[i]
                 for i in range(len(all_counts) - 1)]
@@ -512,8 +573,13 @@ class ShuffleBatchTest(tf.test.TestCase):
       zero64 = tf.constant(0, dtype=tf.int64)
       examples = tf.Variable(zero64)
       counter = examples.count_up_to(num_batches * batch_size)
+      sparse_counter = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(counter, tf.float32)]),
+          shape=[1])
       batched = tf.train.shuffle_batch(
-          [counter, "string"], batch_size=batch_size, capacity=32,
+          [counter, sparse_counter, "string"],
+          batch_size=batch_size, capacity=32,
           min_after_dequeue=16, seed=173205, num_threads=4)
       tf.initialize_all_variables().run()
       threads = tf.train.start_queue_runners()
@@ -524,7 +590,12 @@ class ShuffleBatchTest(tf.test.TestCase):
         tf.logging.info("Batch %d: %s", i, results[0])
         self.assertEqual(len(results[0]), batch_size)
         all_counts.extend(results[0])
-        self.assertAllEqual(results[1], [b"string"] * batch_size)
+        self.assertAllEqual(
+            results[1].indices,
+            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+        self.assertAllEqual(results[0], results[1].values)
+        self.assertAllEqual(results[1].shape, [batch_size, 1])
+        self.assertAllEqual(results[2], [b"string"] * batch_size)
       # Results scrambled, but include all the expected numbers.
       deltas = [all_counts[i + 1] - all_counts[i]
                 for i in range(len(all_counts) - 1)]
@@ -564,17 +635,27 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
       zero64 = tf.constant(0, dtype=tf.int64)
       examples = tf.Variable(zero64)
       counter = examples.count_up_to(num_a)
+      sparse_counter = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(counter, tf.float32)]),
+          shape=[1])
 
       # The second generates (99, "b") 35 times and then stops.
       num_b = 35
       ninety_nine = tf.train.limit_epochs(
           tf.constant(99, dtype=tf.int64), num_b)
+      sparse_ninety_nine = tf.SparseTensor(
+          indices=tf.reshape(zero64, [1, 1]),
+          values=tf.pack([tf.cast(ninety_nine, tf.float32)]),
+          shape=[1])
 
       # These get joined together and grouped into batches of 5.
       batch_size = 5
       batched = tf.train.shuffle_batch_join(
-          [[counter, "a"], [ninety_nine, "b"]], batch_size=batch_size,
-          capacity=32, min_after_dequeue=16, seed=223607)
+          [[counter, sparse_counter, "a"],
+           [ninety_nine, sparse_ninety_nine, "b"]],
+          batch_size=batch_size, capacity=32,
+          min_after_dequeue=16, seed=223607)
 
       tf.initialize_all_variables().run()
       threads = tf.train.start_queue_runners()
@@ -588,9 +669,14 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
         results = sess.run(batched)
         tf.logging.info("Batch %d: %s", i, results[0])
         self.assertEqual(len(results[0]), batch_size)
-        self.assertEqual(len(results[1]), batch_size)
-        which_a = [i for i, s in enumerate(results[1]) if s == b"a"]
-        which_b = [i for i, s in enumerate(results[1]) if s == b"b"]
+        self.assertEqual(len(results[2]), batch_size)
+        self.assertAllEqual(results[0], results[1].values)
+        self.assertAllEqual(
+            results[1].indices,
+            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+        self.assertAllEqual(results[1].shape, [batch_size, 1])
+        which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
+        which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
         self.assertEqual(len(which_a) + len(which_b), batch_size)
         if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
         all_a.extend([results[0][i] for i in which_a])