From 76db97fe3961651617371902a1a623df61f9ed81 Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Wed, 20 Dec 2017 13:29:07 -0800 Subject: [PATCH] Add prefetching into parallel_interleave This change adds 2 parameters to parallel_interleave: - prefetch_input_elements: determines the number of iterators to prefetch allowing buffers to warm up and data to be pre-fetched without blocking the main thread (i.e. the GetNext() call). - buffer_output_elements: in order to avoid creating thousands of threads, we fuse in the .prefetch() operator as an additional parameter. The value of this parameter is identical to the value passed to `.prefetch()` PiperOrigin-RevId: 179726088 --- .../interleave_dataset_op_test.py | 248 +++++++++- tensorflow/contrib/data/python/ops/BUILD | 1 + .../contrib/data/python/ops/interleave_ops.py | 54 ++- .../data/parallel_interleave_dataset_op.cc | 455 +++++++++++------- .../core/ops/compat/ops_history.v1.pbtxt | 8 + tensorflow/core/ops/dataset_ops.cc | 2 + tensorflow/python/data/ops/BUILD | 2 +- tensorflow/python/data/ops/readers.py | 28 +- tensorflow/python/data/util/BUILD | 24 + tensorflow/python/data/util/convert.py | 34 ++ tensorflow/python/data/util/convert_test.py | 53 ++ 11 files changed, 678 insertions(+), 231 deletions(-) create mode 100644 tensorflow/python/data/util/convert.py create mode 100644 tensorflow/python/data/util/convert_test.py diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index e66ed3f7aa2..e13c60c9a71 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -41,6 +41,7 @@ from tensorflow.python.platform import test class InterleaveDatasetTest(test.TestCase): def _interleave(self, lists, cycle_length, block_length): + # TODO(b/69678297): Consolidate python interleave implementations. 
num_open = 0 # `all_iterators` acts as a queue of iterators over each element of `lists`. @@ -255,11 +256,15 @@ class InterleaveDatasetSeriazationTest( class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): + self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) self.sloppy = array_ops.placeholder(dtypes.bool, shape=[]) + self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[]) + self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[]) + self.error = None self.repeat_count = 2 # Set up threading events used to sequence when items are produced that @@ -276,6 +281,10 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.write_coordination_events[x].wait() self.write_coordination_events[x].clear() self.read_coordination_events[x].release() + if self.error: + err = self.error + self.error = None + raise err # pylint: disable=raising-bad-type return x * x def map_fn(x): @@ -286,11 +295,13 @@ class ParallelInterleaveDatasetTest(test.TestCase): dataset = dataset.repeat(x) return dataset.map(map_fn) - self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) - .repeat(self.repeat_count).apply( - interleave_ops.parallel_interleave( - interleave_fn, self.cycle_length, - self.block_length, self.sloppy))) + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) self.iterator = self.dataset.make_initializable_iterator() self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() @@ -380,7 +391,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): for i in range(4, 7): self.write_coordination_events[i].set() 
- def _testSingleThreaded(self, sloppy=False): + def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. with self.test_session() as sess: @@ -391,7 +402,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 1, self.block_length: 1, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, }) for expected_element in self._interleave( @@ -408,6 +421,41 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testSingleThreadedSloppy(self): self._testSingleThreaded(sloppy=True) + def testSingleThreadedPrefetch1Itr(self): + self._testSingleThreaded(prefetch_input_elements=1) + + def testSingleThreadedPrefetch1ItrSloppy(self): + self._testSingleThreaded(prefetch_input_elements=1, sloppy=True) + + def testSingleThreadedRagged(self): + # Tests a sequence with wildly different elements per iterator. 
+ with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [3, 7, 4], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + + # Add coordination values for 3 and 7 + self.read_coordination_events[3] = threading.Semaphore(0) + self.write_coordination_events[3] = threading.Event() + self.read_coordination_events[7] = threading.Semaphore(0) + self.write_coordination_events[7] = threading.Event() + + for expected_element in self._interleave( + [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1): + self.write_coordination_events[expected_element].set() + output = sess.run(self.next_element) + self.assertEqual(expected_element * expected_element, output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior @@ -420,7 +468,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 1, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -463,6 +513,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -502,7 +554,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 
1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -545,7 +599,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -583,7 +639,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [], self.cycle_length: 2, self.block_length: 3, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) @@ -604,7 +662,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [0, 0, 0], self.cycle_length: 2, self.block_length: 3, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) @@ -615,7 +675,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testNonEmptyInputIntoEmptyOutputsSloppy(self): self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) - def _testPartiallyEmptyOutputs(self, sloppy=False): + def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): + race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. 
with self.test_session() as sess: self._clear_coordination_events() @@ -627,27 +688,31 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): self.write_coordination_events[expected_element].set() - if done_first_event: # First event starts the worker threads + # First event starts the worker threads. Additionally, when running the + # sloppy case with prefetch_input_elements=0, we get stuck if we wait + # for the read coordination event for certain event orderings in the + # presence of finishing iterators. + if done_first_event and not (sloppy and (i in race_indices)): self.read_coordination_events[expected_element].acquire() actual_element = sess.run(self.next_element) - if not done_first_event: + if not done_first_event or (sloppy and (i in race_indices)): done_first_event = True self.read_coordination_events[expected_element].acquire() self.assertEqual(expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.next_element) def testPartiallyEmptyOutputs(self): self._testPartiallyEmptyOutputs() def testPartiallyEmptyOutputsSloppy(self): - self._testPartiallyEmptyOutputs(sloppy=True) + self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0) def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid @@ -661,6 +726,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) mis_ordering = [ @@ -683,8 +750,10 @@ class ParallelInterleaveDatasetTest(test.TestCase): feed_dict={ 
self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 3, - self.sloppy: True + self.block_length: 1, + self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) # Test against a generating sequence that differs from the uncontended # case, in order to prove sloppy correctness. @@ -692,7 +761,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self._interleave( [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, cycle_length=2, - block_length=2)): + block_length=3)): self.write_coordination_events[expected_element].set() if done_first_event: # First event starts the worker threads. self.read_coordination_events[expected_element].acquire() @@ -716,7 +785,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 3, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) for i in range(4, 7): self.write_coordination_events[i].set() @@ -790,6 +861,139 @@ class ParallelInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testErrorsInOutputFn(self): + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + + except_on_element_indices = set([3]) + + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if i in except_on_element_indices: + self.error = ValueError() + self.write_coordination_events[expected_element].set() + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + self.write_coordination_events[expected_element].set() + actual_element = sess.run(self.next_element) + 
self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInputFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def map_fn(x): + return script_ops.py_func(map_py_fn, [x], x.dtype) + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(x) + return dataset + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInterleaveFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + y = script_ops.py_func(map_py_fn, [x], x.dtype) + dataset = dataset.repeat(y) + return 
dataset + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 00af1f0b8ed..4349085a101 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -122,6 +122,7 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 53324e06e7f..3124ca1d154 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ 
b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -31,7 +32,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): """A `Dataset` that maps a function over its input and flattens the result.""" def __init__(self, input_dataset, map_func, cycle_length, block_length, - sloppy): + sloppy, buffer_output_elements, prefetch_input_elements): """See `tf.contrib.data.parallel_interleave()` for details.""" super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset @@ -74,6 +75,14 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): block_length, dtype=dtypes.int64, name="block_length") self._sloppy = ops.convert_to_tensor( sloppy, dtype=dtypes.bool, name="sloppy") + self._buffer_output_elements = convert.optional_param_to_tensor( + "buffer_output_elements", + buffer_output_elements, + argument_default=2 * block_length) + self._prefetch_input_elements = convert.optional_param_to_tensor( + "prefetch_input_elements", + prefetch_input_elements, + argument_default=2 * cycle_length) def _as_variant_tensor(self): return gen_dataset_ops.parallel_interleave_dataset( @@ -82,6 +91,8 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): self._cycle_length, self._block_length, self._sloppy, + self._buffer_output_elements, + self._prefetch_input_elements, f=self._map_func, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -101,7 +112,12 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): return self._output_types -def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): +def parallel_interleave(map_func, + cycle_length, + block_length=1, + sloppy=False, 
+ buffer_output_elements=None, + prefetch_input_elements=None): """A parallel version of the `Dataset.interleave()` transformation. `parallel_interleave()` maps `map_func` across its input to produce nested @@ -129,12 +145,17 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): Args: map_func: A function mapping a nested structure of tensors to a `Dataset`. - cycle_length: The number of threads to interleave from in parallel. - block_length: The number of consecutive elements to pull from a thread - before advancing to the next thread. + cycle_length: The number of input `Dataset`s to interleave from in parallel. + block_length: The number of consecutive elements to pull from an input + `Dataset` before advancing to the next input `Dataset`. sloppy: If false, elements are produced in deterministic order. Otherwise, the implementation is allowed, for the sake of expediency, to produce elements in a non-deterministic order. + buffer_output_elements: The number of elements each iterator being + interleaved should buffer (similar to the `.prefetch()` transformation for + each interleaved iterator). + prefetch_input_elements: The number of input elements to transform to + iterators before they are needed for interleaving. Returns: A `Dataset` transformation function, which can be passed to @@ -142,7 +163,9 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): """ def _apply_fn(dataset): return ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy) + dataset, map_func, cycle_length, block_length, sloppy, + buffer_output_elements, prefetch_input_elements) + return _apply_fn @@ -187,11 +210,11 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): map_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a `Dataset`. - cycle_length: The number of threads to interleave from in parallel. 
- block_length: The number of consecutive elements to pull from a thread - before advancing to the next thread. Note: sloppy_interleave will - skip the remainder of elements in the block_length in order to avoid - blocking. + cycle_length: The number of input `Dataset`s to interleave from in parallel. + block_length: The number of consecutive elements to pull from an input + `Dataset` before advancing to the next input `Dataset`. Note: + `sloppy_interleave` will skip the remainder of elements in the + `block_length` in order to avoid blocking. Returns: A `Dataset` transformation function, which can be passed to @@ -199,5 +222,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): """ def _apply_fn(dataset): return ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy=True) + dataset, + map_func, + cycle_length, + block_length, + sloppy=True, + buffer_output_elements=None, + prefetch_input_elements=None) + return _apply_fn diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index cb6a83606e6..e429db215d2 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" @@ -48,28 +50,44 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { other_arguments.push_back(t); } - int64 cycle_length; + int64 cycle_length = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "cycle_length", &cycle_length)); OP_REQUIRES(ctx, cycle_length > 0, errors::InvalidArgument("`cycle_length` must be > 0")); - int64 block_length; + int64 block_length = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "block_length", &block_length)); OP_REQUIRES(ctx, block_length > 0, errors::InvalidArgument("`block_length` must be > 0")); - bool sloppy; + bool sloppy = false; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy)); + int64 buffer_output_elements = 0; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements", + &buffer_output_elements)); + OP_REQUIRES( + ctx, buffer_output_elements > 0, + errors::InvalidArgument("`buffer_output_elements` must be > 0")); + + int64 prefetch_input_elements = 0; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements", + &prefetch_input_elements)); + OP_REQUIRES( + ctx, prefetch_input_elements >= 0, + errors::InvalidArgument("`prefetch_input_elements` must be >= 0")); + std::unique_ptr captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_, std::move(other_arguments), &captured_func)); - *output = new Dataset(input, std::move(captured_func), cycle_length, - block_length, sloppy, output_types_, output_shapes_); + *output = + new Dataset(input, std::move(captured_func), cycle_length, block_length, + sloppy, buffer_output_elements, prefetch_input_elements, + output_types_, output_shapes_); } private: @@ -77,13 +95,16 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: 
Dataset(const DatasetBase* input, std::unique_ptr captured_func, int64 cycle_length, - int64 block_length, bool sloppy, const DataTypeVector& output_types, + int64 block_length, bool sloppy, int64 buffer_output_elements, + int64 prefetch_input_elements, const DataTypeVector& output_types, const std::vector& output_shapes) : input_(input), captured_func_(std::move(captured_func)), cycle_length_(cycle_length), block_length_(block_length), sloppy_(sloppy), + buffer_output_elements_(buffer_output_elements), + prefetch_input_elements_(prefetch_input_elements), output_types_(output_types), output_shapes_(output_shapes) { input_->Ref(); @@ -109,244 +130,317 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } private: + int64 num_threads() const { + return cycle_length_ + prefetch_input_elements_; + } + + // Parallel interleave's implementation is designed around a few principles: + // 1. Thread creation is relatively expensive. (Not reusing + // threads causes a number of indirect costs such as poorer tcmalloc + // performance due to thread-local caches, etc.) We allocate a fixed + // number of threads at the start and never change. This is why we've + // fused functionality that is theoretically orthogonal (i.e. + // .prefetch()) into the implementation. + // 2. Drop-in replacement for standard interleave. The goal will be to + // auto-opt people into an optimized implementation without any work + // on the customer's part. We thus go through great pains to maintain + // identical iteration orders, full determinism (disabled only via a + // flag, etc.) + // 3. Performance across a variety of environments and I/O envelopes. + // + // The actual implementation centers around a collection of worker threads + // and their corresponding worker state (tracked in the `workers_` vector). + // Worker threads repeatedly receive a vector of Tensors that are used as + // input to the flat-map function (`captured_func_`). 
The output of this + // function must be a dataset. The worker thread then repeatedly calls + // `GetNext()`, maintaining a buffer of elements to minimize the likelihood + // that a caller will block waiting for an element to be produced. + // + // Pointers to these worker states are kept in 2 disjoint data structures: + // 1. `interleave_` is a vector containing pointers to `WorkerState`s that + // we + // are interleaving. Worker threads backing these WorkerStates should + // be regularly producing values. + // 2. `staging_` is a deque containing pointers to WorkerStates that we + // will move to `interleave_` when an iterator in `interleave_` is + // exhausted. + // + // The client calls `GetNext[Internal]()` to retrieve an output element. The + // internal implementation updates the state of `interleave_` and `staging_` + // as output iterators (run by the worker threads) are exhausted. + // + // `input_impl_` is the input iterator that generates arguments for the + // flat-map function (`captured_func_`). It is set to an iterator at + // Iterator construction, and is fixed until we consume all input elements. + // Once it is exhausted, we reset the unique_ptr to eagerly deallocate + // memory. + // + // A few invariants are maintained: + // 1. No element in interleave_ should be a nullptr unless `staging_` is + // empty and `input_impl_` is empty. + // 2. Every `workers_` element is pointed to by at most one element of the + // union of `interleave_` and `staging_`. + // 3. Unless `input_impl_` is empty, every `workers_` element must be pointed to by + // an element in `interleave_` or `staging_`. 
class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) : DatasetIterator(params), input_impl_(params.dataset->input_->MakeIterator(params.prefix)), - output_elements_(params.dataset->cycle_length_) {} + workers_(dataset()->num_threads()) {} ~Iterator() override { mutex_lock l(mu_); cancelled_ = true; // Notify all workers in case they are blocked. - for (int64 i = 0; i < dataset()->cycle_length_; ++i) { - output_elements_[i].cond_var.notify_all(); + for (auto& worker : workers_) { + worker.cond_var.notify_all(); } } // It is implemented so that it matches the deterministic interleave - // unless we would block waiting for an element, at which point it skips - // along to the next available value. + // unless getting the next element would block and we are allowed to be + // sloppy. Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); - const int64 num_workers = worker_threads_.size(); - if (num_workers == 0) { - *end_of_sequence = true; - return Status::OK(); - } while (!cancelled_) { // Wait for an item to become available, blocking if necessary. If we // are allowed to be sloppy, we can skip over input datasets that do // not have an item readily available. - const int64 n = dataset()->sloppy_ ? num_workers : 1LL; - for (int64 i = 0; i < n; ++i) { - int64 index = (next_index_ + i) % num_workers; - if (output_elements_[index].is_produced) { + bool can_produce_elements = false; + bool must_wait_for_input = true; + for (int64 i = 0; i < interleave_.size(); ++i) { + int64 index = (next_index_ + i) % interleave_.size(); + WorkerState* current_worker = interleave_[index]; + if (!current_worker) continue; // Empty interleave elements. + can_produce_elements |= current_worker->MayHaveElements(); + if (!current_worker->outputs.empty()) { + // We have an element! 
next_index_ = index; if (i == 0) { block_count_++; if (block_count_ == dataset()->block_length_) { - next_index_ = (index + 1) % num_workers; + next_index_ = (index + 1) % interleave_.size(); block_count_ = 0; } } else { block_count_ = 0; } - // If we encounter an EoF, advance to the next iterator - if (output_elements_[index].end_of_sequence) { - output_elements_[index].is_produced = false; - output_elements_[index].cond_var.notify_one(); - next_index_ = (index + 1) % num_workers; - block_count_ = 0; - i = -1; // Restart the inner loop - continue; - } *end_of_sequence = false; - if (output_elements_[index].output_status.ok()) { - output_elements_[index].output_value.swap(*out_tensors); + Status s = current_worker->outputs.front().status; + current_worker->outputs.front().output.swap(*out_tensors); + current_worker->outputs.pop_front(); + current_worker->cond_var.notify_one(); + return s; + } else if (current_worker->is_producing && !dataset()->sloppy_) { + // current_worker.outputs.empty(), and we must wait for this + // iterator. + if (next_index_ != index) { + // We have advanced to a new iterator; reset block counts. + next_index_ = index; + block_count_ = 0; + } + break; + } else if (!current_worker->is_producing) { + // This iterator has reached end of input. + interleave_[index] = nullptr; + if (input_impl_) { + // Start prefetching a new iterator. + std::vector args; + bool end_of_input = false; + Status s = input_impl_->GetNext(ctx, &args, &end_of_input); + if (end_of_input) { + input_impl_.reset(); + } else { + current_worker->SetInputs(s, std::move(args)); + staging_.emplace_back(current_worker); + } + } + + if (!staging_.empty()) { + // Move a worker from `staging_` to `interleave_`. 
+ interleave_[index] = staging_.front(); + staging_.pop_front(); + + next_index_ = (index + 1) % interleave_.size(); + block_count_ = 0; + // Restart the inner [for] loop + can_produce_elements = true; + must_wait_for_input = false; + break; } - output_elements_[index].is_produced = false; - output_elements_[index].cond_var.notify_one(); - return output_elements_[index].output_status; } } - if (num_active_threads_ == 0) { + if (!can_produce_elements && !input_impl_) { // No potential for future values. - // - // Note: this condition check must occur after checking the output - // buffer, as its possible for there to be values in the output - // buffer, even if the number of live threads is zero. *end_of_sequence = true; return Status::OK(); } - // If we are not allowed to be sloppy and - // `worker_threads_[next_index]` has finished, advance `next_index`. - if (!dataset()->sloppy_ && worker_threads_[next_index_].finished) { - next_index_ = (next_index_ + 1) % num_workers; - continue; + if (must_wait_for_input) { + // Wait for elements to become available. + if (dataset()->sloppy_) { + sloppy_cond_var_.wait(l); + } else { + interleave_[next_index_]->cond_var.wait(l); + } } - - // No values available; wait until woken up. - // TODO(jsimsa): Use slot-specific condition variable for - // coordination of elements consumption. - cond_var_.wait(l); } return errors::Cancelled( "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext"); } private: - // Internal structure to manage thread coordination. All values are - // guarded by the enclosing Iterator's mu_. - struct OutputBufferElement { - // The producer must set `is_produced` to `true` after - // `output_status` or `output_value` has been written. - bool is_produced = false; - // The producer sets `output_status` if either getting the input element - // or applying the function to it fails. - Status output_status; - // Reached end of sequence for the underlying iterator. 
- bool end_of_sequence = false; - // The output data element. - std::vector output_value; - // The producer thread waits on this condition variable after having - // produced an element. The reader thread notifies this condition - // variable after reading the value. - condition_variable cond_var; + // OutputElem contains the information from a call to GetNext by an output + // iterator. + struct OutputElem { + // The output iterator sets `status` if getting the output element + // fails. + Status status; + // The buffered data element. + std::vector output; + + explicit OutputElem(const Status& s) : status(s) {} }; - struct ThreadStatus { - // The underlying thread uses `finished` to communicate to the producer - // that it has finished. - bool finished = false; - // The underlying thread object. - std::unique_ptr thread; + // Worker threads operate on their relevant WorkerState structs. + // + // WorkerState's fields are all protected by mu_; + struct WorkerState { + // The arguments to be used to construct an output iterator. + std::vector input; + // The buffered output elements. + std::deque outputs; + // Set to true iff the worker thread expects to append more elements to + // outputs. is_producing can be false despite !outputs.empty(). + // Concretely, all output elements will have been consumed only when: + // is_producing == false && outputs.empty(); + bool is_producing = false; + // Condition variable used to coordinate between threads. The worker + // thread waits on this condition variable when it is either (1) waiting + // for the main thread to add arguments to `input`, or (2) waiting for + // the main thread to consume an element of `outputs`. The main thread + // waits on cond_var if it is waiting for the worker thread to produce + // an element into `outputs` (this implies sloppy_==false). 
+ condition_variable cond_var; - explicit ThreadStatus(Thread* thread) : thread(thread) {} + inline bool MayHaveElements() const { + return is_producing || !outputs.empty(); + } + + // Sets inputs for a worker thread and notifies it to start processing. + void SetInputs(const Status& s, std::vector input_arguments) { + if (s.ok()) { + DCHECK(!MayHaveElements()) + << "Tried to start inputs, despite already producing!"; + input = std::move(input_arguments); + is_producing = true; + cond_var.notify_one(); + } else { + outputs.emplace_back(s); + } + } }; Status EnsureWorkerThreadsStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (worker_threads_.empty()) { - for (int64 i = 0; i < dataset()->cycle_length_; ++i) { - // Serialize the creation of the workers and their corresponding - // input elements to ensure we match the standard interleave when - // the underlying iterators induce no delay. + worker_threads_.reserve(dataset()->num_threads()); + for (int64 i = 0; i < dataset()->num_threads(); ++i) { std::vector args; - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, &args, &end_of_input_)); - if (end_of_input_) { - LOG(WARNING) << "Input iterator exhausted after " << i - << " elements; cannot start all " - << dataset()->cycle_length_ << " worker threads."; + bool end_of_input = false; + Status s = input_impl_->GetNext(ctx, &args, &end_of_input); + if (end_of_input) { + input_impl_.reset(); return Status::OK(); } - std::unique_ptr itr; - TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( - ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr)); + workers_[i].SetInputs(s, std::move(args)); worker_threads_.emplace_back(ctx->env()->StartThread( {}, "worker_thread", std::bind(&Iterator::WorkerThread, this, - new IteratorContext(*ctx), i, itr.release()))); - num_active_threads_ = i + 1; + new IteratorContext(*ctx), i))); + if (i < dataset()->cycle_length_) { + interleave_.push_back(&workers_[i]); + } else { + staging_.push_back(&workers_[i]); + 
} } + DCHECK(interleave_.size() == dataset()->cycle_length_); + DCHECK(staging_.size() == dataset()->prefetch_input_elements_); } return Status::OK(); } - void BlockAndUpdateOutputBuffer(mutex_lock* l, const int64 thread_index, - const Status& status, - bool end_of_sequence, - std::vector* out_tensors) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // We have produced an element; push it into the output buffer - // when space is available. - while (!cancelled_ && output_elements_[thread_index].is_produced) { - output_elements_[thread_index].cond_var.wait(*l); - } - if (cancelled_) { - return; - } - output_elements_[thread_index].is_produced = true; - output_elements_[thread_index].output_status = status; - output_elements_[thread_index].end_of_sequence = end_of_sequence; - if (status.ok()) { - output_elements_[thread_index].output_value.swap(*out_tensors); - } else { - output_elements_[thread_index].output_value.clear(); - } - cond_var_.notify_one(); - } - - // Races to produce elements into the output queue buffers. - void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index, - IteratorBase* out_iterator_ptr) { + // Produces elements into the worker's output buffers. + void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) { // std::function arguments are copy-constructable, so we pass raw // pointers, and then immediately wrap them to ensure correct ownership. std::unique_ptr ctx(ctx_ptr); - std::unique_ptr out_iterator(out_iterator_ptr); auto cleanup = gtl::MakeCleanup([this, thread_index] { mutex_lock l(mu_); - worker_threads_[thread_index].finished = true; - num_active_threads_--; - cond_var_.notify_all(); + workers_[thread_index].cond_var.notify_all(); }); + while (true) { - // Attempt to produce an element. - bool end_of_out_itr_input = false; - std::vector out_tensors; - Status element_status = out_iterator->GetNext(ctx.get(), &out_tensors, - &end_of_out_itr_input); - // Handle output. + // 1. Wait for input. 
+ std::vector input; { mutex_lock l(mu_); - BlockAndUpdateOutputBuffer(&l, thread_index, element_status, - end_of_out_itr_input, &out_tensors); - if (end_of_out_itr_input) { - // We have exhausted our current iterator; get a new iterator; - // loop to handle errors. - while (!cancelled_) { - if (end_of_input_) { - // No more iterator inputs; we're done! - return; - } - std::vector args; - // BlockAndUpdateOutputBuffer() sequences calls to - // input_impl_->GetNext when the out_iterator doesn't cause - // slopping. - Status input_status = - input_impl_->GetNext(ctx.get(), &args, &end_of_input_); - if (end_of_input_) { - // No more elements to produce, stop the worker thread. - return; - } - if (input_status.ok()) { - input_status = dataset::MakeIteratorFromInputElement( - ctx.get(), args, thread_index, - dataset()->captured_func_.get(), prefix(), &out_iterator); - } - if (input_status.ok()) { - // Successfully have a new out_iterator; restart the outer - // loop to produce an element. - break; - } - - // We encountered an error; push the error to the output buffer. - BlockAndUpdateOutputBuffer(&l, thread_index, input_status, - /* end_of_sequence = */ false, - &out_tensors); - } + while (!cancelled_ && !workers_[thread_index].is_producing) { + workers_[thread_index].cond_var.wait(l); } + if (cancelled_) return; + input.swap(workers_[thread_index].input); + } - // Check if we should exit. - if (cancelled_) { - return; + // 2. Run the user defined function to produce a new iterator. + std::unique_ptr iterator; + Status s = dataset::MakeIteratorFromInputElement( + ctx.get(), input, thread_index, dataset()->captured_func_.get(), + prefix(), &iterator); + input.clear(); // Release memory as early as possible. + + if (!s.ok()) { + mutex_lock l(mu_); + workers_[thread_index].outputs.emplace_back(s); + workers_[thread_index].is_producing = false; + workers_[thread_index].cond_var.notify_one(); + } else { + // 3. 
Produce elements + bool end_of_sequence = false; + while (!end_of_sequence) { + // 3.a Produce an element! + std::vector output_elem; + s = iterator->GetNext(ctx.get(), &output_elem, &end_of_sequence); + + // 3.b Make it available to the client. + { + mutex_lock l(mu_); + + // Wait for space in the prefetch queue. + while (!cancelled_ && workers_[thread_index].outputs.size() == + dataset()->buffer_output_elements_) { + workers_[thread_index].cond_var.wait(l); + } + if (cancelled_) return; + + // Output the element. + workers_[thread_index].is_producing = !end_of_sequence; + if (!end_of_sequence) { + workers_[thread_index].outputs.emplace_back(s); + workers_[thread_index].outputs.back().output.swap( + output_elem); + } + if (dataset()->sloppy_) { + sloppy_cond_var_.notify_one(); + } else { + workers_[thread_index].cond_var.notify_one(); + } + } } } } @@ -355,27 +449,34 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Mutex & condition variable to guard mutable iterator internals and // coordinate among worker threads and client thread[s]. mutex mu_; - condition_variable cond_var_; + // The main thread waits on this condition variable if running in sloppy + // mode and no values are available. + condition_variable sloppy_cond_var_; + // The iterator producing elements which are converted to datasets by // the dataset()->captured_func_ then interleaved together. - const std::unique_ptr input_impl_ GUARDED_BY(mu_); - // Whether the input_impl_ can produce future elements. - bool end_of_input_ GUARDED_BY(mu_) = false; - // The buffer of elements to be produced. Each worker thread operates - // on a single OutputBufferElement. - std::vector output_elements_ GUARDED_BY(mu_); + // input_impl_ is reset when we have exhausted its input. + std::unique_ptr input_impl_ GUARDED_BY(mu_); + + // The WorkerState structs the worker threads operate on. + // workers_ elements are in at most one of interleave_ and staging_. 
+ std::vector workers_ GUARDED_BY(mu_); + + // The iterators to interleave + std::vector interleave_ GUARDED_BY(mu_); + // Prefetched iterators + std::deque staging_ GUARDED_BY(mu_); + // The index into output_elements_ for next element to produce. size_t next_index_ GUARDED_BY(mu_) = 0; // The number of items produced so far within the block size_t block_count_ GUARDED_BY(mu_) = 0; - // Number of active threads. - size_t num_active_threads_ GUARDED_BY(mu_) = 0; // Flag to instruct the worker threads to exit. bool cancelled_ GUARDED_BY(mu_) = false; - // Pointers to the worker threads. This must be last to ensure the + // The worker threads. This must be last to ensure the // threads have exited before any other members are deallocated. // TODO(b/65178177): Avoid allocating additional threads. - std::vector worker_threads_ GUARDED_BY(mu_); + std::vector> worker_threads_ GUARDED_BY(mu_); }; const DatasetBase* const input_; @@ -383,6 +484,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const int64 cycle_length_; const int64 block_length_; const bool sloppy_; + const int64 buffer_output_elements_; + const int64 prefetch_input_elements_; const DataTypeVector output_types_; const std::vector output_shapes_; }; diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 18ac55ec53c..bd420509f13 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -27387,6 +27387,14 @@ op { name: "sloppy" type: DT_BOOL } + input_arg { + name: "buffer_output_elements" + type: DT_INT64 + } + input_arg { + name: "prefetch_input_elements" + type: DT_INT64 + } output_arg { name: "handle" type: DT_VARIANT diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 66e5c163288..e943a698ae2 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -313,6 +313,8 @@ 
REGISTER_OP("ParallelInterleaveDataset") .Input("cycle_length: int64") .Input("block_length: int64") .Input("sloppy: bool") + .Input("buffer_output_elements: int64") + .Input("prefetch_input_elements: int64") .Output("handle: variant") .Attr("f: func") .Attr("Targuments: list(type) >= 0") diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index 695d3ef7904..f12b358a7dc 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -34,11 +34,11 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dataset_ops", - "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/util:convert", ], ) diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index c6fb8531aea..830dc5cec4a 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops.dataset_ops import Dataset -from tensorflow.python.framework import constant_op +from tensorflow.python.data.util import convert from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -29,18 +29,6 @@ from tensorflow.python.ops import gen_dataset_ops _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB -def _convert_optional_param_to_tensor(argument_name, - argument_value, - argument_default=0, - argument_dtype=dtypes.int64): - if argument_value is not None: - return ops.convert_to_tensor( - argument_value, dtype=argument_dtype, name=argument_name) - else: - return constant_op.constant( - argument_default, dtype=argument_dtype, name=argument_name) - - class TextLineDataset(Dataset): """A `Dataset` comprising lines from one or more text files.""" @@ -58,12 
+46,12 @@ class TextLineDataset(Dataset): super(TextLineDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - self._compression_type = _convert_optional_param_to_tensor( + self._compression_type = convert.optional_param_to_tensor( "compression_type", compression_type, argument_default="", argument_dtype=dtypes.string) - self._buffer_size = _convert_optional_param_to_tensor( + self._buffer_size = convert.optional_param_to_tensor( "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) def _as_variant_tensor(self): @@ -100,12 +88,12 @@ class TFRecordDataset(Dataset): # Force the type to string even if filenames is an empty list. self._filenames = ops.convert_to_tensor( filenames, dtypes.string, name="filenames") - self._compression_type = _convert_optional_param_to_tensor( + self._compression_type = convert.optional_param_to_tensor( "compression_type", compression_type, argument_default="", argument_dtype=dtypes.string) - self._buffer_size = _convert_optional_param_to_tensor( + self._buffer_size = convert.optional_param_to_tensor( "buffer_size", buffer_size, argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES) @@ -155,11 +143,11 @@ class FixedLengthRecordDataset(Dataset): self._record_bytes = ops.convert_to_tensor( record_bytes, dtype=dtypes.int64, name="record_bytes") - self._header_bytes = _convert_optional_param_to_tensor( + self._header_bytes = convert.optional_param_to_tensor( "header_bytes", header_bytes) - self._footer_bytes = _convert_optional_param_to_tensor( + self._footer_bytes = convert.optional_param_to_tensor( "footer_bytes", footer_bytes) - self._buffer_size = _convert_optional_param_to_tensor( + self._buffer_size = convert.optional_param_to_tensor( "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) def _as_variant_tensor(self): diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index f7d7fe98d3e..e32c7b54a48 100644 --- 
a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -62,6 +62,30 @@ py_test( ], ) +py_library( + name = "convert", + srcs = ["convert.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + ], +) + +py_test( + name = "convert_test", + size = "small", + srcs = ["convert_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":convert", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:util", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/data/util/convert.py b/tensorflow/python/data/util/convert.py new file mode 100644 index 00000000000..eeb1d700f3c --- /dev/null +++ b/tensorflow/python/data/util/convert.py @@ -0,0 +1,34 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Helpers constructing Datasets.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + + +def optional_param_to_tensor(argument_name, + argument_value, + argument_default=0, + argument_dtype=dtypes.int64): + if argument_value is not None: + return ops.convert_to_tensor( + argument_value, dtype=argument_dtype, name=argument_name) + else: + return constant_op.constant( + argument_default, dtype=argument_dtype, name=argument_name) diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py new file mode 100644 index 00000000000..2cb6488070e --- /dev/null +++ b/tensorflow/python/data/util/convert_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for utilities working with user input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.util import convert +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class ConvertTest(test.TestCase): + + def testInteger(self): + resp = convert.optional_param_to_tensor("foo", 3) + with self.test_session() as sess: + self.assertEqual(3, sess.run(resp)) + + def testIntegerDefault(self): + resp = convert.optional_param_to_tensor("foo", None) + with self.test_session() as sess: + self.assertEqual(0, sess.run(resp)) + + def testStringDefault(self): + resp = convert.optional_param_to_tensor("bar", None, "default", + dtypes.string) + with self.test_session() as sess: + self.assertEqual(compat.as_bytes("default"), sess.run(resp)) + + def testString(self): + resp = convert.optional_param_to_tensor("bar", "value", "default", + dtypes.string) + with self.test_session() as sess: + self.assertEqual(compat.as_bytes("value"), sess.run(resp)) + + +if __name__ == "__main__": + test.main()