Add prefetching into parallel_interleave
This change adds 2 parameters to parallel_interleave: - prefetch_input_elements: determines the number of iterators to prefetch allowing buffers to warm up and data to be pre-fetched without blocking the main thread (i.e. the GetNext() call). - buffer_output_elements: in order to avoid creating thousands of threads, we fuse in the .prefetch() operator as an additional parameter. The value of this parameter is identical to the value passed to `.prefetch()` PiperOrigin-RevId: 179726088
This commit is contained in:
parent
7d94d7672a
commit
76db97fe39
@ -41,6 +41,7 @@ from tensorflow.python.platform import test
|
||||
class InterleaveDatasetTest(test.TestCase):
|
||||
|
||||
def _interleave(self, lists, cycle_length, block_length):
|
||||
# TODO(b/69678297): Consolidate python interleave implementations.
|
||||
num_open = 0
|
||||
|
||||
# `all_iterators` acts as a queue of iterators over each element of `lists`.
|
||||
@ -255,11 +256,15 @@ class InterleaveDatasetSeriazationTest(
|
||||
class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
|
||||
self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
|
||||
self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
self.error = None
|
||||
self.repeat_count = 2
|
||||
|
||||
# Set up threading events used to sequence when items are produced that
|
||||
@ -276,6 +281,10 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.write_coordination_events[x].wait()
|
||||
self.write_coordination_events[x].clear()
|
||||
self.read_coordination_events[x].release()
|
||||
if self.error:
|
||||
err = self.error
|
||||
self.error = None
|
||||
raise err # pylint: disable=raising-bad-type
|
||||
return x * x
|
||||
|
||||
def map_fn(x):
|
||||
@ -286,11 +295,13 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
dataset = dataset.repeat(x)
|
||||
return dataset.map(map_fn)
|
||||
|
||||
self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
|
||||
.repeat(self.repeat_count).apply(
|
||||
interleave_ops.parallel_interleave(
|
||||
interleave_fn, self.cycle_length,
|
||||
self.block_length, self.sloppy)))
|
||||
self.dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(self.input_values)
|
||||
.repeat(self.repeat_count).apply(
|
||||
interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
|
||||
self.block_length, self.sloppy,
|
||||
self.buffer_output_elements,
|
||||
self.prefetch_input_elements)))
|
||||
self.iterator = self.dataset.make_initializable_iterator()
|
||||
self.init_op = self.iterator.initializer
|
||||
self.next_element = self.iterator.get_next()
|
||||
@ -380,7 +391,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
for i in range(4, 7):
|
||||
self.write_coordination_events[i].set()
|
||||
|
||||
def _testSingleThreaded(self, sloppy=False):
|
||||
def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
|
||||
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
|
||||
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
|
||||
with self.test_session() as sess:
|
||||
@ -391,7 +402,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 1,
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: prefetch_input_elements,
|
||||
})
|
||||
|
||||
for expected_element in self._interleave(
|
||||
@ -408,6 +421,41 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
def testSingleThreadedSloppy(self):
|
||||
self._testSingleThreaded(sloppy=True)
|
||||
|
||||
def testSingleThreadedPrefetch1Itr(self):
|
||||
self._testSingleThreaded(prefetch_input_elements=1)
|
||||
|
||||
def testSingleThreadedPrefetch1ItrSloppy(self):
|
||||
self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
|
||||
|
||||
def testSingleThreadedRagged(self):
|
||||
# Tests a sequence with wildly different elements per iterator.
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
sess.run(
|
||||
self.init_op,
|
||||
feed_dict={
|
||||
self.input_values: [3, 7, 4],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: False,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 1,
|
||||
})
|
||||
|
||||
# Add coordination values for 3 and 7
|
||||
self.read_coordination_events[3] = threading.Semaphore(0)
|
||||
self.write_coordination_events[3] = threading.Event()
|
||||
self.read_coordination_events[7] = threading.Semaphore(0)
|
||||
self.write_coordination_events[7] = threading.Event()
|
||||
|
||||
for expected_element in self._interleave(
|
||||
[[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
output = sess.run(self.next_element)
|
||||
self.assertEqual(expected_element * expected_element, output)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def _testTwoThreadsNoContention(self, sloppy=False):
|
||||
# num_threads > 1.
|
||||
# Explicit coordination should result in `Dataset.interleave()` behavior
|
||||
@ -420,7 +468,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 1,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
@ -463,6 +513,8 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 1,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
@ -502,7 +554,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 2,
|
||||
self.sloppy: sloppy
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 1,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
@ -545,7 +599,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 2,
|
||||
self.sloppy: sloppy
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 1,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
@ -583,7 +639,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.input_values: [],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 3,
|
||||
self.sloppy: sloppy
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
@ -604,7 +662,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.input_values: [0, 0, 0],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 3,
|
||||
self.sloppy: sloppy
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
@ -615,7 +675,8 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
def testNonEmptyInputIntoEmptyOutputsSloppy(self):
|
||||
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
|
||||
|
||||
def _testPartiallyEmptyOutputs(self, sloppy=False):
|
||||
def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
|
||||
race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
|
||||
# Mixture of non-empty and empty interleaved datasets.
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
@ -627,27 +688,31 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: prefetch_input_elements,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
if done_first_event: # First event starts the worker threads
|
||||
# First event starts the worker threads. Additionally, when running the
|
||||
# sloppy case with prefetch_input_elements=0, we get stuck if we wait
|
||||
# for the read coordination event for certain event orderings in the
|
||||
# presence of finishing iterators.
|
||||
if done_first_event and not (sloppy and (i in race_indices)):
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
actual_element = sess.run(self.next_element)
|
||||
if not done_first_event:
|
||||
if not done_first_event or (sloppy and (i in race_indices)):
|
||||
done_first_event = True
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
self.assertEqual(expected_element * expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testPartiallyEmptyOutputs(self):
|
||||
self._testPartiallyEmptyOutputs()
|
||||
|
||||
def testPartiallyEmptyOutputsSloppy(self):
|
||||
self._testPartiallyEmptyOutputs(sloppy=True)
|
||||
self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
|
||||
|
||||
def testDelayedOutputSloppy(self):
|
||||
# Explicitly control the sequence of events to ensure we correctly avoid
|
||||
@ -661,6 +726,8 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: True,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
|
||||
mis_ordering = [
|
||||
@ -683,8 +750,10 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 3,
|
||||
self.sloppy: True
|
||||
self.block_length: 1,
|
||||
self.sloppy: True,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 1,
|
||||
})
|
||||
# Test against a generating sequence that differs from the uncontended
|
||||
# case, in order to prove sloppy correctness.
|
||||
@ -692,7 +761,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self._interleave(
|
||||
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
|
||||
cycle_length=2,
|
||||
block_length=2)):
|
||||
block_length=3)):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
if done_first_event: # First event starts the worker threads.
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
@ -716,7 +785,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 3,
|
||||
self.block_length: 2,
|
||||
self.sloppy: sloppy
|
||||
self.sloppy: sloppy,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
for i in range(4, 7):
|
||||
self.write_coordination_events[i].set()
|
||||
@ -790,6 +861,139 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testErrorsInOutputFn(self):
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
sess.run(
|
||||
self.init_op,
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: False,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
|
||||
except_on_element_indices = set([3])
|
||||
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
1)):
|
||||
if i in except_on_element_indices:
|
||||
self.error = ValueError()
|
||||
self.write_coordination_events[expected_element].set()
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(self.next_element)
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
actual_element = sess.run(self.next_element)
|
||||
self.assertEqual(expected_element * expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testErrorsInInputFn(self):
|
||||
|
||||
def map_py_fn(x):
|
||||
if x == 5:
|
||||
raise ValueError()
|
||||
return x
|
||||
|
||||
def map_fn(x):
|
||||
return script_ops.py_func(map_py_fn, [x], x.dtype)
|
||||
|
||||
def interleave_fn(x):
|
||||
dataset = dataset_ops.Dataset.from_tensors(x)
|
||||
dataset = dataset.repeat(x)
|
||||
return dataset
|
||||
|
||||
self.dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn)
|
||||
.repeat(self.repeat_count).apply(
|
||||
interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
|
||||
self.block_length, self.sloppy,
|
||||
self.buffer_output_elements,
|
||||
self.prefetch_input_elements)))
|
||||
|
||||
self.iterator = self.dataset.make_initializable_iterator()
|
||||
self.init_op = self.iterator.initializer
|
||||
self.next_element = self.iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
self.init_op,
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: False,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
|
||||
if expected_element == 5:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(self.next_element)
|
||||
else:
|
||||
actual_element = sess.run(self.next_element)
|
||||
self.assertEqual(expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testErrorsInInterleaveFn(self):
|
||||
|
||||
def map_py_fn(x):
|
||||
if x == 5:
|
||||
raise ValueError()
|
||||
return x
|
||||
|
||||
def interleave_fn(x):
|
||||
dataset = dataset_ops.Dataset.from_tensors(x)
|
||||
y = script_ops.py_func(map_py_fn, [x], x.dtype)
|
||||
dataset = dataset.repeat(y)
|
||||
return dataset
|
||||
|
||||
self.dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(self.input_values)
|
||||
.repeat(self.repeat_count).apply(
|
||||
interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
|
||||
self.block_length, self.sloppy,
|
||||
self.buffer_output_elements,
|
||||
self.prefetch_input_elements)))
|
||||
|
||||
self.iterator = self.dataset.make_initializable_iterator()
|
||||
self.init_op = self.iterator.initializer
|
||||
self.next_element = self.iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(
|
||||
self.init_op,
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: False,
|
||||
self.buffer_output_elements: 1,
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
|
||||
if expected_element == 5:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(self.next_element)
|
||||
else:
|
||||
actual_element = sess.run(self.next_element)
|
||||
self.assertEqual(expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -122,6 +122,7 @@ py_library(
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/util:convert",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//tensorflow/python/data/util:sparse",
|
||||
"//third_party/py/numpy",
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import convert
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import sparse
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -31,7 +32,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
|
||||
"""A `Dataset` that maps a function over its input and flattens the result."""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length,
|
||||
sloppy):
|
||||
sloppy, buffer_output_elements, prefetch_input_elements):
|
||||
"""See `tf.contrib.data.parallel_interleave()` for details."""
|
||||
super(ParallelInterleaveDataset, self).__init__()
|
||||
self._input_dataset = input_dataset
|
||||
@ -74,6 +75,14 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
|
||||
block_length, dtype=dtypes.int64, name="block_length")
|
||||
self._sloppy = ops.convert_to_tensor(
|
||||
sloppy, dtype=dtypes.bool, name="sloppy")
|
||||
self._buffer_output_elements = convert.optional_param_to_tensor(
|
||||
"buffer_output_elements",
|
||||
buffer_output_elements,
|
||||
argument_default=2 * block_length)
|
||||
self._prefetch_input_elements = convert.optional_param_to_tensor(
|
||||
"prefetch_input_elements",
|
||||
prefetch_input_elements,
|
||||
argument_default=2 * cycle_length)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.parallel_interleave_dataset(
|
||||
@ -82,6 +91,8 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._sloppy,
|
||||
self._buffer_output_elements,
|
||||
self._prefetch_input_elements,
|
||||
f=self._map_func,
|
||||
output_types=nest.flatten(
|
||||
sparse.as_dense_types(self.output_types, self.output_classes)),
|
||||
@ -101,7 +112,12 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
|
||||
return self._output_types
|
||||
|
||||
|
||||
def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
|
||||
def parallel_interleave(map_func,
|
||||
cycle_length,
|
||||
block_length=1,
|
||||
sloppy=False,
|
||||
buffer_output_elements=None,
|
||||
prefetch_input_elements=None):
|
||||
"""A parallel version of the `Dataset.interleave()` transformation.
|
||||
|
||||
`parallel_interleave()` maps `map_func` across its input to produce nested
|
||||
@ -129,12 +145,17 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
|
||||
|
||||
Args:
|
||||
map_func: A function mapping a nested structure of tensors to a `Dataset`.
|
||||
cycle_length: The number of threads to interleave from in parallel.
|
||||
block_length: The number of consecutive elements to pull from a thread
|
||||
before advancing to the next thread.
|
||||
cycle_length: The number of input `Dataset`s to interleave from in parallel.
|
||||
block_length: The number of consecutive elements to pull from an input
|
||||
`Dataset` before advancing to the next input `Dataset`.
|
||||
sloppy: If false, elements are produced in deterministic order. Otherwise,
|
||||
the implementation is allowed, for the sake of expediency, to produce
|
||||
elements in a non-deterministic order.
|
||||
buffer_output_elements: The number of elements each iterator being
|
||||
interleaved should buffer (similar to the `.prefetch()` transformation for
|
||||
each interleaved iterator).
|
||||
prefetch_input_elements: The number of input elements to transform to
|
||||
iterators before they are needed for interleaving.
|
||||
|
||||
Returns:
|
||||
A `Dataset` transformation function, which can be passed to
|
||||
@ -142,7 +163,9 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
|
||||
"""
|
||||
def _apply_fn(dataset):
|
||||
return ParallelInterleaveDataset(
|
||||
dataset, map_func, cycle_length, block_length, sloppy)
|
||||
dataset, map_func, cycle_length, block_length, sloppy,
|
||||
buffer_output_elements, prefetch_input_elements)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
|
||||
@ -187,11 +210,11 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
|
||||
map_func: A function mapping a nested structure of tensors (having shapes
|
||||
and types defined by `self.output_shapes` and `self.output_types`) to a
|
||||
`Dataset`.
|
||||
cycle_length: The number of threads to interleave from in parallel.
|
||||
block_length: The number of consecutive elements to pull from a thread
|
||||
before advancing to the next thread. Note: sloppy_interleave will
|
||||
skip the remainder of elements in the block_length in order to avoid
|
||||
blocking.
|
||||
cycle_length: The number of input `Dataset`s to interleave from in parallel.
|
||||
block_length: The number of consecutive elements to pull from an input
|
||||
`Dataset` before advancing to the next input `Dataset`. Note:
|
||||
`sloppy_interleave` will skip the remainder of elements in the
|
||||
`block_length` in order to avoid blocking.
|
||||
|
||||
Returns:
|
||||
A `Dataset` transformation function, which can be passed to
|
||||
@ -199,5 +222,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
|
||||
"""
|
||||
def _apply_fn(dataset):
|
||||
return ParallelInterleaveDataset(
|
||||
dataset, map_func, cycle_length, block_length, sloppy=True)
|
||||
dataset,
|
||||
map_func,
|
||||
cycle_length,
|
||||
block_length,
|
||||
sloppy=True,
|
||||
buffer_output_elements=None,
|
||||
prefetch_input_elements=None)
|
||||
|
||||
return _apply_fn
|
||||
|
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <deque>
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -48,28 +50,44 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.push_back(t);
|
||||
}
|
||||
|
||||
int64 cycle_length;
|
||||
int64 cycle_length = 0;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
|
||||
OP_REQUIRES(ctx, cycle_length > 0,
|
||||
errors::InvalidArgument("`cycle_length` must be > 0"));
|
||||
|
||||
int64 block_length;
|
||||
int64 block_length = 0;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseScalarArgument(ctx, "block_length", &block_length));
|
||||
OP_REQUIRES(ctx, block_length > 0,
|
||||
errors::InvalidArgument("`block_length` must be > 0"));
|
||||
|
||||
bool sloppy;
|
||||
bool sloppy = false;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
|
||||
|
||||
int64 buffer_output_elements = 0;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
|
||||
&buffer_output_elements));
|
||||
OP_REQUIRES(
|
||||
ctx, buffer_output_elements > 0,
|
||||
errors::InvalidArgument("`buffer_output_elements` must be > 0"));
|
||||
|
||||
int64 prefetch_input_elements = 0;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
|
||||
&prefetch_input_elements));
|
||||
OP_REQUIRES(
|
||||
ctx, prefetch_input_elements >= 0,
|
||||
errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));
|
||||
|
||||
std::unique_ptr<CapturedFunction> captured_func;
|
||||
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
|
||||
std::move(other_arguments),
|
||||
&captured_func));
|
||||
|
||||
*output = new Dataset(input, std::move(captured_func), cycle_length,
|
||||
block_length, sloppy, output_types_, output_shapes_);
|
||||
*output =
|
||||
new Dataset(input, std::move(captured_func), cycle_length, block_length,
|
||||
sloppy, buffer_output_elements, prefetch_input_elements,
|
||||
output_types_, output_shapes_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -77,13 +95,16 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
Dataset(const DatasetBase* input,
|
||||
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
|
||||
int64 block_length, bool sloppy, const DataTypeVector& output_types,
|
||||
int64 block_length, bool sloppy, int64 buffer_output_elements,
|
||||
int64 prefetch_input_elements, const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: input_(input),
|
||||
captured_func_(std::move(captured_func)),
|
||||
cycle_length_(cycle_length),
|
||||
block_length_(block_length),
|
||||
sloppy_(sloppy),
|
||||
buffer_output_elements_(buffer_output_elements),
|
||||
prefetch_input_elements_(prefetch_input_elements),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes) {
|
||||
input_->Ref();
|
||||
@ -109,244 +130,317 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
int64 num_threads() const {
|
||||
return cycle_length_ + prefetch_input_elements_;
|
||||
}
|
||||
|
||||
// Parallel interleave's implementation is designed around a few principles:
|
||||
// 1. Thread creation is relatively expensive. (Not reusing
|
||||
// threads causes a number of indirect costs such as poorer tcmalloc
|
||||
// performance due to thread-local caches, etc.) We allocate a fixed
|
||||
// number of threads at the start and never change. This is why we've
|
||||
// fused functionality that is theoretically orthogonal (i.e.
|
||||
// .prefetch()) into the implementation.
|
||||
// 2. Drop-in replacement for standard interleave. The goal will be to
|
||||
// auto-opt people into an optimized implementation without any work
|
||||
// on the customer's part. We thus go through great pains to maintain
|
||||
// identical iteration orders, full determinism (disabled only via a
|
||||
// flag, etc.)
|
||||
// 3. Performance across a variety of environments and I/O envelopes.
|
||||
//
|
||||
// The actual implementation centers around a collection of worker threads
|
||||
// and their corresponding worker state (tracked in the `workers_` vector).
|
||||
// Worker threads repeatedly receive a vector of Tensors that are used as
|
||||
// input to the flat-map function (`captured_func_`). The output of this
|
||||
// function must be a dataset. The worker thread then repeatedly calls
|
||||
// `GetNext()`, maintaining a buffer of elements to minimize the likelihood
|
||||
// that a caller will block waiting for an element to be produced.
|
||||
//
|
||||
// Pointers to these worker states are kept in 2 disjoint data structures:
|
||||
// 1. `interleave_` is a vector containing pointers to `WorkerState`s that
|
||||
// we
|
||||
// are interleaving. Worker threads backing these WorkerStates should
|
||||
// be regularly producing values.
|
||||
// 2. `staging_` is a deque containing pointers to WorkerStates that we
|
||||
// will move to `interleave_` when an iterator in `interleave_` is
|
||||
// exhausted.
|
||||
//
|
||||
// The client calls `GetNext[Internal]()` to retrieve an output element. The
|
||||
// internal implementation updates the state of `interleave_` and `staging_`
|
||||
// as output iterators (run by the worker threads) are exhausted.
|
||||
//
|
||||
// `input_impl_` is the input iterator that generates arguments for the
|
||||
// flat-map function (`captured_func_`). It is set to an iterator at
|
||||
// Iterator construction, and is fixed until we consume all input elements.
|
||||
// Once it is exhausted, we reset the unique_ptr to eagerly deallocate
|
||||
// memory.
|
||||
//
|
||||
// A few invariants are maintained:
|
||||
// 1. No element in interleave_ should be a nullptr unless `staging_` is
|
||||
// empty and `input_impl_` is empty.
|
||||
// 2. Every `worker_` element is pointed to by at most one element of the
|
||||
// union of `interleave_` and `staging_`.
|
||||
// 3. Unless `input_impl_` is empty, every `worker_` must be pointed to by
|
||||
// an element in `interleave_` or `staging_`.
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params),
|
||||
input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
|
||||
output_elements_(params.dataset->cycle_length_) {}
|
||||
workers_(dataset()->num_threads()) {}
|
||||
|
||||
~Iterator() override {
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
// Notify all workers in case they are blocked.
|
||||
for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
|
||||
output_elements_[i].cond_var.notify_all();
|
||||
for (auto& worker : workers_) {
|
||||
worker.cond_var.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
// It is implemented so that it matches the deterministic interleave
|
||||
// unless we would block waiting for an element, at which point it skips
|
||||
// along to the next available value.
|
||||
// unless getting the next element would block and we are allowed to be
|
||||
// sloppy.
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
|
||||
const int64 num_workers = worker_threads_.size();
|
||||
if (num_workers == 0) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
while (!cancelled_) {
|
||||
// Wait for an item to become available, blocking if necessary. If we
|
||||
// are allowed to be sloppy, we can skip over input datasets that do
|
||||
// not have an item readily available.
|
||||
const int64 n = dataset()->sloppy_ ? num_workers : 1LL;
|
||||
for (int64 i = 0; i < n; ++i) {
|
||||
int64 index = (next_index_ + i) % num_workers;
|
||||
if (output_elements_[index].is_produced) {
|
||||
bool can_produce_elements = false;
|
||||
bool must_wait_for_input = true;
|
||||
for (int64 i = 0; i < interleave_.size(); ++i) {
|
||||
int64 index = (next_index_ + i) % interleave_.size();
|
||||
WorkerState* current_worker = interleave_[index];
|
||||
if (!current_worker) continue; // Empty interleave elements.
|
||||
can_produce_elements |= current_worker->MayHaveElements();
|
||||
if (!current_worker->outputs.empty()) {
|
||||
// We have an element!
|
||||
next_index_ = index;
|
||||
if (i == 0) {
|
||||
block_count_++;
|
||||
if (block_count_ == dataset()->block_length_) {
|
||||
next_index_ = (index + 1) % num_workers;
|
||||
next_index_ = (index + 1) % interleave_.size();
|
||||
block_count_ = 0;
|
||||
}
|
||||
} else {
|
||||
block_count_ = 0;
|
||||
}
|
||||
// If we encounter an EoF, advance to the next iterator
|
||||
if (output_elements_[index].end_of_sequence) {
|
||||
output_elements_[index].is_produced = false;
|
||||
output_elements_[index].cond_var.notify_one();
|
||||
next_index_ = (index + 1) % num_workers;
|
||||
block_count_ = 0;
|
||||
i = -1; // Restart the inner loop
|
||||
continue;
|
||||
}
|
||||
*end_of_sequence = false;
|
||||
if (output_elements_[index].output_status.ok()) {
|
||||
output_elements_[index].output_value.swap(*out_tensors);
|
||||
Status s = current_worker->outputs.front().status;
|
||||
current_worker->outputs.front().output.swap(*out_tensors);
|
||||
current_worker->outputs.pop_front();
|
||||
current_worker->cond_var.notify_one();
|
||||
return s;
|
||||
} else if (current_worker->is_producing && !dataset()->sloppy_) {
|
||||
// current_worker.outputs.empty(), and we must wait for this
|
||||
// iterator.
|
||||
if (next_index_ != index) {
|
||||
// We have advanced to a new iterator; reset block counts.
|
||||
next_index_ = index;
|
||||
block_count_ = 0;
|
||||
}
|
||||
break;
|
||||
} else if (!current_worker->is_producing) {
|
||||
// This iterator has reached end of input.
|
||||
interleave_[index] = nullptr;
|
||||
if (input_impl_) {
|
||||
// Start prefetching a new iterator.
|
||||
std::vector<Tensor> args;
|
||||
bool end_of_input = false;
|
||||
Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
|
||||
if (end_of_input) {
|
||||
input_impl_.reset();
|
||||
} else {
|
||||
current_worker->SetInputs(s, std::move(args));
|
||||
staging_.emplace_back(current_worker);
|
||||
}
|
||||
}
|
||||
|
||||
if (!staging_.empty()) {
|
||||
// Move a worker from `staging_` to `interleave_`.
|
||||
interleave_[index] = staging_.front();
|
||||
staging_.pop_front();
|
||||
|
||||
next_index_ = (index + 1) % interleave_.size();
|
||||
block_count_ = 0;
|
||||
// Restart the inner [for] loop
|
||||
can_produce_elements = true;
|
||||
must_wait_for_input = false;
|
||||
break;
|
||||
}
|
||||
output_elements_[index].is_produced = false;
|
||||
output_elements_[index].cond_var.notify_one();
|
||||
return output_elements_[index].output_status;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_active_threads_ == 0) {
|
||||
if (!can_produce_elements && !input_impl_) {
|
||||
// No potential for future values.
|
||||
//
|
||||
// Note: this condition check must occur after checking the output
|
||||
// buffer, as its possible for there to be values in the output
|
||||
// buffer, even if the number of live threads is zero.
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If we are not allowed to be sloppy and
|
||||
// `worker_threads_[next_index]` has finished, advance `next_index`.
|
||||
if (!dataset()->sloppy_ && worker_threads_[next_index_].finished) {
|
||||
next_index_ = (next_index_ + 1) % num_workers;
|
||||
continue;
|
||||
if (must_wait_for_input) {
|
||||
// Wait for elements to become available.
|
||||
if (dataset()->sloppy_) {
|
||||
sloppy_cond_var_.wait(l);
|
||||
} else {
|
||||
interleave_[next_index_]->cond_var.wait(l);
|
||||
}
|
||||
}
|
||||
|
||||
// No values available; wait until woken up.
|
||||
// TODO(jsimsa): Use slot-specific condition variable for
|
||||
// coordination of elements consumption.
|
||||
cond_var_.wait(l);
|
||||
}
|
||||
return errors::Cancelled(
|
||||
"ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
|
||||
}
|
||||
|
||||
private:
|
||||
// Internal structure to manage thread coordination. All values are
|
||||
// guarded by the enclosing Iterator's mu_.
|
||||
struct OutputBufferElement {
|
||||
// The producer must set `is_produced` to `true` after
|
||||
// `output_status` or `output_value` has been written.
|
||||
bool is_produced = false;
|
||||
// The producer sets `output_status` if either getting the input element
|
||||
// or applying the function to it fails.
|
||||
Status output_status;
|
||||
// Reached end of sequence for the underlying iterator.
|
||||
bool end_of_sequence = false;
|
||||
// The output data element.
|
||||
std::vector<Tensor> output_value;
|
||||
// The producer thread waits on this condition variable after having
|
||||
// produced an element. The reader thread notifies this condition
|
||||
// variable after reading the value.
|
||||
condition_variable cond_var;
|
||||
// OutputElem contains the information from a call to GetNext by an output
|
||||
// iterator.
|
||||
struct OutputElem {
|
||||
// The output iterator sets `status` if getting the output element
|
||||
// fails.
|
||||
Status status;
|
||||
// The buffered data element.
|
||||
std::vector<Tensor> output;
|
||||
|
||||
explicit OutputElem(const Status& s) : status(s) {}
|
||||
};
|
||||
|
||||
struct ThreadStatus {
|
||||
// The underlying thread uses `finished` to communicate to the producer
|
||||
// that it has finished.
|
||||
bool finished = false;
|
||||
// The underlying thread object.
|
||||
std::unique_ptr<Thread> thread;
|
||||
// Worker threads operate on their relevant WorkerState structs.
|
||||
//
|
||||
// WorkerState's fields are all protected by mu_;
|
||||
struct WorkerState {
|
||||
// The arguments to be used to construct an output iterator.
|
||||
std::vector<Tensor> input;
|
||||
// The buffered output elements.
|
||||
std::deque<OutputElem> outputs;
|
||||
// Set to true iff the worker thread expects to append more elements to
|
||||
// outputs. is_producing can be false despite !outputs.empty().
|
||||
// Concretely, all output elements will have been consumed only when:
|
||||
// is_producing == false && outputs.empty();
|
||||
bool is_producing = false;
|
||||
// Condition variable used to coordinate between threads. The worker
|
||||
// thread waits on this condition variable when it is either (1) waiting
|
||||
// for the main thread to add arguments to `input`, or (2) waiting for
|
||||
// the main thread to consume an element of `outputs`. The main thread
|
||||
// waits on cond_var if it is waiting for the worker thread to produce
|
||||
// an element into `outputs` (this implies sloppy_==false).
|
||||
condition_variable cond_var;
|
||||
|
||||
explicit ThreadStatus(Thread* thread) : thread(thread) {}
|
||||
inline bool MayHaveElements() const {
|
||||
return is_producing || !outputs.empty();
|
||||
}
|
||||
|
||||
// Sets inputs for a worker thread and notifies it to start processing.
|
||||
void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
|
||||
if (s.ok()) {
|
||||
DCHECK(!MayHaveElements())
|
||||
<< "Tried to start inputs, despite already producing!";
|
||||
input = std::move(input_arguments);
|
||||
is_producing = true;
|
||||
cond_var.notify_one();
|
||||
} else {
|
||||
outputs.emplace_back(s);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (worker_threads_.empty()) {
|
||||
for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
|
||||
// Serialize the creation of the workers and their corresponding
|
||||
// input elements to ensure we match the standard interleave when
|
||||
// the underlying iterators induce no delay.
|
||||
worker_threads_.reserve(dataset()->num_threads());
|
||||
for (int64 i = 0; i < dataset()->num_threads(); ++i) {
|
||||
std::vector<Tensor> args;
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, &args, &end_of_input_));
|
||||
if (end_of_input_) {
|
||||
LOG(WARNING) << "Input iterator exhausted after " << i
|
||||
<< " elements; cannot start all "
|
||||
<< dataset()->cycle_length_ << " worker threads.";
|
||||
bool end_of_input = false;
|
||||
Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
|
||||
if (end_of_input) {
|
||||
input_impl_.reset();
|
||||
return Status::OK();
|
||||
}
|
||||
std::unique_ptr<IteratorBase> itr;
|
||||
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
|
||||
ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
|
||||
workers_[i].SetInputs(s, std::move(args));
|
||||
worker_threads_.emplace_back(ctx->env()->StartThread(
|
||||
{}, "worker_thread",
|
||||
std::bind(&Iterator::WorkerThread, this,
|
||||
new IteratorContext(*ctx), i, itr.release())));
|
||||
num_active_threads_ = i + 1;
|
||||
new IteratorContext(*ctx), i)));
|
||||
if (i < dataset()->cycle_length_) {
|
||||
interleave_.push_back(&workers_[i]);
|
||||
} else {
|
||||
staging_.push_back(&workers_[i]);
|
||||
}
|
||||
}
|
||||
DCHECK(interleave_.size() == dataset()->cycle_length_);
|
||||
DCHECK(staging_.size() == dataset()->prefetch_input_elements_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void BlockAndUpdateOutputBuffer(mutex_lock* l, const int64 thread_index,
|
||||
const Status& status,
|
||||
bool end_of_sequence,
|
||||
std::vector<Tensor>* out_tensors)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
// We have produced an element; push it into the output buffer
|
||||
// when space is available.
|
||||
while (!cancelled_ && output_elements_[thread_index].is_produced) {
|
||||
output_elements_[thread_index].cond_var.wait(*l);
|
||||
}
|
||||
if (cancelled_) {
|
||||
return;
|
||||
}
|
||||
output_elements_[thread_index].is_produced = true;
|
||||
output_elements_[thread_index].output_status = status;
|
||||
output_elements_[thread_index].end_of_sequence = end_of_sequence;
|
||||
if (status.ok()) {
|
||||
output_elements_[thread_index].output_value.swap(*out_tensors);
|
||||
} else {
|
||||
output_elements_[thread_index].output_value.clear();
|
||||
}
|
||||
cond_var_.notify_one();
|
||||
}
|
||||
|
||||
// Races to produce elements into the output queue buffers.
|
||||
void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index,
|
||||
IteratorBase* out_iterator_ptr) {
|
||||
// Produces elements into the worker's output buffers.
|
||||
void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
|
||||
// std::function arguments are copy-constructable, so we pass raw
|
||||
// pointers, and then immediately wrap them to ensure correct ownership.
|
||||
std::unique_ptr<IteratorContext> ctx(ctx_ptr);
|
||||
std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
|
||||
auto cleanup = gtl::MakeCleanup([this, thread_index] {
|
||||
mutex_lock l(mu_);
|
||||
worker_threads_[thread_index].finished = true;
|
||||
num_active_threads_--;
|
||||
cond_var_.notify_all();
|
||||
workers_[thread_index].cond_var.notify_all();
|
||||
});
|
||||
|
||||
while (true) {
|
||||
// Attempt to produce an element.
|
||||
bool end_of_out_itr_input = false;
|
||||
std::vector<Tensor> out_tensors;
|
||||
Status element_status = out_iterator->GetNext(ctx.get(), &out_tensors,
|
||||
&end_of_out_itr_input);
|
||||
// Handle output.
|
||||
// 1. Wait for input.
|
||||
std::vector<Tensor> input;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
BlockAndUpdateOutputBuffer(&l, thread_index, element_status,
|
||||
end_of_out_itr_input, &out_tensors);
|
||||
if (end_of_out_itr_input) {
|
||||
// We have exhausted our current iterator; get a new iterator;
|
||||
// loop to handle errors.
|
||||
while (!cancelled_) {
|
||||
if (end_of_input_) {
|
||||
// No more iterator inputs; we're done!
|
||||
return;
|
||||
}
|
||||
std::vector<Tensor> args;
|
||||
// BlockAndUpdateOutputBuffer() sequences calls to
|
||||
// input_impl_->GetNext when the out_iterator doesn't cause
|
||||
// slopping.
|
||||
Status input_status =
|
||||
input_impl_->GetNext(ctx.get(), &args, &end_of_input_);
|
||||
if (end_of_input_) {
|
||||
// No more elements to produce, stop the worker thread.
|
||||
return;
|
||||
}
|
||||
if (input_status.ok()) {
|
||||
input_status = dataset::MakeIteratorFromInputElement(
|
||||
ctx.get(), args, thread_index,
|
||||
dataset()->captured_func_.get(), prefix(), &out_iterator);
|
||||
}
|
||||
if (input_status.ok()) {
|
||||
// Successfully have a new out_iterator; restart the outer
|
||||
// loop to produce an element.
|
||||
break;
|
||||
}
|
||||
|
||||
// We encountered an error; push the error to the output buffer.
|
||||
BlockAndUpdateOutputBuffer(&l, thread_index, input_status,
|
||||
/* end_of_sequence = */ false,
|
||||
&out_tensors);
|
||||
}
|
||||
while (!cancelled_ && !workers_[thread_index].is_producing) {
|
||||
workers_[thread_index].cond_var.wait(l);
|
||||
}
|
||||
if (cancelled_) return;
|
||||
input.swap(workers_[thread_index].input);
|
||||
}
|
||||
|
||||
// Check if we should exit.
|
||||
if (cancelled_) {
|
||||
return;
|
||||
// 2. Run the user defined function to produce a new iterator.
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
Status s = dataset::MakeIteratorFromInputElement(
|
||||
ctx.get(), input, thread_index, dataset()->captured_func_.get(),
|
||||
prefix(), &iterator);
|
||||
input.clear(); // Release memory as early as possible.
|
||||
|
||||
if (!s.ok()) {
|
||||
mutex_lock l(mu_);
|
||||
workers_[thread_index].outputs.emplace_back(s);
|
||||
workers_[thread_index].is_producing = false;
|
||||
workers_[thread_index].cond_var.notify_one();
|
||||
} else {
|
||||
// 3. Produce elements
|
||||
bool end_of_sequence = false;
|
||||
while (!end_of_sequence) {
|
||||
// 3.a Produce an element!
|
||||
std::vector<Tensor> output_elem;
|
||||
s = iterator->GetNext(ctx.get(), &output_elem, &end_of_sequence);
|
||||
|
||||
// 3.b Make it available to the client.
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
|
||||
// Wait for space in the prefetch queue.
|
||||
while (!cancelled_ && workers_[thread_index].outputs.size() ==
|
||||
dataset()->buffer_output_elements_) {
|
||||
workers_[thread_index].cond_var.wait(l);
|
||||
}
|
||||
if (cancelled_) return;
|
||||
|
||||
// Output the element.
|
||||
workers_[thread_index].is_producing = !end_of_sequence;
|
||||
if (!end_of_sequence) {
|
||||
workers_[thread_index].outputs.emplace_back(s);
|
||||
workers_[thread_index].outputs.back().output.swap(
|
||||
output_elem);
|
||||
}
|
||||
if (dataset()->sloppy_) {
|
||||
sloppy_cond_var_.notify_one();
|
||||
} else {
|
||||
workers_[thread_index].cond_var.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -355,27 +449,34 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Mutex & condition variable to guard mutable iterator internals and
|
||||
// coordinate among worker threads and client thread[s].
|
||||
mutex mu_;
|
||||
condition_variable cond_var_;
|
||||
// The main thread waits on this condition variable if running in sloppy
|
||||
// mode and no values are available.
|
||||
condition_variable sloppy_cond_var_;
|
||||
|
||||
// The iterator producing elements which are converted to datasets by
|
||||
// the dataset()->captured_func_ then interleaved together.
|
||||
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
// Whether the input_impl_ can produce future elements.
|
||||
bool end_of_input_ GUARDED_BY(mu_) = false;
|
||||
// The buffer of elements to be produced. Each worker thread operates
|
||||
// on a single OutputBufferElement.
|
||||
std::vector<OutputBufferElement> output_elements_ GUARDED_BY(mu_);
|
||||
// input_impl_ is reset when we have exhausted its input.
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
|
||||
// The WorkerState structs the worker threads operate on.
|
||||
// workers_ elements are in at most one of interleave_ and staging_.
|
||||
std::vector<WorkerState> workers_ GUARDED_BY(mu_);
|
||||
|
||||
// The iterators to interleave
|
||||
std::vector<WorkerState*> interleave_ GUARDED_BY(mu_);
|
||||
// Prefetched iterators
|
||||
std::deque<WorkerState*> staging_ GUARDED_BY(mu_);
|
||||
|
||||
// The index into output_elements_ for next element to produce.
|
||||
size_t next_index_ GUARDED_BY(mu_) = 0;
|
||||
// The number of items produced so far within the block
|
||||
size_t block_count_ GUARDED_BY(mu_) = 0;
|
||||
// Number of active threads.
|
||||
size_t num_active_threads_ GUARDED_BY(mu_) = 0;
|
||||
// Flag to instruct the worker threads to exit.
|
||||
bool cancelled_ GUARDED_BY(mu_) = false;
|
||||
// Pointers to the worker threads. This must be last to ensure the
|
||||
// The worker threads. This must be last to ensure the
|
||||
// threads have exited before any other members are deallocated.
|
||||
// TODO(b/65178177): Avoid allocating additional threads.
|
||||
std::vector<ThreadStatus> worker_threads_ GUARDED_BY(mu_);
|
||||
std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
@ -383,6 +484,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
const int64 cycle_length_;
|
||||
const int64 block_length_;
|
||||
const bool sloppy_;
|
||||
const int64 buffer_output_elements_;
|
||||
const int64 prefetch_input_elements_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
@ -27387,6 +27387,14 @@ op {
|
||||
name: "sloppy"
|
||||
type: DT_BOOL
|
||||
}
|
||||
input_arg {
|
||||
name: "buffer_output_elements"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "prefetch_input_elements"
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
|
@ -313,6 +313,8 @@ REGISTER_OP("ParallelInterleaveDataset")
|
||||
.Input("cycle_length: int64")
|
||||
.Input("block_length: int64")
|
||||
.Input("sloppy: bool")
|
||||
.Input("buffer_output_elements: int64")
|
||||
.Input("prefetch_input_elements: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
|
@ -34,11 +34,11 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dataset_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dataset_ops_gen",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/util:convert",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,7 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops.dataset_ops import Dataset
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.data.util import convert
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -29,18 +29,6 @@ from tensorflow.python.ops import gen_dataset_ops
|
||||
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB
|
||||
|
||||
|
||||
def _convert_optional_param_to_tensor(argument_name,
|
||||
argument_value,
|
||||
argument_default=0,
|
||||
argument_dtype=dtypes.int64):
|
||||
if argument_value is not None:
|
||||
return ops.convert_to_tensor(
|
||||
argument_value, dtype=argument_dtype, name=argument_name)
|
||||
else:
|
||||
return constant_op.constant(
|
||||
argument_default, dtype=argument_dtype, name=argument_name)
|
||||
|
||||
|
||||
class TextLineDataset(Dataset):
|
||||
"""A `Dataset` comprising lines from one or more text files."""
|
||||
|
||||
@ -58,12 +46,12 @@ class TextLineDataset(Dataset):
|
||||
super(TextLineDataset, self).__init__()
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
self._compression_type = _convert_optional_param_to_tensor(
|
||||
self._compression_type = convert.optional_param_to_tensor(
|
||||
"compression_type",
|
||||
compression_type,
|
||||
argument_default="",
|
||||
argument_dtype=dtypes.string)
|
||||
self._buffer_size = _convert_optional_param_to_tensor(
|
||||
self._buffer_size = convert.optional_param_to_tensor(
|
||||
"buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
@ -100,12 +88,12 @@ class TFRecordDataset(Dataset):
|
||||
# Force the type to string even if filenames is an empty list.
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtypes.string, name="filenames")
|
||||
self._compression_type = _convert_optional_param_to_tensor(
|
||||
self._compression_type = convert.optional_param_to_tensor(
|
||||
"compression_type",
|
||||
compression_type,
|
||||
argument_default="",
|
||||
argument_dtype=dtypes.string)
|
||||
self._buffer_size = _convert_optional_param_to_tensor(
|
||||
self._buffer_size = convert.optional_param_to_tensor(
|
||||
"buffer_size",
|
||||
buffer_size,
|
||||
argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
|
||||
@ -155,11 +143,11 @@ class FixedLengthRecordDataset(Dataset):
|
||||
self._record_bytes = ops.convert_to_tensor(
|
||||
record_bytes, dtype=dtypes.int64, name="record_bytes")
|
||||
|
||||
self._header_bytes = _convert_optional_param_to_tensor(
|
||||
self._header_bytes = convert.optional_param_to_tensor(
|
||||
"header_bytes", header_bytes)
|
||||
self._footer_bytes = _convert_optional_param_to_tensor(
|
||||
self._footer_bytes = convert.optional_param_to_tensor(
|
||||
"footer_bytes", footer_bytes)
|
||||
self._buffer_size = _convert_optional_param_to_tensor(
|
||||
self._buffer_size = convert.optional_param_to_tensor(
|
||||
"buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
|
@ -62,6 +62,30 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "convert",
|
||||
srcs = ["convert.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "convert_test",
|
||||
size = "small",
|
||||
srcs = ["convert_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
34
tensorflow/python/data/util/convert.py
Normal file
34
tensorflow/python/data/util/convert.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helpers constructing Datasets."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
def optional_param_to_tensor(argument_name,
|
||||
argument_value,
|
||||
argument_default=0,
|
||||
argument_dtype=dtypes.int64):
|
||||
if argument_value is not None:
|
||||
return ops.convert_to_tensor(
|
||||
argument_value, dtype=argument_dtype, name=argument_name)
|
||||
else:
|
||||
return constant_op.constant(
|
||||
argument_default, dtype=argument_dtype, name=argument_name)
|
53
tensorflow/python/data/util/convert_test.py
Normal file
53
tensorflow/python/data/util/convert_test.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for utilities working with user input."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.util import convert
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class ConvertTest(test.TestCase):
|
||||
|
||||
def testInteger(self):
|
||||
resp = convert.optional_param_to_tensor("foo", 3)
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual(3, sess.run(resp))
|
||||
|
||||
def testIntegerDefault(self):
|
||||
resp = convert.optional_param_to_tensor("foo", None)
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual(0, sess.run(resp))
|
||||
|
||||
def testStringDefault(self):
|
||||
resp = convert.optional_param_to_tensor("bar", None, "default",
|
||||
dtypes.string)
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
|
||||
|
||||
def testString(self):
|
||||
resp = convert.optional_param_to_tensor("bar", "value", "default",
|
||||
dtypes.string)
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
x
Reference in New Issue
Block a user