diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index a44143ba406..79901b6ee56 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -11,6 +11,7 @@ py_library(
     name = "training_py",
     srcs = [
         "__init__.py",
+        "python/training/bucket_ops.py",
         "python/training/sampling_ops.py",
         "python/training/sequence_queueing_state_saver.py",
     ],
@@ -67,6 +68,18 @@ py_test(
     ],
 )
 
+py_test(
+    name = "bucket_ops_test",
+    size = "medium",
+    srcs = ["python/training/bucket_ops_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":training_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py
index d8cd9058008..3c0ff0f8cfa 100644
--- a/tensorflow/contrib/training/__init__.py
+++ b/tensorflow/contrib/training/__init__.py
@@ -38,6 +38,17 @@ balanced.
 
 @@stratified_sample
 @@stratified_sample_unknown_dist
+
+## Bucketing
+
+Use [`bucket`](#bucket) or
+[`bucket_by_sequence_length`](#bucket_by_sequence_length) to stratify
+minibatches into groups ("buckets"). Use `bucket_by_sequence_length`
+with the argument `dynamic_pad=True` to receive minibatches of similarly
+sized sequences for efficient training via `dynamic_rnn`.
+
+@@bucket
+@@bucket_by_sequence_length
 """
 
 from __future__ import absolute_import
@@ -45,6 +56,7 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.training.python.training.bucket_ops import *
 from tensorflow.contrib.training.python.training.sampling_ops import *
 from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
 from tensorflow.python.util.all_util import make_all
diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py
new file mode 100644
index 00000000000..3a28c9141fa
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/bucket_ops.py
@@ -0,0 +1,374 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Operations for bucketing data into groups.
+
+The classes and functions in this module are used to queue up data into
+buckets conditional on side information (e.g. sequence length).
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import input as input_py +from tensorflow.python.training import queue_runner + + +# pylint: disable=protected-access +_as_original_type = input_py._as_original_type +_as_tensor_list = input_py._as_tensor_list +_deserialize_sparse_tensors = input_py._deserialize_sparse_tensors +_dtypes = input_py._dtypes +_serialize_sparse_tensors = input_py._serialize_sparse_tensors +_shapes = input_py._shapes +_which_queue = input_py._which_queue +# pylint: enable=protected-access + + +def _validate_bucket(tensor_list): + tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list) + if not tensor_list: + raise ValueError("Expected at least one tensor in bucket().") + return tensor_list + + +def bucket(tensors, + which_bucket, + batch_size, + num_buckets, + num_threads=1, + capacity=32, + shapes=None, + dynamic_pad=False, + allow_smaller_final_batch=False, + keep_input=None, + shared_name=None, + name=None): + """Lazy bucketing of input tensors according to `which_bucket`. + + The argument `tensors` can be a list or a dictionary of tensors. + The value returned by the function will be of the same type + as `tensors`. + + The tensors entering this function are put into the bucket given by + `which_bucket`. Each bucket has its own queue. When a bucket contains + `batch_size` elements, this minibatch is pushed onto a top queue. The + tensors returned from this function are a the result of dequeueing the + next minibatch from this top queue. + + This function is implemented using several queues. A `QueueRunner` for the + queues is added to the current `Graph`'s `QUEUE_RUNNER` collection. + + As the returned tensors are the result of of a dequeue operation, evaluating + them will throw a `tf.errors.OutOfRangeError` when the input queue is + exhausted. If these tensors are feeding another input queue, its queue runner + will catch this exception, however, if they are used in your main thread + you are responsible for catching this yourself. + + *N.B.:* If `dynamic_pad` is `False`, you must ensure that either + (i) the `shapes` argument is passed, or (ii) all of the tensors in + `tensors` must have fully-defined shapes. `ValueError` will be + raised if neither of these conditions holds. + + If `dynamic_pad` is `True`, it is sufficient that the *rank* of the + tensors is known, but individual dimensions may have shape `None`. + In this case, for each enqueue the dimensions with value `None` + may have a variable length; upon dequeue, the output tensors will be padded + on the right to the maximum shape of the tensors in the current minibatch. + For numbers, this padding takes value 0. For strings, this padding is + the empty string. See `PaddingFIFOQueue` for more info. 
+
+  If `allow_smaller_final_batch` is `True`, a batch smaller than
+  `batch_size` is returned when the queues are closed and there are not enough
+  elements to fill the batch; otherwise the pending elements are discarded.
+  In addition, all output tensors' static shapes, as accessed via the
+  `get_shape()` method, will have a 0th `Dimension` value of `None`, and
+  operations that depend on a fixed batch_size will fail.
+
+  Args:
+    tensors: The list or dictionary of tensors, representing a single element,
+      to bucket. Nested lists are not supported.
+    which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
+    batch_size: The new batch size pulled from the queue
+      (python int or int32 scalar).
+    num_buckets: A python integer, the number of buckets.
+    num_threads: An integer. The number of threads enqueuing `tensors`.
+    capacity: An integer. The maximum number of minibatches in the top queue,
+      and also the maximum number of elements within each bucket.
+    shapes: (Optional) The shapes for each example. Defaults to the
+      inferred shapes for `tensors`.
+    dynamic_pad: Boolean. Allow variable dimensions in input shapes.
+      The given dimensions are padded upon dequeue so that tensors within a
+      batch have the same shapes.
+    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
+      batches to be smaller if there are insufficient items left in the queues.
+    keep_input: (Optional). A `bool` scalar Tensor. If provided, this tensor
+      controls whether the input is added to the queue or not. If it evaluates
+      to `True`, then `tensors` are added to the bucket; otherwise they are
+      dropped. This tensor essentially acts as a filtering mechanism.
+      The default behavior is to assume `keep_input=True`.
+    shared_name: (Optional). If set, the queues will be shared under the given
+      name across multiple sessions.
+    name: (Optional) A name for the operations.
+
+  Returns:
+    A tuple `(bucket, outputs)` where `bucket` is
+    an `int32` scalar tensor and `outputs` is a list or
+    dictionary of batched outputs corresponding to elements of `tensors`.
+    Every step will receive a new bucket of outputs.
+
+  Raises:
+    ValueError: If the `shapes` are not specified, and cannot be
+      inferred from the elements of `tensors`.
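+
+  Example (a minimal, illustrative sketch; `length` and `features` are
+  hypothetical tensors produced by an upstream reader, and the bucket
+  assignment rule here is arbitrary):
+
+  ```python
+  # Assign each example to one of 4 buckets using a coarse length measure.
+  which = tf.minimum(tf.floordiv(length, 10), 3)
+  which_dequeued, batched = tf.contrib.training.bucket(
+      tensors=[features],
+      which_bucket=which,
+      batch_size=32,
+      num_buckets=4,
+      dynamic_pad=True)
+  ```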
+ """ + tensor_list = _as_tensor_list(tensors) + with ops.name_scope(name, "bucket", tensor_list) as name: + tensor_list = _validate_bucket(tensor_list) + (tensor_list, sparse_info) = _serialize_sparse_tensors( + tensor_list, enqueue_many=False) + + # Round-trip batch_size to a tensor, and possibly back + batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int32, name="batch_size") + static_batch_size = tensor_util.constant_value(batch_size) + batch_size = ( + static_batch_size if static_batch_size is not None else batch_size) + + types = _dtypes([tensor_list]) + shapes = _shapes([tensor_list], shapes, enqueue_many=False) + + which_bucket = ops.convert_to_tensor( + which_bucket, dtype=dtypes.int32, name="which_bucket") + + queue_creator = _which_queue(dynamic_pad) + bucket_queues = [] + for i in range(num_buckets): + shared_name_i = ( + "%s_%d" % (shared_name, i) if shared_name is not None else None) + bucket_queues.append( + queue_creator(capacity=capacity, + dtypes=types, + shapes=shapes, + shared_name=shared_name_i, name="bucket_queue_%d" % i)) + + maybe_static_batch_size = ( + None if allow_smaller_final_batch else static_batch_size) + + bucket_shapes = [tensor_shape.vector(maybe_static_batch_size).concatenate(s) + for s in bucket_queues[0].shapes] + # top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO + # queues because if we use allow_smaller_final_batch, shapes will + # contain Nones in their first entry; as a result, a regular + # FIFOQueue would die when being passed shapes that are not fully defined. + top_queue = data_flow_ops.PaddingFIFOQueue( + capacity=capacity, + dtypes=[dtypes.int32] + types, + shapes=[tensor_shape.scalar()] + bucket_shapes, + shared_name=shared_name, name="top_queue") + + def enqueue_which(): + def enqueue_single(i): + return bucket_queues[i].enqueue(tensor_list) + enqueues = [ + control_flow_ops.cond( + math_ops.equal(which_bucket, i), + functools.partial(enqueue_single, i), + control_flow_ops.no_op) + for i in range(num_buckets)] + return control_flow_ops.group(*enqueues, name="group_enqueues") + + if keep_input is not None: + # TODO(ebrevdo): Expand keep_input param to core training + # methods, and pipe through to _serialize_sparse_tensors; so + # that expensive serialization is guarded by keep_input. + maybe_enqueue = control_flow_ops.cond( + keep_input, + enqueue_which, + control_flow_ops.no_op) + else: + maybe_enqueue = enqueue_which() + + bucket_enqueue_ops = [maybe_enqueue] * num_threads + + if allow_smaller_final_batch: + which_dequeue = lambda q: q.dequeue_up_to + else: + which_dequeue = lambda q: q.dequeue_many + + enqueues_to_top = [ + top_queue.enqueue( + [constant_op.constant(i)] + + which_dequeue(q)(batch_size, name="read_bucket_%d" % i), + name="enqueue_from_bucket_%d" % i) + for i, q in enumerate(bucket_queues)] + + for i, q in enumerate(bucket_queues): + queue_runner.add_queue_runner(queue_runner.QueueRunner( + q, [enqueues_to_top[i]], + queue_closed_exception_types=( + errors.OutOfRangeError, errors.CancelledError))) + queue_runner.add_queue_runner(queue_runner.QueueRunner( + top_queue, bucket_enqueue_ops, + queue_closed_exception_types=( + errors.OutOfRangeError, errors.CancelledError))) + + for q in bucket_queues: + logging_ops.scalar_summary( + "bucket/%s/size" % q.name, + math_ops.cast(top_queue.size(), dtypes.float32)) + logging_ops.scalar_summary( + "bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity), + math_ops.cast(top_queue.size(), dtypes.float32) * (1. 
+
+    dequeued = top_queue.dequeue(name="dequeue_top")
+    which_bucket_dequeued = dequeued[0]
+    dequeued = dequeued[1:]
+    dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
+    return (which_bucket_dequeued, _as_original_type(tensors, dequeued))
+
+
+def bucket_by_sequence_length(input_length,
+                              tensors,
+                              batch_size,
+                              bucket_boundaries,
+                              num_threads=1,
+                              capacity=32,
+                              shapes=None,
+                              dynamic_pad=False,
+                              allow_smaller_final_batch=False,
+                              keep_input=None,
+                              shared_name=None,
+                              name=None):
+  """Lazy bucketing of inputs according to their length.
+
+  This function calls `tf.contrib.training.bucket` under the hood, after first
+  converting `bucket_boundaries` into a set of buckets and identifying which
+  bucket the given `input_length` belongs to. See the documentation for
+  `bucket` for details of the other arguments.
+
+  Args:
+    input_length: `int32` scalar `Tensor`, the sequence length of tensors.
+    tensors: The list or dictionary of tensors, representing a single element,
+      to bucket. Nested lists are not supported.
+    batch_size: The new batch size pulled from the queue
+      (python int or int32 scalar).
+    bucket_boundaries: int list, increasing non-negative numbers.
+      The edges of the buckets to use when bucketing tensors. Two extra buckets
+      are created, one for `input_length < bucket_boundaries[0]` and
+      one for `input_length >= bucket_boundaries[-1]`.
+    num_threads: An integer. The number of threads enqueuing `tensors`.
+    capacity: An integer. The maximum number of minibatches in the top queue,
+      and also the maximum number of elements within each bucket.
+    shapes: (Optional) The shapes for each example. Defaults to the
+      inferred shapes for `tensors`.
+    dynamic_pad: Boolean. Allow variable dimensions in input shapes.
+      The given dimensions are padded upon dequeue so that tensors within a
+      batch have the same shapes.
+    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
+      batches to be smaller if there are insufficient items left in the queues.
+    keep_input: (Optional). A `bool` scalar Tensor. If provided, this tensor
+      controls whether the input is added to the queue or not. If it evaluates
+      to `True`, then `tensors` are added to the bucket; otherwise they are
+      dropped. This tensor essentially acts as a filtering mechanism.
+      The default behavior is to assume `keep_input=True`.
+    shared_name: (Optional). If set, the queues will be shared under the given
+      name across multiple sessions.
+    name: (Optional) A name for the operations.
+
+  Returns:
+    A tuple `(sequence_length, outputs)` where `sequence_length` is
+    a 1-D `Tensor` of size `batch_size` and `outputs` is a list or dictionary
+    of batched, bucketed outputs corresponding to elements of `tensors`.
+
+  Raises:
+    TypeError: if `bucket_boundaries` is not a list of python integers.
+    ValueError: if `bucket_boundaries` is empty or contains non-increasing
+      values.
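+
+  Example (an illustrative sketch; `seq_len`, `tokens`, and `labels` are
+  hypothetical tensors from an upstream reader, one example at a time):
+
+  ```python
+  lengths, batched = tf.contrib.training.bucket_by_sequence_length(
+      input_length=seq_len,
+      tensors=[tokens, labels],
+      batch_size=32,
+      bucket_boundaries=[10, 20, 40],
+      dynamic_pad=True)
+  tokens_batch, labels_batch = batched
+  # Each minibatch now contains sequences of similar length, padded to
+  # the longest sequence in the minibatch, e.g. for tf.nn.dynamic_rnn.
+  ```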
+ """ + tensor_list = _as_tensor_list(tensors) + if not isinstance(bucket_boundaries, (list, tuple)): + raise TypeError( + "bucket_boundaries must be a list or tuple, but received: %s" + % bucket_boundaries) + if not bucket_boundaries: + raise ValueError("bucket_boundaries must not be empty") + for (s, e) in zip(bucket_boundaries[:-1], bucket_boundaries[1:]): + if not isinstance(s, int) or not isinstance(e, int): + raise TypeError( + "bucket boundaries must be integers, but saw: %s and %s" % (s, e)) + if s >= e: + raise ValueError( + "Buckets must contain sequential increasing lengths, but saw: " + "%d before %d" % (s, e)) + + with ops.name_scope(name, "bucket_by_sequence_length", + [input_length] + tensor_list) as name: + input_length = ops.convert_to_tensor( + input_length, dtype=dtypes.int32, name="input_length") + # Bucketing conditions are: + # l < b[0] + # b[0] <= l < b[1] + # b[1] <= l < b[2] + # ... + # b[N-2] <= l < b[N-1] + # b[N-1] <= l + # Equivalent to: + # [-inf, b[0], b[1], ..., b[N-1]] <= l < [b[0], b[1], ..., b[N-1], inf] + buckets_min = [np.iinfo(np.int32).min] + list(bucket_boundaries) + buckets_max = list(bucket_boundaries) + [np.iinfo(np.int32).max] + conditions_c = math_ops.logical_and( + math_ops.less_equal(buckets_min, input_length), + math_ops.less(input_length, buckets_max)) + which_bucket = math_ops.reduce_min(array_ops.where(conditions_c)) + which_bucket = math_ops.to_int32(which_bucket) + + if shapes is not None: + shapes = [tensor_shape.scalar()] + shapes + + _, dequeued = bucket( + tensors=[input_length] + tensor_list, + which_bucket=which_bucket, + batch_size=batch_size, + num_buckets=len(bucket_boundaries) + 1, + num_threads=num_threads, + capacity=capacity, + shapes=shapes, + dynamic_pad=dynamic_pad, + allow_smaller_final_batch=allow_smaller_final_batch, + keep_input=keep_input, + shared_name=shared_name) + + return (dequeued[0], _as_original_type(tensors, dequeued[1:])) + + +__all__ = [ + "bucket", + "bucket_by_sequence_length" +] diff --git a/tensorflow/contrib/training/python/training/bucket_ops_test.py b/tensorflow/contrib/training/python/training/bucket_ops_test.py new file mode 100644 index 00000000000..587cf9411ce --- /dev/null +++ b/tensorflow/contrib/training/python/training/bucket_ops_test.py @@ -0,0 +1,356 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for tf.contrib.training.bucket.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import numpy as np +import tensorflow as tf + + +def _which_bucket(bucket_edges, v): + """Identify which bucket v falls into. + + Args: + bucket_edges: int array, bucket edges + v: int scalar, index + Returns: + int scalar, the bucket. + If v < bucket_edges[0], return 0. + If bucket_edges[0] <= v < bucket_edges[1], return 1. + ... 
+    If bucket_edges[-2] <= v < bucket_edges[-1], return len(bucket_edges) - 1.
+    If v >= bucket_edges[-1], return len(bucket_edges) + 1.
+  """
+  v = np.asarray(v)
+  full = [0] + bucket_edges
+  found = np.where(np.logical_and(v >= full[:-1], v < full[1:]))[0]
+  if not found.size:
+    return len(full)
+  return found[0]
+
+
+class BucketTest(tf.test.TestCase):
+
+  def setUp(self):
+    tf.reset_default_graph()
+
+    self.scalar_int_feed = tf.placeholder(tf.int32, ())
+    self.unk_int64_feed = tf.placeholder(tf.int64, (None,))
+    self.vec3_str_feed = tf.placeholder(tf.string, (3,))
+
+    self._coord = tf.train.Coordinator()
+    # Make capacity very large so we can feed all the inputs in the
+    # main thread without blocking
+    input_queue = tf.PaddingFIFOQueue(
+        5000,
+        dtypes=[tf.int32, tf.int64, tf.string],
+        shapes=[(), (None,), (3,)])
+
+    self._input_enqueue_op = input_queue.enqueue(
+        (self.scalar_int_feed, self.unk_int64_feed, self.vec3_str_feed))
+    self.scalar_int, self.unk_int64, self.vec3_str = input_queue.dequeue()
+    self._threads = None
+    self._close_op = input_queue.close()
+    self._sess = None
+
+  def enqueue_inputs(self, sess, feed_dict):
+    sess.run(self._input_enqueue_op, feed_dict=feed_dict)
+
+  def start_queue_runners(self, sess):
+    # Store session to be able to close inputs later
+    if self._sess is None:
+      self._sess = sess
+    self._threads = tf.train.start_queue_runners(coord=self._coord)
+
+  def tearDown(self):
+    if self._sess is not None:
+      self._sess.run(self._close_op)
+    self._coord.request_stop()
+    self._coord.join(self._threads)
+
+  def testSingleBucket(self):
+    bucketed_dynamic = tf.contrib.training.bucket(
+        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
+        which_bucket=tf.constant(0),
+        num_buckets=2,
+        batch_size=32,
+        num_threads=10,
+        dynamic_pad=True)
+    # Check shape inference on bucketing outputs
+    self.assertAllEqual(
+        [[32], [32, None], [32, 3]],
+        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
+    with self.test_session() as sess:
+      for v in range(32):
+        self.enqueue_inputs(
+            sess,
+            {self.scalar_int_feed: v,
+             self.unk_int64_feed: v * [v],
+             self.vec3_str_feed: 3 * [str(v)]})
+      self.start_queue_runners(sess)
+
+      # Get a single minibatch
+      bucketed_values = sess.run(bucketed_dynamic)
+
+      # (which_bucket, bucket_tensors).
+      self.assertEqual(2, len(bucketed_values))
+
+      # Count number of bucket_tensors.
+      self.assertEqual(3, len(bucketed_values[1]))
+
+      # Ensure bucket 0 was used for all minibatch entries.
+      self.assertAllEqual(0, bucketed_values[0])
+
+      expected_scalar_int = np.arange(32)
+      expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
+      for i in range(32):
+        expected_unk_int64[i, :i] = i
+      expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
+
+      # Must re-sort the output because num_threads > 1 leads to
+      # sometimes-inconsistent insertion order.
+      resort = np.argsort(bucketed_values[1][0])
+      self.assertAllEqual(expected_scalar_int, bucketed_values[1][0][resort])
+      self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
+      self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])
+
+  def testEvenOddBuckets(self):
+    which_bucket = (self.scalar_int % 2)
+    bucketed_dynamic = tf.contrib.training.bucket(
+        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
+        which_bucket=which_bucket,
+        num_buckets=2,
+        batch_size=32,
+        num_threads=10,
+        dynamic_pad=True)
+    # Check shape inference on bucketing outputs
+    self.assertAllEqual(
+        [[32], [32, None], [32, 3]],
+        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
+    with self.test_session() as sess:
+      for v in range(64):
+        self.enqueue_inputs(
+            sess,
+            {self.scalar_int_feed: v,
+             self.unk_int64_feed: v * [v],
+             self.vec3_str_feed: 3 * [str(v)]})
+      self.start_queue_runners(sess)
+
+      # Get two minibatches (one containing even values, one containing odds)
+      bucketed_values_0 = sess.run(bucketed_dynamic)
+      bucketed_values_1 = sess.run(bucketed_dynamic)
+
+      # (which_bucket, bucket_tensors).
+      self.assertEqual(2, len(bucketed_values_0))
+      self.assertEqual(2, len(bucketed_values_1))
+
+      # Count number of bucket_tensors.
+      self.assertEqual(3, len(bucketed_values_0[1]))
+      self.assertEqual(3, len(bucketed_values_1[1]))
+
+      # Figure out which output has the even values (there's
+      # randomness due to the multithreaded nature of bucketing)
+      if bucketed_values_0[0] % 2 == 1:
+        bucketed_values_even, bucketed_values_odd = (
+            bucketed_values_1, bucketed_values_0)
+      else:
+        bucketed_values_even, bucketed_values_odd = (
+            bucketed_values_0, bucketed_values_1)
+
+      # Ensure bucket 0 was used for the even minibatch and bucket 1
+      # for the odd minibatch.
+      self.assertAllEqual(0, bucketed_values_even[0])
+      self.assertAllEqual(1, bucketed_values_odd[0])
+
+      # Test the first bucket output, the evens starting at 0
+      expected_scalar_int = np.arange(0, 32 * 2, 2)
+      expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
+      for i in range(0, 32):
+        expected_unk_int64[i, :2*i] = 2*i
+      expected_vec3_str = np.vstack(
+          3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
+
+      # Must re-sort the output because num_threads > 1 leads to
+      # sometimes-inconsistent insertion order.
+      resort = np.argsort(bucketed_values_even[1][0])
+      self.assertAllEqual(expected_scalar_int,
+                          bucketed_values_even[1][0][resort])
+      self.assertAllEqual(expected_unk_int64,
+                          bucketed_values_even[1][1][resort])
+      self.assertAllEqual(expected_vec3_str,
+                          bucketed_values_even[1][2][resort])
+
+      # Test the second bucket output, the odds starting at 1
+      expected_scalar_int = np.arange(1, 32 * 2 + 1, 2)
+      expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
+      for i in range(0, 32):
+        expected_unk_int64[i, :2*i + 1] = 2*i + 1
+      expected_vec3_str = np.vstack(
+          3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
+
+      # Must re-sort the output because num_threads > 1 leads to
+      # sometimes-inconsistent insertion order.
+      resort = np.argsort(bucketed_values_odd[1][0])
+      self.assertAllEqual(expected_scalar_int,
+                          bucketed_values_odd[1][0][resort])
+      self.assertAllEqual(expected_unk_int64,
+                          bucketed_values_odd[1][1][resort])
+      self.assertAllEqual(expected_vec3_str,
+                          bucketed_values_odd[1][2][resort])
+
+  def testEvenOddBucketsFilterOutAllOdd(self):
+    which_bucket = (self.scalar_int % 2)
+    keep_input = tf.equal(which_bucket, 0)
+    bucketed_dynamic = tf.contrib.training.bucket(
+        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
+        which_bucket=which_bucket,
+        num_buckets=2,
+        batch_size=32,
+        num_threads=10,
+        keep_input=keep_input,
+        dynamic_pad=True)
+    # Check shape inference on bucketing outputs
+    self.assertAllEqual(
+        [[32], [32, None], [32, 3]],
+        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
+    with self.test_session() as sess:
+      for v in range(128):
+        self.enqueue_inputs(
+            sess,
+            {self.scalar_int_feed: v,
+             self.unk_int64_feed: v * [v],
+             self.vec3_str_feed: 3 * [str(v)]})
+      self.start_queue_runners(sess)
+
+      # Get two minibatches (together containing all the even values;
+      # the split across the two batches is thread-order dependent)
+      bucketed_values_even0 = sess.run(bucketed_dynamic)
+      bucketed_values_even1 = sess.run(bucketed_dynamic)
+
+      # Ensure that bucket 1 was completely filtered out
+      self.assertAllEqual(0, bucketed_values_even0[0])
+      self.assertAllEqual(0, bucketed_values_even1[0])
+
+      # Merge their output for sorting and comparison
+      bucketed_values_all_elem0 = np.concatenate(
+          (bucketed_values_even0[1][0],
+           bucketed_values_even1[1][0]))
+
+      self.assertAllEqual(
+          np.arange(0, 128, 2), sorted(bucketed_values_all_elem0))
+
+
+class BucketBySequenceLengthTest(tf.test.TestCase):
+
+  def _testBucketBySequenceLength(self, allow_small_batch):
+    tf.reset_default_graph()
+
+    # Every entry of an input tuple encodes the same length value.
+    # The input reader will get input_length from the first tuple entry.
+    data_len = 4
+    target_len = 3
+    input_pairs = [
+        (length,
+         ([np.int64(length)] * data_len,
+          [str(length).encode("ascii")] * target_len))
+        for length in (1, 3, 4, 5, 6, 10)]
+
+    lengths = tf.placeholder(tf.int32, ())
+    data = tf.placeholder(tf.int64, (data_len,))
+    targets = tf.placeholder(tf.string, (target_len,))
+
+    batch_size = 8
+    bucket_boundaries = [3, 4, 5, 10]
+
+    # Make capacity very large so we can feed all the inputs in the
+    # main thread without blocking
+    input_queue = tf.FIFOQueue(
+        5000, (tf.int32, tf.int64, tf.string),
+        ((), (data_len,), (target_len,)))
+    input_enqueue_op = input_queue.enqueue((lengths, data, targets))
+    lengths_t, data_t, targets_t = input_queue.dequeue()
+    close_input_op = input_queue.close()
+
+    (out_lengths_t, data_and_targets_t) = (
+        tf.contrib.training.bucket_by_sequence_length(
+            input_length=lengths_t,
+            tensors=[data_t, targets_t],
+            batch_size=batch_size,
+            bucket_boundaries=bucket_boundaries,
+            allow_smaller_final_batch=allow_small_batch,
+            num_threads=10))
+
+    expected_batch_size = None if allow_small_batch else batch_size
+    self.assertEqual(out_lengths_t.get_shape().as_list(),
+                     [expected_batch_size])
+    self.assertEqual(data_and_targets_t[0].get_shape().as_list(),
+                     [expected_batch_size, data_len])
+    self.assertEqual(data_and_targets_t[1].get_shape().as_list(),
+                     [expected_batch_size, target_len])
+
+    def _read_test(sess):
+      for _ in range(50):
+        (out_lengths, (data, targets)) = sess.run(
+            (out_lengths_t, data_and_targets_t))
+        if allow_small_batch:
+          self.assertEqual(data_len, data.shape[1])
+          self.assertEqual(target_len, targets.shape[1])
+          self.assertGreaterEqual(batch_size, out_lengths.shape[0])
+          self.assertGreaterEqual(batch_size, data.shape[0])
+          self.assertGreaterEqual(batch_size, targets.shape[0])
+        else:
+          self.assertEqual((batch_size, data_len), data.shape)
+          self.assertEqual((batch_size, target_len), targets.shape)
+          self.assertEqual((batch_size,), out_lengths.shape)
+        for (lr, dr, tr) in zip(out_lengths, data, targets):
+          # Make sure length matches data (here it's the same value)
+          self.assertEqual(dr[0], lr)
+          # Make sure data & targets match
+          self.assertEqual(dr[0], int(tr[0].decode("ascii")))
+          # Make sure for each row, data came from the same bucket.
+          self.assertEqual(_which_bucket(bucket_boundaries, dr[0]),
+                           _which_bucket(bucket_boundaries, dr[1]))
+
+    with self.test_session() as sess:
+      coord = tf.train.Coordinator()
+
+      # Feed the inputs, then close the input queue.
+      for _ in range(50 * batch_size + 100):
+        which = random.randint(0, len(input_pairs) - 1)
+        length, pair = input_pairs[which]
+        sess.run(input_enqueue_op, feed_dict={
+            lengths: length, data: pair[0], targets: pair[1]})
+      sess.run(close_input_op)
+
+      # Start the queue runners
+      threads = tf.train.start_queue_runners(coord=coord)
+      # Read off the top of the bucket and ensure correctness of output
+      _read_test(sess)
+      coord.request_stop()
+      coord.join(threads)
+
+  def testBucketBySequenceLength(self):
+    self._testBucketBySequenceLength(allow_small_batch=False)
+
+  def testBucketBySequenceLengthAllow(self):
+    self._testBucketBySequenceLength(allow_small_batch=True)
+
+
+if __name__ == "__main__":
+  tf.test.main()
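
A minimal driver sketch for the ops above (`train_op` is a hypothetical op built from the batched outputs; the loop mirrors the `OutOfRangeError` semantics documented in `bucket`'s docstring):

```python
import tensorflow as tf

# Hypothetical: `train_op` consumes the batched outputs of `bucket` or
# `bucket_by_sequence_length` constructed elsewhere in the graph.
with tf.Session() as sess:
  coord = tf.train.Coordinator()
  # Start the QueueRunners that `bucket` registered on the graph
  # (the input enqueuers plus the per-bucket-to-top-queue movers).
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      sess.run(train_op)
  except tf.errors.OutOfRangeError:
    # Raised once the input queues are closed and exhausted.
    pass
  finally:
    coord.request_stop()
  coord.join(threads)
```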