Add bucketing input helpers to tf.contrib.training.
Change: 131891671
This commit is contained in:
parent
7a1210bdbd
commit
bc5df827de
@ -11,6 +11,7 @@ py_library(
|
||||
name = "training_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/training/bucket_ops.py",
|
||||
"python/training/sampling_ops.py",
|
||||
"python/training/sequence_queueing_state_saver.py",
|
||||
],
|
||||
@ -67,6 +68,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "bucket_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["python/training/bucket_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -38,6 +38,17 @@ balanced.
|
||||
|
||||
@@stratified_sample
|
||||
@@stratified_sample_unknown_dist
|
||||
|
||||
## Bucketing
|
||||
|
||||
Use ['bucket'](#bucket) or
|
||||
['bucket_by_sequence_length'](#bucket_by_sequence_length) to stratify
|
||||
minibatches into groups ("buckets"). Use `bucket_by_sequence_length`
|
||||
with the argument `dynamic_pad=True` to receive minibatches of similarly
|
||||
sized sequences for efficient training via `dynamic_rnn`.
|
||||
|
||||
@@bucket
|
||||
@@bucket_by_sequence_length
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -45,6 +56,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.training.python.training.bucket_ops import *
|
||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
374
tensorflow/contrib/training/python/training/bucket_ops.py
Normal file
374
tensorflow/contrib/training/python/training/bucket_ops.py
Normal file
@ -0,0 +1,374 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Operations for bucketing data into groups.
|
||||
|
||||
The classes and functions in this module are used to queue up data into
|
||||
buckets conditional on side information (e.g. sequence length).
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.training import input as input_py
|
||||
from tensorflow.python.training import queue_runner
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_as_original_type = input_py._as_original_type
|
||||
_as_tensor_list = input_py._as_tensor_list
|
||||
_deserialize_sparse_tensors = input_py._deserialize_sparse_tensors
|
||||
_dtypes = input_py._dtypes
|
||||
_serialize_sparse_tensors = input_py._serialize_sparse_tensors
|
||||
_shapes = input_py._shapes
|
||||
_which_queue = input_py._which_queue
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _validate_bucket(tensor_list):
|
||||
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
|
||||
if not tensor_list:
|
||||
raise ValueError("Expected at least one tensor in bucket().")
|
||||
return tensor_list
|
||||
|
||||
|
||||
def bucket(tensors,
|
||||
which_bucket,
|
||||
batch_size,
|
||||
num_buckets,
|
||||
num_threads=1,
|
||||
capacity=32,
|
||||
shapes=None,
|
||||
dynamic_pad=False,
|
||||
allow_smaller_final_batch=False,
|
||||
keep_input=None,
|
||||
shared_name=None,
|
||||
name=None):
|
||||
"""Lazy bucketing of input tensors according to `which_bucket`.
|
||||
|
||||
The argument `tensors` can be a list or a dictionary of tensors.
|
||||
The value returned by the function will be of the same type
|
||||
as `tensors`.
|
||||
|
||||
The tensors entering this function are put into the bucket given by
|
||||
`which_bucket`. Each bucket has its own queue. When a bucket contains
|
||||
`batch_size` elements, this minibatch is pushed onto a top queue. The
|
||||
tensors returned from this function are a the result of dequeueing the
|
||||
next minibatch from this top queue.
|
||||
|
||||
This function is implemented using several queues. A `QueueRunner` for the
|
||||
queues is added to the current `Graph`'s `QUEUE_RUNNER` collection.
|
||||
|
||||
As the returned tensors are the result of of a dequeue operation, evaluating
|
||||
them will throw a `tf.errors.OutOfRangeError` when the input queue is
|
||||
exhausted. If these tensors are feeding another input queue, its queue runner
|
||||
will catch this exception, however, if they are used in your main thread
|
||||
you are responsible for catching this yourself.
|
||||
|
||||
*N.B.:* If `dynamic_pad` is `False`, you must ensure that either
|
||||
(i) the `shapes` argument is passed, or (ii) all of the tensors in
|
||||
`tensors` must have fully-defined shapes. `ValueError` will be
|
||||
raised if neither of these conditions holds.
|
||||
|
||||
If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
|
||||
tensors is known, but individual dimensions may have shape `None`.
|
||||
In this case, for each enqueue the dimensions with value `None`
|
||||
may have a variable length; upon dequeue, the output tensors will be padded
|
||||
on the right to the maximum shape of the tensors in the current minibatch.
|
||||
For numbers, this padding takes value 0. For strings, this padding is
|
||||
the empty string. See `PaddingFIFOQueue` for more info.
|
||||
|
||||
If `allow_smaller_final_batch` is `True`, a smaller batch value than
|
||||
`batch_size` is returned when the queues are closed and there are not enough
|
||||
elements to fill the batch, otherwise the pending elements are discarded.
|
||||
In addition, all output tensors' static shapes, as accessed via the
|
||||
`get_shape()` method will have a 0th `Dimension` value of `None`, and
|
||||
operations that depend on fixed batch_size would fail.
|
||||
|
||||
Args:
|
||||
tensors: The list or dictionary of tensors, representing a single element,
|
||||
to bucket. Nested lists are not supported.
|
||||
which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
|
||||
batch_size: The new batch size pulled from the queue
|
||||
(python int or int32 scalar).
|
||||
num_buckets: A python integer, the number of buckets.
|
||||
num_threads: An integer. The number of threads enqueuing `tensors`.
|
||||
capacity: An integer. The maximum number of minibatches in the top queue,
|
||||
and also the maximum number of elements within each bucket.
|
||||
shapes: (Optional) The shapes for each example. Defaults to the
|
||||
inferred shapes for `tensors`.
|
||||
dynamic_pad: Boolean. Allow variable dimensions in input shapes.
|
||||
The given dimensions are padded upon dequeue so that tensors within a
|
||||
batch have the same shapes.
|
||||
allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
|
||||
batches to be smaller if there are insufficient items left in the queues.
|
||||
keep_input: (Optional). A `bool` scalar Tensor. If provided, this tensor
|
||||
controls whether the input is added to the queue or not. If it evaluates
|
||||
`True`, then `tensors` are added to the bucket; otherwise they are
|
||||
dropped. This tensor essentially acts as a filtering mechanism.
|
||||
The default behavior is to assume `keep_input=True`.
|
||||
shared_name: (Optional). If set, the queues will be shared under the given
|
||||
name across multiple sessions.
|
||||
name: (Optional) A name for the operations.
|
||||
|
||||
Returns:
|
||||
A tuple `(bucket, outputs)` where `bucket` is
|
||||
a `int32` scalar tensor and `outputs` is a list or
|
||||
dictionary of batched outputs corresponding to elements of `tensors`.
|
||||
Every step will receive a new bucket of outputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `shapes` are not specified, and cannot be
|
||||
inferred from the elements of `tensors`.
|
||||
"""
|
||||
tensor_list = _as_tensor_list(tensors)
|
||||
with ops.name_scope(name, "bucket", tensor_list) as name:
|
||||
tensor_list = _validate_bucket(tensor_list)
|
||||
(tensor_list, sparse_info) = _serialize_sparse_tensors(
|
||||
tensor_list, enqueue_many=False)
|
||||
|
||||
# Round-trip batch_size to a tensor, and possibly back
|
||||
batch_size = ops.convert_to_tensor(
|
||||
batch_size, dtype=dtypes.int32, name="batch_size")
|
||||
static_batch_size = tensor_util.constant_value(batch_size)
|
||||
batch_size = (
|
||||
static_batch_size if static_batch_size is not None else batch_size)
|
||||
|
||||
types = _dtypes([tensor_list])
|
||||
shapes = _shapes([tensor_list], shapes, enqueue_many=False)
|
||||
|
||||
which_bucket = ops.convert_to_tensor(
|
||||
which_bucket, dtype=dtypes.int32, name="which_bucket")
|
||||
|
||||
queue_creator = _which_queue(dynamic_pad)
|
||||
bucket_queues = []
|
||||
for i in range(num_buckets):
|
||||
shared_name_i = (
|
||||
"%s_%d" % (shared_name, i) if shared_name is not None else None)
|
||||
bucket_queues.append(
|
||||
queue_creator(capacity=capacity,
|
||||
dtypes=types,
|
||||
shapes=shapes,
|
||||
shared_name=shared_name_i, name="bucket_queue_%d" % i))
|
||||
|
||||
maybe_static_batch_size = (
|
||||
None if allow_smaller_final_batch else static_batch_size)
|
||||
|
||||
bucket_shapes = [tensor_shape.vector(maybe_static_batch_size).concatenate(s)
|
||||
for s in bucket_queues[0].shapes]
|
||||
# top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO
|
||||
# queues because if we use allow_smaller_final_batch, shapes will
|
||||
# contain Nones in their first entry; as a result, a regular
|
||||
# FIFOQueue would die when being passed shapes that are not fully defined.
|
||||
top_queue = data_flow_ops.PaddingFIFOQueue(
|
||||
capacity=capacity,
|
||||
dtypes=[dtypes.int32] + types,
|
||||
shapes=[tensor_shape.scalar()] + bucket_shapes,
|
||||
shared_name=shared_name, name="top_queue")
|
||||
|
||||
def enqueue_which():
|
||||
def enqueue_single(i):
|
||||
return bucket_queues[i].enqueue(tensor_list)
|
||||
enqueues = [
|
||||
control_flow_ops.cond(
|
||||
math_ops.equal(which_bucket, i),
|
||||
functools.partial(enqueue_single, i),
|
||||
control_flow_ops.no_op)
|
||||
for i in range(num_buckets)]
|
||||
return control_flow_ops.group(*enqueues, name="group_enqueues")
|
||||
|
||||
if keep_input is not None:
|
||||
# TODO(ebrevdo): Expand keep_input param to core training
|
||||
# methods, and pipe through to _serialize_sparse_tensors; so
|
||||
# that expensive serialization is guarded by keep_input.
|
||||
maybe_enqueue = control_flow_ops.cond(
|
||||
keep_input,
|
||||
enqueue_which,
|
||||
control_flow_ops.no_op)
|
||||
else:
|
||||
maybe_enqueue = enqueue_which()
|
||||
|
||||
bucket_enqueue_ops = [maybe_enqueue] * num_threads
|
||||
|
||||
if allow_smaller_final_batch:
|
||||
which_dequeue = lambda q: q.dequeue_up_to
|
||||
else:
|
||||
which_dequeue = lambda q: q.dequeue_many
|
||||
|
||||
enqueues_to_top = [
|
||||
top_queue.enqueue(
|
||||
[constant_op.constant(i)] +
|
||||
which_dequeue(q)(batch_size, name="read_bucket_%d" % i),
|
||||
name="enqueue_from_bucket_%d" % i)
|
||||
for i, q in enumerate(bucket_queues)]
|
||||
|
||||
for i, q in enumerate(bucket_queues):
|
||||
queue_runner.add_queue_runner(queue_runner.QueueRunner(
|
||||
q, [enqueues_to_top[i]],
|
||||
queue_closed_exception_types=(
|
||||
errors.OutOfRangeError, errors.CancelledError)))
|
||||
queue_runner.add_queue_runner(queue_runner.QueueRunner(
|
||||
top_queue, bucket_enqueue_ops,
|
||||
queue_closed_exception_types=(
|
||||
errors.OutOfRangeError, errors.CancelledError)))
|
||||
|
||||
for q in bucket_queues:
|
||||
logging_ops.scalar_summary(
|
||||
"bucket/%s/size" % q.name,
|
||||
math_ops.cast(top_queue.size(), dtypes.float32))
|
||||
logging_ops.scalar_summary(
|
||||
"bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity),
|
||||
math_ops.cast(top_queue.size(), dtypes.float32) * (1. / capacity))
|
||||
|
||||
dequeued = top_queue.dequeue(name="dequeue_top")
|
||||
which_bucket_dequeued = dequeued[0]
|
||||
dequeued = dequeued[1:]
|
||||
dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
|
||||
return (which_bucket_dequeued, _as_original_type(tensors, dequeued))
|
||||
|
||||
|
||||
def bucket_by_sequence_length(input_length,
|
||||
tensors,
|
||||
batch_size,
|
||||
bucket_boundaries,
|
||||
num_threads=1,
|
||||
capacity=32,
|
||||
shapes=None,
|
||||
dynamic_pad=False,
|
||||
allow_smaller_final_batch=False,
|
||||
keep_input=None,
|
||||
shared_name=None,
|
||||
name=None):
|
||||
"""Lazy bucketing of inputs according to their length.
|
||||
|
||||
This method calls `tf.contrib.training.bucket` under the hood, after first
|
||||
subdividing the bucket boundaries into separate buckets and identifying which
|
||||
bucket the given `input_length` belongs to. See the documentation for
|
||||
`which_bucket` for details of the other arguments.
|
||||
|
||||
Args:
|
||||
input_length: `int32` scalar `Tensor`, the sequence length of tensors.
|
||||
tensors: The list or dictionary of tensors, representing a single element,
|
||||
to bucket. Nested lists are not supported.
|
||||
batch_size: The new batch size pulled from the queue
|
||||
(python int or int32 scalar).
|
||||
bucket_boundaries: int list, increasing non-negative numbers.
|
||||
The edges of the buckets to use when bucketing tensors. Two extra buckets
|
||||
are created, one for `input_length < bucket_boundaries[0]` and
|
||||
one for `input_length >= bucket_boundaries[-1]`.
|
||||
num_threads: An integer. The number of threads enqueuing `tensors`.
|
||||
capacity: An integer. The maximum number of minibatches in the top queue,
|
||||
and also the maximum number of elements within each bucket.
|
||||
shapes: (Optional) The shapes for each example. Defaults to the
|
||||
inferred shapes for `tensors`.
|
||||
dynamic_pad: Boolean. Allow variable dimensions in input shapes.
|
||||
The given dimensions are padded upon dequeue so that tensors within a
|
||||
batch have the same shapes.
|
||||
allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
|
||||
batches to be smaller if there are insufficient items left in the queues.
|
||||
keep_input: (Optional). A `bool` scalar Tensor. If provided, this tensor
|
||||
controls whether the input is added to the queue or not. If it evaluates
|
||||
`True`, then `tensors` are added to the bucket; otherwise they are
|
||||
dropped. This tensor essentially acts as a filtering mechanism.
|
||||
The default behavior is to assume `keep_input=True`.
|
||||
shared_name: (Optional). If set, the queues will be shared under the given
|
||||
name across multiple sessions.
|
||||
name: (Optional) A name for the operations.
|
||||
|
||||
Returns:
|
||||
A tuple `(sequence_length, outputs)` where `sequence_length` is
|
||||
a 1-D `Tensor` of size `batch_size` and `outputs` is a list or dictionary
|
||||
of batched, bucketed, outputs corresponding to elements of `tensors`.
|
||||
|
||||
Raises:
|
||||
TypeError: if `bucket_boundaries` is not a list of python integers.
|
||||
ValueError: if `bucket_boundaries` is empty or contains non-increasing
|
||||
values.
|
||||
"""
|
||||
tensor_list = _as_tensor_list(tensors)
|
||||
if not isinstance(bucket_boundaries, (list, tuple)):
|
||||
raise TypeError(
|
||||
"bucket_boundaries must be a list or tuple, but received: %s"
|
||||
% bucket_boundaries)
|
||||
if not bucket_boundaries:
|
||||
raise ValueError("bucket_boundaries must not be empty")
|
||||
for (s, e) in zip(bucket_boundaries[:-1], bucket_boundaries[1:]):
|
||||
if not isinstance(s, int) or not isinstance(e, int):
|
||||
raise TypeError(
|
||||
"bucket boundaries must be integers, but saw: %s and %s" % (s, e))
|
||||
if s >= e:
|
||||
raise ValueError(
|
||||
"Buckets must contain sequential increasing lengths, but saw: "
|
||||
"%d before %d" % (s, e))
|
||||
|
||||
with ops.name_scope(name, "bucket_by_sequence_length",
|
||||
[input_length] + tensor_list) as name:
|
||||
input_length = ops.convert_to_tensor(
|
||||
input_length, dtype=dtypes.int32, name="input_length")
|
||||
# Bucketing conditions are:
|
||||
# l < b[0]
|
||||
# b[0] <= l < b[1]
|
||||
# b[1] <= l < b[2]
|
||||
# ...
|
||||
# b[N-2] <= l < b[N-1]
|
||||
# b[N-1] <= l
|
||||
# Equivalent to:
|
||||
# [-inf, b[0], b[1], ..., b[N-1]] <= l < [b[0], b[1], ..., b[N-1], inf]
|
||||
buckets_min = [np.iinfo(np.int32).min] + list(bucket_boundaries)
|
||||
buckets_max = list(bucket_boundaries) + [np.iinfo(np.int32).max]
|
||||
conditions_c = math_ops.logical_and(
|
||||
math_ops.less_equal(buckets_min, input_length),
|
||||
math_ops.less(input_length, buckets_max))
|
||||
which_bucket = math_ops.reduce_min(array_ops.where(conditions_c))
|
||||
which_bucket = math_ops.to_int32(which_bucket)
|
||||
|
||||
if shapes is not None:
|
||||
shapes = [tensor_shape.scalar()] + shapes
|
||||
|
||||
_, dequeued = bucket(
|
||||
tensors=[input_length] + tensor_list,
|
||||
which_bucket=which_bucket,
|
||||
batch_size=batch_size,
|
||||
num_buckets=len(bucket_boundaries) + 1,
|
||||
num_threads=num_threads,
|
||||
capacity=capacity,
|
||||
shapes=shapes,
|
||||
dynamic_pad=dynamic_pad,
|
||||
allow_smaller_final_batch=allow_smaller_final_batch,
|
||||
keep_input=keep_input,
|
||||
shared_name=shared_name)
|
||||
|
||||
return (dequeued[0], _as_original_type(tensors, dequeued[1:]))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"bucket",
|
||||
"bucket_by_sequence_length"
|
||||
]
|
356
tensorflow/contrib/training/python/training/bucket_ops_test.py
Normal file
356
tensorflow/contrib/training/python/training/bucket_ops_test.py
Normal file
@ -0,0 +1,356 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for tf.contrib.training.bucket."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def _which_bucket(bucket_edges, v):
|
||||
"""Identify which bucket v falls into.
|
||||
|
||||
Args:
|
||||
bucket_edges: int array, bucket edges
|
||||
v: int scalar, index
|
||||
Returns:
|
||||
int scalar, the bucket.
|
||||
If v < bucket_edges[0], return 0.
|
||||
If bucket_edges[0] <= v < bucket_edges[1], return 1.
|
||||
...
|
||||
If bucket_edges[-2] <= v < bucket_edges[-1], return len(bucket_edges).
|
||||
If v >= bucket_edges[-1], return len(bucket_edges) + 1
|
||||
"""
|
||||
v = np.asarray(v)
|
||||
full = [0] + bucket_edges
|
||||
found = np.where(np.logical_and(v >= full[:-1], v < full[1:]))[0]
|
||||
if not found.size:
|
||||
return len(full)
|
||||
return found[0]
|
||||
|
||||
|
||||
class BucketTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
tf.reset_default_graph()
|
||||
|
||||
self.scalar_int_feed = tf.placeholder(tf.int32, ())
|
||||
self.unk_int64_feed = tf.placeholder(tf.int64, (None,))
|
||||
self.vec3_str_feed = tf.placeholder(tf.string, (3,))
|
||||
|
||||
self._coord = tf.train.Coordinator()
|
||||
# Make capacity very large so we can feed all the inputs in the
|
||||
# main thread without blocking
|
||||
input_queue = tf.PaddingFIFOQueue(
|
||||
5000,
|
||||
dtypes=[tf.int32, tf.int64, tf.string],
|
||||
shapes=[(), (None,), (3,)])
|
||||
|
||||
self._input_enqueue_op = input_queue.enqueue(
|
||||
(self.scalar_int_feed, self.unk_int64_feed, self.vec3_str_feed))
|
||||
self.scalar_int, self.unk_int64, self.vec3_str = input_queue.dequeue()
|
||||
self._threads = None
|
||||
self._close_op = input_queue.close()
|
||||
self._sess = None
|
||||
|
||||
def enqueue_inputs(self, sess, feed_dict):
|
||||
sess.run(self._input_enqueue_op, feed_dict=feed_dict)
|
||||
|
||||
def start_queue_runners(self, sess):
|
||||
# Store session to be able to close inputs later
|
||||
if self._sess is None:
|
||||
self._sess = sess
|
||||
self._threads = tf.train.start_queue_runners(coord=self._coord)
|
||||
|
||||
def tearDown(self):
|
||||
if self._sess is not None:
|
||||
self._sess.run(self._close_op)
|
||||
self._coord.request_stop()
|
||||
self._coord.join(self._threads)
|
||||
|
||||
def testSingleBucket(self):
|
||||
bucketed_dynamic = tf.contrib.training.bucket(
|
||||
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||
which_bucket=tf.constant(0),
|
||||
num_buckets=2,
|
||||
batch_size=32,
|
||||
num_threads=10,
|
||||
dynamic_pad=True)
|
||||
# Check shape inference on bucketing outputs
|
||||
self.assertAllEqual(
|
||||
[[32], [32, None], [32, 3]],
|
||||
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||
with self.test_session() as sess:
|
||||
for v in range(32):
|
||||
self.enqueue_inputs(
|
||||
sess,
|
||||
{self.scalar_int_feed: v,
|
||||
self.unk_int64_feed: v * [v],
|
||||
self.vec3_str_feed: 3 * [str(v)]})
|
||||
self.start_queue_runners(sess)
|
||||
|
||||
# Get a single minibatch
|
||||
bucketed_values = sess.run(bucketed_dynamic)
|
||||
|
||||
# (which_bucket, bucket_tensors).
|
||||
self.assertEqual(2, len(bucketed_values))
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values[1]))
|
||||
|
||||
# Ensure bucket 0 was used for all minibatch entries.
|
||||
self.assertAllEqual(0, bucketed_values[0])
|
||||
|
||||
expected_scalar_int = np.arange(32)
|
||||
expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
|
||||
for i in range(32):
|
||||
expected_unk_int64[i, :i] = i
|
||||
expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
|
||||
|
||||
# Must resort the output because num_threads > 1 leads to
|
||||
# sometimes-inconsistent insertion order.
|
||||
resort = np.argsort(bucketed_values[1][0])
|
||||
self.assertAllEqual(expected_scalar_int, bucketed_values[1][0][resort])
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])
|
||||
|
||||
def testEvenOddBuckets(self):
|
||||
which_bucket = (self.scalar_int % 2)
|
||||
bucketed_dynamic = tf.contrib.training.bucket(
|
||||
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||
which_bucket=which_bucket,
|
||||
num_buckets=2,
|
||||
batch_size=32,
|
||||
num_threads=10,
|
||||
dynamic_pad=True)
|
||||
# Check shape inference on bucketing outputs
|
||||
self.assertAllEqual(
|
||||
[[32], [32, None], [32, 3]],
|
||||
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||
with self.test_session() as sess:
|
||||
for v in range(64):
|
||||
self.enqueue_inputs(
|
||||
sess,
|
||||
{self.scalar_int_feed: v,
|
||||
self.unk_int64_feed: v * [v],
|
||||
self.vec3_str_feed: 3 * [str(v)]})
|
||||
self.start_queue_runners(sess)
|
||||
|
||||
# Get two minibatches (one containing even values, one containing odds)
|
||||
bucketed_values_0 = sess.run(bucketed_dynamic)
|
||||
bucketed_values_1 = sess.run(bucketed_dynamic)
|
||||
|
||||
# (which_bucket, bucket_tensors).
|
||||
self.assertEqual(2, len(bucketed_values_0))
|
||||
self.assertEqual(2, len(bucketed_values_1))
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values_0[1]))
|
||||
self.assertEqual(3, len(bucketed_values_1[1]))
|
||||
|
||||
# Figure out which output has the even values (there's
|
||||
# randomness due to the multithreaded nature of bucketing)
|
||||
if bucketed_values_0[0] % 2 == 1:
|
||||
bucketed_values_even, bucketed_values_odd = (
|
||||
bucketed_values_1, bucketed_values_0)
|
||||
else:
|
||||
bucketed_values_even, bucketed_values_odd = (
|
||||
bucketed_values_0, bucketed_values_1)
|
||||
|
||||
# Ensure bucket 0 was used for all minibatch entries.
|
||||
self.assertAllEqual(0, bucketed_values_even[0])
|
||||
self.assertAllEqual(1, bucketed_values_odd[0])
|
||||
|
||||
# Test the first bucket outputted, the events starting at 0
|
||||
expected_scalar_int = np.arange(0, 32 * 2, 2)
|
||||
expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
|
||||
for i in range(0, 32):
|
||||
expected_unk_int64[i, :2*i] = 2*i
|
||||
expected_vec3_str = np.vstack(
|
||||
3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
|
||||
|
||||
# Must resort the output because num_threads > 1 leads to
|
||||
# sometimes-inconsistent insertion order.
|
||||
resort = np.argsort(bucketed_values_even[1][0])
|
||||
self.assertAllEqual(expected_scalar_int,
|
||||
bucketed_values_even[1][0][resort])
|
||||
self.assertAllEqual(expected_unk_int64,
|
||||
bucketed_values_even[1][1][resort])
|
||||
self.assertAllEqual(expected_vec3_str,
|
||||
bucketed_values_even[1][2][resort])
|
||||
|
||||
# Test the second bucket outputted, the odds starting at 1
|
||||
expected_scalar_int = np.arange(1, 32 * 2 + 1, 2)
|
||||
expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
|
||||
for i in range(0, 32):
|
||||
expected_unk_int64[i, :2*i + 1] = 2*i + 1
|
||||
expected_vec3_str = np.vstack(
|
||||
3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
|
||||
|
||||
# Must resort the output because num_threads > 1 leads to
|
||||
# sometimes-inconsistent insertion order.
|
||||
resort = np.argsort(bucketed_values_odd[1][0])
|
||||
self.assertAllEqual(expected_scalar_int,
|
||||
bucketed_values_odd[1][0][resort])
|
||||
self.assertAllEqual(expected_unk_int64,
|
||||
bucketed_values_odd[1][1][resort])
|
||||
self.assertAllEqual(expected_vec3_str,
|
||||
bucketed_values_odd[1][2][resort])
|
||||
|
||||
def testEvenOddBucketsFilterOutAllOdd(self):
|
||||
which_bucket = (self.scalar_int % 2)
|
||||
keep_input = tf.equal(which_bucket, 0)
|
||||
bucketed_dynamic = tf.contrib.training.bucket(
|
||||
tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
|
||||
which_bucket=which_bucket,
|
||||
num_buckets=2,
|
||||
batch_size=32,
|
||||
num_threads=10,
|
||||
keep_input=keep_input,
|
||||
dynamic_pad=True)
|
||||
# Check shape inference on bucketing outputs
|
||||
self.assertAllEqual(
|
||||
[[32], [32, None], [32, 3]],
|
||||
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
|
||||
with self.test_session() as sess:
|
||||
for v in range(128):
|
||||
self.enqueue_inputs(
|
||||
sess,
|
||||
{self.scalar_int_feed: v,
|
||||
self.unk_int64_feed: v * [v],
|
||||
self.vec3_str_feed: 3 * [str(v)]})
|
||||
self.start_queue_runners(sess)
|
||||
|
||||
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||
bucketed_values_even0 = sess.run(bucketed_dynamic)
|
||||
bucketed_values_even1 = sess.run(bucketed_dynamic)
|
||||
|
||||
# Ensure that bucket 1 was completely filtered out
|
||||
self.assertAllEqual(0, bucketed_values_even0[0])
|
||||
self.assertAllEqual(0, bucketed_values_even1[0])
|
||||
|
||||
# Merge their output for sorting and comparison
|
||||
bucketed_values_all_elem0 = np.concatenate(
|
||||
(bucketed_values_even0[1][0],
|
||||
bucketed_values_even1[1][0]))
|
||||
|
||||
self.assertAllEqual(
|
||||
np.arange(0, 128, 2), sorted(bucketed_values_all_elem0))
|
||||
|
||||
|
||||
class BucketBySequenceLengthTest(tf.test.TestCase):
|
||||
|
||||
def _testBucketBySequenceLength(self, allow_small_batch):
|
||||
tf.reset_default_graph()
|
||||
|
||||
# All inputs must be identical lengths across tuple index.
|
||||
# The input reader will get input_length from the first tuple
|
||||
# entry.
|
||||
data_len = 4
|
||||
target_len = 3
|
||||
input_pairs = [
|
||||
(length,
|
||||
([np.int64(length)] * data_len,
|
||||
[str(length).encode("ascii")] * target_len))
|
||||
for length in (1, 3, 4, 5, 6, 10)]
|
||||
|
||||
lengths = tf.placeholder(tf.int32, ())
|
||||
data = tf.placeholder(tf.int64, (data_len,))
|
||||
targets = tf.placeholder(tf.string, (target_len,))
|
||||
|
||||
batch_size = 8
|
||||
bucket_boundaries = [3, 4, 5, 10]
|
||||
|
||||
# Make capacity very large so we can feed all the inputs in the
|
||||
# main thread without blocking
|
||||
input_queue = tf.FIFOQueue(
|
||||
5000, (tf.int32, tf.int64, tf.string),
|
||||
((), (data_len,), (target_len,)))
|
||||
input_enqueue_op = input_queue.enqueue((lengths, data, targets))
|
||||
lengths_t, data_t, targets_t = input_queue.dequeue()
|
||||
close_input_op = input_queue.close()
|
||||
|
||||
(out_lengths_t, data_and_targets_t) = (
|
||||
tf.contrib.training.bucket_by_sequence_length(
|
||||
input_length=lengths_t,
|
||||
tensors=[data_t, targets_t],
|
||||
batch_size=batch_size,
|
||||
bucket_boundaries=bucket_boundaries,
|
||||
allow_smaller_final_batch=allow_small_batch,
|
||||
num_threads=10))
|
||||
|
||||
expected_batch_size = None if allow_small_batch else batch_size
|
||||
self.assertEqual(out_lengths_t.get_shape().as_list(),
|
||||
[expected_batch_size])
|
||||
self.assertEqual(data_and_targets_t[0].get_shape().as_list(),
|
||||
[expected_batch_size, data_len])
|
||||
self.assertEqual(data_and_targets_t[1].get_shape().as_list(),
|
||||
[expected_batch_size, target_len])
|
||||
|
||||
def _read_test(sess):
|
||||
for _ in range(50):
|
||||
(out_lengths, (data, targets)) = sess.run(
|
||||
(out_lengths_t, data_and_targets_t))
|
||||
if allow_small_batch:
|
||||
self.assertEqual(data_len, data.shape[1])
|
||||
self.assertEqual(target_len, targets.shape[1])
|
||||
self.assertGreaterEqual(batch_size, out_lengths.shape[0])
|
||||
self.assertGreaterEqual(batch_size, data.shape[0])
|
||||
self.assertGreaterEqual(batch_size, targets.shape[0])
|
||||
else:
|
||||
self.assertEqual((batch_size, data_len), data.shape)
|
||||
self.assertEqual((batch_size, target_len), targets.shape)
|
||||
self.assertEqual((batch_size,), out_lengths.shape)
|
||||
for (lr, dr, tr) in zip(out_lengths, data, targets):
|
||||
# Make sure length matches data (here it's the same value)
|
||||
self.assertEqual(dr[0], lr)
|
||||
# Make sure data & targets match
|
||||
self.assertEqual(dr[0], int(tr[0].decode("ascii")))
|
||||
# Make sure for each row, data came from the same bucket.
|
||||
self.assertEqual(_which_bucket(bucket_boundaries, dr[0]),
|
||||
_which_bucket(bucket_boundaries, dr[1]))
|
||||
|
||||
with self.test_session() as sess:
|
||||
coord = tf.train.Coordinator()
|
||||
|
||||
# Feed the inputs, then close the input thread.
|
||||
for _ in range(50 * batch_size + 100):
|
||||
which = random.randint(0, len(input_pairs) - 1)
|
||||
length, pair = input_pairs[which]
|
||||
sess.run(input_enqueue_op, feed_dict={
|
||||
lengths: length, data: pair[0], targets: pair[1]})
|
||||
sess.run(close_input_op)
|
||||
|
||||
# Start the queue runners
|
||||
threads = tf.train.start_queue_runners(coord=coord)
|
||||
# Read off the top of the bucket and ensure correctness of output
|
||||
_read_test(sess)
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def testBucketBySequenceLength(self):
|
||||
self._testBucketBySequenceLength(allow_small_batch=False)
|
||||
|
||||
def testBucketBySequenceLengthAllow(self):
|
||||
self._testBucketBySequenceLength(allow_small_batch=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
Loading…
Reference in New Issue
Block a user