# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training.input."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import os
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test as test_lib
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import input as inp
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.util import compat
class MatchFilenamesOnceTest(test_lib.TestCase):
@test_util.run_deprecated_v1
def test(self):
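    """Checks that match_filenames_once matches glob, `?`, and literal patterns."""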
temp_dir = self.get_temp_dir()
filenames = [os.path.join(temp_dir, n) for n in os.listdir(temp_dir)]
additional = [
os.path.join(self.get_temp_dir(), "match_filenames.%d" % i)
for i in range(3)
]
for name in additional:
      with open(name, "w") as f:
        f.write("Some contents")
filenames = list(set(filenames + additional))
with self.cached_session():
star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*"))
question = inp.match_filenames_once(
os.path.join(self.get_temp_dir(), "match_filenames.?"))
one = inp.match_filenames_once(additional[1])
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
self.assertItemsEqual(
map(compat.as_bytes, filenames), self.evaluate(star))
self.assertItemsEqual(
map(compat.as_bytes, additional), self.evaluate(question))
self.assertItemsEqual([compat.as_bytes(additional[1])],
self.evaluate(one))
class LimitEpochsTest(test_lib.TestCase):
@test_util.run_deprecated_v1
def testNoLimit(self):
with self.cached_session():
seven = constant_op.constant(7)
seven_forever = inp.limit_epochs(seven)
variables.local_variables_initializer().run()
for _ in range(100):
self.assertEqual(7, self.evaluate(seven_forever))
@test_util.run_deprecated_v1
def testLimit(self):
with self.cached_session():
love_me = constant_op.constant("Love Me")
love_me_two_times = inp.limit_epochs(love_me, num_epochs=2)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
self.assertEqual(b"Love Me", self.evaluate(love_me_two_times))
self.assertEqual(b"Love Me", self.evaluate(love_me_two_times))
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(love_me_two_times)
class InputProducerTest(test_lib.TestCase):
@test_util.run_deprecated_v1
def testNoShuffle(self):
with self.cached_session():
input_tensor = [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]]
num_epochs = 2
queue = inp.input_producer(
input_tensor, num_epochs=num_epochs, shuffle=False)
dequeue_many = queue.dequeue_many(len(input_tensor) * num_epochs)
dequeue = queue.dequeue()
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# No randomness, so just see repeated copies of the input.
self.assertAllEqual(input_tensor * num_epochs,
self.evaluate(dequeue_many))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(dequeue)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testNoShapeInference(self):
with self.cached_session():
# Disable shape inference for the input.
input_value = [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]]
input_tensor = array_ops.placeholder_with_default(input_value, shape=None)
num_epochs = 2
queue = inp.input_producer(
input_tensor, element_shape=[4], num_epochs=num_epochs, shuffle=False)
dequeue_many = queue.dequeue_many(len(input_value) * num_epochs)
dequeue = queue.dequeue()
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# No randomness, so just see repeated copies of the input.
self.assertAllEqual(input_value * num_epochs, self.evaluate(dequeue_many))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(dequeue)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testShapeError(self):
input_tensor = array_ops.placeholder(dtypes.float32, None)
with self.assertRaisesRegexp(ValueError, "fully defined shape"):
_ = inp.input_producer(input_tensor)
class StringInputProducerTest(test_lib.TestCase):
@test_util.run_deprecated_v1
def testNoShuffle(self):
with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
num_epochs = 3
queue = inp.string_input_producer(
strings, num_epochs=num_epochs, shuffle=False)
dequeue_many = queue.dequeue_many(len(strings) * num_epochs)
dequeue = queue.dequeue()
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# No randomness, so just see repeated copies of the input.
output = self.evaluate(dequeue_many)
self.assertAllEqual(strings * num_epochs, output)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(dequeue)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testShuffle(self):
with self.cached_session():
strings = [b"a", b"b", b"c"]
num_epochs = 600
queue = inp.string_input_producer(
strings, num_epochs=num_epochs, shuffle=True, seed=271828)
dequeue_many = queue.dequeue_many(len(strings))
dequeue = queue.dequeue()
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Validate that we only shuffle the strings within an epoch and
# count how often each possible order appears.
expected = [b"abc", b"acb", b"bac", b"bca", b"cab", b"cba"]
frequency = {}
for e in expected:
frequency[e] = 0
for _ in range(num_epochs):
output = self.evaluate(dequeue_many)
key = b"".join(output)
self.assertIn(key, expected)
frequency[key] += 1
# Expect an approximately even distribution over all possible orders.
expected_frequency = num_epochs / len(expected)
margin = expected_frequency * 0.4
tf_logging.info("Observed counts: %s", frequency)
for key in expected:
value = frequency[key]
self.assertGreater(value, expected_frequency - margin)
self.assertLess(value, expected_frequency + margin)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(dequeue)
for thread in threads:
thread.join()
def testNullStringPython(self):
# Graph-construction time check for empty string list:
with self.cached_session():
with self.assertRaises(ValueError):
_ = inp.string_input_producer([])
@test_util.run_deprecated_v1
def testNullString(self):
# Runtime check for empty string list. This is slightly oblique:
# The queue runner should die with an assertion error on the null
# input tensor, causing the dequeue to fail with an OutOfRangeError.
with self.cached_session():
coord = coordinator.Coordinator()
queue = inp.string_input_producer(
constant_op.constant(
[], dtype=dtypes.string))
dequeue = queue.dequeue()
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners(coord=coord)
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(dequeue)
coord.request_stop()
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testSharedName(self):
with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(
strings, shared_name="SHARED_NAME_XYZ", name="Q")
self.assertProtoEquals("s: 'SHARED_NAME_XYZ'",
queue.queue_ref.op.node_def.attr["shared_name"])
@test_util.run_deprecated_v1
def testConstructionRace(self):
with self.cached_session() as sess:
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(strings, shuffle=False)
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
for _ in range(2):
for string in strings:
# NOTE(mrry): This is not the recommended way to write
# dequeuing code (instead you should create a single dequeue
# op before starting the queue runners, and run it
# repeatedly), because it leads to concurrent reading and
# writing of the `tf.Graph` object. However, many users
# write code this way, so we include this test to ensure
# that we can support it.
          self.assertEqual(string, self.evaluate(queue.dequeue()))
coord.request_stop()
coord.join(threads)
class RangeInputProducerTest(test_lib.TestCase):
@test_util.run_deprecated_v1
def testNoShuffle(self):
with self.cached_session():
num_epochs = 3
range_size = 5
queue = inp.range_input_producer(
range_size, num_epochs=num_epochs, shuffle=False)
dequeue_many = queue.dequeue_many(range_size * num_epochs)
dequeue = queue.dequeue()
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# No randomness, so just see repeated copies of the input.
output = self.evaluate(dequeue_many)
self.assertAllEqual(list(xrange(range_size)) * num_epochs, output)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(dequeue)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testShuffle(self):
with self.cached_session():
num_epochs = 200
range_size = 2
queue = inp.range_input_producer(
range_size, num_epochs=num_epochs, shuffle=True, seed=314159)
dequeue_many = queue.dequeue_many(range_size)
dequeue = queue.dequeue()
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Validate that we only shuffle the integers within an epoch and
# count how often each possible order appears.
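      # Encode each epoch's order as a two-digit key: [0, 1] -> 12, [1, 0] -> 21.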
expected = [12, 21]
frequency = {}
for e in expected:
frequency[e] = 0
for _ in range(num_epochs):
output = self.evaluate(dequeue_many)
key = 10 * (output[0] + 1) + (output[1] + 1)
self.assertIn(key, expected)
frequency[key] += 1
# Expect an approximately even distribution over all possible orders.
expected_frequency = num_epochs / len(expected)
margin = expected_frequency * 0.4
tf_logging.info("Observed counts: %s", frequency)
for key in expected:
value = frequency[key]
self.assertGreater(value, expected_frequency - margin)
self.assertLess(value, expected_frequency + margin)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(dequeue)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testSharedName(self):
with self.cached_session():
range_size = 5
queue = inp.range_input_producer(
range_size, shared_name="SHARED_NAME_XYZ", name="Q")
self.assertProtoEquals("s: 'SHARED_NAME_XYZ'",
queue.queue_ref.op.node_def.attr["shared_name"])
class SliceInputProducerTest(test_lib.TestCase):
@test_util.run_deprecated_v1
def testNoShuffle(self):
with self.cached_session() as sess:
num_epochs = 3
source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"]
source_ints = [2, 3, 5, 7]
slices = inp.slice_input_producer(
[source_strings, source_ints], num_epochs=num_epochs, shuffle=False)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# No randomness, so just see repeated copies of the input.
num_items = len(source_strings) * num_epochs
output = [self.evaluate(slices) for _ in range(num_items)]
out_strings, out_ints = zip(*output)
self.assertAllEqual(source_strings * num_epochs, out_strings)
self.assertAllEqual(source_ints * num_epochs, out_ints)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(slices)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testShuffle(self):
with self.cached_session() as sess:
num_epochs = 1200
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
slices = inp.slice_input_producer(
[source_strings, source_ints],
num_epochs=num_epochs,
shuffle=True,
seed=161803)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
      # Validate that we only shuffle the slices within an epoch and
      # count how often each possible order appears.
expected = [
b",".join(x)
for x in itertools.permutations([b"A7", b"B3", b"D5", b"G2"])
]
frequency = {}
for e in expected:
frequency[e] = 0
for _ in range(num_epochs):
output = [self.evaluate(slices) for _ in range(len(source_strings))]
key = b",".join(s + compat.as_bytes(str(i)) for s, i in output)
self.assertIn(key, expected)
frequency[key] += 1
# Expect an approximately even distribution over all possible orders.
expected_frequency = num_epochs / len(expected)
margin = expected_frequency * 0.4
tf_logging.info("Observed counts: %s", frequency)
for key in expected:
value = frequency[key]
self.assertGreater(value, expected_frequency - margin)
self.assertLess(value, expected_frequency + margin)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(slices)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testSharedName(self):
with self.cached_session():
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
slices = inp.slice_input_producer(
[source_strings, source_ints],
shared_name="SHARED_NAME_XYZ",
name="sip")
self.assertProtoEquals(
"s: 'SHARED_NAME_XYZ'",
slices[0].op.inputs[1].op.inputs[0].op.node_def.attr["shared_name"])
class DictHelperTest(test_lib.TestCase):
def testListInputs(self):
l = [1, 2, 3, 11, 22, 33]
l2 = inp._as_tensor_list(l)
    self.assertEqual(l, l2)
    l3 = inp._as_original_type(l, l2)
    self.assertEqual(l, l3)
def testDictInputs(self):
d = {"a": 1, "b": 2, "c": 3, "aa": 11, "bb": 22, "cc": 33}
l = inp._as_tensor_list(d)
    self.assertEqual([1, 11, 2, 22, 3, 33], l)
    d2 = inp._as_original_type(d, l)
    self.assertEqual(d, d2)
def testHeterogeneousKeysDictInputs(self):
d = {"z": 1, 1: 42, ("a", "b"): 100}
l = inp._as_tensor_list(d)
    self.assertEqual([100, 42, 1], l)
    d2 = inp._as_original_type(d, l)
    self.assertEqual(d, d2)
class BatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
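    """Batches a counter, a SparseTensor, and a string, as a list or a dict."""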
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
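      # A sparse vector holding [counter, -counter] at each step.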
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(
array_ops.stack([zero64, zero64 + 1]), [2, 1]),
values=math_ops.cast(
array_ops.stack([counter, -counter]), dtypes.float32),
dense_shape=[2])
if use_dict:
batched = inp.batch(
{
"c": counter,
"s": sparse_counter,
"S": "string"
},
batch_size=batch_size)
batched_fetch = [batched["c"], batched["s"], batched["S"]]
else:
batched = inp.batch(
[counter, sparse_counter, "string"], batch_size=batch_size)
batched_fetch = batched
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
for i in range(num_batches):
results = self.evaluate(batched_fetch)
self.assertAllEqual(results[0],
np.arange(i * batch_size, (i + 1) * batch_size))
self.assertAllEqual(
results[1].indices,
np.vstack((
np.arange(2 * batch_size) // 2, # 0, 0, 1, 1, ...
[0, 1] * batch_size)).T)
# [x, -x, x+1, -(x+1), ...]
expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
expected *= ([1, -1] * batch_size) # mult by [1, -1, 1, -1, ...]
self.assertAllEqual(results[1].values, expected)
self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched_fetch)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testOneThread(self):
self._testOneThreadHelper(use_dict=False)
@test_util.run_deprecated_v1
def testOneThreadDict(self):
self._testOneThreadHelper(use_dict=True)
@test_util.run_deprecated_v1
def testUint32DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32)
batched = inp.batch([values], batch_size=2)
with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
self.evaluate(batched)
coord.request_stop()
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testUint64DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64)
batched = inp.batch([values], batch_size=2)
with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
self.evaluate(batched)
coord.request_stop()
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testOneThreadDynamicPad(self):
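    """Checks that dynamic_pad pads variable-length strings to the batch max."""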
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
string = array_ops.tile(["string"],
math_ops.cast(array_ops.stack([counter]),
dtypes.int32))
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
batched = inp.batch(
[counter, string], batch_size=batch_size, dynamic_pad=True)
threads = queue_runner_impl.start_queue_runners()
for i in range(num_batches):
results = self.evaluate(batched)
expected_results = np.arange(i * batch_size, (i + 1) * batch_size)
max_len = expected_results[-1]
self.assertAllEqual(results[0], expected_results)
expected_strings = [[b"string"] * rep + [b""] * (max_len - rep)
for rep in expected_results]
self.assertAllEqual(results[1], expected_strings)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testOneThreadEnqueueMany(self):
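    """Re-batches already-batched tensors using enqueue_many=True."""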
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
pre_batched = inp.batch([counter, sparse_counter, "string"], batch_size=2)
batched = inp.batch(pre_batched, enqueue_many=True, batch_size=batch_size)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
for i in range(num_batches):
results = self.evaluate(batched)
self.assertAllEqual(results[0],
np.arange(i * batch_size, (i + 1) * batch_size))
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[1].values,
np.arange(i * batch_size, (i + 1) * batch_size))
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testManyThreads(self):
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
batched = inp.batch(
[counter, sparse_counter, "string"],
batch_size=batch_size,
num_threads=4)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
all_counts = []
for i in range(num_batches):
results = self.evaluate(batched)
tf_logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
all_counts.extend(results[0])
self.assertAllEqual(results[2], [b"string"] * batch_size)
self.assertItemsEqual(all_counts, range(num_batches * batch_size))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testOneThreadSmallerBatch(self):
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size + extra_elements)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(
array_ops.stack([zero64, zero64 + 1]), [2, 1]),
values=math_ops.cast(
array_ops.stack([counter, -counter]), dtypes.float32),
dense_shape=[2])
batched = inp.batch(
[counter, sparse_counter, "string"],
batch_size=batch_size,
allow_smaller_final_batch=True)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
for i in range(num_batches):
results = self.evaluate(batched)
self.assertAllEqual(results[0],
np.arange(i * batch_size, (i + 1) * batch_size))
self.assertAllEqual(
results[1].indices,
np.vstack((
np.arange(2 * batch_size) // 2, # 0, 0, 1, 1, ...
[0, 1] * batch_size)).T)
# [x, -x, x+1, -(x+1), ...]
expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
expected *= ([1, -1] * batch_size) # mult by [1, -1, 1, -1, ...]
self.assertAllEqual(results[1].values, expected)
self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the final batch with extra_elements.
results = self.evaluate(batched)
self.assertAllEqual(results[0],
np.arange(num_batches * batch_size,
num_batches * batch_size + extra_elements))
self.assertAllEqual(
results[1].indices,
np.vstack((
np.arange(2 * extra_elements) // 2, # 0, 0, 1, 1, ...
[0, 1] * extra_elements)).T)
self.assertAllEqual(results[1].dense_shape, [extra_elements, 2])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testManyThreadsSmallerBatch(self):
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size + extra_elements)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
batched = inp.batch(
[counter, sparse_counter, "string"],
batch_size=batch_size,
num_threads=4,
allow_smaller_final_batch=True)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
all_counts = []
for i in range(num_batches):
results = self.evaluate(batched)
tf_logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
all_counts.extend(results[0])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the final batch with extra_elements.
results = self.evaluate(batched)
tf_logging.info("Last Batch: %s", results[0])
self.assertEqual(len(results[0]), extra_elements)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(extra_elements), np.zeros(extra_elements))).T)
self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
all_counts.extend(results[0])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
self.assertItemsEqual(all_counts,
range(num_batches * batch_size + extra_elements))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testSharedName(self):
with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
batched = inp.batch(
[counter, "string"],
batch_size=batch_size,
shared_name="SHARED_NAME_XYZ",
name="Q")
self.assertProtoEquals(
"s: 'SHARED_NAME_XYZ'",
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
@test_util.run_deprecated_v1
def testCannotInferRankError(self):
with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch([x], batch_size=2)
@test_util.run_deprecated_v1
def testBatchedSparseTensorInferredShape(self):
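    """Batching a rank-1 SparseTensor should yield a rank-2 SparseTensor."""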
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.batch([sparse], batch_size=2)
self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.batch([sparse], batch_size=2, enqueue_many=True)
self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.batch([sparse], batch_size=2)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.batch([sparse], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testSingleElementDict(self):
x = inp.batch({"c": [12, 12]}, batch_size=8)
self.assertAllEqual((8, 2), x["c"].get_shape().as_list())
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
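    """Runs maybe_batch, keeping only the examples with even counter values."""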
with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
counter = examples.count_up_to(num_batches * batch_size * 2)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.zeros(
[1, 1], dtype=dtypes.int64),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
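      # Keep only the examples with even counter values.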
keep_input = array_ops.squeeze(
math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
batched = inp.maybe_batch(
to_batch,
keep_input,
batch_size,
num_threads=num_threads,
enqueue_many=enqueue_many)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
for _ in range(num_batches):
results = self.evaluate(batched)
self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
self.assertAllEqual([b"string"] * batch_size, results[2])
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInput(self):
self._testKeepInputHelper(1, False)
@test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInputEnqueueMany(self):
self._testKeepInputHelper(1, True)
@test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInput(self):
self._testKeepInputHelper(5, False)
@test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInputEnqueueMany(self):
self._testKeepInputHelper(5, True)
@test_util.run_deprecated_v1
def testMaybeEnqueuePerExample(self):
self._testKeepInputHelper(1, True, keep_input_vector=True)
@test_util.run_deprecated_v1
def testMultipleThreadMaybeEnqueuePerExample(self):
self._testKeepInputHelper(5, True, keep_input_vector=True)
@test_util.run_deprecated_v1
def testInvalidKeepInputVector(self):
# Can't have vector `keep_input` with `enqueue_many=False`.
with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
inp.maybe_batch([array_ops.zeros(5)],
keep_input=constant_op.constant([True, False]),
batch_size=1,
enqueue_many=False)
# Can't have `keep_input` with more than one dimension.
with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
inp.maybe_batch([array_ops.zeros(5)],
keep_input=constant_op.constant([[True], [False]]),
batch_size=1,
enqueue_many=True)
# `keep_input` must have dimensions determined at graph construction.
with self.assertRaisesRegexp(ValueError,
"must be known at graph construction"):
inp.maybe_batch([array_ops.zeros(5)],
keep_input=array_ops.placeholder(dtypes.bool),
batch_size=1,
enqueue_many=True)
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShape(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_batch([sparse], keep_input=True, batch_size=2)
self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_batch(
[sparse], keep_input=True, batch_size=2, enqueue_many=True)
self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_batch(
[sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_batch([sparse], keep_input=True, batch_size=2)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_batch(
[sparse], keep_input=True, batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_batch(
[sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testMaybeBatchCorrectValues(self):
sparse_t = sparse_tensor.SparseTensor(
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
dense_shape=[2, 4],
values=[5, 4, 7, 2])
keep = constant_op.constant([True, False])
batched = inp.maybe_batch(
[sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
batched_np = self.evaluate(batched)
coord.request_stop()
for thread in threads:
thread.join()
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
self.assertAllEqual([5, 4], batched_np.values)
self.assertAllEqual([1, 4], batched_np.dense_shape)
class BatchJoinTest(test_lib.TestCase):
def _testTwoThreadsHelper(self, use_dict):
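    """Joins a counting thread ("a") with a constant-99 thread ("b")."""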
with self.cached_session() as sess:
# Two threads, the first generates (0..69, "a").
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_a)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
# The second generates (99, "b") 90 times and then stops.
num_b = 90
ninety_nine = inp.limit_epochs(
constant_op.constant(
99, dtype=dtypes.int64), num_b)
sparse_ninety_nine = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
if use_dict:
batched = inp.batch_join(
[{
"c": counter,
"s": sparse_counter,
"S": "a"
}, {
"c": ninety_nine,
"s": sparse_ninety_nine,
"S": "b"
}],
batch_size=batch_size)
batched_fetch = [batched["c"], batched["s"], batched["S"]]
else:
batched = inp.batch_join(
[[counter, sparse_counter, "a"],
[ninety_nine, sparse_ninety_nine, "b"]],
batch_size=batch_size)
batched_fetch = batched
# Shapes.
self.assertEqual(3, len(batched_fetch))
self.assertAllEqual((batch_size,), batched_fetch[0].get_shape().as_list())
self.assertAllEqual((None, 2),
batched_fetch[1].indices.get_shape().as_list())
self.assertAllEqual((None,),
batched_fetch[1].values.get_shape().as_list())
self.assertAllEqual((2,),
batched_fetch[1].dense_shape.get_shape().as_list())
self.assertAllEqual((batch_size,), batched_fetch[2].get_shape().as_list())
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Should see the "a" and "b" threads mixed together.
all_a = []
seen_b = 0
saw_both = 0
num_batches = (num_a + num_b) // batch_size
for i in range(num_batches):
results = self.evaluate(batched_fetch)
self.assertEqual(3, len(results))
self.assertEqual(batch_size, len(results[0]))
self.assertEqual(batch_size, len(results[2]))
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
self.assertAllEqual([99] * len(which_b),
[results[0][i] for i in which_b])
# We'd like to see some minimum level of mixing of the results of both
# threads, but we can't rely on fair thread scheduling, so we just log.
# self.assertGreater(saw_both, 1)
tf_logging.info("testTwoThreads%s saw both count: %s",
"Dict" if use_dict else "", saw_both)
# Verify the order of results from "a" were preserved.
self.assertAllEqual(all_a, np.arange(num_a))
self.assertEqual(seen_b, num_b)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched_fetch)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testTwoThreads(self):
self._testTwoThreadsHelper(use_dict=False)
@test_util.run_deprecated_v1
def testTwoThreadsDict(self):
self._testTwoThreadsHelper(use_dict=True)
@test_util.run_deprecated_v1
def testMismatchedDictKeys(self):
with self.assertRaisesRegexp(ValueError, "must have the same keys"):
inp.batch_join(
[{
"c": 12,
"s": 123,
"S": "a"
}, {
"cool": -12,
"s": 99,
"S": "b"
}],
batch_size=8)
@test_util.run_deprecated_v1
def testTwoThreadsDynamicPad(self):
with self.cached_session() as sess:
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_a)
# The second generates (99, ["b"] * 99) 90 times and then stops.
num_b = 90
ninety_nine = inp.limit_epochs(
constant_op.constant(
99, dtype=dtypes.int64), num_b)
# These get joined together and grouped into batches of 5.
batch_size = 5
a = array_ops.tile(
["a"],
math_ops.cast(array_ops.stack([counter + 1]), dtypes.int32))
b = array_ops.tile(
["b"],
math_ops.cast(array_ops.stack([ninety_nine]), dtypes.int32))
batched = inp.batch_join(
[[counter, a], [ninety_nine, b]],
batch_size=batch_size,
dynamic_pad=True)
# Shapes.
self.assertEqual(2, len(batched))
self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
self.assertAllEqual((batch_size, None), batched[1].get_shape().as_list())
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Should see the "a" and "b" threads mixed together.
all_a = []
count_string_a = []
seen_b = 0
saw_both = 0
num_batches = (num_a + num_b) // batch_size
for i in range(num_batches):
results = self.evaluate(batched)
self.assertEqual(2, len(results))
self.assertEqual(len(results[0]), batch_size)
self.assertEqual(len(results[1]), batch_size)
for s in results[1]:
if s[0] == b"b":
self.assertAllEqual(s, [b"b"] * 99)
else:
count_string_a.append(sum(x == b"a" for x in s))
which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
self.assertAllEqual([99] * len(which_b),
[results[0][i] for i in which_b])
# We'd like to see some minimum level of mixing of the results of both
# threads, but we can't rely on fair thread scheduling, so we just log.
# self.assertGreater(saw_both, 1)
tf_logging.info("testTwoThreadsDynamicPad saw both count: %s", saw_both)
# Verify the order of results from "a" were preserved.
self.assertAllEqual( # tiled "a" with counter + 1
count_string_a, np.arange(num_a) + 1)
self.assertAllEqual(all_a, np.arange(num_a))
self.assertEqual(seen_b, num_b)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testTwoThreadsSmallerBatch(self):
with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, "a").
num_a = 70 + extra_elements
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_a)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
# The second generates (99, "b") 90 times and then stops.
num_b = 90 + extra_elements
ninety_nine = inp.limit_epochs(
constant_op.constant(
99, dtype=dtypes.int64), num_b)
sparse_ninety_nine = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
batched = inp.batch_join(
[[counter, sparse_counter, "a"],
[ninety_nine, sparse_ninety_nine, "b"]],
batch_size=batch_size,
allow_smaller_final_batch=True)
# Shapes.
self.assertEqual(3, len(batched))
self.assertAllEqual((None,), batched[0].get_shape().as_list())
self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
self.assertAllEqual((None,), batched[2].get_shape().as_list())
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Should see the "a" and "b" threads mixed together.
all_a = []
seen_b = 0
saw_both = 0
num_batches = (num_a + num_b) // batch_size
for i in range(num_batches):
results = self.evaluate(batched)
tf_logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
self.assertEqual(len(results[2]), batch_size)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
self.assertAllEqual([99] * len(which_b),
[results[0][i] for i in which_b])
# Reached the final batch with 2 * extra_elements.
results = self.evaluate(batched)
tf_logging.info("Last Batch: %s", results[0])
self.assertEqual(len(results[0]), 2 * extra_elements)
self.assertEqual(len(results[2]), 2 * extra_elements)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(results[1].indices,
np.vstack((np.arange(2 * extra_elements),
np.zeros(2 * extra_elements))).T)
self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
# We'd like to see some minimum level of mixing of the results of both
# threads, but we can't rely on fair thread scheduling, so we just log.
# self.assertGreater(saw_both, 1)
tf_logging.info("testTwoThreadsSmallerBatch saw both count: %s", saw_both)
# Verify the order of results from "a" were preserved.
self.assertAllEqual(all_a, np.arange(num_a))
self.assertEqual(seen_b, num_b)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testTwoThreadsDynamicPadSmallerBatch(self):
with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70 + extra_elements
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_a)
# The second generates (99, ["b"] * 99) 90 times and then stops.
num_b = 90 + extra_elements
ninety_nine = inp.limit_epochs(
constant_op.constant(
99, dtype=dtypes.int64), num_b)
# These get joined together and grouped into batches of 5.
batch_size = 5
a = array_ops.tile(
["a"],
math_ops.cast(array_ops.stack([counter + 1]), dtypes.int32))
b = array_ops.tile(
["b"],
math_ops.cast(array_ops.stack([ninety_nine]), dtypes.int32))
batched = inp.batch_join(
[[counter, a], [ninety_nine, b]],
batch_size=batch_size,
dynamic_pad=True,
allow_smaller_final_batch=True)
# Shapes.
self.assertEqual(2, len(batched))
self.assertAllEqual((None,), batched[0].get_shape().as_list())
self.assertAllEqual((None, None), batched[1].get_shape().as_list())
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Should see the "a" and "b" threads mixed together.
all_a = []
count_string_a = []
seen_b = 0
saw_both = 0
num_batches = (num_a + num_b) // batch_size
for i in range(num_batches):
results = self.evaluate(batched)
tf_logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
self.assertEqual(len(results[1]), batch_size)
for s in results[1]:
if s[0] == b"b":
self.assertAllEqual(s, [b"b"] * 99)
else:
count_string_a.append(sum(x == b"a" for x in s))
which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
self.assertAllEqual([99] * len(which_b),
[results[0][i] for i in which_b])
# Reached the final batch with 2 * extra_elements.
results = self.evaluate(batched)
tf_logging.info("Last Batch: %s", results[0])
self.assertEqual(len(results[0]), 2 * extra_elements)
self.assertEqual(len(results[1]), 2 * extra_elements)
for s in results[1]:
if s[0] == b"b":
self.assertAllEqual(s, [b"b"] * 99)
else:
count_string_a.append(sum(x == b"a" for x in s))
which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
# We'd like to see some minimum level of mixing of the results of both
# threads, but we can't rely on fair thread scheduling, so we just log.
# self.assertGreater(saw_both, 1)
tf_logging.info("testTwoThreadsDynamicPadSmallerBatch saw both count: %s",
saw_both)
# Verify the order of results from "a" were preserved.
self.assertAllEqual( # tiled "a" with counter + 1
count_string_a, np.arange(num_a) + 1)
self.assertAllEqual(all_a, np.arange(num_a))
self.assertEqual(seen_b, num_b)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_deprecated_v1
def testSharedName(self):
with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
batched = inp.batch_join(
[[counter, "string"]],
batch_size=batch_size,
shared_name="SHARED_NAME_XYZ",
name="Q")
# Shapes.
self.assertEqual(2, len(batched))
self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
self.assertAllEqual((batch_size,), batched[1].get_shape().as_list())
self.assertProtoEquals(
"s: 'SHARED_NAME_XYZ'",
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
@test_util.run_deprecated_v1
def testCannotInferRankError(self):
with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch_join([[x]], batch_size=2)
@test_util.run_deprecated_v1
def testSingleElementDict(self):
x = inp.batch_join([{"c": [12, 12]}], batch_size=8)
self.assertAllEqual((8, 2), x["c"].get_shape().as_list())
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
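    """Runs maybe_batch_join over num_threads copies of the same input."""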
with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
counter = examples.count_up_to(num_batches * batch_size * 2)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.zeros(
[1, 1], dtype=dtypes.int64),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
keep_input = array_ops.squeeze(
math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
batched = inp.maybe_batch_join(
[to_batch] * num_threads,
keep_input,
batch_size,
enqueue_many=enqueue_many)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
for _ in range(num_batches):
results = self.evaluate(batched)
        self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
        self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
self.assertAllEqual([b"string"] * batch_size, results[2])
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
thread.join()
@test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInput(self):
self._testKeepInputHelper(1, False)
@test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInputEnqueueMany(self):
self._testKeepInputHelper(1, True)
@test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInput(self):
self._testKeepInputHelper(5, False)
@test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInputEnqueueMany(self):
self._testKeepInputHelper(5, True)
@test_util.run_deprecated_v1
def testSingleThreadKeepInputPerExample(self):
self._testKeepInputHelper(1, True, keep_input_vector=True)
@test_util.run_deprecated_v1
def testMultipleThreadKeepInputPerExample(self):
self._testKeepInputHelper(5, True, keep_input_vector=True)
@test_util.run_deprecated_v1
def testInvalidKeepInputVector(self):
# Can't have vector `keep_input` with `enqueue_many=False`.
with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
inp.maybe_batch_join([[array_ops.zeros(5)]],
keep_input=constant_op.constant([True, False]),
batch_size=1,
enqueue_many=False)
# Can't have `keep_input` with more than one dimension.
with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
inp.maybe_batch_join([[array_ops.zeros(5)]],
keep_input=constant_op.constant([[True], [False]]),
batch_size=1,
enqueue_many=True)
# `keep_input` must have dimensions determined at graph construction.
with self.assertRaisesRegexp(ValueError,
"must be known at graph construction"):
inp.maybe_batch_join([[array_ops.zeros(5)]],
keep_input=array_ops.placeholder(dtypes.bool),
batch_size=1,
enqueue_many=True)
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShape(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_batch_join(
[[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_batch_join(
[[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_batch_join(
[[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_batch_join(
[[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
@test_util.run_deprecated_v1
def testMaybeBatchCorrectValues(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
dense_shape=[2, 4],
values=[5, 4, 7, 2])
keep = constant_op.constant([True, False])
batched = inp.maybe_batch_join(
[[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
batched_np = self.evaluate(batched)
coord.request_stop()
for thread in threads:
thread.join()
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
self.assertAllEqual([5, 4], batched_np.values)
self.assertAllEqual([1, 4], batched_np.dense_shape)
class ShuffleBatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
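    """Shuffle-batches a counter, a SparseTensor, and a string, one thread."""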
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
if use_dict:
batched = inp.shuffle_batch(
{
"c": counter,
"s": sparse_counter,
"S": "string"
},
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=141421)
batched_fetch = [batched["c"], batched["s"], batched["S"]]
else:
batched = inp.shuffle_batch(
[counter, sparse_counter, "string"],
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=141421)
batched_fetch = batched
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
all_counts = []
for i in range(num_batches):
results = self.evaluate(batched_fetch)
self.assertEqual(len(results[0]), batch_size)
all_counts.extend(results[0])
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Results scrambled, but include all the expected numbers.
deltas = [
all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
]
self.assertFalse(all(d == deltas[0] for d in deltas))
self.assertItemsEqual(all_counts, range(num_batches * batch_size))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched_fetch)
for thread in threads:
        thread.join()

  @test_util.run_deprecated_v1
def testOneThread(self):
    self._testOneThreadHelper(use_dict=False)

  @test_util.run_deprecated_v1
def testOneThreadDict(self):
    self._testOneThreadHelper(use_dict=True)

  @test_util.run_deprecated_v1
def testOneThreadSmallerBatch(self):
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
total_elements = num_batches * batch_size + extra_elements
counter = examples.count_up_to(total_elements)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
batched = inp.shuffle_batch(
[counter, sparse_counter, "string"],
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=141421,
allow_smaller_final_batch=True)
batched_fetch = batched
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
all_counts = []
for _ in range(num_batches):
results = self.evaluate(batched_fetch)
self.assertEqual(len(results[0]), batch_size)
all_counts.extend(results[0])
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the final batch with extra elements.
results = self.evaluate(batched)
self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
all_counts.extend(results[0])
# Results scrambled, but include all the expected numbers.
deltas = [
all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
]
self.assertFalse(all(d == deltas[0] for d in deltas))
self.assertItemsEqual(all_counts, range(total_elements))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched_fetch)
for thread in threads:
        thread.join()

  @test_util.run_deprecated_v1
def testManyThreads(self):
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
batched = inp.shuffle_batch(
[counter, sparse_counter, "string"],
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=173205,
num_threads=4)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
all_counts = []
for i in range(num_batches):
results = self.evaluate(batched)
tf_logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
all_counts.extend(results[0])
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Results scrambled, but include all the expected numbers.
deltas = [
all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
]
self.assertFalse(all(d == deltas[0] for d in deltas))
self.assertItemsEqual(all_counts, range(num_batches * batch_size))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
        thread.join()

  @test_util.run_deprecated_v1
def testManyThreadsSmallerBatch(self):
with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
total_elements = num_batches * batch_size + extra_elements
counter = examples.count_up_to(total_elements)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
batched = inp.shuffle_batch(
[counter, sparse_counter, "string"],
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=173205,
num_threads=4,
allow_smaller_final_batch=True)
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
all_counts = []
for i in range(num_batches):
results = self.evaluate(batched)
tf_logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
all_counts.extend(results[0])
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the final batch with extra elements.
results = self.evaluate(batched)
self.assertAllEqual(results[0].shape, [extra_elements])
self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
all_counts.extend(results[0])
# Results scrambled, but include all the expected numbers.
deltas = [
all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
]
self.assertFalse(all(d == deltas[0] for d in deltas))
self.assertItemsEqual(all_counts, range(total_elements))
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
        thread.join()

  @test_util.run_deprecated_v1
def testSharedName(self):
with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
batched = inp.shuffle_batch(
[counter, "string"],
batch_size=batch_size,
capacity=32,
min_after_dequeue=10,
shared_name="SHARED_NAME_XYZ",
name="Q")
self.assertProtoEquals(
"s: 'SHARED_NAME_XYZ'",
          batched[0].op.inputs[0].op.node_def.attr["shared_name"])

  def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
counter = examples.count_up_to(num_batches * batch_size * 2)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.zeros(
[1, 1], dtype=dtypes.int64),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
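      # Keep only the examples whose counter value is even; maybe_shuffle_batch
      # drops the filtered-out examples before they reach the output batch.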
keep_input = array_ops.squeeze(
math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
batched = inp.maybe_shuffle_batch(
to_batch,
batch_size,
10,
1,
keep_input,
num_threads=num_threads,
enqueue_many=enqueue_many)
      self.evaluate(variables.global_variables_initializer())
      variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
for _ in range(num_batches):
results = self.evaluate(batched)
self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
self.assertAllEqual([b"string"] * batch_size, results[2])
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
        thread.join()

  @test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInput(self):
    self._testKeepInputHelper(1, False)

  @test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(1, True)

  @test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInput(self):
    self._testKeepInputHelper(5, False)

  @test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(5, True)

  @test_util.run_deprecated_v1
def testSingleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(1, True, keep_input_vector=True)

  @test_util.run_deprecated_v1
def testMultipleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(5, True, keep_input_vector=True)

  @test_util.run_deprecated_v1
def testInvalidKeepInputVector(self):
# Can't have vector `keep_input` with `enqueue_many=False`.
with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
keep_input=constant_op.constant([True, False]),
enqueue_many=False)
# Can't have `keep_input` with more than one dimension.
with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
keep_input=constant_op.constant([[True]]),
enqueue_many=True)
# `keep_input` must have dimensions determined at graph construction.
with self.assertRaisesRegexp(ValueError,
"must be known at graph construction"):
inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
keep_input=array_ops.placeholder(dtypes.bool),
                              enqueue_many=True)

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShape(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
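    # Positional arguments to maybe_shuffle_batch below: batch_size=2,
    # capacity=10, min_after_dequeue=1, keep_input=True.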
batched = inp.maybe_shuffle_batch([sparse], 2, 10, 1, True)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_shuffle_batch(
[sparse], 2, 10, 1, True, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_shuffle_batch(
[sparse], 2, 10, 1, [True, False], enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_shuffle_batch([sparse], 2, 10, 1, True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_shuffle_batch(
[sparse], 2, 10, 1, True, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_shuffle_batch(
[sparse], 2, 10, 1, [True, False], enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
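
# A minimal usage sketch of the API exercised by the tests above (hedged:
# `read_example` is a hypothetical producer, not something defined in this
# file):
#
#   example = read_example()  # any op yielding one example per evaluation
#   batch = inp.shuffle_batch([example], batch_size=8,
#                             capacity=32, min_after_dequeue=16)
#
# After building `batch`, start queue runners in a session and evaluate it
# repeatedly, exactly as the tests in this file do.
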
class ShuffleBatchJoinTest(test_lib.TestCase):

  def _testTwoThreadsHelper(self, use_dict):
with self.cached_session() as sess:
# Two threads, the first generates (0..24, "a").
num_a = 25
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_a)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
# The second generates (99, "b") 35 times and then stops.
num_b = 35
ninety_nine = inp.limit_epochs(
constant_op.constant(
99, dtype=dtypes.int64), num_b)
sparse_ninety_nine = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
if use_dict:
batched = inp.shuffle_batch_join(
[{
"c": counter,
"s": sparse_counter,
"S": "a"
}, {
"c": ninety_nine,
"s": sparse_ninety_nine,
"S": "b"
}],
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=223607)
batched_fetch = [batched["c"], batched["s"], batched["S"]]
else:
batched = inp.shuffle_batch_join(
[[counter, sparse_counter, "a"],
[ninety_nine, sparse_ninety_nine, "b"]],
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=223607)
batched_fetch = batched
# Shapes.
self.assertEqual(3, len(batched_fetch))
self.assertAllEqual((batch_size,), batched_fetch[0].get_shape().as_list())
self.assertAllEqual((None, 2),
batched_fetch[1].indices.get_shape().as_list())
self.assertAllEqual((None,),
batched_fetch[1].values.get_shape().as_list())
self.assertAllEqual((2,),
batched_fetch[1].dense_shape.get_shape().as_list())
self.assertAllEqual((batch_size,), batched_fetch[2].get_shape().as_list())
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Should see the "a" and "b" threads mixed together.
all_a = []
seen_b = 0
saw_both = 0
num_batches = (num_a + num_b) // batch_size
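      # (25 + 35) examples // batch_size 5 = 12 full batches, with none left.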
      for _ in range(num_batches):
results = self.evaluate(batched_fetch)
self.assertEqual(3, len(results))
self.assertEqual(len(results[0]), batch_size)
self.assertEqual(len(results[2]), batch_size)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
self.assertAllEqual([99] * len(which_b),
[results[0][i] for i in which_b])
# Some minimum level of mixing of the results of both threads.
self.assertGreater(saw_both, 1)
# Saw all the items from "a", but scrambled.
self.assertItemsEqual(all_a, range(num_a))
deltas = [all_a[i + 1] - all_a[i] for i in range(len(all_a) - 1)]
self.assertFalse(all(d == deltas[0] for d in deltas))
self.assertEqual(seen_b, num_b)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched_fetch)
for thread in threads:
        thread.join()

  @test_util.run_deprecated_v1
def testTwoThreads(self):
    self._testTwoThreadsHelper(use_dict=False)

  @test_util.run_deprecated_v1
def testTwoThreadsDict(self):
    self._testTwoThreadsHelper(use_dict=True)

  @test_util.run_deprecated_v1
def testTwoThreadsSmallerBatch(self):
with self.cached_session() as sess:
# Two threads, the first generates (0..26, "a").
extra_elements = 2
num_a = 25 + extra_elements
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_a)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
# The second generates (99, "b") 37 times and then stops.
num_b = 35 + extra_elements
ninety_nine = inp.limit_epochs(
constant_op.constant(
99, dtype=dtypes.int64), num_b)
sparse_ninety_nine = sparse_tensor.SparseTensor(
indices=array_ops.reshape(zero64, [1, 1]),
values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
batched = inp.shuffle_batch_join(
[[counter, sparse_counter, "a"],
[ninety_nine, sparse_ninety_nine, "b"]],
batch_size=batch_size,
capacity=32,
min_after_dequeue=16,
seed=223607,
allow_smaller_final_batch=True)
# Shapes.
self.assertEqual(3, len(batched))
self.assertAllEqual((None,), batched[0].get_shape().as_list())
self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
self.assertAllEqual((None,), batched[2].get_shape().as_list())
self.evaluate(variables.global_variables_initializer())
variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
# Should see the "a" and "b" threads mixed together.
all_a = []
seen_b = 0
saw_both = 0
num_batches = (num_a + num_b) // batch_size
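      # (27 + 37) examples // batch_size 5 = 12 full batches, leaving
      # 2 * extra_elements = 4 examples for the smaller final batch below.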
for i in range(num_batches):
results = self.evaluate(batched)
tf_logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
self.assertEqual(len(results[2]), batch_size)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
self.assertAllEqual([99] * len(which_b),
[results[0][i] for i in which_b])
      # Reached the end with 2 * extra_elements examples left over.
results = self.evaluate(batched)
self.assertEqual(len(results[0]), 2 * extra_elements)
self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
self.assertEqual(len(results[2]), 2 * extra_elements)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(results[1].indices,
np.vstack((np.arange(2 * extra_elements),
np.zeros(2 * extra_elements))).T)
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
if which_a and which_b:
saw_both += 1
all_a.extend(results[0][i] for i in which_a)
seen_b += len(which_b)
# Some minimum level of mixing of the results of both threads.
self.assertGreater(saw_both, 1)
# Saw all the items from "a", but scrambled, including extras.
self.assertItemsEqual(all_a, range(num_a))
deltas = [all_a[i + 1] - all_a[i] for i in range(len(all_a) - 1)]
self.assertFalse(all(d == deltas[0] for d in deltas))
self.assertEqual(seen_b, num_b)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
        thread.join()

  @test_util.run_deprecated_v1
def testMismatchedDictKeys(self):
with self.assertRaisesRegexp(ValueError, "must have the same keys"):
inp.shuffle_batch_join(
[{
"c": 12,
"s": 123,
"S": "a"
}, {
"cool": -12,
"s": 99,
"S": "b"
}],
batch_size=8,
capacity=32,
min_after_dequeue=16,
        seed=223607)

  @test_util.run_deprecated_v1
def testSharedName(self):
with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
examples = variables.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
batched = inp.shuffle_batch_join(
[[counter, "string"]],
batch_size=batch_size,
capacity=32,
min_after_dequeue=10,
shared_name="SHARED_NAME_XYZ",
name="Q")
# Shapes.
self.assertEqual(2, len(batched))
self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
self.assertAllEqual((batch_size,), batched[1].get_shape().as_list())
self.assertProtoEquals(
"s: 'SHARED_NAME_XYZ'",
          batched[0].op.inputs[0].op.node_def.attr["shared_name"])

  def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
counter = examples.count_up_to(num_batches * batch_size * 2)
sparse_counter = sparse_tensor.SparseTensor(
indices=array_ops.zeros(
[1, 1], dtype=dtypes.int64),
values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
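      # As in ShuffleBatchTest above, keep only examples with an even counter.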
keep_input = array_ops.squeeze(
math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
batched = inp.maybe_shuffle_batch_join(
[to_batch] * num_threads,
batch_size,
10,
1,
keep_input,
enqueue_many=enqueue_many)
      self.evaluate(variables.global_variables_initializer())
      variables.local_variables_initializer().run()
threads = queue_runner_impl.start_queue_runners()
for _ in range(num_batches):
results = self.evaluate(batched)
self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
self.assertAllEqual([b"string"] * batch_size, results[2])
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
self.evaluate(batched)
for thread in threads:
        thread.join()

  @test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInput(self):
    self._testKeepInputHelper(1, False)

  @test_util.run_v1_only("b/120545219")
def testSingleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(1, True)

  @test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInput(self):
    self._testKeepInputHelper(5, False)

  @test_util.run_v1_only("b/120545219")
def testMultipleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(5, True)

  @test_util.run_deprecated_v1
def testSingleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(1, True, keep_input_vector=True)

  @test_util.run_deprecated_v1
def testMultipleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(5, True, keep_input_vector=True)

  @test_util.run_deprecated_v1
def testInvalidKeepInputVector(self):
# Can't have vector `keep_input` with `enqueue_many=False`.
with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
inp.maybe_shuffle_batch_join(
[[array_ops.zeros(5)]], 1, 10, 1,
keep_input=constant_op.constant([True, False]),
enqueue_many=False)
# Can't have `keep_input` with more than one dimension.
with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
inp.maybe_shuffle_batch_join(
[[array_ops.zeros(5)]], 1, 10, 1,
keep_input=constant_op.constant([[True]]),
enqueue_many=True)
# `keep_input` must have dimensions determined at graph construction.
with self.assertRaisesRegexp(ValueError,
"must be known at graph construction"):
inp.maybe_shuffle_batch_join(
[[array_ops.zeros(5)]], 1, 10, 1,
keep_input=array_ops.placeholder(dtypes.bool),
          enqueue_many=True)

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShape(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
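    # Positional arguments to maybe_shuffle_batch_join below: batch_size=2,
    # capacity=10, min_after_dequeue=1, keep_input=True.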
batched = inp.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_shuffle_batch_join(
[[sparse]], 2, 10, 1, True, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = inp.maybe_shuffle_batch_join(
[[sparse]], 2, 10, 1, [True, False], enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_shuffle_batch_join(
[[sparse]], 2, 10, 1, True, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  @test_util.run_deprecated_v1
def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
sparse = sparse_tensor.SparseTensor(
indices=array_ops.placeholder(dtypes.int64),
values=array_ops.placeholder(dtypes.float32),
dense_shape=array_ops.placeholder(dtypes.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = inp.maybe_shuffle_batch_join(
[[sparse]], 2, 10, 1, [True, False], enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())


if __name__ == "__main__":
  test_lib.main()