Add bucketing input helpers to tf.contrib.training.
Change: 131891671
This commit is contained in:
parent 7a1210bdbd
commit bc5df827de
tensorflow/contrib/training/BUILD
@@ -11,6 +11,7 @@ py_library(
    name = "training_py",
    srcs = [
        "__init__.py",
        "python/training/bucket_ops.py",
        "python/training/sampling_ops.py",
        "python/training/sequence_queueing_state_saver.py",
    ],
@@ -67,6 +68,18 @@ py_test(
    ],
)

py_test(
    name = "bucket_ops_test",
    size = "medium",
    srcs = ["python/training/bucket_ops_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":training_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
tensorflow/contrib/training/__init__.py
@@ -38,6 +38,17 @@ balanced.

@@stratified_sample
@@stratified_sample_unknown_dist

## Bucketing

Use [`bucket`](#bucket) or
[`bucket_by_sequence_length`](#bucket_by_sequence_length) to stratify
minibatches into groups ("buckets").  Use `bucket_by_sequence_length`
with the argument `dynamic_pad=True` to receive minibatches of similarly
sized sequences for efficient training via `dynamic_rnn`.

@@bucket
@@bucket_by_sequence_length
"""

from __future__ import absolute_import
@@ -45,6 +56,7 @@ from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.training.python.training.bucket_ops import *
from tensorflow.contrib.training.python.training.sampling_ops import *
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
from tensorflow.python.util.all_util import make_all
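A minimal usage sketch of the pattern described above (the upstream queue, shapes, boundaries, and cell size here are illustrative assumptions, not part of this commit):

import tensorflow as tf

# Assumed upstream pipeline yielding (length, sequence) pairs; a
# PaddingFIFOQueue permits the statically unknown time dimension.
input_queue = tf.PaddingFIFOQueue(
    capacity=1000, dtypes=(tf.int32, tf.float32), shapes=((), (None, 8)))
length_t, sequence_t = input_queue.dequeue()

lengths, (padded_seqs,) = tf.contrib.training.bucket_by_sequence_length(
    input_length=length_t,
    tensors=[sequence_t],
    batch_size=32,
    bucket_boundaries=[10, 20, 40],  # 4 buckets: <10, 10-19, 20-39, >=40
    dynamic_pad=True)  # pad to the longest sequence in each minibatch

# Similarly sized sequences batch efficiently through dynamic_rnn.
cell = tf.nn.rnn_cell.GRUCell(num_units=64)
outputs, _ = tf.nn.dynamic_rnn(
    cell, padded_seqs, sequence_length=lengths, dtype=tf.float32)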
							
								
								
									
tensorflow/contrib/training/python/training/bucket_ops.py (new file, 374 lines)
@@ -0,0 +1,374 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for bucketing data into groups.

The classes and functions in this module are used to queue up data into
buckets conditional on side information (e.g. sequence length).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import input as input_py
from tensorflow.python.training import queue_runner


# pylint: disable=protected-access
_as_original_type = input_py._as_original_type
_as_tensor_list = input_py._as_tensor_list
_deserialize_sparse_tensors = input_py._deserialize_sparse_tensors
_dtypes = input_py._dtypes
_serialize_sparse_tensors = input_py._serialize_sparse_tensors
_shapes = input_py._shapes
_which_queue = input_py._which_queue
# pylint: enable=protected-access


def _validate_bucket(tensor_list):
  tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
  if not tensor_list:
    raise ValueError("Expected at least one tensor in bucket().")
  return tensor_list


def bucket(tensors,
           which_bucket,
           batch_size,
           num_buckets,
           num_threads=1,
           capacity=32,
           shapes=None,
           dynamic_pad=False,
           allow_smaller_final_batch=False,
           keep_input=None,
           shared_name=None,
           name=None):
  """Lazy bucketing of input tensors according to `which_bucket`.

  The argument `tensors` can be a list or a dictionary of tensors.
  The value returned by the function will be of the same type
  as `tensors`.

  The tensors entering this function are put into the bucket given by
  `which_bucket`.  Each bucket has its own queue.  When a bucket contains
  `batch_size` elements, this minibatch is pushed onto a top queue.  The
  tensors returned from this function are the result of dequeueing the
  next minibatch from this top queue.

  This function is implemented using several queues. A `QueueRunner` for the
  queues is added to the current `Graph`'s `QUEUE_RUNNER` collection.

  As the returned tensors are the result of a dequeue operation, evaluating
  them will throw a `tf.errors.OutOfRangeError` when the input queue is
  exhausted.  If these tensors are feeding another input queue, its queue
  runner will catch this exception; however, if they are used in your main
  thread you are responsible for catching this yourself.

  *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
  (i) the `shapes` argument is passed, or (ii) all of the tensors in
  `tensors` have fully-defined shapes.  A `ValueError` will be
  raised if neither of these conditions holds.

  If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
  tensors is known, but individual dimensions may have shape `None`.
  In this case, for each enqueue the dimensions with value `None`
  may have a variable length; upon dequeue, the output tensors will be padded
  on the right to the maximum shape of the tensors in the current minibatch.
  For numbers, this padding takes value 0.  For strings, this padding is
  the empty string.  See `PaddingFIFOQueue` for more info.

  If `allow_smaller_final_batch` is `True`, a smaller batch than
  `batch_size` is returned when the queues are closed and there are not enough
  elements to fill the batch; otherwise the pending elements are discarded.
  In addition, all output tensors' static shapes, as accessed via the
  `get_shape()` method, will have a 0th `Dimension` value of `None`, and
  operations that depend on a fixed `batch_size` will fail.

  Args:
    tensors: The list or dictionary of tensors, representing a single element,
      to bucket.  Nested lists are not supported.
    which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
    batch_size: The new batch size pulled from the queue
      (python int or int32 scalar).
    num_buckets: A python integer, the number of buckets.
    num_threads: An integer.  The number of threads enqueuing `tensors`.
    capacity: An integer. The maximum number of minibatches in the top queue,
      and also the maximum number of elements within each bucket.
    shapes: (Optional) The shapes for each example.  Defaults to the
      inferred shapes for `tensors`.
    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
      The given dimensions are padded upon dequeue so that tensors within a
      batch have the same shapes.
    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
      batches to be smaller if there are insufficient items left in the queues.
    keep_input: (Optional).  A `bool` scalar Tensor.  If provided, this tensor
      controls whether the input is added to the queue or not.  If it evaluates
      to `True`, then `tensors` are added to the bucket; otherwise they are
      dropped.  This tensor essentially acts as a filtering mechanism.
      The default behavior is to assume `keep_input=True`.
    shared_name: (Optional). If set, the queues will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.

  Returns:
    A tuple `(bucket, outputs)` where `bucket` is
    an `int32` scalar tensor and `outputs` is a list or
    dictionary of batched outputs corresponding to elements of `tensors`.
    Every step will receive a new bucket of outputs.

  Raises:
    ValueError: If the `shapes` are not specified, and cannot be
      inferred from the elements of `tensors`.
  """
  tensor_list = _as_tensor_list(tensors)
  with ops.name_scope(name, "bucket", tensor_list) as name:
    tensor_list = _validate_bucket(tensor_list)
    (tensor_list, sparse_info) = _serialize_sparse_tensors(
        tensor_list, enqueue_many=False)

    # Round-trip batch_size to a tensor, and possibly back
    batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int32, name="batch_size")
    static_batch_size = tensor_util.constant_value(batch_size)
    batch_size = (
        static_batch_size if static_batch_size is not None else batch_size)

    types = _dtypes([tensor_list])
    shapes = _shapes([tensor_list], shapes, enqueue_many=False)

    which_bucket = ops.convert_to_tensor(
        which_bucket, dtype=dtypes.int32, name="which_bucket")

    queue_creator = _which_queue(dynamic_pad)
    bucket_queues = []
    for i in range(num_buckets):
      shared_name_i = (
          "%s_%d" % (shared_name, i) if shared_name is not None else None)
      bucket_queues.append(
          queue_creator(capacity=capacity,
                        dtypes=types,
                        shapes=shapes,
                        shared_name=shared_name_i, name="bucket_queue_%d" % i))

    maybe_static_batch_size = (
        None if allow_smaller_final_batch else static_batch_size)

    bucket_shapes = [tensor_shape.vector(maybe_static_batch_size).concatenate(s)
                     for s in bucket_queues[0].shapes]
    # top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO
    # queues because if we use allow_smaller_final_batch, shapes will
    # contain Nones in their first entry; as a result, a regular
    # FIFOQueue would die when being passed shapes that are not fully defined.
    top_queue = data_flow_ops.PaddingFIFOQueue(
        capacity=capacity,
        dtypes=[dtypes.int32] + types,
        shapes=[tensor_shape.scalar()] + bucket_shapes,
        shared_name=shared_name, name="top_queue")

    def enqueue_which():
      def enqueue_single(i):
        return bucket_queues[i].enqueue(tensor_list)
      enqueues = [
          control_flow_ops.cond(
              math_ops.equal(which_bucket, i),
              functools.partial(enqueue_single, i),
              control_flow_ops.no_op)
          for i in range(num_buckets)]
      return control_flow_ops.group(*enqueues, name="group_enqueues")

    if keep_input is not None:
      # TODO(ebrevdo): Expand keep_input param to core training
      # methods, and pipe through to _serialize_sparse_tensors; so
      # that expensive serialization is guarded by keep_input.
      maybe_enqueue = control_flow_ops.cond(
          keep_input,
          enqueue_which,
          control_flow_ops.no_op)
    else:
      maybe_enqueue = enqueue_which()

    bucket_enqueue_ops = [maybe_enqueue] * num_threads

    if allow_smaller_final_batch:
      which_dequeue = lambda q: q.dequeue_up_to
    else:
      which_dequeue = lambda q: q.dequeue_many

    enqueues_to_top = [
        top_queue.enqueue(
            [constant_op.constant(i)] +
            which_dequeue(q)(batch_size, name="read_bucket_%d" % i),
            name="enqueue_from_bucket_%d" % i)
        for i, q in enumerate(bucket_queues)]

    for i, q in enumerate(bucket_queues):
      queue_runner.add_queue_runner(queue_runner.QueueRunner(
          q, [enqueues_to_top[i]],
          queue_closed_exception_types=(
              errors.OutOfRangeError, errors.CancelledError)))
    queue_runner.add_queue_runner(queue_runner.QueueRunner(
        top_queue, bucket_enqueue_ops,
        queue_closed_exception_types=(
            errors.OutOfRangeError, errors.CancelledError)))

    for q in bucket_queues:
      logging_ops.scalar_summary(
          "bucket/%s/size" % q.name,
          math_ops.cast(q.size(), dtypes.float32))
    logging_ops.scalar_summary(
        "bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity),
        math_ops.cast(top_queue.size(), dtypes.float32) * (1. / capacity))

    dequeued = top_queue.dequeue(name="dequeue_top")
    which_bucket_dequeued = dequeued[0]
    dequeued = dequeued[1:]
    dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
    return (which_bucket_dequeued, _as_original_type(tensors, dequeued))


def bucket_by_sequence_length(input_length,
                              tensors,
                              batch_size,
                              bucket_boundaries,
                              num_threads=1,
                              capacity=32,
                              shapes=None,
                              dynamic_pad=False,
                              allow_smaller_final_batch=False,
                              keep_input=None,
                              shared_name=None,
                              name=None):
  """Lazy bucketing of inputs according to their length.

  This method calls `tf.contrib.training.bucket` under the hood, after first
  subdividing the bucket boundaries into separate buckets and identifying which
  bucket the given `input_length` belongs to.  See the documentation for
  `bucket` for details of the other arguments.

  Args:
    input_length: `int32` scalar `Tensor`, the sequence length of tensors.
    tensors: The list or dictionary of tensors, representing a single element,
      to bucket.  Nested lists are not supported.
    batch_size: The new batch size pulled from the queue
      (python int or int32 scalar).
    bucket_boundaries: int list, increasing non-negative numbers.
      The edges of the buckets to use when bucketing tensors.  Two extra buckets
      are created, one for `input_length < bucket_boundaries[0]` and
      one for `input_length >= bucket_boundaries[-1]`.
    num_threads: An integer.  The number of threads enqueuing `tensors`.
    capacity: An integer. The maximum number of minibatches in the top queue,
      and also the maximum number of elements within each bucket.
    shapes: (Optional) The shapes for each example.  Defaults to the
      inferred shapes for `tensors`.
    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
      The given dimensions are padded upon dequeue so that tensors within a
      batch have the same shapes.
    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
      batches to be smaller if there are insufficient items left in the queues.
    keep_input: (Optional).  A `bool` scalar Tensor.  If provided, this tensor
      controls whether the input is added to the queue or not.  If it evaluates
      to `True`, then `tensors` are added to the bucket; otherwise they are
      dropped.  This tensor essentially acts as a filtering mechanism.
      The default behavior is to assume `keep_input=True`.
    shared_name: (Optional). If set, the queues will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.

  Returns:
    A tuple `(sequence_length, outputs)` where `sequence_length` is
    a 1-D `Tensor` of size `batch_size` and `outputs` is a list or dictionary
    of batched, bucketed, outputs corresponding to elements of `tensors`.

  Raises:
    TypeError: if `bucket_boundaries` is not a list of python integers.
    ValueError: if `bucket_boundaries` is empty or contains non-increasing
      values.
  """
  tensor_list = _as_tensor_list(tensors)
  if not isinstance(bucket_boundaries, (list, tuple)):
    raise TypeError(
        "bucket_boundaries must be a list or tuple, but received: %s"
        % bucket_boundaries)
  if not bucket_boundaries:
    raise ValueError("bucket_boundaries must not be empty")
  for (s, e) in zip(bucket_boundaries[:-1], bucket_boundaries[1:]):
    if not isinstance(s, int) or not isinstance(e, int):
      raise TypeError(
          "bucket boundaries must be integers, but saw: %s and %s" % (s, e))
    if s >= e:
      raise ValueError(
          "Buckets must contain sequential increasing lengths, but saw: "
          "%d before %d" % (s, e))

  with ops.name_scope(name, "bucket_by_sequence_length",
                      [input_length] + tensor_list) as name:
    input_length = ops.convert_to_tensor(
        input_length, dtype=dtypes.int32, name="input_length")
    # Bucketing conditions are:
    #   l < b[0]
    #   b[0] <= l < b[1]
    #   b[1] <= l < b[2]
    #   ...
    #   b[N-2] <= l < b[N-1]
    #   b[N-1] <= l
    # Equivalent to:
    #   [-inf, b[0], b[1], ..., b[N-1]] <= l < [b[0], b[1], ..., b[N-1], inf]
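    #
    # For example, with bucket_boundaries = [3, 4, 5, 10] (illustrative
    # values), there are len(bucket_boundaries) + 1 = 5 buckets:
    #   l < 3        -> bucket 0
    #   3 <= l < 4   -> bucket 1
    #   4 <= l < 5   -> bucket 2
    #   5 <= l < 10  -> bucket 3
    #   10 <= l      -> bucket 4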
    buckets_min = [np.iinfo(np.int32).min] + list(bucket_boundaries)
    buckets_max = list(bucket_boundaries) + [np.iinfo(np.int32).max]
    conditions_c = math_ops.logical_and(
        math_ops.less_equal(buckets_min, input_length),
        math_ops.less(input_length, buckets_max))
    which_bucket = math_ops.reduce_min(array_ops.where(conditions_c))
    which_bucket = math_ops.to_int32(which_bucket)

    if shapes is not None:
      shapes = [tensor_shape.scalar()] + shapes

    _, dequeued = bucket(
        tensors=[input_length] + tensor_list,
        which_bucket=which_bucket,
        batch_size=batch_size,
        num_buckets=len(bucket_boundaries) + 1,
        num_threads=num_threads,
        capacity=capacity,
        shapes=shapes,
        dynamic_pad=dynamic_pad,
        allow_smaller_final_batch=allow_smaller_final_batch,
        keep_input=keep_input,
        shared_name=shared_name)

    return (dequeued[0], _as_original_type(tensors, dequeued[1:]))


__all__ = [
    "bucket",
    "bucket_by_sequence_length"
]
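A minimal sketch of calling `bucket` directly (the upstream queue, the routing rule, and all sizes below are assumptions for illustration, not part of this file):

import tensorflow as tf

# Assumed upstream pipeline yielding one (length, features) example at a
# time; a PaddingFIFOQueue allows the variable-length features shape.
input_queue = tf.PaddingFIFOQueue(
    capacity=1000, dtypes=(tf.int32, tf.float32), shapes=((), (None,)))
length, features = input_queue.dequeue()

# Route examples to buckets 0/1/2 by length in steps of 10 (capped at 2),
# and use keep_input to drop empty examples, per the docstring above.
which = tf.to_int32(tf.minimum(tf.floordiv(length, 10), 2))
bucket_id, (lengths, padded) = tf.contrib.training.bucket(
    tensors=[length, features],
    which_bucket=which,
    batch_size=16,
    num_buckets=3,
    keep_input=tf.greater(length, 0),
    dynamic_pad=True)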
							
								
								
									
tensorflow/contrib/training/python/training/bucket_ops_test.py (new file, 356 lines)
@@ -0,0 +1,356 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tf.contrib.training.bucket."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np
import tensorflow as tf


def _which_bucket(bucket_edges, v):
  """Identify which bucket v falls into.

  Args:
    bucket_edges: int array, bucket edges.
    v: int scalar, index.
  Returns:
    int scalar, the bucket.
    If v < bucket_edges[0], return 0.
    If bucket_edges[0] <= v < bucket_edges[1], return 1.
    ...
    If bucket_edges[-2] <= v < bucket_edges[-1], return len(bucket_edges) - 1.
    If v >= bucket_edges[-1], return len(bucket_edges) + 1.
  """
  v = np.asarray(v)
  full = [0] + bucket_edges
  found = np.where(np.logical_and(v >= full[:-1], v < full[1:]))[0]
  if not found.size:
    return len(full)
  return found[0]


class BucketTest(tf.test.TestCase):

  def setUp(self):
    tf.reset_default_graph()

    self.scalar_int_feed = tf.placeholder(tf.int32, ())
    self.unk_int64_feed = tf.placeholder(tf.int64, (None,))
    self.vec3_str_feed = tf.placeholder(tf.string, (3,))

    self._coord = tf.train.Coordinator()
    # Make capacity very large so we can feed all the inputs in the
    # main thread without blocking
    input_queue = tf.PaddingFIFOQueue(
        5000,
        dtypes=[tf.int32, tf.int64, tf.string],
        shapes=[(), (None,), (3,)])

    self._input_enqueue_op = input_queue.enqueue(
        (self.scalar_int_feed, self.unk_int64_feed, self.vec3_str_feed))
    self.scalar_int, self.unk_int64, self.vec3_str = input_queue.dequeue()
    self._threads = None
    self._close_op = input_queue.close()
    self._sess = None

  def enqueue_inputs(self, sess, feed_dict):
    sess.run(self._input_enqueue_op, feed_dict=feed_dict)

  def start_queue_runners(self, sess):
    # Store session to be able to close inputs later
    if self._sess is None:
      self._sess = sess
    self._threads = tf.train.start_queue_runners(coord=self._coord)

  def tearDown(self):
    if self._sess is not None:
      self._sess.run(self._close_op)
    self._coord.request_stop()
    self._coord.join(self._threads)

  def testSingleBucket(self):
    bucketed_dynamic = tf.contrib.training.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
        which_bucket=tf.constant(0),
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(32):
        self.enqueue_inputs(
            sess,
            {self.scalar_int_feed: v,
             self.unk_int64_feed: v * [v],
             self.vec3_str_feed: 3 * [str(v)]})
      self.start_queue_runners(sess)

      # Get a single minibatch
      bucketed_values = sess.run(bucketed_dynamic)

      # (which_bucket, bucket_tensors).
      self.assertEqual(2, len(bucketed_values))

      # Count number of bucket_tensors.
      self.assertEqual(3, len(bucketed_values[1]))

      # Ensure bucket 0 was used for all minibatch entries.
      self.assertAllEqual(0, bucketed_values[0])

      expected_scalar_int = np.arange(32)
      expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
      for i in range(32):
        expected_unk_int64[i, :i] = i
      expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 leads to
      # sometimes-inconsistent insertion order.
      resort = np.argsort(bucketed_values[1][0])
      self.assertAllEqual(expected_scalar_int, bucketed_values[1][0][resort])
      self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
      self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])

  def testEvenOddBuckets(self):
    which_bucket = (self.scalar_int % 2)
    bucketed_dynamic = tf.contrib.training.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(64):
        self.enqueue_inputs(
            sess,
            {self.scalar_int_feed: v,
             self.unk_int64_feed: v * [v],
             self.vec3_str_feed: 3 * [str(v)]})
      self.start_queue_runners(sess)

      # Get two minibatches (one containing even values, one containing odds)
      bucketed_values_0 = sess.run(bucketed_dynamic)
      bucketed_values_1 = sess.run(bucketed_dynamic)

      # (which_bucket, bucket_tensors).
      self.assertEqual(2, len(bucketed_values_0))
      self.assertEqual(2, len(bucketed_values_1))

      # Count number of bucket_tensors.
      self.assertEqual(3, len(bucketed_values_0[1]))
      self.assertEqual(3, len(bucketed_values_1[1]))

      # Figure out which output has the even values (there's
      # randomness due to the multithreaded nature of bucketing)
      if bucketed_values_0[0] % 2 == 1:
        bucketed_values_even, bucketed_values_odd = (
            bucketed_values_1, bucketed_values_0)
      else:
        bucketed_values_even, bucketed_values_odd = (
            bucketed_values_0, bucketed_values_1)

      # Ensure bucket 0 was used for the evens and bucket 1 for the odds.
      self.assertAllEqual(0, bucketed_values_even[0])
      self.assertAllEqual(1, bucketed_values_odd[0])

      # Test the first bucket output, the evens starting at 0
      expected_scalar_int = np.arange(0, 32 * 2, 2)
      expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2*i] = 2*i
      expected_vec3_str = np.vstack(
          3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 leads to
      # sometimes-inconsistent insertion order.
      resort = np.argsort(bucketed_values_even[1][0])
      self.assertAllEqual(expected_scalar_int,
                          bucketed_values_even[1][0][resort])
      self.assertAllEqual(expected_unk_int64,
                          bucketed_values_even[1][1][resort])
      self.assertAllEqual(expected_vec3_str,
                          bucketed_values_even[1][2][resort])

      # Test the second bucket output, the odds starting at 1
      expected_scalar_int = np.arange(1, 32 * 2 + 1, 2)
      expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2*i + 1] = 2*i + 1
      expected_vec3_str = np.vstack(
          3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 leads to
      # sometimes-inconsistent insertion order.
      resort = np.argsort(bucketed_values_odd[1][0])
      self.assertAllEqual(expected_scalar_int,
                          bucketed_values_odd[1][0][resort])
      self.assertAllEqual(expected_unk_int64,
                          bucketed_values_odd[1][1][resort])
      self.assertAllEqual(expected_vec3_str,
                          bucketed_values_odd[1][2][resort])

  def testEvenOddBucketsFilterOutAllOdd(self):
    which_bucket = (self.scalar_int % 2)
    keep_input = tf.equal(which_bucket, 0)
    bucketed_dynamic = tf.contrib.training.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        keep_input=keep_input,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(128):
        self.enqueue_inputs(
            sess,
            {self.scalar_int_feed: v,
             self.unk_int64_feed: v * [v],
             self.vec3_str_feed: 3 * [str(v)]})
      self.start_queue_runners(sess)

      # Get two minibatches; together they contain all the even values.
      bucketed_values_even0 = sess.run(bucketed_dynamic)
      bucketed_values_even1 = sess.run(bucketed_dynamic)

      # Ensure that bucket 1 was completely filtered out
      self.assertAllEqual(0, bucketed_values_even0[0])
      self.assertAllEqual(0, bucketed_values_even1[0])

      # Merge their output for sorting and comparison
      bucketed_values_all_elem0 = np.concatenate(
          (bucketed_values_even0[1][0],
           bucketed_values_even1[1][0]))

      self.assertAllEqual(
          np.arange(0, 128, 2), sorted(bucketed_values_all_elem0))


class BucketBySequenceLengthTest(tf.test.TestCase):

  def _testBucketBySequenceLength(self, allow_small_batch):
    tf.reset_default_graph()

    # All inputs must have identical lengths across tuple entries.
    # The input reader will get input_length from the first tuple
    # entry.
    data_len = 4
    target_len = 3
    input_pairs = [
        (length,
         ([np.int64(length)] * data_len,
          [str(length).encode("ascii")] * target_len))
        for length in (1, 3, 4, 5, 6, 10)]

    lengths = tf.placeholder(tf.int32, ())
    data = tf.placeholder(tf.int64, (data_len,))
    targets = tf.placeholder(tf.string, (target_len,))

    batch_size = 8
    bucket_boundaries = [3, 4, 5, 10]

    # Make capacity very large so we can feed all the inputs in the
    # main thread without blocking
    input_queue = tf.FIFOQueue(
        5000, (tf.int32, tf.int64, tf.string),
        ((), (data_len,), (target_len,)))
    input_enqueue_op = input_queue.enqueue((lengths, data, targets))
    lengths_t, data_t, targets_t = input_queue.dequeue()
    close_input_op = input_queue.close()

    (out_lengths_t, data_and_targets_t) = (
        tf.contrib.training.bucket_by_sequence_length(
            input_length=lengths_t,
            tensors=[data_t, targets_t],
            batch_size=batch_size,
            bucket_boundaries=bucket_boundaries,
            allow_smaller_final_batch=allow_small_batch,
            num_threads=10))

    expected_batch_size = None if allow_small_batch else batch_size
    self.assertEqual(out_lengths_t.get_shape().as_list(),
                     [expected_batch_size])
    self.assertEqual(data_and_targets_t[0].get_shape().as_list(),
                     [expected_batch_size, data_len])
    self.assertEqual(data_and_targets_t[1].get_shape().as_list(),
                     [expected_batch_size, target_len])

    def _read_test(sess):
      for _ in range(50):
        (out_lengths, (data, targets)) = sess.run(
            (out_lengths_t, data_and_targets_t))
        if allow_small_batch:
          self.assertEqual(data_len, data.shape[1])
          self.assertEqual(target_len, targets.shape[1])
          self.assertGreaterEqual(batch_size, out_lengths.shape[0])
          self.assertGreaterEqual(batch_size, data.shape[0])
          self.assertGreaterEqual(batch_size, targets.shape[0])
        else:
          self.assertEqual((batch_size, data_len), data.shape)
          self.assertEqual((batch_size, target_len), targets.shape)
          self.assertEqual((batch_size,), out_lengths.shape)
        for (lr, dr, tr) in zip(out_lengths, data, targets):
          # Make sure length matches data (here it's the same value)
          self.assertEqual(dr[0], lr)
          # Make sure data & targets match
          self.assertEqual(dr[0], int(tr[0].decode("ascii")))
          # Make sure for each row, data came from the same bucket.
          self.assertEqual(_which_bucket(bucket_boundaries, dr[0]),
                           _which_bucket(bucket_boundaries, dr[1]))

    with self.test_session() as sess:
      coord = tf.train.Coordinator()

      # Feed the inputs, then close the input queue.
      for _ in range(50 * batch_size + 100):
        which = random.randint(0, len(input_pairs) - 1)
        length, pair = input_pairs[which]
        sess.run(input_enqueue_op, feed_dict={
            lengths: length, data: pair[0], targets: pair[1]})
      sess.run(close_input_op)

      # Start the queue runners
      threads = tf.train.start_queue_runners(coord=coord)
      # Read off the top of the bucket and ensure correctness of output
      _read_test(sess)
      coord.request_stop()
      coord.join(threads)

  def testBucketBySequenceLength(self):
    self._testBucketBySequenceLength(allow_small_batch=False)

  def testBucketBySequenceLengthAllow(self):
    self._testBucketBySequenceLength(allow_small_batch=True)


if __name__ == "__main__":
  tf.test.main()